rotated_task_aligned_assigner.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from ppdet.core.workspace import register
  21. from ..rbox_utils import rotated_iou_similarity, check_points_in_rotated_boxes
  22. from .utils import gather_topk_anchors, compute_max_iou_anchor
  23. __all__ = ['RotatedTaskAlignedAssigner']
  24. @register
  25. class RotatedTaskAlignedAssigner(nn.Layer):
  26. """TOOD: Task-aligned One-stage Object Detection
  27. """
  28. def __init__(self, topk=13, alpha=1.0, beta=6.0, eps=1e-9):
  29. super(RotatedTaskAlignedAssigner, self).__init__()
  30. self.topk = topk
  31. self.alpha = alpha
  32. self.beta = beta
  33. self.eps = eps
  34. @paddle.no_grad()
  35. def forward(self,
  36. pred_scores,
  37. pred_bboxes,
  38. anchor_points,
  39. num_anchors_list,
  40. gt_labels,
  41. gt_bboxes,
  42. pad_gt_mask,
  43. bg_index,
  44. gt_scores=None):
  45. r"""This code is based on
  46. https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/task_aligned_assigner.py
  47. The assignment is done in following steps
  48. 1. compute alignment metric between all bbox (bbox of all pyramid levels) and gt
  49. 2. select top-k bbox as candidates for each gt
  50. 3. limit the positive sample's center in gt (because the anchor-free detector
  51. only can predict positive distance)
  52. 4. if an anchor box is assigned to multiple gts, the one with the
  53. highest iou will be selected.
  54. Args:
  55. pred_scores (Tensor, float32): predicted class probability, shape(B, L, C)
  56. pred_bboxes (Tensor, float32): predicted bounding boxes, shape(B, L, 5)
  57. anchor_points (Tensor, float32): pre-defined anchors, shape(1, L, 2), "cxcy" format
  58. num_anchors_list (List): num of anchors in each level, shape(L)
  59. gt_labels (Tensor, int64|int32): Label of gt_bboxes, shape(B, n, 1)
  60. gt_bboxes (Tensor, float32): Ground truth bboxes, shape(B, n, 5)
  61. pad_gt_mask (Tensor, float32): 1 means bbox, 0 means no bbox, shape(B, n, 1)
  62. bg_index (int): background index
  63. gt_scores (Tensor|None, float32) Score of gt_bboxes, shape(B, n, 1)
  64. Returns:
  65. assigned_labels (Tensor): (B, L)
  66. assigned_bboxes (Tensor): (B, L, 5)
  67. assigned_scores (Tensor): (B, L, C)
  68. """
  69. assert pred_scores.ndim == pred_bboxes.ndim
  70. assert gt_labels.ndim == gt_bboxes.ndim and \
  71. gt_bboxes.ndim == 3
  72. batch_size, num_anchors, num_classes = pred_scores.shape
  73. _, num_max_boxes, _ = gt_bboxes.shape
  74. # negative batch
  75. if num_max_boxes == 0:
  76. assigned_labels = paddle.full(
  77. [batch_size, num_anchors], bg_index, dtype=gt_labels.dtype)
  78. assigned_bboxes = paddle.zeros([batch_size, num_anchors, 5])
  79. assigned_scores = paddle.zeros(
  80. [batch_size, num_anchors, num_classes])
  81. return assigned_labels, assigned_bboxes, assigned_scores
  82. # compute iou between gt and pred bbox, [B, n, L]
  83. ious = rotated_iou_similarity(gt_bboxes, pred_bboxes)
  84. ious = paddle.where(ious > 1 + self.eps, paddle.zeros_like(ious), ious)
  85. ious.stop_gradient = True
  86. # gather pred bboxes class score
  87. pred_scores = pred_scores.transpose([0, 2, 1])
  88. batch_ind = paddle.arange(
  89. end=batch_size, dtype=gt_labels.dtype).unsqueeze(-1)
  90. gt_labels_ind = paddle.stack(
  91. [batch_ind.tile([1, num_max_boxes]), gt_labels.squeeze(-1)],
  92. axis=-1)
  93. bbox_cls_scores = paddle.gather_nd(pred_scores, gt_labels_ind)
  94. # compute alignment metrics, [B, n, L]
  95. alignment_metrics = bbox_cls_scores.pow(self.alpha) * ious.pow(
  96. self.beta)
  97. # check the positive sample's center in gt, [B, n, L]
  98. is_in_gts = check_points_in_rotated_boxes(anchor_points, gt_bboxes)
  99. # select topk largest alignment metrics pred bbox as candidates
  100. # for each gt, [B, n, L]
  101. is_in_topk = gather_topk_anchors(
  102. alignment_metrics * is_in_gts, self.topk, topk_mask=pad_gt_mask)
  103. # select positive sample, [B, n, L]
  104. mask_positive = is_in_topk * is_in_gts * pad_gt_mask
  105. # if an anchor box is assigned to multiple gts,
  106. # the one with the highest iou will be selected, [B, n, L]
  107. mask_positive_sum = mask_positive.sum(axis=-2)
  108. if mask_positive_sum.max() > 1:
  109. mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).tile(
  110. [1, num_max_boxes, 1])
  111. is_max_iou = compute_max_iou_anchor(ious)
  112. mask_positive = paddle.where(mask_multiple_gts, is_max_iou,
  113. mask_positive)
  114. mask_positive_sum = mask_positive.sum(axis=-2)
  115. assigned_gt_index = mask_positive.argmax(axis=-2)
  116. # assigned target
  117. assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes
  118. assigned_labels = paddle.gather(
  119. gt_labels.flatten(), assigned_gt_index.flatten(), axis=0)
  120. assigned_labels = assigned_labels.reshape([batch_size, num_anchors])
  121. assigned_labels = paddle.where(
  122. mask_positive_sum > 0, assigned_labels,
  123. paddle.full_like(assigned_labels, bg_index))
  124. assigned_bboxes = paddle.gather(
  125. gt_bboxes.reshape([-1, 5]), assigned_gt_index.flatten(), axis=0)
  126. assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 5])
  127. assigned_scores = F.one_hot(assigned_labels, num_classes + 1)
  128. ind = list(range(num_classes + 1))
  129. ind.remove(bg_index)
  130. assigned_scores = paddle.index_select(
  131. assigned_scores, paddle.to_tensor(ind), axis=-1)
  132. # rescale alignment metrics
  133. alignment_metrics *= mask_positive
  134. max_metrics_per_instance = alignment_metrics.max(axis=-1, keepdim=True)
  135. max_ious_per_instance = (ious * mask_positive).max(axis=-1,
  136. keepdim=True)
  137. alignment_metrics = alignment_metrics / (
  138. max_metrics_per_instance + self.eps) * max_ious_per_instance
  139. alignment_metrics = alignment_metrics.max(-2).unsqueeze(-1)
  140. assigned_scores = assigned_scores * alignment_metrics
  141. assigned_bboxes.stop_gradient = True
  142. assigned_scores.stop_gradient = True
  143. assigned_labels.stop_gradient = True
  144. return assigned_labels, assigned_bboxes, assigned_scores