proposal_generator.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. from ppdet.core.workspace import register, serializable
  16. from .. import ops
  17. @register
  18. @serializable
  19. class ProposalGenerator(object):
  20. """
  21. Proposal generation module
  22. For more details, please refer to the document of generate_proposals
  23. in ppdet/modeing/ops.py
  24. Args:
  25. pre_nms_top_n (int): Number of total bboxes to be kept per
  26. image before NMS. default 6000
  27. post_nms_top_n (int): Number of total bboxes to be kept per
  28. image after NMS. default 1000
  29. nms_thresh (float): Threshold in NMS. default 0.5
  30. min_size (flaot): Remove predicted boxes with either height or
  31. width < min_size. default 0.1
  32. eta (float): Apply in adaptive NMS, if adaptive `threshold > 0.5`,
  33. `adaptive_threshold = adaptive_threshold * eta` in each iteration.
  34. default 1.
  35. topk_after_collect (bool): whether to adopt topk after batch
  36. collection. If topk_after_collect is true, box filter will not be
  37. used after NMS at each image in proposal generation. default false
  38. """
  39. def __init__(self,
  40. pre_nms_top_n=12000,
  41. post_nms_top_n=2000,
  42. nms_thresh=.5,
  43. min_size=.1,
  44. eta=1.,
  45. topk_after_collect=False):
  46. super(ProposalGenerator, self).__init__()
  47. self.pre_nms_top_n = pre_nms_top_n
  48. self.post_nms_top_n = post_nms_top_n
  49. self.nms_thresh = nms_thresh
  50. self.min_size = min_size
  51. self.eta = eta
  52. self.topk_after_collect = topk_after_collect
  53. def __call__(self, scores, bbox_deltas, anchors, im_shape):
  54. top_n = self.pre_nms_top_n if self.topk_after_collect else self.post_nms_top_n
  55. variances = paddle.ones_like(anchors)
  56. if hasattr(paddle.vision.ops, "generate_proposals"):
  57. rpn_rois, rpn_rois_prob, rpn_rois_num = paddle.vision.ops.generate_proposals(
  58. scores,
  59. bbox_deltas,
  60. im_shape,
  61. anchors,
  62. variances,
  63. pre_nms_top_n=self.pre_nms_top_n,
  64. post_nms_top_n=top_n,
  65. nms_thresh=self.nms_thresh,
  66. min_size=self.min_size,
  67. eta=self.eta,
  68. return_rois_num=True)
  69. else:
  70. rpn_rois, rpn_rois_prob, rpn_rois_num = ops.generate_proposals(
  71. scores,
  72. bbox_deltas,
  73. im_shape,
  74. anchors,
  75. variances,
  76. pre_nms_top_n=self.pre_nms_top_n,
  77. post_nms_top_n=top_n,
  78. nms_thresh=self.nms_thresh,
  79. min_size=self.min_size,
  80. eta=self.eta,
  81. return_rois_num=True)
  82. return rpn_rois, rpn_rois_prob, rpn_rois_num, self.post_nms_top_n