mot_operators.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. try:
  18. from collections.abc import Sequence
  19. except Exception:
  20. from collections import Sequence
  21. from numbers import Integral
  22. import cv2
  23. import copy
  24. import numpy as np
  25. import random
  26. import math
  27. from .operators import BaseOperator, register_op
  28. from .batch_operators import Gt2TTFTarget
  29. from ppdet.modeling.bbox_utils import bbox_iou_np_expand
  30. from ppdet.utils.logger import setup_logger
  31. from .op_helper import gaussian_radius
  # Module-level logger, named after this module.
  logger = setup_logger(__name__)

  # Names exported by a star-import of this module.
  __all__ = [
      'RGBReverse',
      'LetterBoxResize',
      'MOTRandomAffine',
      'Gt2JDETargetThres',
      'Gt2JDETargetMax',
      'Gt2FairMOTTarget',
  ]
  @register_op
  class RGBReverse(BaseOperator):
      """Reverse the channel order of the sample image (RGB <-> BGR).

      The original note says this operator is "sensitive to MOTRandomAffine";
      presumably because that operator's ``borderValue`` is given in BGR
      order, so where this op sits in the pipeline matters -- confirm against
      the pipeline config.
      """

      def __init__(self):
          super(RGBReverse, self).__init__()

      def apply(self, sample, context=None):
          """Replace ``sample['image']`` with a channel-reversed copy.

          Args:
              sample (dict): must hold an array under ``'image'`` that is
                  indexable as ``[:, :, ::-1]`` (channels on the last axis).
              context: unused.

          Returns:
              dict: the same ``sample`` object, updated in place.
          """
          flipped = sample['image'][:, :, ::-1]
          # The reversed slice is a negative-stride view; store a contiguous
          # array instead.
          sample['image'] = np.ascontiguousarray(flipped)
          return sample
  @register_op
  class LetterBoxResize(BaseOperator):
      """Letterbox an image to a fixed size and map its boxes accordingly."""

      def __init__(self, target_size):
          """
          Resize image to target size, convert normalized xywh to pixel xyxy
          format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).

          Args:
              target_size (int|list): image target size. A single integer is
                  expanded to ``[target_size, target_size]``; a sequence is
                  unpacked as ``(height, width)`` in ``apply``.

          Raises:
              TypeError: if ``target_size`` is neither an integer nor a
                  sequence.
          """
          super(LetterBoxResize, self).__init__()
          if not isinstance(target_size, (Integral, Sequence)):
              raise TypeError(
                  "Type of target_size is invalid. Must be Integer or List or Tuple, now is {}".
                  format(type(target_size)))
          if isinstance(target_size, Integral):
              target_size = [target_size, target_size]
          self.target_size = target_size

      def apply_image(self, img, height, width, color=(127.5, 127.5, 127.5)):
          """Resize ``img`` keeping its aspect ratio, then pad to the target.

          Args:
              img (np.ndarray): input image; the first two axes are
                  (height, width).
              height (int): target height.
              width (int): target width.
              color (tuple): constant value used for the padded border.

          Returns:
              tuple: ``(padded_image, ratio, padw, padh)`` where ``ratio`` is
              the resize factor and ``padw`` / ``padh`` are the (possibly
              fractional) per-side paddings.
          """
          src_h, src_w = img.shape[:2]
          # A single scale for both axes: the tighter of the two constraints.
          ratio = min(float(height) / src_h, float(width) / src_w)
          resized_wh = (round(src_w * ratio), round(src_h * ratio))

          padw = (width - resized_wh[0]) / 2
          padh = (height - resized_wh[1]) / 2
          # The +/-0.1 nudges split an odd total padding unevenly between the
          # two opposite sides instead of rounding both the same way.
          top = round(padh - 0.1)
          bottom = round(padh + 0.1)
          left = round(padw - 0.1)
          right = round(padw + 0.1)

          resized = cv2.resize(img, resized_wh, interpolation=cv2.INTER_AREA)
          padded = cv2.copyMakeBorder(
              resized,
              top,
              bottom,
              left,
              right,
              cv2.BORDER_CONSTANT,
              value=color)
          return padded, ratio, padw, padh

      def apply_bbox(self, bbox0, h, w, ratio, padw, padh):
          """Map center-format boxes to corner-format letterboxed boxes.

          Args:
              bbox0 (np.ndarray): boxes with columns
                  ``[x_center, y_center, width, height]``, scaled here by
                  ``w`` / ``h`` (normalized per the ``__init__`` docstring).
              h, w: source image height and width.
              ratio: resize factor returned by ``apply_image``.
              padw, padh: paddings returned by ``apply_image``.

          Returns:
              np.ndarray: a copy of ``bbox0`` whose first four columns are
              ``[x0, y0, x1, y1]``.
          """
          half_w = bbox0[:, 2] / 2
          half_h = bbox0[:, 3] / 2
          x_scale = ratio * w
          y_scale = ratio * h

          out = bbox0.copy()
          out[:, 0] = x_scale * (bbox0[:, 0] - half_w) + padw
          out[:, 1] = y_scale * (bbox0[:, 1] - half_h) + padh
          out[:, 2] = x_scale * (bbox0[:, 0] + half_w) + padw
          out[:, 3] = y_scale * (bbox0[:, 1] + half_h) + padh
          return out

      def apply(self, sample, context=None):
          """Letterbox ``sample['image']`` and update the related fields.

          Writes ``image``, ``im_shape``, ``scale_factor`` and, when the sample
          has a non-empty ``gt_bbox``, the converted ``gt_bbox``.

          Raises:
              TypeError: if the image is not a numpy array.
              PIL.UnidentifiedImageError: if the image is not 3-dimensional.
          """
          im = sample['image']
          h, w = sample['im_shape']
          if not isinstance(im, np.ndarray):
              raise TypeError("{}: image type is not numpy.".format(self))
          if im.ndim != 3:
              from PIL import UnidentifiedImageError
              raise UnidentifiedImageError(
                  '{}: image is not 3-dimensional.'.format(self))

          # image
          target_h, target_w = self.target_size
          letterboxed, ratio, padw, padh = self.apply_image(
              im, height=target_h, width=target_w)
          sample['image'] = letterboxed
          sample['im_shape'] = np.asarray(
              (round(h * ratio), round(w * ratio)), dtype=np.float32)
          sample['scale_factor'] = np.asarray(
              [ratio, ratio], dtype=np.float32)

          # boxes
          has_boxes = 'gt_bbox' in sample and len(sample['gt_bbox']) > 0
          if has_boxes:
              sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], h, w,
                                                  ratio, padw, padh)
          return sample
  @register_op
  class MOTRandomAffine(BaseOperator):
      """
      Affine transform to image and coords to achieve the rotate, scale and
      shift effect for training image.

      Args:
          degrees (list[2]): the rotate range to apply, transform range is [min, max]
          translate (list[2]): maximum x and y shift, as fractions of the image
              width and height respectively; each shift is drawn uniformly
              from [-fraction, +fraction]
          scale (list[2]): the scale range to apply, transform range is [min, max]
          shear (list[2]): the shear range to apply, transform range is [min, max]
          borderValue (list[3]): value used in case of a constant border when applying
              the perspective transformation
          reject_outside (bool): clip warped bounding boxes to the image
              before the size/area/aspect-ratio filter is applied

      Returns:
          records(dict): contain the image and coords after transformed
      """

      def __init__(self,
                   degrees=(-5, 5),
                   translate=(0.10, 0.10),
                   scale=(0.50, 1.20),
                   shear=(-2, 2),
                   borderValue=(127.5, 127.5, 127.5),
                   reject_outside=True):
          super(MOTRandomAffine, self).__init__()
          self.degrees = degrees
          self.translate = translate
          self.scale = scale
          self.shear = shear
          self.borderValue = borderValue
          self.reject_outside = reject_outside

      def apply(self, sample, context=None):
          """Randomly warp ``sample['image']`` and its boxes.

          The image and the box-aligned fields (``gt_bbox``, ``gt_class`` and,
          when present, ``difficult`` / ``gt_ide`` / ``is_crowd``) are only
          ever updated together. If the sample has no boxes, or no warped box
          survives the filter, the sample is returned unchanged so that image
          and annotations stay consistent.

          Args:
              sample (dict): needs ``'image'`` (first two axes are height and
                  width); ``'gt_bbox'`` is read as an (n, 4) array of
                  ``[x1, y1, x2, y2]`` corners.
              context: unused.

          Returns:
              dict: the same ``sample`` object.
          """
          # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
          border = 0  # width of added border (optional)
          img = sample['image']
          height, width = img.shape[0], img.shape[1]

          # Rotation and Scale (about the image centre)
          R = np.eye(3)
          a = random.random() * (self.degrees[1] - self.degrees[0]
                                 ) + self.degrees[0]
          s = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0]
          R[:2] = cv2.getRotationMatrix2D(
              angle=a, center=(width / 2, height / 2), scale=s)

          # Translation: the x shift scales with the width and the y shift
          # with the height, so ``translate`` is a fraction of the matching
          # image side.
          T = np.eye(3)
          T[0, 2] = (random.random() * 2 - 1
                     ) * self.translate[0] * width + border  # x (pixels)
          T[1, 2] = (random.random() * 2 - 1
                     ) * self.translate[1] * height + border  # y (pixels)

          # Shear
          S = np.eye(3)
          S[0, 1] = math.tan((random.random() *
                              (self.shear[1] - self.shear[0]) + self.shear[0]) *
                             math.pi / 180)  # x shear (deg)
          S[1, 0] = math.tan((random.random() *
                              (self.shear[1] - self.shear[0]) + self.shear[0]) *
                             math.pi / 180)  # y shear (deg)

          # Combined matrix. ORDER IS IMPORTANT HERE!!
          M = S @ T @ R

          if 'gt_bbox' not in sample or len(sample['gt_bbox']) == 0:
              # Nothing to keep in sync with: leave the sample as it is.
              return sample

          targets = sample['gt_bbox']
          n = targets.shape[0]
          points = targets.copy()
          area0 = (points[:, 2] - points[:, 0]) * (points[:, 3] - points[:, 1])

          # warp the four corners of every box
          xy = np.ones((n * 4, 3))
          xy[:, :2] = points[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
              n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
          xy = (xy @ M.T)[:, :2].reshape(n, 8)

          # axis-aligned box enclosing the warped corners
          x = xy[:, [0, 2, 4, 6]]
          y = xy[:, [1, 3, 5, 7]]
          xy = np.concatenate(
              (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

          # apply angle-based reduction: the enclosing box of a rotated box is
          # larger than the box itself, so shrink it around its centre
          radians = a * math.pi / 180
          reduction = max(abs(math.sin(radians)), abs(math.cos(radians)))**0.5
          x = (xy[:, 2] + xy[:, 0]) / 2
          y = (xy[:, 3] + xy[:, 1]) / 2
          w = (xy[:, 2] - xy[:, 0]) * reduction
          h = (xy[:, 3] - xy[:, 1]) * reduction
          xy = np.concatenate(
              (x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T

          # clip warped boxes to the image
          if self.reject_outside:
              np.clip(xy[:, 0], 0, width, out=xy[:, 0])
              np.clip(xy[:, 2], 0, width, out=xy[:, 2])
              np.clip(xy[:, 1], 0, height, out=xy[:, 1])
              np.clip(xy[:, 3], 0, height, out=xy[:, 3])

          # keep boxes that are large enough, retain more than 10% of their
          # original area and are not extremely elongated
          w = xy[:, 2] - xy[:, 0]
          h = xy[:, 3] - xy[:, 1]
          area = w * h
          ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))
          keep = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.1) & (ar < 10)

          if not keep.any():
              # Every box was rejected. Installing the warped image next to
              # the old, un-warped boxes would make the labels wrong, so the
              # augmentation is skipped for this sample.
              return sample

          sample['gt_bbox'] = xy[keep].astype(targets.dtype)
          sample['gt_class'] = sample['gt_class'][keep]
          for key in ('difficult', 'gt_ide', 'is_crowd'):
              if key in sample:
                  sample[key] = sample[key][keep]

          # The warp is only computed once we know it will be used.
          sample['image'] = cv2.warpPerspective(
              img,
              M,
              dsize=(width, height),
              flags=cv2.INTER_LINEAR,
              borderValue=self.borderValue)  # BGR order borderValue
          return sample
  @register_op
  class Gt2JDETargetThres(BaseOperator):
      """
      Generate JDE targets from ground truth data when training.

      Args:
          anchors (list): anchors of JDE model
          anchor_masks (list): anchor_masks of JDE model
          downsample_ratios (list): downsample ratios of JDE model
          ide_thresh (float): thresh of identity, higher is ground truth
          fg_thresh (float): thresh of foreground, higher is foreground
          bg_thresh (float): thresh of background, lower is background
          num_classes (int): number of classes
      """
      __shared__ = ['num_classes']

      def __init__(self,
                   anchors,
                   anchor_masks,
                   downsample_ratios,
                   ide_thresh=0.5,
                   fg_thresh=0.5,
                   bg_thresh=0.4,
                   num_classes=1):
          super(Gt2JDETargetThres, self).__init__()
          self.anchors = anchors
          self.anchor_masks = anchor_masks
          self.downsample_ratios = downsample_ratios
          self.ide_thresh = ide_thresh
          self.fg_thresh = fg_thresh
          self.bg_thresh = bg_thresh
          self.num_classes = num_classes

      def generate_anchor(self, nGh, nGw, anchor_hw):
          """Build the dense anchor grid for one feature map.

          Args:
              nGh (int): grid height.
              nGw (int): grid width.
              anchor_hw (np.ndarray): (nA, 2) anchor sizes in grid units.

          Returns:
              np.ndarray: (nA, 4, nGh, nGw); channels 0-1 hold the cell x/y
              index, channels 2-3 hold the two columns of ``anchor_hw``.
          """
          nA = len(anchor_hw)
          yy, xx = np.meshgrid(np.arange(nGh), np.arange(nGw))
          mesh = np.stack([xx.T, yy.T], axis=0)  # [2, nGh, nGw]
          mesh = np.repeat(mesh[None, :], nA, axis=0)  # [nA, 2, nGh, nGw]
          # broadcast every anchor size over the whole grid
          anchor_offset_mesh = anchor_hw[:, :, None][:, :, :, None]
          anchor_offset_mesh = np.repeat(anchor_offset_mesh, nGh, axis=-2)
          anchor_offset_mesh = np.repeat(anchor_offset_mesh, nGw, axis=-1)
          anchor_mesh = np.concatenate(
              [mesh, anchor_offset_mesh], axis=1)  # [nA, 4, nGh, nGw]
          return anchor_mesh

      def encode_delta(self, gt_box_list, fg_anchor_list):
          """Encode boxes as regression deltas relative to their anchors.

          Both inputs are (K, 4) arrays with columns ``[x, y, w, h]``.

          Returns:
              np.ndarray: (K, 4) array ``[dx, dy, dw, dh]`` where the offsets
              are normalized by the anchor size and the sizes are log ratios.
          """
          px, py, pw, ph = fg_anchor_list[:, 0], fg_anchor_list[:, 1], \
              fg_anchor_list[:, 2], fg_anchor_list[:, 3]
          gx, gy, gw, gh = gt_box_list[:, 0], gt_box_list[:, 1], \
              gt_box_list[:, 2], gt_box_list[:, 3]
          dx = (gx - px) / pw
          dy = (gy - py) / ph
          dw = np.log(gw / pw)
          dh = np.log(gh / ph)
          return np.stack([dx, dy, dw, dh], axis=1)

      def pad_box(self, sample, num_max):
          """Zero-pad the per-box fields of ``sample`` to ``num_max`` rows.

          ``gt_bbox`` becomes a (num_max, 4) float32 array. ``gt_score``,
          ``difficult``, ``is_crowd`` and ``gt_ide`` (when present) are read
          from column 0 of the stored array and become flat (num_max,) arrays.

          Returns:
              dict: the same ``sample`` object, updated in place.
          """
          assert 'gt_bbox' in sample
          bbox = sample['gt_bbox']
          gt_num = len(bbox)
          pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
          if gt_num > 0:
              pad_bbox[:gt_num, :] = bbox[:gt_num, :]
          sample['gt_bbox'] = pad_bbox

          flat_fields = (('gt_score', np.float32), ('difficult', np.int32),
                         ('is_crowd', np.int32), ('gt_ide', np.int32))
          for key, dtype in flat_fields:
              if key not in sample:
                  continue
              padded = np.zeros((num_max, ), dtype=dtype)
              if gt_num > 0:
                  padded[:gt_num] = sample[key][:gt_num, 0]
              sample[key] = padded
          return sample

      def _build_scale_targets(self, gt_bbox, gt_ide, anchor_hw, nGh, nGw):
          """Compute (tbox, tconf, tid) for one sample at one scale."""
          nA = len(anchor_hw)
          tbox = np.zeros((nA, nGh, nGw, 4), dtype=np.float32)
          tconf = np.zeros((nA, nGh, nGw), dtype=np.float32)
          tid = -np.ones((nA, nGh, nGw, 1), dtype=np.float32)

          if len(gt_bbox) == 0:
              # No ground truth: every anchor keeps the default targets.
              # Without this guard the max/argmax below would reduce over a
              # zero-length axis, which numpy rejects.
              return tbox, tconf, tid

          # Scale boxes to grid units; assumes gt_bbox holds normalized
          # [x_center, y_center, w, h] -- TODO confirm against the reader.
          gxy, gwh = gt_bbox[:, 0:2].copy(), gt_bbox[:, 2:4].copy()
          gxy[:, 0] = gxy[:, 0] * nGw
          gxy[:, 1] = gxy[:, 1] * nGh
          gwh[:, 0] = gwh[:, 0] * nGw
          gwh[:, 1] = gwh[:, 1] * nGh
          gxy[:, 0] = np.clip(gxy[:, 0], 0, nGw - 1)
          gxy[:, 1] = np.clip(gxy[:, 1], 0, nGh - 1)
          tboxes = np.concatenate([gxy, gwh], axis=1)

          anchor_mesh = self.generate_anchor(nGh, nGw, anchor_hw)
          anchor_list = np.transpose(anchor_mesh, (0, 2, 3, 1)).reshape(-1, 4)
          iou_pdist = bbox_iou_np_expand(anchor_list, tboxes, x1y1x2y2=False)

          # best-matching gt box for every anchor
          iou_map = np.max(iou_pdist, axis=1).reshape(nA, nGh, nGw)
          gt_index_map = np.argmax(iou_pdist, axis=1).reshape(nA, nGh, nGw)

          id_index = iou_map > self.ide_thresh
          fg_index = iou_map > self.fg_thresh
          bg_index = iou_map < self.bg_thresh
          ign_index = (iou_map < self.fg_thresh) * (iou_map > self.bg_thresh)
          tconf[fg_index] = 1
          tconf[bg_index] = 0
          tconf[ign_index] = -1

          if fg_index.any():
              tid[id_index] = gt_ide[gt_index_map[id_index]]
              gt_box_list = tboxes[gt_index_map[fg_index]]
              fg_anchor_list = anchor_list.reshape(nA, nGh, nGw, 4)[fg_index]
              tbox[fg_index] = self.encode_delta(gt_box_list, fg_anchor_list)
          return tbox, tconf, tid

      def __call__(self, samples, context=None):
          """Attach per-scale JDE targets to every sample of a batch.

          For scale ``i`` the keys ``tbox{i}``, ``tconf{i}`` and ``tide{i}``
          are written. ``gt_class`` is removed and the per-box fields are
          zero-padded to the largest box count in the batch.

          Args:
              samples (list[dict]): batch; each sample needs ``image`` (height
                  and width are read from axes 1 and 2), ``gt_bbox`` and
                  ``gt_ide``.
              context: unused.

          Returns:
              list[dict]: the same ``samples`` list.
          """
          assert len(self.anchor_masks) == len(self.downsample_ratios), \
              "anchor_masks', and 'downsample_ratios' should have same length."
          h, w = samples[0]['image'].shape[1:3]

          num_max = 0
          for sample in samples:
              num_max = max(num_max, len(sample['gt_bbox']))

          for sample in samples:
              gt_bbox = sample['gt_bbox']
              gt_ide = sample['gt_ide']
              for i, (anchor_hw, downsample_ratio) in enumerate(
                      zip(self.anchors, self.downsample_ratios)):
                  # anchors are given in pixels; convert to grid units
                  anchor_hw = np.array(
                      anchor_hw, dtype=np.float32) / downsample_ratio
                  nGh, nGw = int(h / downsample_ratio), int(w / downsample_ratio)
                  tbox, tconf, tid = self._build_scale_targets(
                      gt_bbox, gt_ide, anchor_hw, nGh, nGw)
                  sample['tbox{}'.format(i)] = tbox
                  sample['tconf{}'.format(i)] = tconf
                  sample['tide{}'.format(i)] = tid

              # tolerate samples that carry no gt_class field
              sample.pop('gt_class', None)
              sample = self.pad_box(sample, num_max)
          return samples
  @register_op
  class Gt2JDETargetMax(BaseOperator):
      """
      Generate JDE targets from ground truth data when evaluating.

      Args:
          anchors (list): anchors of JDE model
          anchor_masks (list): anchor_masks of JDE model
          downsample_ratios (list): downsample ratios of JDE model
          max_iou_thresh (float): iou thresh for high quality anchor
          num_classes (int): number of classes
      """
      __shared__ = ['num_classes']

      def __init__(self,
                   anchors,
                   anchor_masks,
                   downsample_ratios,
                   max_iou_thresh=0.60,
                   num_classes=1):
          super(Gt2JDETargetMax, self).__init__()
          self.anchors = anchors
          self.anchor_masks = anchor_masks
          self.downsample_ratios = downsample_ratios
          self.max_iou_thresh = max_iou_thresh
          self.num_classes = num_classes

      def __call__(self, samples, context=None):
          """Attach per-scale JDE targets to every sample of a batch.

          Each ground truth box is matched to its best anchor by width/height
          IoU; a match is used when that IoU exceeds ``max_iou_thresh``. For
          scale ``i`` the keys ``tbox{i}``, ``tconf{i}`` and ``tide{i}`` are
          written.

          Args:
              samples (list[dict]): batch; each sample needs ``image`` (height
                  and width are read from axes 1 and 2), ``gt_bbox`` and
                  ``gt_ide``.
              context: unused.

          Returns:
              list[dict]: the same ``samples`` list.
          """
          assert len(self.anchor_masks) == len(self.downsample_ratios), \
              "anchor_masks', and 'downsample_ratios' should have same length."
          h, w = samples[0]['image'].shape[1:3]
          for sample in samples:
              gt_bbox = sample['gt_bbox']
              gt_ide = sample['gt_ide']
              for i, (anchor_hw, downsample_ratio) in enumerate(
                      zip(self.anchors, self.downsample_ratios)):
                  # anchors are given in pixels; convert to grid units
                  anchor_hw = np.array(
                      anchor_hw, dtype=np.float32) / downsample_ratio
                  nA = len(anchor_hw)
                  nGh, nGw = int(h / downsample_ratio), int(w / downsample_ratio)
                  tbox = np.zeros((nA, nGh, nGw, 4), dtype=np.float32)
                  tconf = np.zeros((nA, nGh, nGw), dtype=np.float32)
                  tid = -np.ones((nA, nGh, nGw, 1), dtype=np.float32)

                  # With no ground truth the default targets are kept as is.
                  if len(gt_bbox) > 0:
                      gxy, gwh = gt_bbox[:, 0:2].copy(), gt_bbox[:, 2:4].copy()
                      gxy[:, 0] = gxy[:, 0] * nGw
                      gxy[:, 1] = gxy[:, 1] * nGh
                      gwh[:, 0] = gwh[:, 0] * nGw
                      gwh[:, 1] = gwh[:, 1] * nGh
                      gi = np.clip(gxy[:, 0], 0, nGw - 1).astype(int)
                      gj = np.clip(gxy[:, 1], 0, nGh - 1).astype(int)

                      # iou of targets-anchors (using wh only)
                      box1 = gwh
                      box2 = anchor_hw[:, None, :]
                      inter_area = np.minimum(box1, box2).prod(2)
                      iou = inter_area / (
                          box1.prod(1) + box2.prod(2) - inter_area + 1e-16)

                      # Select best iou_pred and anchor
                      iou_best = iou.max(0)  # best anchor [0-2] for each target
                      a = np.argmax(iou, axis=0)

                      # Select best unique target-anchor combinations
                      iou_order = np.argsort(-iou_best)  # best to worst

                      # Unique anchor selection: when several targets claim the
                      # same (cell, anchor), keep the first one in iou order.
                      u = np.stack((gi, gj, a), 0)[:, iou_order]
                      _, first_unique = np.unique(u, axis=1, return_index=True)
                      mask = iou_order[first_unique]

                      # best anchor must share significant commonality (iou) with target
                      # TODO: examine arbitrary threshold
                      idx = mask[iou_best[mask] > self.max_iou_thresh]

                      if len(idx) > 0:
                          a_i, gj_i, gi_i = a[idx], gj[idx], gi[idx]
                          t_box = gt_bbox[idx]
                          t_id = gt_ide[idx]
                          if len(t_box.shape) == 1:
                              t_box = t_box.reshape(1, 4)

                          gxy, gwh = t_box[:, 0:2].copy(), t_box[:, 2:4].copy()
                          gxy[:, 0] = gxy[:, 0] * nGw
                          gxy[:, 1] = gxy[:, 1] * nGh
                          gwh[:, 0] = gwh[:, 0] * nGw
                          gwh[:, 1] = gwh[:, 1] * nGh

                          # XY coordinates: offset inside the grid cell
                          tbox[:, :, :, 0:2][a_i, gj_i, gi_i] = gxy - gxy.astype(
                              int)
                          # Width and height in yolo method
                          tbox[:, :, :, 2:4][a_i, gj_i, gi_i] = np.log(
                              gwh / anchor_hw[a_i])
                          tconf[a_i, gj_i, gi_i] = 1
                          tid[a_i, gj_i, gi_i] = t_id

                  sample['tbox{}'.format(i)] = tbox
                  sample['tconf{}'.format(i)] = tconf
                  sample['tide{}'.format(i)] = tid
          # Return the batch so the operator can be chained like the other
          # target generators in this module.
          return samples
  class Gt2FairMOTTarget(Gt2TTFTarget):
      """
      Generate FairMOT targets by ground truth data.
      Difference between Gt2FairMOTTarget and Gt2TTFTarget are:
      1. the gaussian kernel radius to generate a heatmap.
      2. the targets needed during training.

      Args:
          num_classes(int): the number of classes.
          down_ratio(int): the down ratio from images to heatmap, 4 by default.
          max_objs(int): the maximum number of ground truth objects in a image, 500 by default.
              Objects beyond this count are dropped with a warning.
      """
      __shared__ = ['num_classes']

      def __init__(self, num_classes=1, down_ratio=4, max_objs=500):
          # super(Gt2TTFTarget, self) starts the lookup *after* Gt2TTFTarget,
          # so the parent's own __init__ is skipped.
          # NOTE(review): presumably intentional, to avoid the parent's
          # constructor arguments -- confirm against Gt2TTFTarget.
          super(Gt2TTFTarget, self).__init__()
          self.down_ratio = down_ratio
          self.num_classes = num_classes
          self.max_objs = max_objs

      def __call__(self, samples, context=None):
          """Attach FairMOT training targets to every sample of a batch.

          Writes ``heatmap``, ``index``, ``offset``, ``size``, ``index_mask``,
          ``reid`` and ``bbox_xys`` (plus ``cls_id_map`` / ``cls_tr_ids`` when
          ``num_classes > 1``), then removes the raw annotation fields.

          Args:
              samples (list[dict]): batch; each sample needs ``image`` (height
                  and width are read from axes 1 and 2), ``gt_bbox``,
                  ``gt_class`` and ``gt_ide``.
              context: unused.

          Returns:
              list[dict]: the same ``samples`` list.
          """
          for sample in samples:
              output_h = sample['image'].shape[1] // self.down_ratio
              output_w = sample['image'].shape[2] // self.down_ratio

              heatmap = np.zeros(
                  (self.num_classes, output_h, output_w), dtype='float32')
              bbox_size = np.zeros((self.max_objs, 4), dtype=np.float32)
              center_offset = np.zeros((self.max_objs, 2), dtype=np.float32)
              index = np.zeros((self.max_objs, ), dtype=np.int64)
              index_mask = np.zeros((self.max_objs, ), dtype=np.int32)
              reid = np.zeros((self.max_objs, ), dtype=np.int64)
              bbox_xys = np.zeros((self.max_objs, 4), dtype=np.float32)
              if self.num_classes > 1:
                  # each category corresponds to a set of track ids
                  cls_tr_ids = np.zeros(
                      (self.num_classes, output_h, output_w), dtype=np.int64)
                  cls_id_map = np.full((output_h, output_w), -1, dtype=np.int64)

              gt_bbox = sample['gt_bbox']
              gt_class = sample['gt_class']
              gt_ide = sample['gt_ide']

              # The target arrays have max_objs rows; extra boxes cannot be
              # stored, so drop them instead of indexing out of bounds.
              num_objs = len(gt_bbox)
              if num_objs > self.max_objs:
                  logger.warning(
                      'Gt2FairMOTTarget: {} ground truth boxes exceed '
                      'max_objs={}, extra boxes are ignored.'.format(
                          num_objs, self.max_objs))
                  num_objs = self.max_objs

              for k in range(num_objs):
                  cls_id = gt_class[k][0]
                  # Work on a copy so the caller's gt_bbox array is not
                  # rescaled in place.
                  bbox = gt_bbox[k].copy()
                  ide = gt_ide[k][0]

                  # scale to heatmap units; the code treats bbox as
                  # [x_center, y_center, w, h]
                  bbox[[0, 2]] = bbox[[0, 2]] * output_w
                  bbox[[1, 3]] = bbox[[1, 3]] * output_h

                  # corner-format box computed before the centre is clipped
                  bbox_amodal = bbox.copy()
                  bbox_amodal[0] = bbox_amodal[0] - bbox_amodal[2] / 2.
                  bbox_amodal[1] = bbox_amodal[1] - bbox_amodal[3] / 2.
                  bbox_amodal[2] = bbox_amodal[0] + bbox_amodal[2]
                  bbox_amodal[3] = bbox_amodal[1] + bbox_amodal[3]

                  # keep the centre inside the heatmap
                  bbox[0] = np.clip(bbox[0], 0, output_w - 1)
                  bbox[1] = np.clip(bbox[1], 0, output_h - 1)
                  h = bbox[3]
                  w = bbox[2]

                  # corner-format box computed from the clipped centre
                  bbox_xy = bbox.copy()
                  bbox_xy[0] = bbox_xy[0] - bbox_xy[2] / 2
                  bbox_xy[1] = bbox_xy[1] - bbox_xy[3] / 2
                  bbox_xy[2] = bbox_xy[0] + bbox_xy[2]
                  bbox_xy[3] = bbox_xy[1] + bbox_xy[3]

                  if h > 0 and w > 0:
                      radius = gaussian_radius((math.ceil(h), math.ceil(w)), 0.7)
                      radius = max(0, int(radius))
                      ct = np.array([bbox[0], bbox[1]], dtype=np.float32)
                      ct_int = ct.astype(np.int32)
                      self.draw_truncate_gaussian(heatmap[cls_id], ct_int,
                                                  radius, radius)
                      # distances from the centre to the four box sides
                      bbox_size[k] = ct[0] - bbox_amodal[0], ct[1] - bbox_amodal[1], \
                          bbox_amodal[2] - ct[0], bbox_amodal[3] - ct[1]
                      # flattened heatmap position of the centre cell
                      index[k] = ct_int[1] * output_w + ct_int[0]
                      center_offset[k] = ct - ct_int
                      index_mask[k] = 1
                      reid[k] = ide
                      bbox_xys[k] = bbox_xy
                      if self.num_classes > 1:
                          cls_id_map[ct_int[1], ct_int[0]] = cls_id
                          # track id start from 0
                          cls_tr_ids[cls_id][ct_int[1]][ct_int[0]] = ide - 1

              sample['heatmap'] = heatmap
              sample['index'] = index
              sample['offset'] = center_offset
              sample['size'] = bbox_size
              sample['index_mask'] = index_mask
              sample['reid'] = reid
              if self.num_classes > 1:
                  sample['cls_id_map'] = cls_id_map
                  sample['cls_tr_ids'] = cls_tr_ids
              sample['bbox_xys'] = bbox_xys

              # raw annotations are no longer needed once targets exist
              for key in ('is_crowd', 'difficult', 'gt_class', 'gt_bbox',
                          'gt_score', 'gt_ide'):
                  sample.pop(key, None)
          return samples