pose3d_loss.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from itertools import cycle, islice
  18. from collections import abc
  19. import paddle
  20. import paddle.nn as nn
  21. from ppdet.core.workspace import register, serializable
  22. __all__ = ['Pose3DLoss']
  23. @register
  24. @serializable
  25. class Pose3DLoss(nn.Layer):
  26. def __init__(self, weight_3d=1.0, weight_2d=0.0, reduction='none'):
  27. """
  28. KeyPointMSELoss layer
  29. Args:
  30. weight_3d (float): weight of 3d loss
  31. weight_2d (float): weight of 2d loss
  32. reduction (bool): whether use reduction to loss
  33. """
  34. super(Pose3DLoss, self).__init__()
  35. self.weight_3d = weight_3d
  36. self.weight_2d = weight_2d
  37. self.criterion_2dpose = nn.MSELoss(reduction=reduction)
  38. self.criterion_3dpose = nn.MSELoss(reduction=reduction)
  39. self.criterion_smoothl1 = nn.SmoothL1Loss(
  40. reduction=reduction, delta=1.0)
  41. self.criterion_vertices = nn.L1Loss()
  42. def forward(self, pred3d, pred2d, inputs):
  43. """
  44. mpjpe: mpjpe loss between 3d joints
  45. keypoint_2d_loss: 2d joints loss compute by criterion_2dpose
  46. """
  47. gt_3d_joints = inputs['joints_3d']
  48. gt_2d_joints = inputs['joints_2d']
  49. has_3d_joints = inputs['has_3d_joints']
  50. has_2d_joints = inputs['has_2d_joints']
  51. loss_3d = mpjpe(pred3d, gt_3d_joints, has_3d_joints)
  52. loss_2d = keypoint_2d_loss(self.criterion_2dpose, pred2d, gt_2d_joints,
  53. has_2d_joints)
  54. return self.weight_3d * loss_3d + self.weight_2d * loss_2d
  55. def filter_3d_joints(pred, gt, has_3d_joints):
  56. """
  57. filter 3d joints
  58. """
  59. gt = gt[has_3d_joints == 1]
  60. gt = gt[:, :, :3]
  61. pred = pred[has_3d_joints == 1]
  62. gt_pelvis = (gt[:, 2, :] + gt[:, 3, :]) / 2
  63. gt = gt - gt_pelvis[:, None, :]
  64. pred_pelvis = (pred[:, 2, :] + pred[:, 3, :]) / 2
  65. pred = pred - pred_pelvis[:, None, :]
  66. return pred, gt
  67. @register
  68. @serializable
  69. def mpjpe(pred, gt, has_3d_joints):
  70. """
  71. mPJPE loss
  72. """
  73. pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
  74. error = paddle.sqrt(((pred - gt)**2).sum(axis=-1)).mean()
  75. return error
  76. @register
  77. @serializable
  78. def mpjpe_criterion(pred, gt, has_3d_joints, criterion_pose3d):
  79. """
  80. mPJPE loss of self define criterion
  81. """
  82. pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
  83. error = paddle.sqrt(criterion_pose3d(pred, gt).sum(axis=-1)).mean()
  84. return error
  85. @register
  86. @serializable
  87. def weighted_mpjpe(pred, gt, has_3d_joints):
  88. """
  89. Weighted_mPJPE
  90. """
  91. pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
  92. weight = paddle.linalg.norm(pred, p=2, axis=-1)
  93. weight = paddle.to_tensor(
  94. [1.5, 1.3, 1.2, 1.2, 1.3, 1.5, 1.5, 1.3, 1.2, 1.2, 1.3, 1.5, 1., 1.])
  95. error = (weight * paddle.linalg.norm(pred - gt, p=2, axis=-1)).mean()
  96. return error
  97. @register
  98. @serializable
  99. def normed_mpjpe(pred, gt, has_3d_joints):
  100. """
  101. Normalized MPJPE (scale only), adapted from:
  102. https://github.com/hrhodin/UnsupervisedGeometryAwareRepresentationLearning/blob/master/losses/poses.py
  103. """
  104. assert pred.shape == gt.shape
  105. pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
  106. norm_predicted = paddle.mean(
  107. paddle.sum(pred**2, axis=3, keepdim=True), axis=2, keepdim=True)
  108. norm_target = paddle.mean(
  109. paddle.sum(gt * pred, axis=3, keepdim=True), axis=2, keepdim=True)
  110. scale = norm_target / norm_predicted
  111. return mpjpe(scale * pred, gt)
  112. @register
  113. @serializable
  114. def mpjpe_np(pred, gt, has_3d_joints):
  115. """
  116. mPJPE_NP
  117. """
  118. pred, gt = filter_3d_joints(pred, gt, has_3d_joints)
  119. error = np.sqrt(((pred - gt)**2).sum(axis=-1)).mean()
  120. return error
  121. @register
  122. @serializable
  123. def mean_per_vertex_error(pred, gt, has_smpl):
  124. """
  125. Compute mPVE
  126. """
  127. pred = pred[has_smpl == 1]
  128. gt = gt[has_smpl == 1]
  129. with paddle.no_grad():
  130. error = paddle.sqrt(((pred - gt)**2).sum(axis=-1)).mean()
  131. return error
  132. @register
  133. @serializable
  134. def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d,
  135. has_pose_2d):
  136. """
  137. Compute 2D reprojection loss if 2D keypoint annotations are available.
  138. The confidence (conf) is binary and indicates whether the keypoints exist or not.
  139. """
  140. conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
  141. loss = (conf * criterion_keypoints(pred_keypoints_2d,
  142. gt_keypoints_2d[:, :, :-1])).mean()
  143. return loss
  144. @register
  145. @serializable
  146. def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d,
  147. has_pose_3d):
  148. """
  149. Compute 3D keypoint loss if 3D keypoint annotations are available.
  150. """
  151. conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
  152. gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
  153. gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
  154. conf = conf[has_pose_3d == 1]
  155. pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
  156. if len(gt_keypoints_3d) > 0:
  157. gt_pelvis = (gt_keypoints_3d[:, 2, :] + gt_keypoints_3d[:, 3, :]) / 2
  158. gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
  159. pred_pelvis = (
  160. pred_keypoints_3d[:, 2, :] + pred_keypoints_3d[:, 3, :]) / 2
  161. pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
  162. return (conf * criterion_keypoints(pred_keypoints_3d,
  163. gt_keypoints_3d)).mean()
  164. else:
  165. return paddle.to_tensor([1.]).fill_(0.)
  166. @register
  167. @serializable
  168. def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl):
  169. """
  170. Compute per-vertex loss if vertex annotations are available.
  171. """
  172. pred_vertices_with_shape = pred_vertices[has_smpl == 1]
  173. gt_vertices_with_shape = gt_vertices[has_smpl == 1]
  174. if len(gt_vertices_with_shape) > 0:
  175. return criterion_vertices(pred_vertices_with_shape,
  176. gt_vertices_with_shape)
  177. else:
  178. return paddle.to_tensor([1.]).fill_(0.)
  179. @register
  180. @serializable
  181. def rectify_pose(pose):
  182. pose = pose.copy()
  183. R_mod = cv2.Rodrigues(np.array([np.pi, 0, 0]))[0]
  184. R_root = cv2.Rodrigues(pose[:3])[0]
  185. new_root = R_root.dot(R_mod)
  186. pose[:3] = cv2.Rodrigues(new_root)[0].reshape(3)
  187. return pose