# test_rbox_iou.py
# Unit test comparing the ``rbox_iou`` custom op against a shapely-based
# reference implementation of rotated-box IoU.
import sys
import time
import unittest

import numpy as np
import paddle
from shapely.geometry import Polygon

from ext_op import rbox_iou
  8. def rbox2poly_single(rrect, get_best_begin_point=False):
  9. """
  10. rrect:[x_ctr,y_ctr,w,h,angle]
  11. to
  12. poly:[x0,y0,x1,y1,x2,y2,x3,y3]
  13. """
  14. x_ctr, y_ctr, width, height, angle = rrect[:5]
  15. tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2
  16. # rect 2x4
  17. rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]])
  18. R = np.array([[np.cos(angle), -np.sin(angle)],
  19. [np.sin(angle), np.cos(angle)]])
  20. # poly
  21. poly = R.dot(rect)
  22. x0, x1, x2, x3 = poly[0, :4] + x_ctr
  23. y0, y1, y2, y3 = poly[1, :4] + y_ctr
  24. poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3], dtype=np.float64)
  25. return poly
  26. def intersection(g, p):
  27. """
  28. Intersection.
  29. """
  30. g = g[:8].reshape((4, 2))
  31. p = p[:8].reshape((4, 2))
  32. a = g
  33. b = p
  34. use_filter = True
  35. if use_filter:
  36. # step1:
  37. inter_x1 = np.maximum(np.min(a[:, 0]), np.min(b[:, 0]))
  38. inter_x2 = np.minimum(np.max(a[:, 0]), np.max(b[:, 0]))
  39. inter_y1 = np.maximum(np.min(a[:, 1]), np.min(b[:, 1]))
  40. inter_y2 = np.minimum(np.max(a[:, 1]), np.max(b[:, 1]))
  41. if inter_x1 >= inter_x2 or inter_y1 >= inter_y2:
  42. return 0.
  43. x1 = np.minimum(np.min(a[:, 0]), np.min(b[:, 0]))
  44. x2 = np.maximum(np.max(a[:, 0]), np.max(b[:, 0]))
  45. y1 = np.minimum(np.min(a[:, 1]), np.min(b[:, 1]))
  46. y2 = np.maximum(np.max(a[:, 1]), np.max(b[:, 1]))
  47. if x1 >= x2 or y1 >= y2 or (x2 - x1) < 2 or (y2 - y1) < 2:
  48. return 0.
  49. g = Polygon(g)
  50. p = Polygon(p)
  51. if not g.is_valid or not p.is_valid:
  52. return 0
  53. inter = Polygon(g).intersection(Polygon(p)).area
  54. union = g.area + p.area - inter
  55. if union == 0:
  56. return 0
  57. else:
  58. return inter / union
  59. def rbox_overlaps(anchors, gt_bboxes, use_cv2=False):
  60. """
  61. Args:
  62. anchors: [NA, 5] x1,y1,x2,y2,angle
  63. gt_bboxes: [M, 5] x1,y1,x2,y2,angle
  64. Returns:
  65. iou: [NA, M]
  66. """
  67. assert anchors.shape[1] == 5
  68. assert gt_bboxes.shape[1] == 5
  69. gt_bboxes_ploy = [rbox2poly_single(e) for e in gt_bboxes]
  70. anchors_ploy = [rbox2poly_single(e) for e in anchors]
  71. num_gt, num_anchors = len(gt_bboxes_ploy), len(anchors_ploy)
  72. iou = np.zeros((num_anchors, num_gt), dtype=np.float64)
  73. start_time = time.time()
  74. for i in range(num_anchors):
  75. for j in range(num_gt):
  76. try:
  77. iou[i, j] = intersection(anchors_ploy[i], gt_bboxes_ploy[j])
  78. except Exception as e:
  79. print('cur anchors_ploy[i]', anchors_ploy[i],
  80. 'gt_bboxes_ploy[j]', gt_bboxes_ploy[j], e)
  81. return iou
  82. def gen_sample(n):
  83. rbox = np.random.rand(n, 5)
  84. rbox[:, 0:4] = rbox[:, 0:4] * 0.45 + 0.001
  85. rbox[:, 4] = rbox[:, 4] - 0.5
  86. return rbox
  87. class RBoxIoUTest(unittest.TestCase):
  88. def setUp(self):
  89. self.initTestCase()
  90. self.rbox1 = gen_sample(self.n)
  91. self.rbox2 = gen_sample(self.m)
  92. def initTestCase(self):
  93. self.n = 13000
  94. self.m = 7
  95. def assertAllClose(self, x, y, msg, atol=5e-1, rtol=1e-2):
  96. self.assertTrue(np.allclose(x, y, atol=atol, rtol=rtol), msg=msg)
  97. def get_places(self):
  98. places = [paddle.CPUPlace()]
  99. if paddle.device.is_compiled_with_cuda():
  100. places.append(paddle.CUDAPlace(0))
  101. return places
  102. def check_output(self, place):
  103. paddle.disable_static()
  104. pd_rbox1 = paddle.to_tensor(self.rbox1, place=place)
  105. pd_rbox2 = paddle.to_tensor(self.rbox2, place=place)
  106. actual_t = rbox_iou(pd_rbox1, pd_rbox2).numpy()
  107. poly_rbox1 = self.rbox1
  108. poly_rbox2 = self.rbox2
  109. poly_rbox1[:, 0:4] = self.rbox1[:, 0:4] * 1024
  110. poly_rbox2[:, 0:4] = self.rbox2[:, 0:4] * 1024
  111. expect_t = rbox_overlaps(poly_rbox1, poly_rbox2, use_cv2=False)
  112. self.assertAllClose(
  113. actual_t,
  114. expect_t,
  115. msg="rbox_iou has diff at {} \nExpect {}\nBut got {}".format(
  116. str(place), str(expect_t), str(actual_t)))
  117. def test_output(self):
  118. places = self.get_places()
  119. for place in places:
  120. self.check_output(place)
  121. if __name__ == "__main__":
  122. unittest.main()