test_matched_rbox_iou.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. import numpy as np
  2. import sys
  3. import time
  4. from shapely.geometry import Polygon
  5. import paddle
  6. import unittest
  7. from ext_op import matched_rbox_iou
  8. def rbox2poly_single(rrect, get_best_begin_point=False):
  9. """
  10. rrect:[x_ctr,y_ctr,w,h,angle]
  11. to
  12. poly:[x0,y0,x1,y1,x2,y2,x3,y3]
  13. """
  14. x_ctr, y_ctr, width, height, angle = rrect[:5]
  15. tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2
  16. # rect 2x4
  17. rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]])
  18. R = np.array([[np.cos(angle), -np.sin(angle)],
  19. [np.sin(angle), np.cos(angle)]])
  20. # poly
  21. poly = R.dot(rect)
  22. x0, x1, x2, x3 = poly[0, :4] + x_ctr
  23. y0, y1, y2, y3 = poly[1, :4] + y_ctr
  24. poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3], dtype=np.float64)
  25. return poly
  26. def intersection(g, p):
  27. """
  28. Intersection.
  29. """
  30. g = g[:8].reshape((4, 2))
  31. p = p[:8].reshape((4, 2))
  32. a = g
  33. b = p
  34. use_filter = True
  35. if use_filter:
  36. # step1:
  37. inter_x1 = np.maximum(np.min(a[:, 0]), np.min(b[:, 0]))
  38. inter_x2 = np.minimum(np.max(a[:, 0]), np.max(b[:, 0]))
  39. inter_y1 = np.maximum(np.min(a[:, 1]), np.min(b[:, 1]))
  40. inter_y2 = np.minimum(np.max(a[:, 1]), np.max(b[:, 1]))
  41. if inter_x1 >= inter_x2 or inter_y1 >= inter_y2:
  42. return 0.
  43. x1 = np.minimum(np.min(a[:, 0]), np.min(b[:, 0]))
  44. x2 = np.maximum(np.max(a[:, 0]), np.max(b[:, 0]))
  45. y1 = np.minimum(np.min(a[:, 1]), np.min(b[:, 1]))
  46. y2 = np.maximum(np.max(a[:, 1]), np.max(b[:, 1]))
  47. if x1 >= x2 or y1 >= y2 or (x2 - x1) < 2 or (y2 - y1) < 2:
  48. return 0.
  49. g = Polygon(g)
  50. p = Polygon(p)
  51. if not g.is_valid or not p.is_valid:
  52. return 0
  53. inter = Polygon(g).intersection(Polygon(p)).area
  54. union = g.area + p.area - inter
  55. if union == 0:
  56. return 0
  57. else:
  58. return inter / union
  59. def matched_rbox_overlaps(anchors, gt_bboxes, use_cv2=False):
  60. """
  61. Args:
  62. anchors: [M, 5] x1,y1,x2,y2,angle
  63. gt_bboxes: [M, 5] x1,y1,x2,y2,angle
  64. Returns:
  65. macthed_iou: [M]
  66. """
  67. assert anchors.shape[1] == 5
  68. assert gt_bboxes.shape[1] == 5
  69. gt_bboxes_ploy = [rbox2poly_single(e) for e in gt_bboxes]
  70. anchors_ploy = [rbox2poly_single(e) for e in anchors]
  71. num = len(anchors_ploy)
  72. iou = np.zeros((num, ), dtype=np.float64)
  73. start_time = time.time()
  74. for i in range(num):
  75. try:
  76. iou[i] = intersection(gt_bboxes_ploy[i], anchors_ploy[i])
  77. except Exception as e:
  78. print('cur gt_bboxes_ploy[i]', gt_bboxes_ploy[i],
  79. 'anchors_ploy[j]', anchors_ploy[i], e)
  80. return iou
  81. def gen_sample(n):
  82. rbox = np.random.rand(n, 5)
  83. rbox[:, 0:4] = rbox[:, 0:4] * 0.45 + 0.001
  84. rbox[:, 4] = rbox[:, 4] - 0.5
  85. return rbox
  86. class MatchedRBoxIoUTest(unittest.TestCase):
  87. def setUp(self):
  88. self.initTestCase()
  89. self.rbox1 = gen_sample(self.n)
  90. self.rbox2 = gen_sample(self.n)
  91. def initTestCase(self):
  92. self.n = 1000
  93. def assertAllClose(self, x, y, msg, atol=5e-1, rtol=1e-2):
  94. self.assertTrue(np.allclose(x, y, atol=atol, rtol=rtol), msg=msg)
  95. def get_places(self):
  96. places = [paddle.CPUPlace()]
  97. if paddle.device.is_compiled_with_cuda():
  98. places.append(paddle.CUDAPlace(0))
  99. return places
  100. def check_output(self, place):
  101. paddle.disable_static()
  102. pd_rbox1 = paddle.to_tensor(self.rbox1, place=place)
  103. pd_rbox2 = paddle.to_tensor(self.rbox2, place=place)
  104. actual_t = matched_rbox_iou(pd_rbox1, pd_rbox2).numpy()
  105. poly_rbox1 = self.rbox1
  106. poly_rbox2 = self.rbox2
  107. poly_rbox1[:, 0:4] = self.rbox1[:, 0:4] * 1024
  108. poly_rbox2[:, 0:4] = self.rbox2[:, 0:4] * 1024
  109. expect_t = matched_rbox_overlaps(poly_rbox1, poly_rbox2, use_cv2=False)
  110. self.assertAllClose(
  111. actual_t,
  112. expect_t,
  113. msg="rbox_iou has diff at {} \nExpect {}\nBut got {}".format(
  114. str(place), str(expect_t), str(actual_t)))
  115. def test_output(self):
  116. places = self.get_places()
  117. for place in places:
  118. self.check_output(place)
  119. if __name__ == "__main__":
  120. unittest.main()