roi_align_rotated.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/roi_align_rotated.py
  17. """
  18. import paddle
  19. import paddle.nn as nn
  20. from paddle.utils.cpp_extension import load
  21. custom_ops = load(
  22. name="custom_jit_ops",
  23. sources=[
  24. "ppocr/ext_op/roi_align_rotated/roi_align_rotated.cc",
  25. "ppocr/ext_op/roi_align_rotated/roi_align_rotated.cu"
  26. ])
  27. roi_align_rotated = custom_ops.roi_align_rotated
  28. class RoIAlignRotated(nn.Layer):
  29. """RoI align pooling layer for rotated proposals.
  30. """
  31. def __init__(self,
  32. out_size,
  33. spatial_scale,
  34. sample_num=0,
  35. aligned=True,
  36. clockwise=False):
  37. super(RoIAlignRotated, self).__init__()
  38. if isinstance(out_size, int):
  39. self.out_h = out_size
  40. self.out_w = out_size
  41. elif isinstance(out_size, tuple):
  42. assert len(out_size) == 2
  43. assert isinstance(out_size[0], int)
  44. assert isinstance(out_size[1], int)
  45. self.out_h, self.out_w = out_size
  46. else:
  47. raise TypeError(
  48. '"out_size" must be an integer or tuple of integers')
  49. self.spatial_scale = float(spatial_scale)
  50. self.sample_num = int(sample_num)
  51. self.aligned = aligned
  52. self.clockwise = clockwise
  53. def forward(self, feats, rois):
  54. output = roi_align_rotated(feats, rois, self.out_h, self.out_w,
  55. self.spatial_scale, self.sample_num,
  56. self.aligned, self.clockwise)
  57. return output