# post_process.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ppdet.modeling.bbox_utils import nonempty_bbox
from .transformers import bbox_cxcywh_to_xyxy

try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence

__all__ = [
    'BBoxPostProcess', 'MaskPostProcess', 'JDEBBoxPostProcess',
    'CenterNetPostProcess', 'DETRBBoxPostProcess', 'SparsePostProcess'
]

@register
class BBoxPostProcess(object):
    __shared__ = ['num_classes', 'export_onnx', 'export_eb']
    __inject__ = ['decode', 'nms']

    def __init__(self,
                 num_classes=80,
                 decode=None,
                 nms=None,
                 export_onnx=False,
                 export_eb=False):
        super(BBoxPostProcess, self).__init__()
        self.num_classes = num_classes
        self.decode = decode
        self.nms = nms
        self.export_onnx = export_onnx
        self.export_eb = export_eb

    def __call__(self, head_out, rois, im_shape, scale_factor):
        """
        Decode the bbox and do NMS if needed.

        Args:
            head_out (tuple): bbox_pred and cls_prob of bbox_head output.
            rois (tuple): roi and rois_num of rpn_head output.
            im_shape (Tensor): The shape of the input image.
            scale_factor (Tensor): The scale factor of the input image.
            export_onnx (bool): whether to export the model to ONNX.
        Returns:
            bbox_pred (Tensor): The output prediction with shape [N, 6], including
                labels, scores and bboxes. The bbox coordinates correspond to the
                input image and may be used in other branches.
            bbox_num (Tensor): The number of prediction boxes of each batch with
                shape [1], and is N.
        """
        if self.nms is not None:
            bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
            bbox_pred, bbox_num, _ = self.nms(bboxes, score, self.num_classes)
        else:
            bbox_pred, bbox_num = self.decode(head_out, rois, im_shape,
                                              scale_factor)

        if self.export_onnx:
            # add fake box after postprocess when exporting onnx
            fake_bboxes = paddle.to_tensor(
                np.array(
                    [[0., 0.0, 0.0, 0.0, 1.0, 1.0]], dtype='float32'))
            bbox_pred = paddle.concat([bbox_pred, fake_bboxes])
            bbox_num = bbox_num + 1

        return bbox_pred, bbox_num

    def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
        """
        Rescale, clip and filter the bbox from the output of NMS to
        get final prediction.

        Notes:
            Currently only supports bs = 1.

        Args:
            bboxes (Tensor): The output bboxes with shape [N, 6] after decode
                and NMS, including labels, scores and bboxes.
            bbox_num (Tensor): The number of prediction boxes of each batch with
                shape [1], and is N.
            im_shape (Tensor): The shape of the input image.
            scale_factor (Tensor): The scale factor of the input image.
        Returns:
            pred_result (Tensor): The final prediction results with shape [N, 6]
                including labels, scores and bboxes.
        """
        if self.export_eb:
            # enable RCNN models on EdgeBoard hardware to skip the following postprocess.
            return bboxes, bboxes, bbox_num

        if not self.export_onnx:
            bboxes_list = []
            bbox_num_list = []
            id_start = 0
            fake_bboxes = paddle.to_tensor(
                np.array(
                    [[0., 0.0, 0.0, 0.0, 1.0, 1.0]], dtype='float32'))
            fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))

            # add fake bbox when output is empty for each batch
            for i in range(bbox_num.shape[0]):
                if bbox_num[i] == 0:
                    bboxes_i = fake_bboxes
                    bbox_num_i = fake_bbox_num
                else:
                    bboxes_i = bboxes[id_start:id_start + bbox_num[i], :]
                    bbox_num_i = bbox_num[i]
                    id_start += bbox_num[i]
                bboxes_list.append(bboxes_i)
                bbox_num_list.append(bbox_num_i)
            bboxes = paddle.concat(bboxes_list)
            bbox_num = paddle.concat(bbox_num_list)

        origin_shape = paddle.floor(im_shape / scale_factor + 0.5)

        if not self.export_onnx:
            origin_shape_list = []
            scale_factor_list = []
            # scale_factor: scale_y, scale_x
            for i in range(bbox_num.shape[0]):
                expand_shape = paddle.expand(origin_shape[i:i + 1, :],
                                             [bbox_num[i], 2])
                scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
                scale = paddle.concat([scale_x, scale_y, scale_x, scale_y])
                expand_scale = paddle.expand(scale, [bbox_num[i], 4])
                origin_shape_list.append(expand_shape)
                scale_factor_list.append(expand_scale)

            self.origin_shape_list = paddle.concat(origin_shape_list)
            scale_factor_list = paddle.concat(scale_factor_list)
        else:
            # simplify the computation for bs=1 when exporting onnx
            scale_y, scale_x = scale_factor[0][0], scale_factor[0][1]
            scale = paddle.concat(
                [scale_x, scale_y, scale_x, scale_y]).unsqueeze(0)
            self.origin_shape_list = paddle.expand(origin_shape,
                                                   [bbox_num[0], 2])
            scale_factor_list = paddle.expand(scale, [bbox_num[0], 4])

        # bboxes: [N, 6], label, score, bbox
        pred_label = bboxes[:, 0:1]
        pred_score = bboxes[:, 1:2]
        pred_bbox = bboxes[:, 2:]
        # rescale bbox to original image
        scaled_bbox = pred_bbox / scale_factor_list
        origin_h = self.origin_shape_list[:, 0]
        origin_w = self.origin_shape_list[:, 1]
        zeros = paddle.zeros_like(origin_h)
        # clip bbox to [0, original_size]
        x1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 0], origin_w), zeros)
        y1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 1], origin_h), zeros)
        x2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 2], origin_w), zeros)
        y2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 3], origin_h), zeros)
        pred_bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
        # filter empty bbox
        keep_mask = nonempty_bbox(pred_bbox, return_mask=True)
        keep_mask = paddle.unsqueeze(keep_mask, [1])
        pred_label = paddle.where(keep_mask, pred_label,
                                  paddle.ones_like(pred_label) * -1)
        pred_result = paddle.concat([pred_label, pred_score, pred_bbox], axis=1)
        return bboxes, pred_result, bbox_num

    def get_origin_shape(self):
        return self.origin_shape_list

@register
class MaskPostProcess(object):
    __shared__ = ['export_onnx', 'assign_on_cpu']
    """
    refer to:
    https://github.com/facebookresearch/detectron2/layers/mask_ops.py

    Get Mask output according to the output from model
    """

    def __init__(self,
                 binary_thresh=0.5,
                 export_onnx=False,
                 assign_on_cpu=False):
        super(MaskPostProcess, self).__init__()
        self.binary_thresh = binary_thresh
        self.export_onnx = export_onnx
        self.assign_on_cpu = assign_on_cpu

    def paste_mask(self, masks, boxes, im_h, im_w):
        """
        Paste the mask prediction to the original image.
        """
        x0_int, y0_int = 0, 0
        x1_int, y1_int = im_w, im_h
        x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)
        N = masks.shape[0]
        img_y = paddle.arange(y0_int, y1_int) + 0.5
        img_x = paddle.arange(x0_int, x1_int) + 0.5
        img_y = (img_y - y0) / (y1 - y0) * 2 - 1
        img_x = (img_x - x0) / (x1 - x0) * 2 - 1
        # img_x, img_y have shapes (N, w), (N, h)
        if self.assign_on_cpu:
            paddle.set_device('cpu')
        gx = img_x[:, None, :].expand(
            [N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]])
        gy = img_y[:, :, None].expand(
            [N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]])
        grid = paddle.stack([gx, gy], axis=3)
        img_masks = F.grid_sample(masks, grid, align_corners=False)
        return img_masks[:, 0]

    def __call__(self, mask_out, bboxes, bbox_num, origin_shape):
        """
        Decode the mask_out and paste the mask to the origin image.

        Args:
            mask_out (Tensor): mask_head output with shape [N, 28, 28].
            bbox_pred (Tensor): The output bboxes with shape [N, 6] after decode
                and NMS, including labels, scores and bboxes.
            bbox_num (Tensor): The number of prediction boxes of each batch with
                shape [1], and is N.
            origin_shape (Tensor): The origin shape of the input image, the tensor
                shape is [N, 2], and each row is [h, w].
        Returns:
            pred_result (Tensor): The final prediction mask results with shape
                [N, h, w] in binary mask style.
        """
        num_mask = mask_out.shape[0]
        origin_shape = paddle.cast(origin_shape, 'int32')
        device = paddle.device.get_device()

        if self.export_onnx:
            h, w = origin_shape[0][0], origin_shape[0][1]
            mask_onnx = self.paste_mask(mask_out[:, None, :, :], bboxes[:, 2:],
                                        h, w)
            mask_onnx = mask_onnx >= self.binary_thresh
            pred_result = paddle.cast(mask_onnx, 'int32')
        else:
            max_h = paddle.max(origin_shape[:, 0])
            max_w = paddle.max(origin_shape[:, 1])
            pred_result = paddle.zeros(
                [num_mask, max_h, max_w], dtype='int32') - 1

            id_start = 0
            for i in range(paddle.shape(bbox_num)[0]):
                bboxes_i = bboxes[id_start:id_start + bbox_num[i], :]
                mask_out_i = mask_out[id_start:id_start + bbox_num[i], :, :]
                im_h = origin_shape[i, 0]
                im_w = origin_shape[i, 1]
                bbox_num_i = bbox_num[id_start]
                pred_mask = self.paste_mask(mask_out_i[:, None, :, :],
                                            bboxes_i[:, 2:], im_h, im_w)
                pred_mask = paddle.cast(pred_mask >= self.binary_thresh,
                                        'int32')
                pred_result[id_start:id_start + bbox_num[i], :im_h, :
                            im_w] = pred_mask
                id_start += bbox_num[i]
        if self.assign_on_cpu:
            paddle.set_device(device)

        return pred_result

@register
class JDEBBoxPostProcess(nn.Layer):
    __shared__ = ['num_classes']
    __inject__ = ['decode', 'nms']

    def __init__(self, num_classes=1, decode=None, nms=None, return_idx=True):
        super(JDEBBoxPostProcess, self).__init__()
        self.num_classes = num_classes
        self.decode = decode
        self.nms = nms
        self.return_idx = return_idx

        self.fake_bbox_pred = paddle.to_tensor(
            np.array(
                [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
        self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
        self.fake_nms_keep_idx = paddle.to_tensor(
            np.array(
                [[0]], dtype='int32'))

        self.fake_yolo_boxes_out = paddle.to_tensor(
            np.array(
                [[[0.0, 0.0, 0.0, 0.0]]], dtype='float32'))
        self.fake_yolo_scores_out = paddle.to_tensor(
            np.array(
                [[[0.0]]], dtype='float32'))
        self.fake_boxes_idx = paddle.to_tensor(np.array([[0]], dtype='int64'))

    def forward(self, head_out, anchors):
        """
        Decode the bbox and do NMS for JDE model.

        Args:
            head_out (list): Bbox_pred and cls_prob of bbox_head output.
            anchors (list): Anchors of JDE model.

        Returns:
            boxes_idx (Tensor): The indices of kept bboxes after decoding with 'JDEBox'.
            bbox_pred (Tensor): The output is the prediction with shape [N, 6]
                including labels, scores and bboxes.
            bbox_num (Tensor): The number of predictions in each batch with shape [N].
            nms_keep_idx (Tensor): The indices of kept bboxes after NMS.
        """
        boxes_idx, yolo_boxes_scores = self.decode(head_out, anchors)

        if len(boxes_idx) == 0:
            boxes_idx = self.fake_boxes_idx
            yolo_boxes_out = self.fake_yolo_boxes_out
            yolo_scores_out = self.fake_yolo_scores_out
        else:
            yolo_boxes = paddle.gather_nd(yolo_boxes_scores, boxes_idx)
            # TODO: only support bs=1 now
            yolo_boxes_out = paddle.reshape(
                yolo_boxes[:, :4], shape=[1, len(boxes_idx), 4])
            yolo_scores_out = paddle.reshape(
                yolo_boxes[:, 4:5], shape=[1, 1, len(boxes_idx)])
            boxes_idx = boxes_idx[:, 1:]

        if self.return_idx:
            bbox_pred, bbox_num, nms_keep_idx = self.nms(
                yolo_boxes_out, yolo_scores_out, self.num_classes)
            if bbox_pred.shape[0] == 0:
                bbox_pred = self.fake_bbox_pred
                bbox_num = self.fake_bbox_num
                nms_keep_idx = self.fake_nms_keep_idx
            return boxes_idx, bbox_pred, bbox_num, nms_keep_idx
        else:
            bbox_pred, bbox_num, _ = self.nms(yolo_boxes_out, yolo_scores_out,
                                              self.num_classes)
            if bbox_pred.shape[0] == 0:
                bbox_pred = self.fake_bbox_pred
                bbox_num = self.fake_bbox_num
            return _, bbox_pred, bbox_num, _

@register
class CenterNetPostProcess(object):
    """
    Postprocess the model outputs to get final prediction:
        1. Do NMS for heatmap to get top `max_per_img` bboxes.
        2. Decode bboxes using center offset and box size.
        3. Rescale decoded bboxes with reference to the original image shape.

    Args:
        max_per_img (int): the maximum number of predicted objects in an image,
            500 by default.
        down_ratio (int): the down ratio from images to heatmap, 4 by default.
        regress_ltrb (bool): whether to regress left/top/right/bottom or
            width/height for a box, True by default.
    """
    __shared__ = ['down_ratio']

    def __init__(self, max_per_img=500, down_ratio=4, regress_ltrb=True):
        super(CenterNetPostProcess, self).__init__()
        self.max_per_img = max_per_img
        self.down_ratio = down_ratio
        self.regress_ltrb = regress_ltrb

    # _simple_nms() and _topk() are the same as TTFBox in ppdet/modeling/layers.py
    def _simple_nms(self, heat, kernel=3):
        """ Use maxpool to filter the max score, get local peaks. """
        pad = (kernel - 1) // 2
        hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
        keep = paddle.cast(hmax == heat, 'float32')
        return heat * keep

    def _topk(self, scores):
        """ Select top k scores and decode to get xy coordinates. """
        k = self.max_per_img
        shape_fm = paddle.shape(scores)
        shape_fm.stop_gradient = True
        cat, height, width = shape_fm[1], shape_fm[2], shape_fm[3]
        # batch size is 1
        scores_r = paddle.reshape(scores, [cat, -1])
        # per-class top-k over the flattened h * w positions
        topk_scores, topk_inds = paddle.topk(scores_r, k)
        topk_ys = topk_inds // width
        topk_xs = topk_inds % width

        # global top-k over the cat * k candidates; index // k recovers the class id
        topk_score_r = paddle.reshape(topk_scores, [-1])
        topk_score, topk_ind = paddle.topk(topk_score_r, k)
        k_t = paddle.full(paddle.shape(topk_ind), k, dtype='int64')
        topk_clses = paddle.cast(paddle.floor_divide(topk_ind, k_t), 'float32')

        topk_inds = paddle.reshape(topk_inds, [-1])
        topk_ys = paddle.reshape(topk_ys, [-1, 1])
        topk_xs = paddle.reshape(topk_xs, [-1, 1])
        topk_inds = paddle.gather(topk_inds, topk_ind)
        topk_ys = paddle.gather(topk_ys, topk_ind)
        topk_xs = paddle.gather(topk_xs, topk_ind)
        return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

    def __call__(self, hm, wh, reg, im_shape, scale_factor):
        # 1. get clses and scores; note that sigmoid has already been applied to hm
        heat = self._simple_nms(hm)
        scores, inds, topk_clses, ys, xs = self._topk(heat)
        clses = topk_clses.unsqueeze(1)
        scores = scores.unsqueeze(1)

        # 2. get bboxes; note that only batch_size=1 is supported now
        reg_t = paddle.transpose(reg, [0, 2, 3, 1])
        reg = paddle.reshape(reg_t, [-1, reg_t.shape[-1]])
        reg = paddle.gather(reg, inds)
        xs = paddle.cast(xs, 'float32')
        ys = paddle.cast(ys, 'float32')
        xs = xs + reg[:, 0:1]
        ys = ys + reg[:, 1:2]
        wh_t = paddle.transpose(wh, [0, 2, 3, 1])
        wh = paddle.reshape(wh_t, [-1, wh_t.shape[-1]])
        wh = paddle.gather(wh, inds)
        if self.regress_ltrb:
            x1 = xs - wh[:, 0:1]
            y1 = ys - wh[:, 1:2]
            x2 = xs + wh[:, 2:3]
            y2 = ys + wh[:, 3:4]
        else:
            x1 = xs - wh[:, 0:1] / 2
            y1 = ys - wh[:, 1:2] / 2
            x2 = xs + wh[:, 0:1] / 2
            y2 = ys + wh[:, 1:2] / 2

        # 3. map bboxes from the heatmap scale back to the network input, then
        #    rescale to the original image
        n, c, feat_h, feat_w = paddle.shape(hm)
        padw = (feat_w * self.down_ratio - im_shape[0, 1]) / 2
        padh = (feat_h * self.down_ratio - im_shape[0, 0]) / 2
        x1 = x1 * self.down_ratio
        y1 = y1 * self.down_ratio
        x2 = x2 * self.down_ratio
        y2 = y2 * self.down_ratio
        x1 = x1 - padw
        y1 = y1 - padh
        x2 = x2 - padw
        y2 = y2 - padh

        bboxes = paddle.concat([x1, y1, x2, y2], axis=1)
        scale_y = scale_factor[:, 0:1]
        scale_x = scale_factor[:, 1:2]
        scale_expand = paddle.concat(
            [scale_x, scale_y, scale_x, scale_y], axis=1)
        boxes_shape = bboxes.shape[:]
        scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
        bboxes = paddle.divide(bboxes, scale_expand)
        results = paddle.concat([clses, scores, bboxes], axis=1)
        return results, paddle.shape(results)[0:1], inds, topk_clses, ys, xs

@register
class DETRBBoxPostProcess(object):
    __shared__ = ['num_classes', 'use_focal_loss']
    __inject__ = []

    def __init__(self,
                 num_classes=80,
                 num_top_queries=100,
                 use_focal_loss=False):
        super(DETRBBoxPostProcess, self).__init__()
        self.num_classes = num_classes
        self.num_top_queries = num_top_queries
        self.use_focal_loss = use_focal_loss

    def __call__(self, head_out, im_shape, scale_factor):
        """
        Decode the bbox.

        Args:
            head_out (tuple): bbox_pred, cls_logit and masks of bbox_head output.
            im_shape (Tensor): The shape of the input image.
            scale_factor (Tensor): The scale factor of the input image.
        Returns:
            bbox_pred (Tensor): The output prediction with shape [N, 6], including
                labels, scores and bboxes. The bbox coordinates correspond to the
                input image and may be used in other branches.
            bbox_num (Tensor): The number of prediction boxes of each batch with
                shape [bs], and is N.
        """
        bboxes, logits, masks = head_out

        bbox_pred = bbox_cxcywh_to_xyxy(bboxes)
        origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
        img_h, img_w = paddle.split(origin_shape, 2, axis=-1)
        origin_shape = paddle.concat(
            [img_w, img_h, img_w, img_h], axis=-1).reshape([-1, 1, 4])
        bbox_pred *= origin_shape

        scores = F.sigmoid(logits) if self.use_focal_loss else F.softmax(
            logits)[:, :, :-1]

        if not self.use_focal_loss:
            scores, labels = scores.max(-1), scores.argmax(-1)
            if scores.shape[1] > self.num_top_queries:
                scores, index = paddle.topk(
                    scores, self.num_top_queries, axis=-1)
                batch_ind = paddle.arange(
                    end=scores.shape[0]).unsqueeze(-1).tile(
                        [1, self.num_top_queries])
                index = paddle.stack([batch_ind, index], axis=-1)
                labels = paddle.gather_nd(labels, index)
                bbox_pred = paddle.gather_nd(bbox_pred, index)
        else:
            scores, index = paddle.topk(
                scores.flatten(1), self.num_top_queries, axis=-1)
            labels = index % self.num_classes
            index = index // self.num_classes
            batch_ind = paddle.arange(end=scores.shape[0]).unsqueeze(-1).tile(
                [1, self.num_top_queries])
            index = paddle.stack([batch_ind, index], axis=-1)
            bbox_pred = paddle.gather_nd(bbox_pred, index)

        bbox_pred = paddle.concat(
            [
                labels.unsqueeze(-1).astype('float32'), scores.unsqueeze(-1),
                bbox_pred
            ],
            axis=-1)
        bbox_num = paddle.to_tensor(
            bbox_pred.shape[1], dtype='int32').tile([bbox_pred.shape[0]])
        bbox_pred = bbox_pred.reshape([-1, 6])
        return bbox_pred, bbox_num

@register
class SparsePostProcess(object):
    __shared__ = ['num_classes']

    def __init__(self, num_proposals, num_classes=80):
        super(SparsePostProcess, self).__init__()
        self.num_classes = num_classes
        self.num_proposals = num_proposals

    def __call__(self, box_cls, box_pred, scale_factor_wh, img_whwh):
        """
        Arguments:
            box_cls (Tensor): tensor of shape (batch_size, num_proposals, K).
                The tensor predicts the classification probability for each proposal.
            box_pred (Tensor): tensor of shape (batch_size, num_proposals, 4).
                The tensor predicts 4-vector (x, y, w, h) box
                regression values for every proposal.
            scale_factor_wh (Tensor): tensor of shape [batch_size, 2], the scale
                factor of each image.
            img_whwh (Tensor): tensor of shape [batch_size, 4].
        Returns:
            bbox_pred (Tensor): tensor of shape [num_boxes, 6]. Each row has 6 values:
                [label, confidence, xmin, ymin, xmax, ymax].
            bbox_num (Tensor): tensor of shape [batch_size], the number of RoIs in each image.
        """
        assert len(box_cls) == len(scale_factor_wh) == len(img_whwh)

        img_wh = img_whwh[:, :2]

        scores = F.sigmoid(box_cls)
        labels = paddle.arange(0, self.num_classes). \
            unsqueeze(0).tile([self.num_proposals, 1]).flatten(
                start_axis=0, stop_axis=1)

        classes_all = []
        scores_all = []
        boxes_all = []
        for i, (scores_per_image,
                box_pred_per_image) in enumerate(zip(scores, box_pred)):
            scores_per_image, topk_indices = scores_per_image.flatten(
                0, 1).topk(
                    self.num_proposals, sorted=False)
            labels_per_image = paddle.gather(labels, topk_indices, axis=0)

            box_pred_per_image = box_pred_per_image.reshape([-1, 1, 4]).tile(
                [1, self.num_classes, 1]).reshape([-1, 4])
            box_pred_per_image = paddle.gather(
                box_pred_per_image, topk_indices, axis=0)

            classes_all.append(labels_per_image)
            scores_all.append(scores_per_image)
            boxes_all.append(box_pred_per_image)

        bbox_num = paddle.zeros([len(scale_factor_wh)], dtype="int32")
        boxes_final = []

        for i in range(len(scale_factor_wh)):
            classes = classes_all[i]
            boxes = boxes_all[i]
            scores = scores_all[i]

            boxes[:, 0::2] = paddle.clip(
                boxes[:, 0::2], min=0, max=img_wh[i][0]) / scale_factor_wh[i][0]
            boxes[:, 1::2] = paddle.clip(
                boxes[:, 1::2], min=0, max=img_wh[i][1]) / scale_factor_wh[i][1]
            boxes_w, boxes_h = (boxes[:, 2] - boxes[:, 0]).numpy(), (
                boxes[:, 3] - boxes[:, 1]).numpy()

            keep = (boxes_w > 1.) & (boxes_h > 1.)

            if keep.sum() == 0:
                bboxes = paddle.zeros([1, 6]).astype("float32")
            else:
                boxes = paddle.to_tensor(boxes.numpy()[keep]).astype("float32")
                classes = paddle.to_tensor(classes.numpy()[keep]).astype(
                    "float32").unsqueeze(-1)
                scores = paddle.to_tensor(scores.numpy()[keep]).astype(
                    "float32").unsqueeze(-1)
                bboxes = paddle.concat([classes, scores, boxes], axis=-1)

            boxes_final.append(bboxes)
            bbox_num[i] = bboxes.shape[0]

        bbox_pred = paddle.concat(boxes_final)
        return bbox_pred, bbox_num

def multiclass_nms(bboxs, num_classes, match_threshold=0.6, match_metric='iou'):
    """Apply class-wise NMS; each row of `bboxs` is [class, score, x1, y1, x2, y2]."""
    final_boxes = []
    for c in range(num_classes):
        idxs = bboxs[:, 0] == c
        if np.count_nonzero(idxs) == 0:
            continue
        r = nms(bboxs[idxs, 1:], match_threshold, match_metric)
        final_boxes.append(np.concatenate([np.full((r.shape[0], 1), c), r], 1))
    return final_boxes

def nms(dets, match_threshold=0.6, match_metric='iou'):
    """ Apply NMS to avoid detecting too many overlapping bounding boxes.

    Args:
        dets: shape [N, 5], [score, x1, y1, x2, y2]
        match_metric: 'iou' or 'ios'
        match_threshold: overlap threshold for the match metric.
    """
    if dets.shape[0] == 0:
        return dets[[], :]
    scores = dets[:, 0]
    x1 = dets[:, 1]
    y1 = dets[:, 2]
    x2 = dets[:, 3]
    y2 = dets[:, 4]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    ndets = dets.shape[0]
    suppressed = np.zeros((ndets), dtype=np.int32)

    for _i in range(ndets):
        i = order[_i]
        if suppressed[i] == 1:
            continue
        ix1 = x1[i]
        iy1 = y1[i]
        ix2 = x2[i]
        iy2 = y2[i]
        iarea = areas[i]
        for _j in range(_i + 1, ndets):
            j = order[_j]
            if suppressed[j] == 1:
                continue
            xx1 = max(ix1, x1[j])
            yy1 = max(iy1, y1[j])
            xx2 = min(ix2, x2[j])
            yy2 = min(iy2, y2[j])
            w = max(0.0, xx2 - xx1 + 1)
            h = max(0.0, yy2 - yy1 + 1)
            inter = w * h
            if match_metric == 'iou':
                union = iarea + areas[j] - inter
                match_value = inter / union
            elif match_metric == 'ios':
                smaller = min(iarea, areas[j])
                match_value = inter / smaller
            else:
                raise ValueError(
                    "match_metric must be 'iou' or 'ios', got {}".format(
                        match_metric))
            if match_value >= match_threshold:
                suppressed[j] = 1
    keep = np.where(suppressed == 0)[0]
    dets = dets[keep, :]
    return dets
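
# A minimal usage sketch for the numpy-based `nms` / `multiclass_nms` helpers
# above; the detections below are made-up values for illustration only, with
# rows in the [class, score, x1, y1, x2, y2] layout that `multiclass_nms`
# expects.
if __name__ == '__main__':
    sample_dets = np.array(
        [
            # class, score, x1,   y1,   x2,   y2
            [0, 0.90, 10., 10., 60., 60.],
            [0, 0.75, 12., 12., 62., 62.],  # overlaps the box above, IoU ~ 0.86
            [1, 0.80, 100., 100., 150., 150.],
        ],
        dtype='float32')

    # Class-wise NMS with the default IoU threshold of 0.6: the second class-0
    # box is suppressed, leaving one box per class.
    kept = multiclass_nms(sample_dets, num_classes=2, match_threshold=0.6)
    for per_class in kept:
        print(per_class)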