task_aligned_assigner.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from ppdet.core.workspace import register
  21. from ..bbox_utils import batch_iou_similarity
  22. from .utils import (gather_topk_anchors, check_points_inside_bboxes,
  23. compute_max_iou_anchor)
  24. __all__ = ['TaskAlignedAssigner']
  25. def is_close_gt(anchor, gt, stride_lst, max_dist=2.0, alpha=2.):
  26. """Calculate distance ratio of box1 and box2 in batch for larger stride
  27. anchors dist/stride to promote the survive of large distance match
  28. Args:
  29. anchor (Tensor): box with the shape [L, 2]
  30. gt (Tensor): box with the shape [N, M2, 4]
  31. Return:
  32. dist (Tensor): dist ratio between box1 and box2 with the shape [N, M1, M2]
  33. """
  34. center1 = anchor.unsqueeze(0)
  35. center2 = (gt[..., :2] + gt[..., -2:]) / 2.
  36. center1 = center1.unsqueeze(1) # [N, M1, 2] -> [N, 1, M1, 2]
  37. center2 = center2.unsqueeze(2) # [N, M2, 2] -> [N, M2, 1, 2]
  38. stride = paddle.concat([
  39. paddle.full([x], 32 / pow(2, idx)) for idx, x in enumerate(stride_lst)
  40. ]).unsqueeze(0).unsqueeze(0)
  41. dist = paddle.linalg.norm(center1 - center2, p=2, axis=-1) / stride
  42. dist_ratio = dist
  43. dist_ratio[dist < max_dist] = 1.
  44. dist_ratio[dist >= max_dist] = 0.
  45. return dist_ratio
  46. @register
  47. class TaskAlignedAssigner(nn.Layer):
  48. """TOOD: Task-aligned One-stage Object Detection
  49. """
  50. def __init__(self,
  51. topk=13,
  52. alpha=1.0,
  53. beta=6.0,
  54. eps=1e-9,
  55. is_close_gt=False):
  56. super(TaskAlignedAssigner, self).__init__()
  57. self.topk = topk
  58. self.alpha = alpha
  59. self.beta = beta
  60. self.eps = eps
  61. self.is_close_gt = is_close_gt
  62. @paddle.no_grad()
  63. def forward(self,
  64. pred_scores,
  65. pred_bboxes,
  66. anchor_points,
  67. num_anchors_list,
  68. gt_labels,
  69. gt_bboxes,
  70. pad_gt_mask,
  71. bg_index,
  72. gt_scores=None):
  73. r"""This code is based on
  74. https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/task_aligned_assigner.py
  75. The assignment is done in following steps
  76. 1. compute alignment metric between all bbox (bbox of all pyramid levels) and gt
  77. 2. select top-k bbox as candidates for each gt
  78. 3. limit the positive sample's center in gt (because the anchor-free detector
  79. only can predict positive distance)
  80. 4. if an anchor box is assigned to multiple gts, the one with the
  81. highest iou will be selected.
  82. Args:
  83. pred_scores (Tensor, float32): predicted class probability, shape(B, L, C)
  84. pred_bboxes (Tensor, float32): predicted bounding boxes, shape(B, L, 4)
  85. anchor_points (Tensor, float32): pre-defined anchors, shape(L, 2), "cxcy" format
  86. num_anchors_list (List): num of anchors in each level, shape(L)
  87. gt_labels (Tensor, int64|int32): Label of gt_bboxes, shape(B, n, 1)
  88. gt_bboxes (Tensor, float32): Ground truth bboxes, shape(B, n, 4)
  89. pad_gt_mask (Tensor, float32): 1 means bbox, 0 means no bbox, shape(B, n, 1)
  90. bg_index (int): background index
  91. gt_scores (Tensor|None, float32) Score of gt_bboxes, shape(B, n, 1)
  92. Returns:
  93. assigned_labels (Tensor): (B, L)
  94. assigned_bboxes (Tensor): (B, L, 4)
  95. assigned_scores (Tensor): (B, L, C)
  96. """
  97. assert pred_scores.ndim == pred_bboxes.ndim
  98. assert gt_labels.ndim == gt_bboxes.ndim and \
  99. gt_bboxes.ndim == 3
  100. batch_size, num_anchors, num_classes = pred_scores.shape
  101. _, num_max_boxes, _ = gt_bboxes.shape
  102. # negative batch
  103. if num_max_boxes == 0:
  104. assigned_labels = paddle.full(
  105. [batch_size, num_anchors], bg_index, dtype='int32')
  106. assigned_bboxes = paddle.zeros([batch_size, num_anchors, 4])
  107. assigned_scores = paddle.zeros(
  108. [batch_size, num_anchors, num_classes])
  109. return assigned_labels, assigned_bboxes, assigned_scores
  110. # compute iou between gt and pred bbox, [B, n, L]
  111. ious = batch_iou_similarity(gt_bboxes, pred_bboxes)
  112. # gather pred bboxes class score
  113. pred_scores = pred_scores.transpose([0, 2, 1])
  114. batch_ind = paddle.arange(
  115. end=batch_size, dtype=gt_labels.dtype).unsqueeze(-1)
  116. gt_labels_ind = paddle.stack(
  117. [batch_ind.tile([1, num_max_boxes]), gt_labels.squeeze(-1)],
  118. axis=-1)
  119. bbox_cls_scores = paddle.gather_nd(pred_scores, gt_labels_ind)
  120. # compute alignment metrics, [B, n, L]
  121. alignment_metrics = bbox_cls_scores.pow(self.alpha) * ious.pow(
  122. self.beta)
  123. # check the positive sample's center in gt, [B, n, L]
  124. if self.is_close_gt:
  125. is_in_gts = is_close_gt(anchor_points, gt_bboxes, num_anchors_list)
  126. else:
  127. is_in_gts = check_points_inside_bboxes(anchor_points, gt_bboxes)
  128. # select topk largest alignment metrics pred bbox as candidates
  129. # for each gt, [B, n, L]
  130. is_in_topk = gather_topk_anchors(
  131. alignment_metrics * is_in_gts, self.topk, topk_mask=pad_gt_mask)
  132. # select positive sample, [B, n, L]
  133. mask_positive = is_in_topk * is_in_gts * pad_gt_mask
  134. # if an anchor box is assigned to multiple gts,
  135. # the one with the highest iou will be selected, [B, n, L]
  136. mask_positive_sum = mask_positive.sum(axis=-2)
  137. if mask_positive_sum.max() > 1:
  138. mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).tile(
  139. [1, num_max_boxes, 1])
  140. is_max_iou = compute_max_iou_anchor(ious)
  141. mask_positive = paddle.where(mask_multiple_gts, is_max_iou,
  142. mask_positive)
  143. mask_positive_sum = mask_positive.sum(axis=-2)
  144. assigned_gt_index = mask_positive.argmax(axis=-2)
  145. # assigned target
  146. assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes
  147. assigned_labels = paddle.gather(
  148. gt_labels.flatten(), assigned_gt_index.flatten(), axis=0)
  149. assigned_labels = assigned_labels.reshape([batch_size, num_anchors])
  150. assigned_labels = paddle.where(
  151. mask_positive_sum > 0, assigned_labels,
  152. paddle.full_like(assigned_labels, bg_index))
  153. assigned_bboxes = paddle.gather(
  154. gt_bboxes.reshape([-1, 4]), assigned_gt_index.flatten(), axis=0)
  155. assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 4])
  156. assigned_scores = F.one_hot(assigned_labels, num_classes + 1)
  157. ind = list(range(num_classes + 1))
  158. ind.remove(bg_index)
  159. assigned_scores = paddle.index_select(
  160. assigned_scores, paddle.to_tensor(ind), axis=-1)
  161. # rescale alignment metrics
  162. alignment_metrics *= mask_positive
  163. max_metrics_per_instance = alignment_metrics.max(axis=-1, keepdim=True)
  164. max_ious_per_instance = (ious * mask_positive).max(axis=-1,
  165. keepdim=True)
  166. alignment_metrics = alignment_metrics / (
  167. max_metrics_per_instance + self.eps) * max_ious_per_instance
  168. alignment_metrics = alignment_metrics.max(-2).unsqueeze(-1)
  169. assigned_scores = assigned_scores * alignment_metrics
  170. return assigned_labels, assigned_bboxes, assigned_scores