simota_assigner.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # The code is based on:
  15. # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/bbox/assigners/sim_ota_assigner.py
  16. import paddle
  17. import numpy as np
  18. import paddle.nn.functional as F
  19. from ppdet.modeling.losses.varifocal_loss import varifocal_loss
  20. from ppdet.modeling.bbox_utils import batch_bbox_overlaps
  21. from ppdet.core.workspace import register
  @register
  class SimOTAAssigner(object):
      """Computes matching between predictions and ground truth (SimOTA).

      Args:
          center_radius (int | float, optional): Ground truth center size
              to judge whether a prior is in center. A prior is "in center"
              of a gt when its point lies strictly inside the region
              ``gt_center +/- center_radius * stride`` on both axes.
              Default 2.5.
          candidate_topk (int, optional): The candidate top-k which used to
              get top-k ious to calculate dynamic-k. Default 10.
          iou_weight (int | float, optional): The scale factor for regression
              iou cost. Default 3.0.
          cls_weight (int | float, optional): The scale factor for
              classification cost. Default 1.0.
          num_classes (int): The num_classes of dataset. It is also the label
              written for background (unmatched) priors. Default 80.
          use_vfl (bool): Whether to use varifocal_loss when calculating the
              cost matrix. Default True.
      """
      __shared__ = ['num_classes']

      # Cost added to (prior, gt) pairs whose prior point is NOT inside both
      # the gt box and the gt center region, so dynamic-k matching strongly
      # prefers pairs that satisfy both conditions.
      _OUT_OF_BOX_AND_CENTER_COST = 100000000

      def __init__(self,
                   center_radius=2.5,
                   candidate_topk=10,
                   iou_weight=3.0,
                   cls_weight=1.0,
                   num_classes=80,
                   use_vfl=True):
          self.center_radius = center_radius
          self.candidate_topk = candidate_topk
          self.iou_weight = iou_weight
          self.cls_weight = cls_weight
          self.num_classes = num_classes
          self.use_vfl = use_vfl

      def get_in_gt_and_in_center_info(self, flatten_center_and_stride,
                                       gt_bboxes):
          """Locate priors inside gt boxes and/or gt center regions.

          Args:
              flatten_center_and_stride (Tensor): per-prior values whose
                  columns 0..3 are read as x, y, stride_x, stride_y.
              gt_bboxes (Tensor): gt boxes whose columns 0..3 are read as
                  x1, y1, x2, y2; shape [num_gt, 4].

          Returns:
              tuple:
                  is_in_gts_or_centers_all (Tensor[bool], [num_priors]):
                      prior lies inside at least one gt box or gt center
                      region.
                  is_in_gts_or_centers_all_inds (Tensor, [num_valid]):
                      indices of those priors.
                  is_in_gts_and_centers (Tensor[bool], [num_valid, num_gt]):
                      for each selected prior, whether it lies inside BOTH
                      the box and the center region of each gt.
          """
          num_gt = gt_bboxes.shape[0]
          flatten_x = flatten_center_and_stride[:, 0].unsqueeze(1).tile(
              [1, num_gt])
          flatten_y = flatten_center_and_stride[:, 1].unsqueeze(1).tile(
              [1, num_gt])
          flatten_stride_x = flatten_center_and_stride[:, 2].unsqueeze(1).tile(
              [1, num_gt])
          flatten_stride_y = flatten_center_and_stride[:, 3].unsqueeze(1).tile(
              [1, num_gt])

          # is prior centers in gt bboxes, shape: [n_center, n_gt]
          l_ = flatten_x - gt_bboxes[:, 0]
          t_ = flatten_y - gt_bboxes[:, 1]
          r_ = gt_bboxes[:, 2] - flatten_x
          b_ = gt_bboxes[:, 3] - flatten_y
          deltas = paddle.stack([l_, t_, r_, b_], axis=1)
          # Inside iff all four distances to the box edges are positive.
          is_in_gts = deltas.min(axis=1) > 0
          is_in_gts_all = is_in_gts.sum(axis=1) > 0

          # is prior centers in gt centers
          gt_center_xs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
          gt_center_ys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
          ct_bound_l = gt_center_xs - self.center_radius * flatten_stride_x
          ct_bound_t = gt_center_ys - self.center_radius * flatten_stride_y
          ct_bound_r = gt_center_xs + self.center_radius * flatten_stride_x
          ct_bound_b = gt_center_ys + self.center_radius * flatten_stride_y
          cl_ = flatten_x - ct_bound_l
          ct_ = flatten_y - ct_bound_t
          cr_ = ct_bound_r - flatten_x
          cb_ = ct_bound_b - flatten_y
          ct_deltas = paddle.stack([cl_, ct_, cr_, cb_], axis=1)
          is_in_cts = ct_deltas.min(axis=1) > 0
          is_in_cts_all = is_in_cts.sum(axis=1) > 0

          # in any of gts or gt centers, shape: [n_center]
          is_in_gts_or_centers_all = paddle.logical_or(is_in_gts_all,
                                                       is_in_cts_all)
          is_in_gts_or_centers_all_inds = paddle.nonzero(
              is_in_gts_or_centers_all).squeeze(1)

          # both in gts and gt centers, shape: [num_fg, num_gt]
          # NOTE(review): the bool masks are round-tripped through int around
          # gather, presumably because gather does not accept bool here.
          is_in_gts_and_centers = paddle.logical_and(
              paddle.gather(
                  is_in_gts.cast('int'), is_in_gts_or_centers_all_inds,
                  axis=0).cast('bool'),
              paddle.gather(
                  is_in_cts.cast('int'), is_in_gts_or_centers_all_inds,
                  axis=0).cast('bool'))
          return is_in_gts_or_centers_all, is_in_gts_or_centers_all_inds, is_in_gts_and_centers

      def dynamic_k_matching(self, cost_matrix, pairwise_ious, num_gt):
          """Pick positive priors for every gt with the dynamic-k rule.

          For each gt, k is the int-cast sum of its top ``candidate_topk``
          IoUs (clipped to at least 1). The k candidates with the lowest cost
          are matched to that gt. A prior matched to several gts keeps only
          the gt with the lowest cost.

          Args:
              cost_matrix (Tensor): [num_valid, num_gt] matching cost.
              pairwise_ious (Tensor): [num_valid, num_gt] IoUs.
              num_gt (int): number of gt boxes (must be >= 1).

          Returns:
              tuple:
                  match_gt_inds_to_fg (np.ndarray): matched gt index for each
                      matched candidate, in candidate order.
                  match_fg_mask_inmatrix (np.ndarray[bool], [num_valid]):
                      which candidates were matched.
          """
          match_matrix = np.zeros_like(cost_matrix.numpy())
          # select candidate topk ious for dynamic-k calculation
          topk_ious, _ = paddle.topk(
              pairwise_ious,
              min(self.candidate_topk, pairwise_ious.shape[0]),
              axis=0)
          # calculate dynamic k for each gt
          dynamic_ks = paddle.clip(topk_ious.sum(0).cast('int'), min=1)
          for gt_idx in range(num_gt):
              _, pos_idx = paddle.topk(
                  cost_matrix[:, gt_idx], k=dynamic_ks[gt_idx], largest=False)
              # match_matrix[:, gt_idx] is a numpy view, so this writes through.
              match_matrix[:, gt_idx][pos_idx.numpy()] = 1.0
          del topk_ious, dynamic_ks, pos_idx

          # priors matched to two or more gts keep only their cheapest gt
          extra_match_gts_mask = match_matrix.sum(1) > 1
          if extra_match_gts_mask.sum() > 0:
              cost_matrix = cost_matrix.numpy()
              cost_argmin = np.argmin(
                  cost_matrix[extra_match_gts_mask, :], axis=1)
              match_matrix[extra_match_gts_mask, :] *= 0.0
              match_matrix[extra_match_gts_mask, cost_argmin] = 1.0

          # get foreground mask
          match_fg_mask_inmatrix = match_matrix.sum(1) > 0
          match_gt_inds_to_fg = match_matrix[match_fg_mask_inmatrix, :].argmax(1)
          return match_gt_inds_to_fg, match_fg_mask_inmatrix

      def get_sample(self, assign_gt_inds, gt_bboxes):
          """Split priors into positives/negatives and gather positive boxes.

          Args:
              assign_gt_inds (np.ndarray): per-prior assignment; 0 marks a
                  negative, ``i > 0`` marks a match to gt ``i - 1``.
              gt_bboxes (np.ndarray): gt boxes, expected [num_gt, 4]; a flat
                  array is reshaped to [-1, 4].

          Returns:
              tuple: (pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds)
              where ``pos_assigned_gt_inds`` are 0-based gt indices.
          """
          pos_inds = np.unique(np.nonzero(assign_gt_inds > 0)[0])
          neg_inds = np.unique(np.nonzero(assign_gt_inds == 0)[0])
          pos_assigned_gt_inds = assign_gt_inds[pos_inds] - 1

          if gt_bboxes.size == 0:
              # hack for index error case: without gts nothing can be positive
              assert pos_assigned_gt_inds.size == 0
              pos_gt_bboxes = np.empty_like(gt_bboxes).reshape(-1, 4)
          else:
              if len(gt_bboxes.shape) < 2:
                  # ndarray.resize works in place, returns None and rejects
                  # -1; reshape returns the [N, 4] array we need.
                  gt_bboxes = gt_bboxes.reshape(-1, 4)
              pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
          return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds

      def __call__(self,
                   flatten_cls_pred_scores,
                   flatten_center_and_stride,
                   flatten_bboxes,
                   gt_bboxes,
                   gt_labels,
                   eps=1e-7):
          """Assign gt to priors using SimOTA.

          Args:
              flatten_cls_pred_scores (Tensor): per-prior class scores,
                  [num_priors, num_classes]. They are fed to
                  ``varifocal_loss(use_sigmoid=False)`` /
                  ``binary_cross_entropy``, so presumably already activated
                  probabilities — TODO confirm with the caller.
              flatten_center_and_stride (Tensor): per-prior x, y, stride_x,
                  stride_y.
              flatten_bboxes (Tensor): predicted boxes, [num_priors, 4].
              gt_bboxes (Tensor): gt boxes (x1, y1, x2, y2), [num_gt, 4].
              gt_labels (Tensor): gt class ids, presumably [num_gt, 1].
              eps (float): added to IoUs before the log in the non-VFL cost.

          Returns:
              tuple:
                  pos_num (int): number of positive priors clamped to at
                      least 1; 0 when there is no gt or no prior.
                  label (np.ndarray[int64], [num_priors]): assigned class,
                      ``num_classes`` for background.
                  label_weight (np.ndarray[float32], [num_priors]): 1.0 for
                      every positive or negative prior.
                  bbox_target (np.ndarray): matched gt box for positives,
                      zeros elsewhere.
          """
          num_gt = gt_bboxes.shape[0]
          num_bboxes = flatten_bboxes.shape[0]
          if num_gt == 0 or num_bboxes == 0:
              # No ground truth or boxes
              label = np.ones([num_bboxes], dtype=np.int64) * self.num_classes
              label_weight = np.ones([num_bboxes], dtype=np.float32)
              bbox_target = np.zeros_like(flatten_center_and_stride)
              return 0, label, label_weight, bbox_target

          is_in_gts_or_centers_all, is_in_gts_or_centers_all_inds, is_in_boxes_and_center = \
              self.get_in_gt_and_in_center_info(flatten_center_and_stride,
                                                gt_bboxes)

          # Only priors inside some gt box / center region are candidates.
          num_valid_bboxes = is_in_gts_or_centers_all_inds.shape[0]
          if num_valid_bboxes == 0:
              # No candidates: every prior is negative. This is what the
              # general path below would produce, returned early because
              # dynamic_k_matching would otherwise call topk with k=0.
              label = np.ones([num_bboxes], dtype=np.int64) * self.num_classes
              label_weight = np.ones([num_bboxes], dtype=np.float32)
              bbox_target = np.zeros_like(flatten_bboxes)
              return 1, label, label_weight, bbox_target

          # bboxes and scores to calculate matrix
          valid_flatten_bboxes = flatten_bboxes[is_in_gts_or_centers_all_inds]
          valid_cls_pred_scores = flatten_cls_pred_scores[
              is_in_gts_or_centers_all_inds]
          pairwise_ious = batch_bbox_overlaps(valid_flatten_bboxes,
                                              gt_bboxes)  # [num_points,num_gts]

          # Penalty for pairs not inside both the gt box and its center region.
          out_of_box_and_center_cost = paddle.logical_not(
              is_in_boxes_and_center).cast(
                  'float32') * self._OUT_OF_BOX_AND_CENTER_COST

          if self.use_vfl:
              # Row r of the flattened [num_valid * num_gt] layout is
              # (candidate r // num_gt, gt r % num_gt); the VFL target puts the
              # pair IoU at that gt's class.
              gt_vfl_labels = gt_labels.squeeze(-1).unsqueeze(0).tile(
                  [num_valid_bboxes, 1]).reshape([-1])
              valid_pred_scores = valid_cls_pred_scores.unsqueeze(1).tile(
                  [1, num_gt, 1]).reshape([-1, self.num_classes])
              vfl_score = np.zeros(valid_pred_scores.shape)
              vfl_score[np.arange(0, vfl_score.shape[0]), gt_vfl_labels.numpy(
              )] = pairwise_ious.reshape([-1])
              vfl_score = paddle.to_tensor(vfl_score)
              losses_vfl = varifocal_loss(
                  valid_pred_scores, vfl_score,
                  use_sigmoid=False).reshape([num_valid_bboxes, num_gt])
              # NOTE(review): added as a cost, so this assumes
              # batch_bbox_overlaps(mode='giou') returns a loss-like value
              # (e.g. 1 - GIoU) — confirm in bbox_utils.
              losses_giou = batch_bbox_overlaps(
                  valid_flatten_bboxes, gt_bboxes, mode='giou')
              cost_matrix = (
                  losses_vfl * self.cls_weight + losses_giou * self.iou_weight +
                  out_of_box_and_center_cost)
          else:
              iou_cost = -paddle.log(pairwise_ious + eps)
              gt_onehot = F.one_hot(
                  gt_labels.squeeze(-1).cast(paddle.int64),
                  flatten_cls_pred_scores.shape[-1]).cast('float32')
              gt_onehot_label = gt_onehot.unsqueeze(0).tile(
                  [num_valid_bboxes, 1, 1])
              valid_pred_scores = valid_cls_pred_scores.unsqueeze(1).tile(
                  [1, num_gt, 1])
              cls_cost = F.binary_cross_entropy(
                  valid_pred_scores, gt_onehot_label, reduction='none').sum(-1)
              cost_matrix = (
                  cls_cost * self.cls_weight + iou_cost * self.iou_weight +
                  out_of_box_and_center_cost)

          match_gt_inds_to_fg, match_fg_mask_inmatrix = \
              self.dynamic_k_matching(
                  cost_matrix, pairwise_ious, num_gt)

          # Scatter candidate-level matches back to all priors:
          # assigned_gt_inds holds gt index + 1 for positives, 0 otherwise.
          assigned_gt_inds = np.zeros([num_bboxes], dtype=np.int64)
          match_fg_mask_inall = np.zeros_like(assigned_gt_inds)
          match_fg_mask_inall[is_in_gts_or_centers_all.numpy(
          )] = match_fg_mask_inmatrix
          assigned_gt_inds[match_fg_mask_inall.astype(
              np.bool_)] = match_gt_inds_to_fg + 1

          pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds \
              = self.get_sample(assigned_gt_inds, gt_bboxes.numpy())

          bbox_target = np.zeros_like(flatten_bboxes)
          label = np.ones([num_bboxes], dtype=np.int64) * self.num_classes
          label_weight = np.zeros([num_bboxes], dtype=np.float32)
          if len(pos_inds) > 0:
              bbox_target[pos_inds, :] = pos_gt_bboxes
              # reshape(-1) accepts both [num_gt, 1] and [num_gt] labels.
              flat_gt_labels = gt_labels.numpy().reshape(-1)
              label[pos_inds] = flat_gt_labels[pos_assigned_gt_inds]
              label_weight[pos_inds] = 1.0
          if len(neg_inds) > 0:
              label_weight[neg_inds] = 1.0
          pos_num = max(pos_inds.size, 1)
          return pos_num, label, label_weight, bbox_target