atss_assigner.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import paddle
  19. import paddle.nn as nn
  20. import paddle.nn.functional as F
  21. from ppdet.core.workspace import register
  22. from ..bbox_utils import iou_similarity, batch_iou_similarity
  23. from ..bbox_utils import bbox_center
  24. from .utils import (check_points_inside_bboxes, compute_max_iou_anchor,
  25. compute_max_iou_gt)
  26. __all__ = ['ATSSAssigner']
  27. @register
  28. class ATSSAssigner(nn.Layer):
  29. """Bridging the Gap Between Anchor-based and Anchor-free Detection
  30. via Adaptive Training Sample Selection
  31. """
  32. __shared__ = ['num_classes']
  33. def __init__(self,
  34. topk=9,
  35. num_classes=80,
  36. force_gt_matching=False,
  37. eps=1e-9,
  38. sm_use=False):
  39. super(ATSSAssigner, self).__init__()
  40. self.topk = topk
  41. self.num_classes = num_classes
  42. self.force_gt_matching = force_gt_matching
  43. self.eps = eps
  44. self.sm_use = sm_use
  45. def _gather_topk_pyramid(self, gt2anchor_distances, num_anchors_list,
  46. pad_gt_mask):
  47. gt2anchor_distances_list = paddle.split(
  48. gt2anchor_distances, num_anchors_list, axis=-1)
  49. num_anchors_index = np.cumsum(num_anchors_list).tolist()
  50. num_anchors_index = [0, ] + num_anchors_index[:-1]
  51. is_in_topk_list = []
  52. topk_idxs_list = []
  53. for distances, anchors_index in zip(gt2anchor_distances_list,
  54. num_anchors_index):
  55. num_anchors = distances.shape[-1]
  56. _, topk_idxs = paddle.topk(
  57. distances, self.topk, axis=-1, largest=False)
  58. topk_idxs_list.append(topk_idxs + anchors_index)
  59. is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(
  60. axis=-2).astype(gt2anchor_distances.dtype)
  61. is_in_topk_list.append(is_in_topk * pad_gt_mask)
  62. is_in_topk_list = paddle.concat(is_in_topk_list, axis=-1)
  63. topk_idxs_list = paddle.concat(topk_idxs_list, axis=-1)
  64. return is_in_topk_list, topk_idxs_list
  65. @paddle.no_grad()
  66. def forward(self,
  67. anchor_bboxes,
  68. num_anchors_list,
  69. gt_labels,
  70. gt_bboxes,
  71. pad_gt_mask,
  72. bg_index,
  73. gt_scores=None,
  74. pred_bboxes=None):
  75. r"""This code is based on
  76. https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/atss_assigner.py
  77. The assignment is done in following steps
  78. 1. compute iou between all bbox (bbox of all pyramid levels) and gt
  79. 2. compute center distance between all bbox and gt
  80. 3. on each pyramid level, for each gt, select k bbox whose center
  81. are closest to the gt center, so we total select k*l bbox as
  82. candidates for each gt
  83. 4. get corresponding iou for the these candidates, and compute the
  84. mean and std, set mean + std as the iou threshold
  85. 5. select these candidates whose iou are greater than or equal to
  86. the threshold as positive
  87. 6. limit the positive sample's center in gt
  88. 7. if an anchor box is assigned to multiple gts, the one with the
  89. highest iou will be selected.
  90. Args:
  91. anchor_bboxes (Tensor, float32): pre-defined anchors, shape(L, 4),
  92. "xmin, xmax, ymin, ymax" format
  93. num_anchors_list (List): num of anchors in each level
  94. gt_labels (Tensor, int64|int32): Label of gt_bboxes, shape(B, n, 1)
  95. gt_bboxes (Tensor, float32): Ground truth bboxes, shape(B, n, 4)
  96. pad_gt_mask (Tensor, float32): 1 means bbox, 0 means no bbox, shape(B, n, 1)
  97. bg_index (int): background index
  98. gt_scores (Tensor|None, float32) Score of gt_bboxes,
  99. shape(B, n, 1), if None, then it will initialize with one_hot label
  100. pred_bboxes (Tensor, float32, optional): predicted bounding boxes, shape(B, L, 4)
  101. Returns:
  102. assigned_labels (Tensor): (B, L)
  103. assigned_bboxes (Tensor): (B, L, 4)
  104. assigned_scores (Tensor): (B, L, C), if pred_bboxes is not None, then output ious
  105. """
  106. assert gt_labels.ndim == gt_bboxes.ndim and \
  107. gt_bboxes.ndim == 3
  108. num_anchors, _ = anchor_bboxes.shape
  109. batch_size, num_max_boxes, _ = gt_bboxes.shape
  110. # negative batch
  111. if num_max_boxes == 0:
  112. assigned_labels = paddle.full(
  113. [batch_size, num_anchors], bg_index, dtype='int32')
  114. assigned_bboxes = paddle.zeros([batch_size, num_anchors, 4])
  115. assigned_scores = paddle.zeros(
  116. [batch_size, num_anchors, self.num_classes])
  117. return assigned_labels, assigned_bboxes, assigned_scores
  118. # 1. compute iou between gt and anchor bbox, [B, n, L]
  119. ious = iou_similarity(gt_bboxes.reshape([-1, 4]), anchor_bboxes)
  120. ious = ious.reshape([batch_size, -1, num_anchors])
  121. # 2. compute center distance between all anchors and gt, [B, n, L]
  122. gt_centers = bbox_center(gt_bboxes.reshape([-1, 4])).unsqueeze(1)
  123. anchor_centers = bbox_center(anchor_bboxes)
  124. gt2anchor_distances = (gt_centers - anchor_centers.unsqueeze(0)) \
  125. .norm(2, axis=-1).reshape([batch_size, -1, num_anchors])
  126. # 3. on each pyramid level, selecting topk closest candidates
  127. # based on the center distance, [B, n, L]
  128. is_in_topk, topk_idxs = self._gather_topk_pyramid(
  129. gt2anchor_distances, num_anchors_list, pad_gt_mask)
  130. # 4. get corresponding iou for the these candidates, and compute the
  131. # mean and std, 5. set mean + std as the iou threshold
  132. iou_candidates = ious * is_in_topk
  133. iou_threshold = paddle.index_sample(
  134. iou_candidates.flatten(stop_axis=-2),
  135. topk_idxs.flatten(stop_axis=-2))
  136. iou_threshold = iou_threshold.reshape([batch_size, num_max_boxes, -1])
  137. iou_threshold = iou_threshold.mean(axis=-1, keepdim=True) + \
  138. iou_threshold.std(axis=-1, keepdim=True)
  139. is_in_topk = paddle.where(iou_candidates > iou_threshold, is_in_topk,
  140. paddle.zeros_like(is_in_topk))
  141. # 6. check the positive sample's center in gt, [B, n, L]
  142. if self.sm_use:
  143. is_in_gts = check_points_inside_bboxes(
  144. anchor_centers, gt_bboxes, sm_use=True)
  145. else:
  146. is_in_gts = check_points_inside_bboxes(anchor_centers, gt_bboxes)
  147. # select positive sample, [B, n, L]
  148. mask_positive = is_in_topk * is_in_gts * pad_gt_mask
  149. # 7. if an anchor box is assigned to multiple gts,
  150. # the one with the highest iou will be selected.
  151. mask_positive_sum = mask_positive.sum(axis=-2)
  152. if mask_positive_sum.max() > 1:
  153. mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).tile(
  154. [1, num_max_boxes, 1])
  155. if self.sm_use:
  156. is_max_iou = compute_max_iou_anchor(ious * mask_positive)
  157. else:
  158. is_max_iou = compute_max_iou_anchor(ious)
  159. mask_positive = paddle.where(mask_multiple_gts, is_max_iou,
  160. mask_positive)
  161. mask_positive_sum = mask_positive.sum(axis=-2)
  162. # 8. make sure every gt_bbox matches the anchor
  163. if self.force_gt_matching:
  164. is_max_iou = compute_max_iou_gt(ious) * pad_gt_mask
  165. mask_max_iou = (is_max_iou.sum(-2, keepdim=True) == 1).tile(
  166. [1, num_max_boxes, 1])
  167. mask_positive = paddle.where(mask_max_iou, is_max_iou,
  168. mask_positive)
  169. mask_positive_sum = mask_positive.sum(axis=-2)
  170. assigned_gt_index = mask_positive.argmax(axis=-2)
  171. # assigned target
  172. batch_ind = paddle.arange(
  173. end=batch_size, dtype=gt_labels.dtype).unsqueeze(-1)
  174. assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes
  175. assigned_labels = paddle.gather(
  176. gt_labels.flatten(), assigned_gt_index.flatten(), axis=0)
  177. assigned_labels = assigned_labels.reshape([batch_size, num_anchors])
  178. assigned_labels = paddle.where(
  179. mask_positive_sum > 0, assigned_labels,
  180. paddle.full_like(assigned_labels, bg_index))
  181. assigned_bboxes = paddle.gather(
  182. gt_bboxes.reshape([-1, 4]), assigned_gt_index.flatten(), axis=0)
  183. assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 4])
  184. assigned_scores = F.one_hot(assigned_labels, self.num_classes + 1)
  185. ind = list(range(self.num_classes + 1))
  186. ind.remove(bg_index)
  187. assigned_scores = paddle.index_select(
  188. assigned_scores, paddle.to_tensor(ind), axis=-1)
  189. if pred_bboxes is not None:
  190. # assigned iou
  191. ious = batch_iou_similarity(gt_bboxes, pred_bboxes) * mask_positive
  192. ious = ious.max(axis=-2).unsqueeze(-1)
  193. assigned_scores *= ious
  194. elif gt_scores is not None:
  195. gather_scores = paddle.gather(
  196. gt_scores.flatten(), assigned_gt_index.flatten(), axis=0)
  197. gather_scores = gather_scores.reshape([batch_size, num_anchors])
  198. gather_scores = paddle.where(mask_positive_sum > 0, gather_scores,
  199. paddle.zeros_like(gather_scores))
  200. assigned_scores *= gather_scores.unsqueeze(-1)
  201. return assigned_labels, assigned_bboxes, assigned_scores