bbox_utils.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import paddle
  16. import numpy as np
  17. def bbox2delta(src_boxes, tgt_boxes, weights=[1.0, 1.0, 1.0, 1.0]):
  18. """Encode bboxes to deltas.
  19. """
  20. src_w = src_boxes[:, 2] - src_boxes[:, 0]
  21. src_h = src_boxes[:, 3] - src_boxes[:, 1]
  22. src_ctr_x = src_boxes[:, 0] + 0.5 * src_w
  23. src_ctr_y = src_boxes[:, 1] + 0.5 * src_h
  24. tgt_w = tgt_boxes[:, 2] - tgt_boxes[:, 0]
  25. tgt_h = tgt_boxes[:, 3] - tgt_boxes[:, 1]
  26. tgt_ctr_x = tgt_boxes[:, 0] + 0.5 * tgt_w
  27. tgt_ctr_y = tgt_boxes[:, 1] + 0.5 * tgt_h
  28. wx, wy, ww, wh = weights
  29. dx = wx * (tgt_ctr_x - src_ctr_x) / src_w
  30. dy = wy * (tgt_ctr_y - src_ctr_y) / src_h
  31. dw = ww * paddle.log(tgt_w / src_w)
  32. dh = wh * paddle.log(tgt_h / src_h)
  33. deltas = paddle.stack((dx, dy, dw, dh), axis=1)
  34. return deltas
  35. def delta2bbox(deltas, boxes, weights=[1.0, 1.0, 1.0, 1.0], max_shape=None):
  36. """Decode deltas to boxes. Used in RCNNBox,CascadeHead,RCNNHead,RetinaHead.
  37. Note: return tensor shape [n,1,4]
  38. If you want to add a reshape, please add after the calling code instead of here.
  39. """
  40. clip_scale = math.log(1000.0 / 16)
  41. widths = boxes[:, 2] - boxes[:, 0]
  42. heights = boxes[:, 3] - boxes[:, 1]
  43. ctr_x = boxes[:, 0] + 0.5 * widths
  44. ctr_y = boxes[:, 1] + 0.5 * heights
  45. wx, wy, ww, wh = weights
  46. dx = deltas[:, 0::4] / wx
  47. dy = deltas[:, 1::4] / wy
  48. dw = deltas[:, 2::4] / ww
  49. dh = deltas[:, 3::4] / wh
  50. # Prevent sending too large values into paddle.exp()
  51. dw = paddle.clip(dw, max=clip_scale)
  52. dh = paddle.clip(dh, max=clip_scale)
  53. pred_ctr_x = dx * widths.unsqueeze(1) + ctr_x.unsqueeze(1)
  54. pred_ctr_y = dy * heights.unsqueeze(1) + ctr_y.unsqueeze(1)
  55. pred_w = paddle.exp(dw) * widths.unsqueeze(1)
  56. pred_h = paddle.exp(dh) * heights.unsqueeze(1)
  57. pred_boxes = []
  58. pred_boxes.append(pred_ctr_x - 0.5 * pred_w)
  59. pred_boxes.append(pred_ctr_y - 0.5 * pred_h)
  60. pred_boxes.append(pred_ctr_x + 0.5 * pred_w)
  61. pred_boxes.append(pred_ctr_y + 0.5 * pred_h)
  62. pred_boxes = paddle.stack(pred_boxes, axis=-1)
  63. if max_shape is not None:
  64. pred_boxes[..., 0::2] = pred_boxes[..., 0::2].clip(
  65. min=0, max=max_shape[1])
  66. pred_boxes[..., 1::2] = pred_boxes[..., 1::2].clip(
  67. min=0, max=max_shape[0])
  68. return pred_boxes
  69. def bbox2delta_v2(src_boxes,
  70. tgt_boxes,
  71. delta_mean=[0.0, 0.0, 0.0, 0.0],
  72. delta_std=[1.0, 1.0, 1.0, 1.0]):
  73. """Encode bboxes to deltas.
  74. Modified from bbox2delta() which just use weight parameters to multiply deltas.
  75. """
  76. src_w = src_boxes[:, 2] - src_boxes[:, 0]
  77. src_h = src_boxes[:, 3] - src_boxes[:, 1]
  78. src_ctr_x = src_boxes[:, 0] + 0.5 * src_w
  79. src_ctr_y = src_boxes[:, 1] + 0.5 * src_h
  80. tgt_w = tgt_boxes[:, 2] - tgt_boxes[:, 0]
  81. tgt_h = tgt_boxes[:, 3] - tgt_boxes[:, 1]
  82. tgt_ctr_x = tgt_boxes[:, 0] + 0.5 * tgt_w
  83. tgt_ctr_y = tgt_boxes[:, 1] + 0.5 * tgt_h
  84. dx = (tgt_ctr_x - src_ctr_x) / src_w
  85. dy = (tgt_ctr_y - src_ctr_y) / src_h
  86. dw = paddle.log(tgt_w / src_w)
  87. dh = paddle.log(tgt_h / src_h)
  88. deltas = paddle.stack((dx, dy, dw, dh), axis=1)
  89. deltas = (
  90. deltas - paddle.to_tensor(delta_mean)) / paddle.to_tensor(delta_std)
  91. return deltas
  92. def delta2bbox_v2(deltas,
  93. boxes,
  94. delta_mean=[0.0, 0.0, 0.0, 0.0],
  95. delta_std=[1.0, 1.0, 1.0, 1.0],
  96. max_shape=None,
  97. ctr_clip=32.0):
  98. """Decode deltas to bboxes.
  99. Modified from delta2bbox() which just use weight parameters to be divided by deltas.
  100. Used in YOLOFHead.
  101. Note: return tensor shape [n,1,4]
  102. If you want to add a reshape, please add after the calling code instead of here.
  103. """
  104. clip_scale = math.log(1000.0 / 16)
  105. widths = boxes[:, 2] - boxes[:, 0]
  106. heights = boxes[:, 3] - boxes[:, 1]
  107. ctr_x = boxes[:, 0] + 0.5 * widths
  108. ctr_y = boxes[:, 1] + 0.5 * heights
  109. deltas = deltas * paddle.to_tensor(delta_std) + paddle.to_tensor(delta_mean)
  110. dx = deltas[:, 0::4]
  111. dy = deltas[:, 1::4]
  112. dw = deltas[:, 2::4]
  113. dh = deltas[:, 3::4]
  114. # Prevent sending too large values into paddle.exp()
  115. dx = dx * widths.unsqueeze(1)
  116. dy = dy * heights.unsqueeze(1)
  117. if ctr_clip is not None:
  118. dx = paddle.clip(dx, max=ctr_clip, min=-ctr_clip)
  119. dy = paddle.clip(dy, max=ctr_clip, min=-ctr_clip)
  120. dw = paddle.clip(dw, max=clip_scale)
  121. dh = paddle.clip(dh, max=clip_scale)
  122. else:
  123. dw = dw.clip(min=-ctr_clip, max=ctr_clip)
  124. dh = dh.clip(min=-ctr_clip, max=ctr_clip)
  125. pred_ctr_x = dx + ctr_x.unsqueeze(1)
  126. pred_ctr_y = dy + ctr_y.unsqueeze(1)
  127. pred_w = paddle.exp(dw) * widths.unsqueeze(1)
  128. pred_h = paddle.exp(dh) * heights.unsqueeze(1)
  129. pred_boxes = []
  130. pred_boxes.append(pred_ctr_x - 0.5 * pred_w)
  131. pred_boxes.append(pred_ctr_y - 0.5 * pred_h)
  132. pred_boxes.append(pred_ctr_x + 0.5 * pred_w)
  133. pred_boxes.append(pred_ctr_y + 0.5 * pred_h)
  134. pred_boxes = paddle.stack(pred_boxes, axis=-1)
  135. if max_shape is not None:
  136. pred_boxes[..., 0::2] = pred_boxes[..., 0::2].clip(
  137. min=0, max=max_shape[1])
  138. pred_boxes[..., 1::2] = pred_boxes[..., 1::2].clip(
  139. min=0, max=max_shape[0])
  140. return pred_boxes
  141. def expand_bbox(bboxes, scale):
  142. w_half = (bboxes[:, 2] - bboxes[:, 0]) * .5
  143. h_half = (bboxes[:, 3] - bboxes[:, 1]) * .5
  144. x_c = (bboxes[:, 2] + bboxes[:, 0]) * .5
  145. y_c = (bboxes[:, 3] + bboxes[:, 1]) * .5
  146. w_half *= scale
  147. h_half *= scale
  148. bboxes_exp = np.zeros(bboxes.shape, dtype=np.float32)
  149. bboxes_exp[:, 0] = x_c - w_half
  150. bboxes_exp[:, 2] = x_c + w_half
  151. bboxes_exp[:, 1] = y_c - h_half
  152. bboxes_exp[:, 3] = y_c + h_half
  153. return bboxes_exp
  154. def clip_bbox(boxes, im_shape):
  155. h, w = im_shape[0], im_shape[1]
  156. x1 = boxes[:, 0].clip(0, w)
  157. y1 = boxes[:, 1].clip(0, h)
  158. x2 = boxes[:, 2].clip(0, w)
  159. y2 = boxes[:, 3].clip(0, h)
  160. return paddle.stack([x1, y1, x2, y2], axis=1)
  161. def nonempty_bbox(boxes, min_size=0, return_mask=False):
  162. w = boxes[:, 2] - boxes[:, 0]
  163. h = boxes[:, 3] - boxes[:, 1]
  164. mask = paddle.logical_and(h > min_size, w > min_size)
  165. if return_mask:
  166. return mask
  167. keep = paddle.nonzero(mask).flatten()
  168. return keep
  169. def bbox_area(boxes):
  170. return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
  171. def bbox_overlaps(boxes1, boxes2):
  172. """
  173. Calculate overlaps between boxes1 and boxes2
  174. Args:
  175. boxes1 (Tensor): boxes with shape [M, 4]
  176. boxes2 (Tensor): boxes with shape [N, 4]
  177. Return:
  178. overlaps (Tensor): overlaps between boxes1 and boxes2 with shape [M, N]
  179. """
  180. M = boxes1.shape[0]
  181. N = boxes2.shape[0]
  182. if M * N == 0:
  183. return paddle.zeros([M, N], dtype='float32')
  184. area1 = bbox_area(boxes1)
  185. area2 = bbox_area(boxes2)
  186. xy_max = paddle.minimum(
  187. paddle.unsqueeze(boxes1, 1)[:, :, 2:], boxes2[:, 2:])
  188. xy_min = paddle.maximum(
  189. paddle.unsqueeze(boxes1, 1)[:, :, :2], boxes2[:, :2])
  190. width_height = xy_max - xy_min
  191. width_height = width_height.clip(min=0)
  192. inter = width_height.prod(axis=2)
  193. overlaps = paddle.where(inter > 0, inter /
  194. (paddle.unsqueeze(area1, 1) + area2 - inter),
  195. paddle.zeros_like(inter))
  196. return overlaps
  197. def batch_bbox_overlaps(bboxes1,
  198. bboxes2,
  199. mode='iou',
  200. is_aligned=False,
  201. eps=1e-6):
  202. """Calculate overlap between two set of bboxes.
  203. If ``is_aligned `` is ``False``, then calculate the overlaps between each
  204. bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned
  205. pair of bboxes1 and bboxes2.
  206. Args:
  207. bboxes1 (Tensor): shape (B, m, 4) in <x1, y1, x2, y2> format or empty.
  208. bboxes2 (Tensor): shape (B, n, 4) in <x1, y1, x2, y2> format or empty.
  209. B indicates the batch dim, in shape (B1, B2, ..., Bn).
  210. If ``is_aligned `` is ``True``, then m and n must be equal.
  211. mode (str): "iou" (intersection over union) or "iof" (intersection over
  212. foreground).
  213. is_aligned (bool, optional): If True, then m and n must be equal.
  214. Default False.
  215. eps (float, optional): A value added to the denominator for numerical
  216. stability. Default 1e-6.
  217. Returns:
  218. Tensor: shape (m, n) if ``is_aligned `` is False else shape (m,)
  219. """
  220. assert mode in ['iou', 'iof', 'giou'], 'Unsupported mode {}'.format(mode)
  221. # Either the boxes are empty or the length of boxes's last dimenstion is 4
  222. assert (bboxes1.shape[-1] == 4 or bboxes1.shape[0] == 0)
  223. assert (bboxes2.shape[-1] == 4 or bboxes2.shape[0] == 0)
  224. # Batch dim must be the same
  225. # Batch dim: (B1, B2, ... Bn)
  226. assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
  227. batch_shape = bboxes1.shape[:-2]
  228. rows = bboxes1.shape[-2] if bboxes1.shape[0] > 0 else 0
  229. cols = bboxes2.shape[-2] if bboxes2.shape[0] > 0 else 0
  230. if is_aligned:
  231. assert rows == cols
  232. if rows * cols == 0:
  233. if is_aligned:
  234. return paddle.full(batch_shape + (rows, ), 1)
  235. else:
  236. return paddle.full(batch_shape + (rows, cols), 1)
  237. area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1])
  238. area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1])
  239. if is_aligned:
  240. lt = paddle.maximum(bboxes1[:, :2], bboxes2[:, :2]) # [B, rows, 2]
  241. rb = paddle.minimum(bboxes1[:, 2:], bboxes2[:, 2:]) # [B, rows, 2]
  242. wh = (rb - lt).clip(min=0) # [B, rows, 2]
  243. overlap = wh[:, 0] * wh[:, 1]
  244. if mode in ['iou', 'giou']:
  245. union = area1 + area2 - overlap
  246. else:
  247. union = area1
  248. if mode == 'giou':
  249. enclosed_lt = paddle.minimum(bboxes1[:, :2], bboxes2[:, :2])
  250. enclosed_rb = paddle.maximum(bboxes1[:, 2:], bboxes2[:, 2:])
  251. else:
  252. lt = paddle.maximum(bboxes1[:, :2].reshape([rows, 1, 2]),
  253. bboxes2[:, :2]) # [B, rows, cols, 2]
  254. rb = paddle.minimum(bboxes1[:, 2:].reshape([rows, 1, 2]),
  255. bboxes2[:, 2:]) # [B, rows, cols, 2]
  256. wh = (rb - lt).clip(min=0) # [B, rows, cols, 2]
  257. overlap = wh[:, :, 0] * wh[:, :, 1]
  258. if mode in ['iou', 'giou']:
  259. union = area1.reshape([rows,1]) \
  260. + area2.reshape([1,cols]) - overlap
  261. else:
  262. union = area1[:, None]
  263. if mode == 'giou':
  264. enclosed_lt = paddle.minimum(bboxes1[:, :2].reshape([rows, 1, 2]),
  265. bboxes2[:, :2])
  266. enclosed_rb = paddle.maximum(bboxes1[:, 2:].reshape([rows, 1, 2]),
  267. bboxes2[:, 2:])
  268. eps = paddle.to_tensor([eps])
  269. union = paddle.maximum(union, eps)
  270. ious = overlap / union
  271. if mode in ['iou', 'iof']:
  272. return ious
  273. # calculate gious
  274. enclose_wh = (enclosed_rb - enclosed_lt).clip(min=0)
  275. enclose_area = enclose_wh[:, :, 0] * enclose_wh[:, :, 1]
  276. enclose_area = paddle.maximum(enclose_area, eps)
  277. gious = ious - (enclose_area - union) / enclose_area
  278. return 1 - gious
  279. def xywh2xyxy(box):
  280. x, y, w, h = box
  281. x1 = x - w * 0.5
  282. y1 = y - h * 0.5
  283. x2 = x + w * 0.5
  284. y2 = y + h * 0.5
  285. return [x1, y1, x2, y2]
  286. def make_grid(h, w, dtype):
  287. yv, xv = paddle.meshgrid([paddle.arange(h), paddle.arange(w)])
  288. return paddle.stack((xv, yv), 2).cast(dtype=dtype)
  289. def decode_yolo(box, anchor, downsample_ratio):
  290. """decode yolo box
  291. Args:
  292. box (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
  293. anchor (list): anchor with the shape [na, 2]
  294. downsample_ratio (int): downsample ratio, default 32
  295. scale (float): scale, default 1.
  296. Return:
  297. box (list): decoded box, [x, y, w, h], all have the shape [b, na, h, w, 1]
  298. """
  299. x, y, w, h = box
  300. na, grid_h, grid_w = x.shape[1:4]
  301. grid = make_grid(grid_h, grid_w, x.dtype).reshape((1, 1, grid_h, grid_w, 2))
  302. x1 = (x + grid[:, :, :, :, 0:1]) / grid_w
  303. y1 = (y + grid[:, :, :, :, 1:2]) / grid_h
  304. anchor = paddle.to_tensor(anchor, dtype=x.dtype)
  305. anchor = anchor.reshape((1, na, 1, 1, 2))
  306. w1 = paddle.exp(w) * anchor[:, :, :, :, 0:1] / (downsample_ratio * grid_w)
  307. h1 = paddle.exp(h) * anchor[:, :, :, :, 1:2] / (downsample_ratio * grid_h)
  308. return [x1, y1, w1, h1]
  309. def batch_iou_similarity(box1, box2, eps=1e-9):
  310. """Calculate iou of box1 and box2 in batch
  311. Args:
  312. box1 (Tensor): box with the shape [N, M1, 4]
  313. box2 (Tensor): box with the shape [N, M2, 4]
  314. Return:
  315. iou (Tensor): iou between box1 and box2 with the shape [N, M1, M2]
  316. """
  317. box1 = box1.unsqueeze(2) # [N, M1, 4] -> [N, M1, 1, 4]
  318. box2 = box2.unsqueeze(1) # [N, M2, 4] -> [N, 1, M2, 4]
  319. px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
  320. gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
  321. x1y1 = paddle.maximum(px1y1, gx1y1)
  322. x2y2 = paddle.minimum(px2y2, gx2y2)
  323. overlap = (x2y2 - x1y1).clip(0).prod(-1)
  324. area1 = (px2y2 - px1y1).clip(0).prod(-1)
  325. area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
  326. union = area1 + area2 - overlap + eps
  327. return overlap / union
  328. def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9):
  329. """calculate the iou of box1 and box2
  330. Args:
  331. box1 (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
  332. box2 (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
  333. giou (bool): whether use giou or not, default False
  334. diou (bool): whether use diou or not, default False
  335. ciou (bool): whether use ciou or not, default False
  336. eps (float): epsilon to avoid divide by zero
  337. Return:
  338. iou (Tensor): iou of box1 and box1, with the shape [b, na, h, w, 1]
  339. """
  340. px1, py1, px2, py2 = box1
  341. gx1, gy1, gx2, gy2 = box2
  342. x1 = paddle.maximum(px1, gx1)
  343. y1 = paddle.maximum(py1, gy1)
  344. x2 = paddle.minimum(px2, gx2)
  345. y2 = paddle.minimum(py2, gy2)
  346. overlap = ((x2 - x1).clip(0)) * ((y2 - y1).clip(0))
  347. area1 = (px2 - px1) * (py2 - py1)
  348. area1 = area1.clip(0)
  349. area2 = (gx2 - gx1) * (gy2 - gy1)
  350. area2 = area2.clip(0)
  351. union = area1 + area2 - overlap + eps
  352. iou = overlap / union
  353. if giou or ciou or diou:
  354. # convex w, h
  355. cw = paddle.maximum(px2, gx2) - paddle.minimum(px1, gx1)
  356. ch = paddle.maximum(py2, gy2) - paddle.minimum(py1, gy1)
  357. if giou:
  358. c_area = cw * ch + eps
  359. return iou - (c_area - union) / c_area
  360. else:
  361. # convex diagonal squared
  362. c2 = cw**2 + ch**2 + eps
  363. # center distance
  364. rho2 = ((px1 + px2 - gx1 - gx2)**2 + (py1 + py2 - gy1 - gy2)**2) / 4
  365. if diou:
  366. return iou - rho2 / c2
  367. else:
  368. w1, h1 = px2 - px1, py2 - py1 + eps
  369. w2, h2 = gx2 - gx1, gy2 - gy1 + eps
  370. delta = paddle.atan(w1 / h1) - paddle.atan(w2 / h2)
  371. v = (4 / math.pi**2) * paddle.pow(delta, 2)
  372. alpha = v / (1 + eps - iou + v)
  373. alpha.stop_gradient = True
  374. return iou - (rho2 / c2 + v * alpha)
  375. else:
  376. return iou
  377. def bbox_iou_np_expand(box1, box2, x1y1x2y2=True, eps=1e-16):
  378. """
  379. Calculate the iou of box1 and box2 with numpy.
  380. Args:
  381. box1 (ndarray): [N, 4]
  382. box2 (ndarray): [M, 4], usually N != M
  383. x1y1x2y2 (bool): whether in x1y1x2y2 stype, default True
  384. eps (float): epsilon to avoid divide by zero
  385. Return:
  386. iou (ndarray): iou of box1 and box2, [N, M]
  387. """
  388. N, M = len(box1), len(box2) # usually N != M
  389. if x1y1x2y2:
  390. b1_x1, b1_y1 = box1[:, 0], box1[:, 1]
  391. b1_x2, b1_y2 = box1[:, 2], box1[:, 3]
  392. b2_x1, b2_y1 = box2[:, 0], box2[:, 1]
  393. b2_x2, b2_y2 = box2[:, 2], box2[:, 3]
  394. else:
  395. # cxcywh style
  396. # Transform from center and width to exact coordinates
  397. b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
  398. b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
  399. b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
  400. b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
  401. # get the coordinates of the intersection rectangle
  402. inter_rect_x1 = np.zeros((N, M), dtype=np.float32)
  403. inter_rect_y1 = np.zeros((N, M), dtype=np.float32)
  404. inter_rect_x2 = np.zeros((N, M), dtype=np.float32)
  405. inter_rect_y2 = np.zeros((N, M), dtype=np.float32)
  406. for i in range(len(box2)):
  407. inter_rect_x1[:, i] = np.maximum(b1_x1, b2_x1[i])
  408. inter_rect_y1[:, i] = np.maximum(b1_y1, b2_y1[i])
  409. inter_rect_x2[:, i] = np.minimum(b1_x2, b2_x2[i])
  410. inter_rect_y2[:, i] = np.minimum(b1_y2, b2_y2[i])
  411. # Intersection area
  412. inter_area = np.maximum(inter_rect_x2 - inter_rect_x1, 0) * np.maximum(
  413. inter_rect_y2 - inter_rect_y1, 0)
  414. # Union Area
  415. b1_area = np.repeat(
  416. ((b1_x2 - b1_x1) * (b1_y2 - b1_y1)).reshape(-1, 1), M, axis=-1)
  417. b2_area = np.repeat(
  418. ((b2_x2 - b2_x1) * (b2_y2 - b2_y1)).reshape(1, -1), N, axis=0)
  419. ious = inter_area / (b1_area + b2_area - inter_area + eps)
  420. return ious
  421. def bbox2distance(points, bbox, max_dis=None, eps=0.1):
  422. """Decode bounding box based on distances.
  423. Args:
  424. points (Tensor): Shape (n, 2), [x, y].
  425. bbox (Tensor): Shape (n, 4), "xyxy" format
  426. max_dis (float): Upper bound of the distance.
  427. eps (float): a small value to ensure target < max_dis, instead <=
  428. Returns:
  429. Tensor: Decoded distances.
  430. """
  431. left = points[:, 0] - bbox[:, 0]
  432. top = points[:, 1] - bbox[:, 1]
  433. right = bbox[:, 2] - points[:, 0]
  434. bottom = bbox[:, 3] - points[:, 1]
  435. if max_dis is not None:
  436. left = left.clip(min=0, max=max_dis - eps)
  437. top = top.clip(min=0, max=max_dis - eps)
  438. right = right.clip(min=0, max=max_dis - eps)
  439. bottom = bottom.clip(min=0, max=max_dis - eps)
  440. return paddle.stack([left, top, right, bottom], -1)
  441. def distance2bbox(points, distance, max_shape=None):
  442. """Decode distance prediction to bounding box.
  443. Args:
  444. points (Tensor): Shape (n, 2), [x, y].
  445. distance (Tensor): Distance from the given point to 4
  446. boundaries (left, top, right, bottom).
  447. max_shape (tuple): Shape of the image.
  448. Returns:
  449. Tensor: Decoded bboxes.
  450. """
  451. x1 = points[:, 0] - distance[:, 0]
  452. y1 = points[:, 1] - distance[:, 1]
  453. x2 = points[:, 0] + distance[:, 2]
  454. y2 = points[:, 1] + distance[:, 3]
  455. if max_shape is not None:
  456. x1 = x1.clip(min=0, max=max_shape[1])
  457. y1 = y1.clip(min=0, max=max_shape[0])
  458. x2 = x2.clip(min=0, max=max_shape[1])
  459. y2 = y2.clip(min=0, max=max_shape[0])
  460. return paddle.stack([x1, y1, x2, y2], -1)
  461. def bbox_center(boxes):
  462. """Get bbox centers from boxes.
  463. Args:
  464. boxes (Tensor): boxes with shape (..., 4), "xmin, ymin, xmax, ymax" format.
  465. Returns:
  466. Tensor: boxes centers with shape (..., 2), "cx, cy" format.
  467. """
  468. boxes_cx = (boxes[..., 0] + boxes[..., 2]) / 2
  469. boxes_cy = (boxes[..., 1] + boxes[..., 3]) / 2
  470. return paddle.stack([boxes_cx, boxes_cy], axis=-1)
  471. def batch_distance2bbox(points, distance, max_shapes=None):
  472. """Decode distance prediction to bounding box for batch.
  473. Args:
  474. points (Tensor): [B, ..., 2], "xy" format
  475. distance (Tensor): [B, ..., 4], "ltrb" format
  476. max_shapes (Tensor): [B, 2], "h,w" format, Shape of the image.
  477. Returns:
  478. Tensor: Decoded bboxes, "x1y1x2y2" format.
  479. """
  480. lt, rb = paddle.split(distance, 2, -1)
  481. # while tensor add parameters, parameters should be better placed on the second place
  482. x1y1 = -lt + points
  483. x2y2 = rb + points
  484. out_bbox = paddle.concat([x1y1, x2y2], -1)
  485. if max_shapes is not None:
  486. max_shapes = max_shapes.flip(-1).tile([1, 2])
  487. delta_dim = out_bbox.ndim - max_shapes.ndim
  488. for _ in range(delta_dim):
  489. max_shapes.unsqueeze_(1)
  490. out_bbox = paddle.where(out_bbox < max_shapes, out_bbox, max_shapes)
  491. out_bbox = paddle.where(out_bbox > 0, out_bbox,
  492. paddle.zeros_like(out_bbox))
  493. return out_bbox
  494. def iou_similarity(box1, box2, eps=1e-10):
  495. """Calculate iou of box1 and box2
  496. Args:
  497. box1 (Tensor): box with the shape [M1, 4]
  498. box2 (Tensor): box with the shape [M2, 4]
  499. Return:
  500. iou (Tensor): iou between box1 and box2 with the shape [M1, M2]
  501. """
  502. box1 = box1.unsqueeze(1) # [M1, 4] -> [M1, 1, 4]
  503. box2 = box2.unsqueeze(0) # [M2, 4] -> [1, M2, 4]
  504. px1y1, px2y2 = box1[:, :, 0:2], box1[:, :, 2:4]
  505. gx1y1, gx2y2 = box2[:, :, 0:2], box2[:, :, 2:4]
  506. x1y1 = paddle.maximum(px1y1, gx1y1)
  507. x2y2 = paddle.minimum(px2y2, gx2y2)
  508. overlap = (x2y2 - x1y1).clip(0).prod(-1)
  509. area1 = (px2y2 - px1y1).clip(0).prod(-1)
  510. area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
  511. union = area1 + area2 - overlap + eps
  512. return overlap / union