mot_jde_infer.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import time
  16. import yaml
  17. import cv2
  18. import numpy as np
  19. from collections import defaultdict
  20. import paddle
  21. from benchmark_utils import PaddleInferBenchmark
  22. from preprocess import decode_image
  23. from mot_utils import argsparser, Timer, get_current_memory_mb
  24. from det_infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
  25. # add python path
  26. import sys
  27. parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
  28. sys.path.insert(0, parent_path)
  29. from mot import JDETracker
  30. from mot.utils import MOTTimer, write_mot_results, flow_statistic
  31. from mot.visualize import plot_tracking, plot_tracking_dict
  32. # Global dictionary
  33. MOT_JDE_SUPPORT_MODELS = {
  34. 'JDE',
  35. 'FairMOT',
  36. }
  37. class JDE_Detector(Detector):
  38. """
  39. Args:
  40. model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
  41. device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
  42. run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
  43. batch_size (int): size of pre batch in inference
  44. trt_min_shape (int): min shape for dynamic shape in trt
  45. trt_max_shape (int): max shape for dynamic shape in trt
  46. trt_opt_shape (int): opt shape for dynamic shape in trt
  47. trt_calib_mode (bool): If the model is produced by TRT offline quantitative
  48. calibration, trt_calib_mode need to set True
  49. cpu_threads (int): cpu threads
  50. enable_mkldnn (bool): whether to open MKLDNN
  51. output_dir (string): The path of output, default as 'output'
  52. threshold (float): Score threshold of the detected bbox, default as 0.5
  53. save_images (bool): Whether to save visualization image results, default as False
  54. save_mot_txts (bool): Whether to save tracking results (txt), default as False
  55. draw_center_traj (bool): Whether drawing the trajectory of center, default as False
  56. secs_interval (int): The seconds interval to count after tracking, default as 10
  57. skip_frame_num (int): Skip frame num to get faster MOT results, default as -1
  58. do_entrance_counting(bool): Whether counting the numbers of identifiers entering
  59. or getting out from the entrance, default as False,only support single class
  60. counting in MOT.
  61. do_break_in_counting(bool): Whether counting the numbers of identifiers break in
  62. the area, default as False,only support single class counting in MOT,
  63. and the video should be taken by a static camera.
  64. region_type (str): Area type for entrance counting or break in counting, 'horizontal'
  65. and 'vertical' used when do entrance counting. 'custom' used when do break in counting.
  66. Note that only support single-class MOT, and the video should be taken by a static camera.
  67. region_polygon (list): Clockwise point coords (x0,y0,x1,y1...) of polygon of area when
  68. do_break_in_counting. Note that only support single-class MOT and
  69. the video should be taken by a static camera.
  70. """
  71. def __init__(self,
  72. model_dir,
  73. tracker_config=None,
  74. device='CPU',
  75. run_mode='paddle',
  76. batch_size=1,
  77. trt_min_shape=1,
  78. trt_max_shape=1088,
  79. trt_opt_shape=608,
  80. trt_calib_mode=False,
  81. cpu_threads=1,
  82. enable_mkldnn=False,
  83. output_dir='output',
  84. threshold=0.5,
  85. save_images=False,
  86. save_mot_txts=False,
  87. draw_center_traj=False,
  88. secs_interval=10,
  89. skip_frame_num=-1,
  90. do_entrance_counting=False,
  91. do_break_in_counting=False,
  92. region_type='horizontal',
  93. region_polygon=[]):
  94. super(JDE_Detector, self).__init__(
  95. model_dir=model_dir,
  96. device=device,
  97. run_mode=run_mode,
  98. batch_size=batch_size,
  99. trt_min_shape=trt_min_shape,
  100. trt_max_shape=trt_max_shape,
  101. trt_opt_shape=trt_opt_shape,
  102. trt_calib_mode=trt_calib_mode,
  103. cpu_threads=cpu_threads,
  104. enable_mkldnn=enable_mkldnn,
  105. output_dir=output_dir,
  106. threshold=threshold, )
  107. self.save_images = save_images
  108. self.save_mot_txts = save_mot_txts
  109. self.draw_center_traj = draw_center_traj
  110. self.secs_interval = secs_interval
  111. self.skip_frame_num = skip_frame_num
  112. self.do_entrance_counting = do_entrance_counting
  113. self.do_break_in_counting = do_break_in_counting
  114. self.region_type = region_type
  115. self.region_polygon = region_polygon
  116. if self.region_type == 'custom':
  117. assert len(
  118. self.region_polygon
  119. ) > 6, 'region_type is custom, region_polygon should be at least 3 pairs of point coords.'
  120. assert batch_size == 1, "MOT model only supports batch_size=1."
  121. self.det_times = Timer(with_tracker=True)
  122. self.num_classes = len(self.pred_config.labels)
  123. if self.skip_frame_num > 1:
  124. self.previous_det_result = None
  125. # tracker config
  126. assert self.pred_config.tracker, "The exported JDE Detector model should have tracker."
  127. cfg = self.pred_config.tracker
  128. min_box_area = cfg.get('min_box_area', 0.0)
  129. vertical_ratio = cfg.get('vertical_ratio', 0.0)
  130. conf_thres = cfg.get('conf_thres', 0.0)
  131. tracked_thresh = cfg.get('tracked_thresh', 0.7)
  132. metric_type = cfg.get('metric_type', 'euclidean')
  133. self.tracker = JDETracker(
  134. num_classes=self.num_classes,
  135. min_box_area=min_box_area,
  136. vertical_ratio=vertical_ratio,
  137. conf_thres=conf_thres,
  138. tracked_thresh=tracked_thresh,
  139. metric_type=metric_type)
  140. def postprocess(self, inputs, result):
  141. # postprocess output of predictor
  142. np_boxes = result['pred_dets']
  143. if np_boxes.shape[0] <= 0:
  144. print('[WARNNING] No object detected.')
  145. result = {'pred_dets': np.zeros([0, 6]), 'pred_embs': None}
  146. result = {k: v for k, v in result.items() if v is not None}
  147. return result
  148. def tracking(self, det_results):
  149. pred_dets = det_results['pred_dets'] # cls_id, score, x0, y0, x1, y1
  150. pred_embs = det_results['pred_embs']
  151. online_targets_dict = self.tracker.update(pred_dets, pred_embs)
  152. online_tlwhs = defaultdict(list)
  153. online_scores = defaultdict(list)
  154. online_ids = defaultdict(list)
  155. for cls_id in range(self.num_classes):
  156. online_targets = online_targets_dict[cls_id]
  157. for t in online_targets:
  158. tlwh = t.tlwh
  159. tid = t.track_id
  160. tscore = t.score
  161. if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue
  162. if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
  163. 3] > self.tracker.vertical_ratio:
  164. continue
  165. online_tlwhs[cls_id].append(tlwh)
  166. online_ids[cls_id].append(tid)
  167. online_scores[cls_id].append(tscore)
  168. return online_tlwhs, online_scores, online_ids
  169. def predict(self, repeats=1):
  170. '''
  171. Args:
  172. repeats (int): repeats number for prediction
  173. Returns:
  174. result (dict): include 'pred_dets': np.ndarray: shape:[N,6], N: number of box,
  175. matix element:[class, score, x_min, y_min, x_max, y_max]
  176. FairMOT(JDE)'s result include 'pred_embs': np.ndarray:
  177. shape: [N, 128]
  178. '''
  179. # model prediction
  180. np_pred_dets, np_pred_embs = None, None
  181. for i in range(repeats):
  182. self.predictor.run()
  183. output_names = self.predictor.get_output_names()
  184. boxes_tensor = self.predictor.get_output_handle(output_names[0])
  185. np_pred_dets = boxes_tensor.copy_to_cpu()
  186. embs_tensor = self.predictor.get_output_handle(output_names[1])
  187. np_pred_embs = embs_tensor.copy_to_cpu()
  188. result = dict(pred_dets=np_pred_dets, pred_embs=np_pred_embs)
  189. return result
  190. def predict_image(self,
  191. image_list,
  192. run_benchmark=False,
  193. repeats=1,
  194. visual=True,
  195. seq_name=None,
  196. reuse_det_result=False):
  197. mot_results = []
  198. num_classes = self.num_classes
  199. image_list.sort()
  200. ids2names = self.pred_config.labels
  201. data_type = 'mcmot' if num_classes > 1 else 'mot'
  202. for frame_id, img_file in enumerate(image_list):
  203. batch_image_list = [img_file] # bs=1 in MOT model
  204. if run_benchmark:
  205. # preprocess
  206. inputs = self.preprocess(batch_image_list) # warmup
  207. self.det_times.preprocess_time_s.start()
  208. inputs = self.preprocess(batch_image_list)
  209. self.det_times.preprocess_time_s.end()
  210. # model prediction
  211. result_warmup = self.predict(repeats=repeats) # warmup
  212. self.det_times.inference_time_s.start()
  213. result = self.predict(repeats=repeats)
  214. self.det_times.inference_time_s.end(repeats=repeats)
  215. # postprocess
  216. result_warmup = self.postprocess(inputs, result) # warmup
  217. self.det_times.postprocess_time_s.start()
  218. det_result = self.postprocess(inputs, result)
  219. self.det_times.postprocess_time_s.end()
  220. # tracking
  221. result_warmup = self.tracking(det_result)
  222. self.det_times.tracking_time_s.start()
  223. online_tlwhs, online_scores, online_ids = self.tracking(
  224. det_result)
  225. self.det_times.tracking_time_s.end()
  226. self.det_times.img_num += 1
  227. cm, gm, gu = get_current_memory_mb()
  228. self.cpu_mem += cm
  229. self.gpu_mem += gm
  230. self.gpu_util += gu
  231. else:
  232. self.det_times.preprocess_time_s.start()
  233. if not reuse_det_result:
  234. inputs = self.preprocess(batch_image_list)
  235. self.det_times.preprocess_time_s.end()
  236. self.det_times.inference_time_s.start()
  237. if not reuse_det_result:
  238. result = self.predict()
  239. self.det_times.inference_time_s.end()
  240. self.det_times.postprocess_time_s.start()
  241. if not reuse_det_result:
  242. det_result = self.postprocess(inputs, result)
  243. self.previous_det_result = det_result
  244. else:
  245. assert self.previous_det_result is not None
  246. det_result = self.previous_det_result
  247. self.det_times.postprocess_time_s.end()
  248. # tracking process
  249. self.det_times.tracking_time_s.start()
  250. online_tlwhs, online_scores, online_ids = self.tracking(
  251. det_result)
  252. self.det_times.tracking_time_s.end()
  253. self.det_times.img_num += 1
  254. if visual:
  255. if len(image_list) > 1 and frame_id % 10 == 0:
  256. print('Tracking frame {}'.format(frame_id))
  257. frame, _ = decode_image(img_file, {})
  258. im = plot_tracking_dict(
  259. frame,
  260. num_classes,
  261. online_tlwhs,
  262. online_ids,
  263. online_scores,
  264. frame_id=frame_id,
  265. ids2names=ids2names)
  266. if seq_name is None:
  267. seq_name = image_list[0].split('/')[-2]
  268. save_dir = os.path.join(self.output_dir, seq_name)
  269. if not os.path.exists(save_dir):
  270. os.makedirs(save_dir)
  271. cv2.imwrite(
  272. os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
  273. mot_results.append([online_tlwhs, online_scores, online_ids])
  274. return mot_results
  275. def predict_video(self, video_file, camera_id):
  276. video_out_name = 'mot_output.mp4'
  277. if camera_id != -1:
  278. capture = cv2.VideoCapture(camera_id)
  279. else:
  280. capture = cv2.VideoCapture(video_file)
  281. video_out_name = os.path.split(video_file)[-1]
  282. # Get Video info : resolution, fps, frame count
  283. width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
  284. height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
  285. fps = int(capture.get(cv2.CAP_PROP_FPS))
  286. frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
  287. print("fps: %d, frame_count: %d" % (fps, frame_count))
  288. if not os.path.exists(self.output_dir):
  289. os.makedirs(self.output_dir)
  290. out_path = os.path.join(self.output_dir, video_out_name)
  291. video_format = 'mp4v'
  292. fourcc = cv2.VideoWriter_fourcc(*video_format)
  293. writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
  294. frame_id = 0
  295. timer = MOTTimer()
  296. results = defaultdict(list) # support single class and multi classes
  297. num_classes = self.num_classes
  298. data_type = 'mcmot' if num_classes > 1 else 'mot'
  299. ids2names = self.pred_config.labels
  300. center_traj = None
  301. entrance = None
  302. records = None
  303. if self.draw_center_traj:
  304. center_traj = [{} for i in range(num_classes)]
  305. if num_classes == 1:
  306. id_set = set()
  307. interval_id_set = set()
  308. in_id_list = list()
  309. out_id_list = list()
  310. prev_center = dict()
  311. records = list()
  312. if self.do_entrance_counting or self.do_break_in_counting:
  313. if self.region_type == 'horizontal':
  314. entrance = [0, height / 2., width, height / 2.]
  315. elif self.region_type == 'vertical':
  316. entrance = [width / 2, 0., width / 2, height]
  317. elif self.region_type == 'custom':
  318. entrance = []
  319. assert len(
  320. self.region_polygon
  321. ) % 2 == 0, "region_polygon should be pairs of coords points when do break_in counting."
  322. for i in range(0, len(self.region_polygon), 2):
  323. entrance.append([
  324. self.region_polygon[i], self.region_polygon[i + 1]
  325. ])
  326. entrance.append([width, height])
  327. else:
  328. raise ValueError("region_type:{} is not supported.".format(
  329. self.region_type))
  330. video_fps = fps
  331. while (1):
  332. ret, frame = capture.read()
  333. if not ret:
  334. break
  335. if frame_id % 10 == 0:
  336. print('Tracking frame: %d' % (frame_id))
  337. timer.tic()
  338. mot_skip_frame_num = self.skip_frame_num
  339. reuse_det_result = False
  340. if mot_skip_frame_num > 1 and frame_id > 0 and frame_id % mot_skip_frame_num > 0:
  341. reuse_det_result = True
  342. seq_name = video_out_name.split('.')[0]
  343. mot_results = self.predict_image(
  344. [frame],
  345. visual=False,
  346. seq_name=seq_name,
  347. reuse_det_result=reuse_det_result)
  348. timer.toc()
  349. online_tlwhs, online_scores, online_ids = mot_results[0]
  350. for cls_id in range(num_classes):
  351. results[cls_id].append(
  352. (frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
  353. online_ids[cls_id]))
  354. # NOTE: just implement flow statistic for single class
  355. if num_classes == 1:
  356. result = (frame_id + 1, online_tlwhs[0], online_scores[0],
  357. online_ids[0])
  358. statistic = flow_statistic(
  359. result,
  360. self.secs_interval,
  361. self.do_entrance_counting,
  362. self.do_break_in_counting,
  363. self.region_type,
  364. video_fps,
  365. entrance,
  366. id_set,
  367. interval_id_set,
  368. in_id_list,
  369. out_id_list,
  370. prev_center,
  371. records,
  372. data_type,
  373. ids2names=self.pred_config.labels)
  374. records = statistic['records']
  375. fps = 1. / timer.duration
  376. im = plot_tracking_dict(
  377. frame,
  378. num_classes,
  379. online_tlwhs,
  380. online_ids,
  381. online_scores,
  382. frame_id=frame_id,
  383. fps=fps,
  384. ids2names=ids2names,
  385. do_entrance_counting=self.do_entrance_counting,
  386. entrance=entrance,
  387. records=records,
  388. center_traj=center_traj)
  389. writer.write(im)
  390. if camera_id != -1:
  391. cv2.imshow('Mask Detection', im)
  392. if cv2.waitKey(1) & 0xFF == ord('q'):
  393. break
  394. frame_id += 1
  395. if self.save_mot_txts:
  396. result_filename = os.path.join(
  397. self.output_dir, video_out_name.split('.')[-2] + '.txt')
  398. write_mot_results(result_filename, results, data_type, num_classes)
  399. if num_classes == 1:
  400. result_filename = os.path.join(
  401. self.output_dir,
  402. video_out_name.split('.')[-2] + '_flow_statistic.txt')
  403. f = open(result_filename, 'w')
  404. for line in records:
  405. f.write(line)
  406. print('Flow statistic save in {}'.format(result_filename))
  407. f.close()
  408. writer.release()
  409. def main():
  410. detector = JDE_Detector(
  411. FLAGS.model_dir,
  412. tracker_config=None,
  413. device=FLAGS.device,
  414. run_mode=FLAGS.run_mode,
  415. batch_size=1,
  416. trt_min_shape=FLAGS.trt_min_shape,
  417. trt_max_shape=FLAGS.trt_max_shape,
  418. trt_opt_shape=FLAGS.trt_opt_shape,
  419. trt_calib_mode=FLAGS.trt_calib_mode,
  420. cpu_threads=FLAGS.cpu_threads,
  421. enable_mkldnn=FLAGS.enable_mkldnn,
  422. output_dir=FLAGS.output_dir,
  423. threshold=FLAGS.threshold,
  424. save_images=FLAGS.save_images,
  425. save_mot_txts=FLAGS.save_mot_txts,
  426. draw_center_traj=FLAGS.draw_center_traj,
  427. secs_interval=FLAGS.secs_interval,
  428. skip_frame_num=FLAGS.skip_frame_num,
  429. do_entrance_counting=FLAGS.do_entrance_counting,
  430. do_break_in_counting=FLAGS.do_break_in_counting,
  431. region_type=FLAGS.region_type,
  432. region_polygon=FLAGS.region_polygon)
  433. # predict from video file or camera video stream
  434. if FLAGS.video_file is not None or FLAGS.camera_id != -1:
  435. detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
  436. else:
  437. # predict from image
  438. img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
  439. detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
  440. if not FLAGS.run_benchmark:
  441. detector.det_times.info(average=True)
  442. else:
  443. mode = FLAGS.run_mode
  444. model_dir = FLAGS.model_dir
  445. model_info = {
  446. 'model_name': model_dir.strip('/').split('/')[-1],
  447. 'precision': mode.split('_')[-1]
  448. }
  449. bench_log(detector, img_list, model_info, name='MOT')
  450. if __name__ == '__main__':
  451. paddle.enable_static()
  452. parser = argsparser()
  453. FLAGS = parser.parse_args()
  454. print_arguments(FLAGS)
  455. FLAGS.device = FLAGS.device.upper()
  456. assert FLAGS.device in ['CPU', 'GPU', 'XPU'
  457. ], "device should be CPU, GPU or XPU"
  458. main()