uniform_assigner.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from ppdet.core.workspace import register
  18. from ppdet.modeling.bbox_utils import batch_bbox_overlaps
  19. from ppdet.modeling.transformers import bbox_xyxy_to_cxcywh
  20. __all__ = ['UniformAssigner']
  21. def batch_p_dist(x, y, p=2):
  22. """
  23. calculate pairwise p_dist, the first index of x and y are batch
  24. return [x.shape[0], y.shape[0]]
  25. """
  26. x = x.unsqueeze(1)
  27. diff = x - y
  28. return paddle.norm(diff, p=p, axis=list(range(2, diff.dim())))
  29. @register
  30. class UniformAssigner(nn.Layer):
  31. def __init__(self, pos_ignore_thr, neg_ignore_thr, match_times=4):
  32. super(UniformAssigner, self).__init__()
  33. self.pos_ignore_thr = pos_ignore_thr
  34. self.neg_ignore_thr = neg_ignore_thr
  35. self.match_times = match_times
  36. def forward(self, bbox_pred, anchor, gt_bboxes, gt_labels=None):
  37. num_bboxes = bbox_pred.shape[0]
  38. num_gts = gt_bboxes.shape[0]
  39. match_labels = paddle.full([num_bboxes], -1, dtype=paddle.int32)
  40. pred_ious = batch_bbox_overlaps(bbox_pred, gt_bboxes)
  41. pred_max_iou = pred_ious.max(axis=1)
  42. neg_ignore = pred_max_iou > self.neg_ignore_thr
  43. # exclude potential ignored neg samples first, deal with pos samples later
  44. #match_labels: -2(ignore), -1(neg) or >=0(pos_inds)
  45. match_labels = paddle.where(neg_ignore,
  46. paddle.full_like(match_labels, -2),
  47. match_labels)
  48. bbox_pred_c = bbox_xyxy_to_cxcywh(bbox_pred)
  49. anchor_c = bbox_xyxy_to_cxcywh(anchor)
  50. gt_bboxes_c = bbox_xyxy_to_cxcywh(gt_bboxes)
  51. bbox_pred_dist = batch_p_dist(bbox_pred_c, gt_bboxes_c, p=1)
  52. anchor_dist = batch_p_dist(anchor_c, gt_bboxes_c, p=1)
  53. top_pred = bbox_pred_dist.topk(
  54. k=self.match_times, axis=0, largest=False)[1]
  55. top_anchor = anchor_dist.topk(
  56. k=self.match_times, axis=0, largest=False)[1]
  57. tar_pred = paddle.arange(num_gts).expand([self.match_times, num_gts])
  58. tar_anchor = paddle.arange(num_gts).expand([self.match_times, num_gts])
  59. pos_places = paddle.concat([top_pred, top_anchor]).reshape([-1])
  60. pos_inds = paddle.concat([tar_pred, tar_anchor]).reshape([-1])
  61. pos_anchor = anchor[pos_places]
  62. pos_tar_bbox = gt_bboxes[pos_inds]
  63. pos_ious = batch_bbox_overlaps(
  64. pos_anchor, pos_tar_bbox, is_aligned=True)
  65. pos_ignore = pos_ious < self.pos_ignore_thr
  66. pos_inds = paddle.where(pos_ignore,
  67. paddle.full_like(pos_inds, -2), pos_inds)
  68. match_labels[pos_places] = pos_inds
  69. match_labels.stop_gradient = True
  70. pos_keep = ~pos_ignore
  71. if pos_keep.sum() > 0:
  72. pos_places_keep = pos_places[pos_keep]
  73. pos_bbox_pred = bbox_pred[pos_places_keep].reshape([-1, 4])
  74. pos_bbox_tar = pos_tar_bbox[pos_keep].reshape([-1, 4]).detach()
  75. else:
  76. pos_bbox_pred = None
  77. pos_bbox_tar = None
  78. return match_labels, pos_bbox_pred, pos_bbox_tar