rbox_iou_utils.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // The code is based on
  16. // https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/csrc/box_iou_rotated/
  17. #pragma once
  18. #include <cassert>
  19. #include <cmath>
  20. #include <vector>
  21. #ifdef __CUDACC__
  22. // Designates functions callable from the host (CPU) and the device (GPU)
  23. #define HOST_DEVICE __host__ __device__
  24. #define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
  25. #else
  26. #include <algorithm>
  27. #define HOST_DEVICE
  28. #define HOST_DEVICE_INLINE HOST_DEVICE inline
  29. #endif
  30. namespace {
  // Rotated box described by its center (x_ctr, y_ctr), its width w, its
  // height h and its rotation angle a. Within this header the angle is passed
  // straight to cos/sin, i.e. it is expected in radians.
  template <typename T>
  struct RotatedBox {
    T x_ctr;
    T y_ctr;
    T w;
    T h;
    T a;
  };
  32. template <typename T> struct Point {
  33. T x, y;
  34. HOST_DEVICE_INLINE Point(const T &px = 0, const T &py = 0) : x(px), y(py) {}
  35. HOST_DEVICE_INLINE Point operator+(const Point &p) const {
  36. return Point(x + p.x, y + p.y);
  37. }
  38. HOST_DEVICE_INLINE Point &operator+=(const Point &p) {
  39. x += p.x;
  40. y += p.y;
  41. return *this;
  42. }
  43. HOST_DEVICE_INLINE Point operator-(const Point &p) const {
  44. return Point(x - p.x, y - p.y);
  45. }
  46. HOST_DEVICE_INLINE Point operator*(const T coeff) const {
  47. return Point(x * coeff, y * coeff);
  48. }
  49. };
  50. template <typename T>
  51. HOST_DEVICE_INLINE T dot_2d(const Point<T> &A, const Point<T> &B) {
  52. return A.x * B.x + A.y * B.y;
  53. }
  54. template <typename T>
  55. HOST_DEVICE_INLINE T cross_2d(const Point<T> &A, const Point<T> &B) {
  56. return A.x * B.y - B.x * A.y;
  57. }
  58. template <typename T>
  59. HOST_DEVICE_INLINE void get_rotated_vertices(const RotatedBox<T> &box,
  60. Point<T> (&pts)[4]) {
  61. // M_PI / 180. == 0.01745329251
  62. // double theta = box.a * 0.01745329251;
  63. // MODIFIED
  64. double theta = box.a;
  65. T cosTheta2 = (T)cos(theta) * 0.5f;
  66. T sinTheta2 = (T)sin(theta) * 0.5f;
  67. // y: top --> down; x: left --> right
  68. pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w;
  69. pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
  70. pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w;
  71. pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
  72. pts[2].x = 2 * box.x_ctr - pts[0].x;
  73. pts[2].y = 2 * box.y_ctr - pts[0].y;
  74. pts[3].x = 2 * box.x_ctr - pts[1].x;
  75. pts[3].y = 2 * box.y_ctr - pts[1].y;
  76. }
  77. template <typename T>
  78. HOST_DEVICE_INLINE int get_intersection_points(const Point<T> (&pts1)[4],
  79. const Point<T> (&pts2)[4],
  80. Point<T> (&intersections)[24]) {
  81. // Line vector
  82. // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
  83. Point<T> vec1[4], vec2[4];
  84. for (int i = 0; i < 4; i++) {
  85. vec1[i] = pts1[(i + 1) % 4] - pts1[i];
  86. vec2[i] = pts2[(i + 1) % 4] - pts2[i];
  87. }
  88. // Line test - test all line combos for intersection
  89. int num = 0; // number of intersections
  90. for (int i = 0; i < 4; i++) {
  91. for (int j = 0; j < 4; j++) {
  92. // Solve for 2x2 Ax=b
  93. T det = cross_2d<T>(vec2[j], vec1[i]);
  94. // This takes care of parallel lines
  95. if (fabs(det) <= 1e-14) {
  96. continue;
  97. }
  98. auto vec12 = pts2[j] - pts1[i];
  99. T t1 = cross_2d<T>(vec2[j], vec12) / det;
  100. T t2 = cross_2d<T>(vec1[i], vec12) / det;
  101. if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
  102. intersections[num++] = pts1[i] + vec1[i] * t1;
  103. }
  104. }
  105. }
  106. // Check for vertices of rect1 inside rect2
  107. {
  108. const auto &AB = vec2[0];
  109. const auto &DA = vec2[3];
  110. auto ABdotAB = dot_2d<T>(AB, AB);
  111. auto ADdotAD = dot_2d<T>(DA, DA);
  112. for (int i = 0; i < 4; i++) {
  113. // assume ABCD is the rectangle, and P is the point to be judged
  114. // P is inside ABCD iff. P's projection on AB lies within AB
  115. // and P's projection on AD lies within AD
  116. auto AP = pts1[i] - pts2[0];
  117. auto APdotAB = dot_2d<T>(AP, AB);
  118. auto APdotAD = -dot_2d<T>(AP, DA);
  119. if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
  120. (APdotAD <= ADdotAD)) {
  121. intersections[num++] = pts1[i];
  122. }
  123. }
  124. }
  125. // Reverse the check - check for vertices of rect2 inside rect1
  126. {
  127. const auto &AB = vec1[0];
  128. const auto &DA = vec1[3];
  129. auto ABdotAB = dot_2d<T>(AB, AB);
  130. auto ADdotAD = dot_2d<T>(DA, DA);
  131. for (int i = 0; i < 4; i++) {
  132. auto AP = pts2[i] - pts1[0];
  133. auto APdotAB = dot_2d<T>(AP, AB);
  134. auto APdotAD = -dot_2d<T>(AP, DA);
  135. if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
  136. (APdotAD <= ADdotAD)) {
  137. intersections[num++] = pts2[i];
  138. }
  139. }
  140. }
  141. return num;
  142. }
  143. template <typename T>
  144. HOST_DEVICE_INLINE int convex_hull_graham(const Point<T> (&p)[24],
  145. const int &num_in, Point<T> (&q)[24],
  146. bool shift_to_zero = false) {
  147. assert(num_in >= 2);
  148. // Step 1:
  149. // Find point with minimum y
  150. // if more than 1 points have the same minimum y,
  151. // pick the one with the minimum x.
  152. int t = 0;
  153. for (int i = 1; i < num_in; i++) {
  154. if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
  155. t = i;
  156. }
  157. }
  158. auto &start = p[t]; // starting point
  159. // Step 2:
  160. // Subtract starting point from every points (for sorting in the next step)
  161. for (int i = 0; i < num_in; i++) {
  162. q[i] = p[i] - start;
  163. }
  164. // Swap the starting point to position 0
  165. auto tmp = q[0];
  166. q[0] = q[t];
  167. q[t] = tmp;
  168. // Step 3:
  169. // Sort point 1 ~ num_in according to their relative cross-product values
  170. // (essentially sorting according to angles)
  171. // If the angles are the same, sort according to their distance to origin
  172. T dist[24];
  173. for (int i = 0; i < num_in; i++) {
  174. dist[i] = dot_2d<T>(q[i], q[i]);
  175. }
  176. #ifdef __CUDACC__
  177. // CUDA version
  178. // In the future, we can potentially use thrust
  179. // for sorting here to improve speed (though not guaranteed)
  180. for (int i = 1; i < num_in - 1; i++) {
  181. for (int j = i + 1; j < num_in; j++) {
  182. T crossProduct = cross_2d<T>(q[i], q[j]);
  183. if ((crossProduct < -1e-6) ||
  184. (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
  185. auto q_tmp = q[i];
  186. q[i] = q[j];
  187. q[j] = q_tmp;
  188. auto dist_tmp = dist[i];
  189. dist[i] = dist[j];
  190. dist[j] = dist_tmp;
  191. }
  192. }
  193. }
  194. #else
  195. // CPU version
  196. std::sort(q + 1, q + num_in,
  197. [](const Point<T> &A, const Point<T> &B) -> bool {
  198. T temp = cross_2d<T>(A, B);
  199. if (fabs(temp) < 1e-6) {
  200. return dot_2d<T>(A, A) < dot_2d<T>(B, B);
  201. } else {
  202. return temp > 0;
  203. }
  204. });
  205. #endif
  206. // Step 4:
  207. // Make sure there are at least 2 points (that don't overlap with each other)
  208. // in the stack
  209. int k; // index of the non-overlapped second point
  210. for (k = 1; k < num_in; k++) {
  211. if (dist[k] > 1e-8) {
  212. break;
  213. }
  214. }
  215. if (k == num_in) {
  216. // We reach the end, which means the convex hull is just one point
  217. q[0] = p[t];
  218. return 1;
  219. }
  220. q[1] = q[k];
  221. int m = 2; // 2 points in the stack
  222. // Step 5:
  223. // Finally we can start the scanning process.
  224. // When a non-convex relationship between the 3 points is found
  225. // (either concave shape or duplicated points),
  226. // we pop the previous point from the stack
  227. // until the 3-point relationship is convex again, or
  228. // until the stack only contains two points
  229. for (int i = k + 1; i < num_in; i++) {
  230. while (m > 1 && cross_2d<T>(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
  231. m--;
  232. }
  233. q[m++] = q[i];
  234. }
  235. // Step 6 (Optional):
  236. // In general sense we need the original coordinates, so we
  237. // need to shift the points back (reverting Step 2)
  238. // But if we're only interested in getting the area/perimeter of the shape
  239. // We can simply return.
  240. if (!shift_to_zero) {
  241. for (int i = 0; i < m; i++) {
  242. q[i] += start;
  243. }
  244. }
  245. return m;
  246. }
  247. template <typename T>
  248. HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int &m) {
  249. if (m <= 2) {
  250. return 0;
  251. }
  252. T area = 0;
  253. for (int i = 1; i < m - 1; i++) {
  254. area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
  255. }
  256. return area / 2.0;
  257. }
  258. template <typename T>
  259. HOST_DEVICE_INLINE T rboxes_intersection(const RotatedBox<T> &box1,
  260. const RotatedBox<T> &box2) {
  261. // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
  262. // from rotated_rect_intersection_pts
  263. Point<T> intersectPts[24], orderedPts[24];
  264. Point<T> pts1[4];
  265. Point<T> pts2[4];
  266. get_rotated_vertices<T>(box1, pts1);
  267. get_rotated_vertices<T>(box2, pts2);
  268. int num = get_intersection_points<T>(pts1, pts2, intersectPts);
  269. if (num <= 2) {
  270. return 0.0;
  271. }
  272. // Convex Hull to order the intersection points in clockwise order and find
  273. // the contour area.
  274. int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
  275. return polygon_area<T>(orderedPts, num_convex);
  276. }
  277. } // namespace
  278. template <typename T>
  279. HOST_DEVICE_INLINE T rbox_iou_single(T const *const box1_raw,
  280. T const *const box2_raw) {
  281. // shift center to the middle point to achieve higher precision in result
  282. RotatedBox<T> box1, box2;
  283. auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
  284. auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;
  285. box1.x_ctr = box1_raw[0] - center_shift_x;
  286. box1.y_ctr = box1_raw[1] - center_shift_y;
  287. box1.w = box1_raw[2];
  288. box1.h = box1_raw[3];
  289. box1.a = box1_raw[4];
  290. box2.x_ctr = box2_raw[0] - center_shift_x;
  291. box2.y_ctr = box2_raw[1] - center_shift_y;
  292. box2.w = box2_raw[2];
  293. box2.h = box2_raw[3];
  294. box2.a = box2_raw[4];
  295. if (box1.w < 1e-2 || box1.h < 1e-2 || box2.w < 1e-2 || box2.h < 1e-2) {
  296. return 0.f;
  297. }
  298. const T area1 = box1.w * box1.h;
  299. const T area2 = box2.w * box2.h;
  300. const T intersection = rboxes_intersection<T>(box1, box2);
  301. const T iou = intersection / (area1 + area2 - intersection);
  302. return iou;
  303. }
  304. /**
  305. Computes ceil(a / b)
  306. */
  307. HOST_DEVICE inline int CeilDiv(const int a, const int b) {
  308. return (a + b - 1) / b;
  309. }