# mot_centertrack_infer.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import math
import time
import yaml
import cv2
import numpy as np
from collections import defaultdict

import paddle

from benchmark_utils import PaddleInferBenchmark
from utils import gaussian_radius, gaussian2D, draw_umich_gaussian
from preprocess import preprocess, decode_image, WarpAffine, NormalizeImage, Permute
from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
from keypoint_preprocess import get_affine_transform

# add python path
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

from pptracking.python.mot import CenterTracker
from pptracking.python.mot.utils import MOTTimer, write_mot_results
from pptracking.python.mot.visualize import plot_tracking


def transform_preds_with_trans(coords, trans):
    # Apply a 2x3 affine `trans` to N points given as an [N, 2] array.
    target_coords = np.ones((coords.shape[0], 3), np.float32)
    target_coords[:, :2] = coords
    target_coords = np.dot(trans, target_coords.transpose()).transpose()
    return target_coords[:, :2]


def affine_transform(pt, t):
    # Apply a 2x3 affine `t` to a single point (x, y).
    new_pt = np.array([pt[0], pt[1], 1.]).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2]


def affine_transform_bbox(bbox, trans, width, height):
    # Transform an xyxy bbox and clip it to the image bounds.
    bbox = np.array(copy.deepcopy(bbox), dtype=np.float32)
    bbox[:2] = affine_transform(bbox[:2], trans)
    bbox[2:] = affine_transform(bbox[2:], trans)
    bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, width - 1)
    bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, height - 1)
    return bbox
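
# A quick sanity check for the helpers above (illustrative comment only,
# assuming an identity 2x3 affine matrix; not executed by the pipeline):
#
#   trans = np.array([[1., 0., 0.], [0., 1., 0.]], dtype=np.float32)
#   affine_transform_bbox([10, 20, 110, 220], trans, width=640, height=480)
#   # -> array([ 10.,  20., 110., 220.], dtype=float32)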


class CenterTrack(Detector):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): device to run on, one of CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): batch size used in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantization
            calibration, trt_calib_mode needs to be set to True
        cpu_threads (int): number of cpu threads
        enable_mkldnn (bool): whether to enable MKLDNN
        output_dir (str): path of the output, default as 'output'
        threshold (float): score threshold of the detected bbox, default as 0.5
        save_images (bool): whether to save visualization image results, default as False
        save_mot_txts (bool): whether to save tracking results (txt), default as False
    """

    def __init__(self,
                 model_dir,
                 tracker_config=None,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=960,
                 trt_opt_shape=544,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 output_dir='output',
                 threshold=0.5,
                 save_images=False,
                 save_mot_txts=False):
        super(CenterTrack, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            output_dir=output_dir,
            threshold=threshold)
        self.save_images = save_images
        self.save_mot_txts = save_mot_txts
        assert batch_size == 1, "MOT model only supports batch_size=1."
        self.det_times = Timer(with_tracker=True)
        self.num_classes = len(self.pred_config.labels)

        # tracker config
        cfg = self.pred_config.tracker
        min_box_area = cfg.get('min_box_area', -1)
        vertical_ratio = cfg.get('vertical_ratio', -1)
        track_thresh = cfg.get('track_thresh', 0.4)
        pre_thresh = cfg.get('pre_thresh', 0.5)

        self.tracker = CenterTracker(
            num_classes=self.num_classes,
            min_box_area=min_box_area,
            vertical_ratio=vertical_ratio,
            track_thresh=track_thresh,
            pre_thresh=pre_thresh)

        self.pre_image = None
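
    # The tracker section of infer_cfg.yml supplies the cfg.get(...) lookups
    # above. A hypothetical snippet (the values shown are simply the defaults
    # used in this file):
    #
    #   tracker:
    #     min_box_area: -1
    #     vertical_ratio: -1
    #     track_thresh: 0.4
    #     pre_thresh: 0.5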

    def get_additional_inputs(self, dets, meta, with_hm=True):
        # Render the input heatmap from the previous frame's tracks.
        trans_input = meta['trans_input']
        inp_width, inp_height = int(meta['inp_width']), int(meta['inp_height'])
        input_hm = np.zeros((1, inp_height, inp_width), dtype=np.float32)

        for det in dets:
            if det['score'] < self.tracker.pre_thresh:
                continue
            bbox = affine_transform_bbox(det['bbox'], trans_input, inp_width,
                                         inp_height)
            h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
            if h > 0 and w > 0:
                radius = gaussian_radius(
                    (math.ceil(h), math.ceil(w)), min_overlap=0.7)
                radius = max(0, int(radius))
                ct = np.array(
                    [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2],
                    dtype=np.float32)
                ct_int = ct.astype(np.int32)
                if with_hm:
                    input_hm[0] = draw_umich_gaussian(input_hm[0], ct_int,
                                                      radius)
        if with_hm:
            input_hm = input_hm[np.newaxis]
        return input_hm
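
    # Sketch of what the rendering above does for one prior track, assuming
    # utils.draw_umich_gaussian splats a 2-D Gaussian onto the heatmap in place
    # and returns it (illustrative comment only, not executed):
    #
    #   hm = np.zeros((1, 8, 8), dtype=np.float32)
    #   hm[0] = draw_umich_gaussian(hm[0], np.array([4, 4], np.int32), 1)
    #   # hm[0] now peaks at 1.0 at pixel (4, 4) and decays over its neighbors.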

    def preprocess(self, image_list):
        preprocess_ops = []
        for op_info in self.pred_config.preprocess_infos:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            preprocess_ops.append(eval(op_type)(**new_op_info))

        assert len(image_list) == 1, 'MOT only supports batch_size=1.'
        im_path = image_list[0]
        im, im_info = preprocess(im_path, preprocess_ops)

        inputs = {}
        inputs['image'] = np.array((im, )).astype('float32')
        inputs['im_shape'] = np.array(
            (im_info['im_shape'], )).astype('float32')
        inputs['scale_factor'] = np.array(
            (im_info['scale_factor'], )).astype('float32')
        inputs['trans_input'] = im_info['trans_input']
        inputs['inp_width'] = im_info['inp_width']
        inputs['inp_height'] = im_info['inp_height']
        inputs['center'] = im_info['center']
        inputs['scale'] = im_info['scale']
        inputs['out_height'] = im_info['out_height']
        inputs['out_width'] = im_info['out_width']

        if self.pre_image is None:
            self.pre_image = inputs['image']
            # initialize the tracker for the first frame
            self.tracker.init_track([])
        inputs['pre_image'] = self.pre_image
        self.pre_image = inputs['image']  # update for the next frame

        # render the input heatmap from tracker status
        pre_hm = self.get_additional_inputs(
            self.tracker.tracks, inputs, with_hm=True)
        inputs['pre_hm'] = pre_hm

        input_names = self.predictor.get_input_names()
        for i in range(len(input_names)):
            input_tensor = self.predictor.get_input_handle(input_names[i])
            if input_names[i] == 'x':
                input_tensor.copy_from_cpu(inputs['image'])
            else:
                input_tensor.copy_from_cpu(inputs[input_names[i]])

        return inputs

    def postprocess(self, inputs, result):
        # postprocess the output of the predictor
        np_bboxes = result['bboxes']
        if np_bboxes.shape[0] <= 0:
            print('[WARNING] No object detected and tracked.')
            result = {'bboxes': np.zeros([0, 6]), 'cts': None, 'tracking': None}
            return result
        result = {k: v for k, v in result.items() if v is not None}
        return result

    def centertrack_post_process(self, dets, meta, out_thresh):
        if 'bboxes' not in dets:
            return [{}]

        preds = []
        c, s = meta['center'], meta['scale']
        h, w = meta['out_height'], meta['out_width']
        trans = get_affine_transform(
            center=c,
            input_size=s,
            rot=0,
            output_size=[w, h],
            shift=(0., 0.),
            inv=True).astype(np.float32)

        for i, dets_bbox in enumerate(dets['bboxes']):
            if dets_bbox[1] < out_thresh:
                break
            item = {}
            item['score'] = dets_bbox[1]
            item['class'] = int(dets_bbox[0]) + 1
            item['ct'] = transform_preds_with_trans(
                dets['cts'][i].reshape([1, 2]), trans).reshape(2)

            if 'tracking' in dets:
                tracking = transform_preds_with_trans(
                    (dets['tracking'][i] + dets['cts'][i]).reshape([1, 2]),
                    trans).reshape(2)
                item['tracking'] = tracking - item['ct']

            if 'bboxes' in dets:
                bbox = transform_preds_with_trans(
                    dets_bbox[2:6].reshape([2, 2]), trans).reshape(4)
                item['bbox'] = bbox

            preds.append(item)
        return preds
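
    # The inv=True transform above maps coordinates from the network-output
    # grid (out_width x out_height) back to the original image. A sketch with
    # made-up numbers; by construction the output-grid center maps back to
    # `center` (illustrative comment only):
    #
    #   trans = get_affine_transform(
    #       center=np.array([960., 540.]), input_size=1088, rot=0,
    #       output_size=[272, 152], shift=(0., 0.), inv=True)
    #   transform_preds_with_trans(np.array([[136., 76.]], np.float32), trans)
    #   # -> approximately [[960., 540.]]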

    def tracking(self, inputs, det_results):
        result = self.centertrack_post_process(
            det_results, inputs, self.tracker.out_thresh)
        online_targets = self.tracker.update(result)

        online_tlwhs, online_scores, online_ids = [], [], []
        for t in online_targets:
            bbox = t['bbox']
            tlwh = [bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]]
            tscore = float(t['score'])
            tid = int(t['tracking_id'])
            if tlwh[2] * tlwh[3] > 0:
                online_tlwhs.append(tlwh)
                online_ids.append(tid)
                online_scores.append(tscore)
        return online_tlwhs, online_scores, online_ids

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): number of repeats for prediction
        Returns:
            result (dict): includes 'bboxes', 'cts' and 'tracking':
                np.ndarray of shape [N, 6], [N, 2] and [N, 2], N: number of boxes
        '''
        # model prediction
        np_bboxes, np_cts, np_tracking = None, None, None
        for i in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            bboxes_tensor = self.predictor.get_output_handle(output_names[0])
            np_bboxes = bboxes_tensor.copy_to_cpu()
            cts_tensor = self.predictor.get_output_handle(output_names[1])
            np_cts = cts_tensor.copy_to_cpu()
            tracking_tensor = self.predictor.get_output_handle(output_names[2])
            np_tracking = tracking_tensor.copy_to_cpu()

        result = dict(bboxes=np_bboxes, cts=np_cts, tracking=np_tracking)
        return result

    def predict_image(self,
                      image_list,
                      run_benchmark=False,
                      repeats=1,
                      visual=True,
                      seq_name=None):
        mot_results = []
        num_classes = self.num_classes
        image_list.sort()
        ids2names = self.pred_config.labels
        data_type = 'mcmot' if num_classes > 1 else 'mot'
        for frame_id, img_file in enumerate(image_list):
            batch_image_list = [img_file]  # bs=1 in MOT model
            if run_benchmark:
                # preprocess
                inputs = self.preprocess(batch_image_list)  # warmup
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                # model prediction
                result_warmup = self.predict(repeats=repeats)  # warmup
                self.det_times.inference_time_s.start()
                result = self.predict(repeats=repeats)
                self.det_times.inference_time_s.end(repeats=repeats)

                # postprocess
                result_warmup = self.postprocess(inputs, result)  # warmup
                self.det_times.postprocess_time_s.start()
                det_result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()

                # tracking
                result_warmup = self.tracking(inputs, det_result)
                self.det_times.tracking_time_s.start()
                online_tlwhs, online_scores, online_ids = self.tracking(
                    inputs, det_result)
                self.det_times.tracking_time_s.end()
                self.det_times.img_num += 1

                cm, gm, gu = get_current_memory_mb()
                self.cpu_mem += cm
                self.gpu_mem += gm
                self.gpu_util += gu
            else:
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                self.det_times.inference_time_s.start()
                result = self.predict()
                self.det_times.inference_time_s.end()

                self.det_times.postprocess_time_s.start()
                det_result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()

                # tracking process
                self.det_times.tracking_time_s.start()
                online_tlwhs, online_scores, online_ids = self.tracking(
                    inputs, det_result)
                self.det_times.tracking_time_s.end()
                self.det_times.img_num += 1

            if visual:
                if len(image_list) > 1 and frame_id % 10 == 0:
                    print('Tracking frame {}'.format(frame_id))
                frame, _ = decode_image(img_file, {})
                im = plot_tracking(
                    frame,
                    online_tlwhs,
                    online_ids,
                    online_scores,
                    frame_id=frame_id,
                    ids2names=ids2names)
                if seq_name is None:
                    seq_name = image_list[0].split('/')[-2]
                save_dir = os.path.join(self.output_dir, seq_name)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)

            mot_results.append([online_tlwhs, online_scores, online_ids])
        return mot_results

    def predict_video(self, video_file, camera_id):
        video_out_name = 'mot_output.mp4'
        if camera_id != -1:
            capture = cv2.VideoCapture(camera_id)
        else:
            capture = cv2.VideoCapture(video_file)
            video_out_name = os.path.split(video_file)[-1]
        # Get video info: resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (fps, frame_count))

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        video_format = 'mp4v'
        fourcc = cv2.VideoWriter_fourcc(*video_format)
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))

        frame_id = 1
        timer = MOTTimer()
        results = defaultdict(list)  # CenterTrack only supports single class
        num_classes = self.num_classes
        data_type = 'mcmot' if num_classes > 1 else 'mot'
        ids2names = self.pred_config.labels

        while True:
            ret, frame = capture.read()
            if not ret:
                break
            if frame_id % 10 == 0:
                print('Tracking frame: %d' % (frame_id))
            frame_id += 1

            timer.tic()
            seq_name = video_out_name.split('.')[0]
            mot_results = self.predict_image(
                [frame[:, :, ::-1]], visual=False, seq_name=seq_name)
            timer.toc()

            fps = 1. / timer.duration
            online_tlwhs, online_scores, online_ids = mot_results[0]
            results[0].append(
                (frame_id + 1, online_tlwhs, online_scores, online_ids))

            im = plot_tracking(
                frame,
                online_tlwhs,
                online_ids,
                online_scores,
                frame_id=frame_id,
                fps=fps,
                ids2names=ids2names)
            writer.write(im)
            if camera_id != -1:
                cv2.imshow('Tracking', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

        if self.save_mot_txts:
            result_filename = os.path.join(
                self.output_dir, video_out_name.split('.')[-2] + '.txt')
            write_mot_results(result_filename, results, data_type, num_classes)

        writer.release()


def main():
    detector = CenterTrack(
        FLAGS.model_dir,
        tracker_config=None,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        batch_size=1,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        output_dir=FLAGS.output_dir,
        threshold=FLAGS.threshold,
        save_images=FLAGS.save_images,
        save_mot_txts=FLAGS.save_mot_txts)

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
    else:
        # predict from image
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
        detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)

        if not FLAGS.run_benchmark:
            detector.det_times.info(average=True)
        else:
            mode = FLAGS.run_mode
            model_dir = FLAGS.model_dir
            model_info = {
                'model_name': model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            bench_log(detector, img_list, model_info, name='MOT')


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU',
                            'XPU'], "device should be CPU, GPU or XPU"

    main()
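
# Example invocations (hedged: the model directory and media paths below are
# placeholders; the flags come from argsparser() as consumed in main()):
#
#   # track a video on GPU and save MOT txt results
#   python mot_centertrack_infer.py \
#       --model_dir=output_inference/centertrack_model/ \
#       --video_file=input.mp4 --device=GPU --save_mot_txts
#
#   # track a directory of image frames
#   python mot_centertrack_infer.py \
#       --model_dir=output_inference/centertrack_model/ \
#       --image_dir=./frames/ --device=GPU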