fcosr_assigner.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import paddle
  19. import paddle.nn as nn
  20. import paddle.nn.functional as F
  21. from ppdet.core.workspace import register
  22. from ppdet.modeling.rbox_utils import box2corners, check_points_in_polys, paddle_gather
  23. __all__ = ['FCOSRAssigner']
  24. EPS = 1e-9
  25. @register
  26. class FCOSRAssigner(nn.Layer):
  27. """ FCOSR Assigner, refer to https://arxiv.org/abs/2111.10780 for details
  28. 1. compute normalized gaussian distribution score and refined gaussian distribution score
  29. 2. refer to ellipse center sampling, sample points whose normalized gaussian distribution score is greater than threshold
  30. 3. refer to multi-level sampling, assign ground truth to feature map which follows two conditions.
  31. i). first, the ratio between the short edge of the target and the stride of the feature map is less than 2.
  32. ii). second, the long edge of minimum bounding rectangle of the target is larger than the acceptance range of feature map
  33. 4. refer to fuzzy sample label assignment, the points satisfying 2 and 3 will be assigned to the ground truth according to gaussian distribution score
  34. """
  35. __shared__ = ['num_classes']
  36. def __init__(self,
  37. num_classes=80,
  38. factor=12,
  39. threshold=0.23,
  40. boundary=[[-1, 128], [128, 320], [320, 10000]],
  41. score_type='iou'):
  42. super(FCOSRAssigner, self).__init__()
  43. self.num_classes = num_classes
  44. self.factor = factor
  45. self.threshold = threshold
  46. self.boundary = [
  47. paddle.to_tensor(
  48. l, dtype=paddle.float32).reshape([1, 1, 2]) for l in boundary
  49. ]
  50. self.score_type = score_type
  51. def get_gaussian_distribution_score(self, points, gt_rboxes, gt_polys):
  52. # projecting points to coordinate system defined by each rbox
  53. # [B, N, 4, 2] -> 4 * [B, N, 1, 2]
  54. a, b, c, d = gt_polys.split(4, axis=2)
  55. # [1, L, 2] -> [1, 1, L, 2]
  56. points = points.unsqueeze(0)
  57. ab = b - a
  58. ad = d - a
  59. # [B, N, 5] -> [B, N, 2], [B, N, 2], [B, N, 1]
  60. xy, wh, angle = gt_rboxes.split([2, 2, 1], axis=-1)
  61. # [B, N, 2] -> [B, N, 1, 2]
  62. xy = xy.unsqueeze(2)
  63. # vector of points to center [B, N, L, 2]
  64. vec = points - xy
  65. # <ab, vec> = |ab| * |vec| * cos(theta) [B, N, L]
  66. vec_dot_ab = paddle.sum(vec * ab, axis=-1)
  67. # <ad, vec> = |ad| * |vec| * cos(theta) [B, N, L]
  68. vec_dot_ad = paddle.sum(vec * ad, axis=-1)
  69. # norm_ab [B, N, L]
  70. norm_ab = paddle.sum(ab * ab, axis=-1).sqrt()
  71. # norm_ad [B, N, L]
  72. norm_ad = paddle.sum(ad * ad, axis=-1).sqrt()
  73. # min(h, w), [B, N, 1]
  74. min_edge = paddle.min(wh, axis=-1, keepdim=True)
  75. # delta_x, delta_y [B, N, L]
  76. delta_x = vec_dot_ab.pow(2) / (norm_ab.pow(3) * min_edge + EPS)
  77. delta_y = vec_dot_ad.pow(2) / (norm_ad.pow(3) * min_edge + EPS)
  78. # score [B, N, L]
  79. norm_score = paddle.exp(-0.5 * self.factor * (delta_x + delta_y))
  80. # simplified calculation
  81. sigma = min_edge / self.factor
  82. refined_score = norm_score / (2 * np.pi * sigma + EPS)
  83. return norm_score, refined_score
  84. def get_rotated_inside_mask(self, points, gt_polys, scores):
  85. inside_mask = check_points_in_polys(points, gt_polys)
  86. center_mask = scores >= self.threshold
  87. return (inside_mask & center_mask).cast(paddle.float32)
  88. def get_inside_range_mask(self, points, gt_bboxes, gt_rboxes, stride_tensor,
  89. regress_range):
  90. # [1, L, 2] -> [1, 1, L, 2]
  91. points = points.unsqueeze(0)
  92. # [B, n, 4] -> [B, n, 1, 4]
  93. x1y1, x2y2 = gt_bboxes.unsqueeze(2).split(2, axis=-1)
  94. # [B, n, L, 2]
  95. lt = points - x1y1
  96. rb = x2y2 - points
  97. # [B, n, L, 4]
  98. ltrb = paddle.concat([lt, rb], axis=-1)
  99. # [B, n, L, 4] -> [B, n, L]
  100. inside_mask = paddle.min(ltrb, axis=-1) > EPS
  101. # regress_range [1, L, 2] -> [1, 1, L, 2]
  102. regress_range = regress_range.unsqueeze(0)
  103. # stride_tensor [1, L, 1] -> [1, 1, L]
  104. stride_tensor = stride_tensor.transpose((0, 2, 1))
  105. # fcos range
  106. # [B, n, L, 4] -> [B, n, L]
  107. ltrb_max = paddle.max(ltrb, axis=-1)
  108. # [1, 1, L, 2] -> [1, 1, L]
  109. low, high = regress_range[..., 0], regress_range[..., 1]
  110. # [B, n, L]
  111. regress_mask = (ltrb_max >= low) & (ltrb_max <= high)
  112. # mask for rotated
  113. # [B, n, 1]
  114. min_edge = paddle.min(gt_rboxes[..., 2:4], axis=-1, keepdim=True)
  115. # [B, n , L]
  116. rotated_mask = ((min_edge / stride_tensor) < 2.0) & (ltrb_max > high)
  117. mask = inside_mask & (regress_mask | rotated_mask)
  118. return mask.cast(paddle.float32)
  119. @paddle.no_grad()
  120. def forward(self,
  121. anchor_points,
  122. stride_tensor,
  123. num_anchors_list,
  124. gt_labels,
  125. gt_bboxes,
  126. gt_rboxes,
  127. pad_gt_mask,
  128. bg_index,
  129. pred_rboxes=None):
  130. r"""
  131. Args:
  132. anchor_points (Tensor, float32): pre-defined anchor points, shape(1, L, 2),
  133. "x, y" format
  134. stride_tensor (Tensor, float32): stride tensor, shape (1, L, 1)
  135. num_anchors_list (List): num of anchors in each level
  136. gt_labels (Tensor, int64|int32): Label of gt_bboxes, shape(B, n, 1)
  137. gt_bboxes (Tensor, float32): Ground truth bboxes, shape(B, n, 4)
  138. gt_rboxes (Tensor, float32): Ground truth bboxes, shape(B, n, 5)
  139. pad_gt_mask (Tensor, float32): 1 means bbox, 0 means no bbox, shape(B, n, 1)
  140. bg_index (int): background index
  141. pred_rboxes (Tensor, float32, optional): predicted bounding boxes, shape(B, L, 5)
  142. Returns:
  143. assigned_labels (Tensor): (B, L)
  144. assigned_rboxes (Tensor): (B, L, 5)
  145. assigned_scores (Tensor): (B, L, C), if pred_rboxes is not None, then output ious
  146. """
  147. _, num_anchors, _ = anchor_points.shape
  148. batch_size, num_max_boxes, _ = gt_rboxes.shape
  149. if num_max_boxes == 0:
  150. assigned_labels = paddle.full(
  151. [batch_size, num_anchors], bg_index, dtype=gt_labels.dtype)
  152. assigned_rboxes = paddle.zeros([batch_size, num_anchors, 5])
  153. assigned_scores = paddle.zeros(
  154. [batch_size, num_anchors, self.num_classes])
  155. return assigned_labels, assigned_rboxes, assigned_scores
  156. # get normalized gaussian distribution score and refined distribution score
  157. gt_polys = box2corners(gt_rboxes)
  158. score, refined_score = self.get_gaussian_distribution_score(
  159. anchor_points, gt_rboxes, gt_polys)
  160. inside_mask = self.get_rotated_inside_mask(anchor_points, gt_polys,
  161. score)
  162. regress_ranges = []
  163. for num, bound in zip(num_anchors_list, self.boundary):
  164. regress_ranges.append(bound.tile((1, num, 1)))
  165. regress_ranges = paddle.concat(regress_ranges, axis=1)
  166. regress_mask = self.get_inside_range_mask(
  167. anchor_points, gt_bboxes, gt_rboxes, stride_tensor, regress_ranges)
  168. # [B, n, L]
  169. mask_positive = inside_mask * regress_mask * pad_gt_mask
  170. refined_score = refined_score * mask_positive - (1. - mask_positive)
  171. argmax_refined_score = refined_score.argmax(axis=-2)
  172. max_refined_score = refined_score.max(axis=-2)
  173. assigned_gt_index = argmax_refined_score
  174. # assigned target
  175. batch_ind = paddle.arange(
  176. end=batch_size, dtype=gt_labels.dtype).unsqueeze(-1)
  177. assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes
  178. assigned_labels = paddle.gather(
  179. gt_labels.flatten(), assigned_gt_index.flatten(), axis=0)
  180. assigned_labels = assigned_labels.reshape([batch_size, num_anchors])
  181. assigned_labels = paddle.where(
  182. max_refined_score > 0, assigned_labels,
  183. paddle.full_like(assigned_labels, bg_index))
  184. assigned_rboxes = paddle.gather(
  185. gt_rboxes.reshape([-1, 5]), assigned_gt_index.flatten(), axis=0)
  186. assigned_rboxes = assigned_rboxes.reshape([batch_size, num_anchors, 5])
  187. assigned_scores = F.one_hot(assigned_labels, self.num_classes + 1)
  188. ind = list(range(self.num_classes + 1))
  189. ind.remove(bg_index)
  190. assigned_scores = paddle.index_select(
  191. assigned_scores, paddle.to_tensor(ind), axis=-1)
  192. if self.score_type == 'gaussian':
  193. selected_scores = paddle_gather(
  194. score, 1, argmax_refined_score.unsqueeze(-2)).squeeze(-2)
  195. assigned_scores = assigned_scores * selected_scores.unsqueeze(-1)
  196. elif self.score_type == 'iou':
  197. assert pred_rboxes is not None, 'If score type is iou, pred_rboxes should not be None'
  198. from ext_op import matched_rbox_iou
  199. b, l = pred_rboxes.shape[:2]
  200. iou_score = matched_rbox_iou(
  201. pred_rboxes.reshape((-1, 5)), assigned_rboxes.reshape(
  202. (-1, 5))).reshape((b, l, 1))
  203. assigned_scores = assigned_scores * iou_score
  204. return assigned_labels, assigned_rboxes, assigned_scores