mot_sde_infer.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import time
  16. import yaml
  17. import cv2
  18. import re
  19. import glob
  20. import numpy as np
  21. from collections import defaultdict
  22. import paddle
  23. from benchmark_utils import PaddleInferBenchmark
  24. from preprocess import decode_image
  25. # add python path
  26. import sys
  27. parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
  28. sys.path.insert(0, parent_path)
  29. from det_infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor
  30. from mot_utils import argsparser, Timer, get_current_memory_mb, video2frames, _is_valid_video
  31. from mot.tracker import JDETracker, DeepSORTTracker, OCSORTTracker, BOTSORTTracker
  32. from mot.utils import MOTTimer, write_mot_results, get_crops, clip_box, flow_statistic
  33. from mot.visualize import plot_tracking, plot_tracking_dict
  34. from mot.mtmct.utils import parse_bias
  35. from mot.mtmct.postprocess import trajectory_fusion, sub_cluster, gen_res, print_mtmct_result
  36. from mot.mtmct.postprocess import get_mtmct_matching_results, save_mtmct_crops, save_mtmct_vis_results
  37. class SDE_Detector(Detector):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        tracker_config (str): tracker config path
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): batch size used in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode needs to be set to True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to enable MKLDNN
        output_dir (string): The path of output, default as 'output'
        threshold (float): Score threshold of the detected bbox, default as 0.5
        save_images (bool): Whether to save visualization image results, default as False
        save_mot_txts (bool): Whether to save tracking results (txt), default as False
        draw_center_traj (bool): Whether to draw the trajectory of center, default as False
        secs_interval (int): The seconds interval to count after tracking, default as 10
        skip_frame_num (int): Skip frame num to get faster MOT results, default as -1
        warmup_frame (int): Warmup frame num to test speed of MOT, default as 50
        do_entrance_counting (bool): Whether to count the number of identifiers entering
            or getting out from the entrance, default as False. Only supports single-class
            counting in MOT, and the video should be taken by a static camera.
        do_break_in_counting (bool): Whether to count the number of identifiers breaking into
            the area, default as False. Only supports single-class counting in MOT,
            and the video should be taken by a static camera.
        region_type (str): Area type for entrance counting or break-in counting. 'horizontal'
            and 'vertical' are used for entrance counting; 'custom' is used for break-in counting.
            Note that only single-class MOT is supported, and the video should be taken by a static camera.
        region_polygon (list): Clockwise point coords (x0,y0,x1,y1...) of the polygon area used when
            do_break_in_counting is set. Note that only single-class MOT is supported and
            the video should be taken by a static camera.
        reid_model_dir (str): reid model dir, default None for ByteTrack, but set for DeepSORT
        mtmct_dir (str): MTMCT dir, default None, set for doing MTMCT
    """
  75. def __init__(self,
  76. model_dir,
  77. tracker_config,
  78. device='CPU',
  79. run_mode='paddle',
  80. batch_size=1,
  81. trt_min_shape=1,
  82. trt_max_shape=1280,
  83. trt_opt_shape=640,
  84. trt_calib_mode=False,
  85. cpu_threads=1,
  86. enable_mkldnn=False,
  87. output_dir='output',
  88. threshold=0.5,
  89. save_images=False,
  90. save_mot_txts=False,
  91. draw_center_traj=False,
  92. secs_interval=10,
  93. skip_frame_num=-1,
  94. warmup_frame=50,
  95. do_entrance_counting=False,
  96. do_break_in_counting=False,
  97. region_type='horizontal',
  98. region_polygon=[],
  99. reid_model_dir=None,
  100. mtmct_dir=None):
  101. super(SDE_Detector, self).__init__(
  102. model_dir=model_dir,
  103. device=device,
  104. run_mode=run_mode,
  105. batch_size=batch_size,
  106. trt_min_shape=trt_min_shape,
  107. trt_max_shape=trt_max_shape,
  108. trt_opt_shape=trt_opt_shape,
  109. trt_calib_mode=trt_calib_mode,
  110. cpu_threads=cpu_threads,
  111. enable_mkldnn=enable_mkldnn,
  112. output_dir=output_dir,
  113. threshold=threshold, )
  114. self.save_images = save_images
  115. self.save_mot_txts = save_mot_txts
  116. self.draw_center_traj = draw_center_traj
  117. self.secs_interval = secs_interval
  118. self.skip_frame_num = skip_frame_num
  119. self.warmup_frame = warmup_frame
  120. self.do_entrance_counting = do_entrance_counting
  121. self.do_break_in_counting = do_break_in_counting
  122. self.region_type = region_type
  123. self.region_polygon = region_polygon
  124. if self.region_type == 'custom':
  125. assert len(
  126. self.region_polygon
  127. ) > 6, 'region_type is custom, region_polygon should be at least 3 pairs of point coords.'
  128. assert batch_size == 1, "MOT model only supports batch_size=1."
  129. self.det_times = Timer(with_tracker=True)
  130. self.num_classes = len(self.pred_config.labels)
  131. if self.skip_frame_num > 1:
  132. self.previous_det_result = None
  133. # reid config
  134. self.use_reid = False if reid_model_dir is None else True
  135. if self.use_reid:
  136. self.reid_pred_config = self.set_config(reid_model_dir)
  137. self.reid_predictor, self.config = load_predictor(
  138. reid_model_dir,
  139. run_mode=run_mode,
  140. batch_size=50, # reid_batch_size
  141. min_subgraph_size=self.reid_pred_config.min_subgraph_size,
  142. device=device,
  143. use_dynamic_shape=self.reid_pred_config.use_dynamic_shape,
  144. trt_min_shape=trt_min_shape,
  145. trt_max_shape=trt_max_shape,
  146. trt_opt_shape=trt_opt_shape,
  147. trt_calib_mode=trt_calib_mode,
  148. cpu_threads=cpu_threads,
  149. enable_mkldnn=enable_mkldnn)
  150. else:
  151. self.reid_pred_config = None
  152. self.reid_predictor = None
  153. assert tracker_config is not None, 'Note that tracker_config should be set.'
  154. self.tracker_config = tracker_config
  155. tracker_cfg = yaml.safe_load(open(self.tracker_config))
  156. cfg = tracker_cfg[tracker_cfg['type']]
  157. # tracker config
  158. self.use_deepsort_tracker = True if tracker_cfg[
  159. 'type'] == 'DeepSORTTracker' else False
  160. self.use_ocsort_tracker = True if tracker_cfg[
  161. 'type'] == 'OCSORTTracker' else False
  162. self.use_botsort_tracker = True if tracker_cfg[
  163. 'type'] == 'BOTSORTTracker' else False
  164. if self.use_deepsort_tracker:
  165. if self.reid_pred_config is not None and hasattr(
  166. self.reid_pred_config, 'tracker'):
  167. cfg = self.reid_pred_config.tracker
  168. budget = cfg.get('budget', 100)
  169. max_age = cfg.get('max_age', 30)
  170. max_iou_distance = cfg.get('max_iou_distance', 0.7)
  171. matching_threshold = cfg.get('matching_threshold', 0.2)
  172. min_box_area = cfg.get('min_box_area', 0)
  173. vertical_ratio = cfg.get('vertical_ratio', 0)
  174. self.tracker = DeepSORTTracker(
  175. budget=budget,
  176. max_age=max_age,
  177. max_iou_distance=max_iou_distance,
  178. matching_threshold=matching_threshold,
  179. min_box_area=min_box_area,
  180. vertical_ratio=vertical_ratio, )
  181. elif self.use_ocsort_tracker:
  182. det_thresh = cfg.get('det_thresh', 0.4)
  183. max_age = cfg.get('max_age', 30)
  184. min_hits = cfg.get('min_hits', 3)
  185. iou_threshold = cfg.get('iou_threshold', 0.3)
  186. delta_t = cfg.get('delta_t', 3)
  187. inertia = cfg.get('inertia', 0.2)
  188. min_box_area = cfg.get('min_box_area', 0)
  189. vertical_ratio = cfg.get('vertical_ratio', 0)
  190. use_byte = cfg.get('use_byte', False)
  191. use_angle_cost = cfg.get('use_angle_cost', False)
  192. self.tracker = OCSORTTracker(
  193. det_thresh=det_thresh,
  194. max_age=max_age,
  195. min_hits=min_hits,
  196. iou_threshold=iou_threshold,
  197. delta_t=delta_t,
  198. inertia=inertia,
  199. min_box_area=min_box_area,
  200. vertical_ratio=vertical_ratio,
  201. use_byte=use_byte,
  202. use_angle_cost=use_angle_cost)
  203. elif self.use_botsort_tracker:
  204. track_high_thresh = cfg.get('track_high_thresh', 0.3)
  205. track_low_thresh = cfg.get('track_low_thresh', 0.2)
  206. new_track_thresh = cfg.get('new_track_thresh', 0.4)
  207. match_thresh = cfg.get('match_thresh', 0.7)
  208. track_buffer = cfg.get('track_buffer', 30)
  209. camera_motion = cfg.get('camera_motion', False)
  210. cmc_method = cfg.get('cmc_method', 'sparseOptFlow')
  211. self.tracker = BOTSORTTracker(
  212. track_high_thresh=track_high_thresh,
  213. track_low_thresh=track_low_thresh,
  214. new_track_thresh=new_track_thresh,
  215. match_thresh=match_thresh,
  216. track_buffer=track_buffer,
  217. camera_motion=camera_motion,
  218. cmc_method=cmc_method)
  219. else:
  220. # use ByteTracker
  221. use_byte = cfg.get('use_byte', False)
  222. det_thresh = cfg.get('det_thresh', 0.3)
  223. min_box_area = cfg.get('min_box_area', 0)
  224. vertical_ratio = cfg.get('vertical_ratio', 0)
  225. match_thres = cfg.get('match_thres', 0.9)
  226. conf_thres = cfg.get('conf_thres', 0.6)
  227. low_conf_thres = cfg.get('low_conf_thres', 0.1)
  228. self.tracker = JDETracker(
  229. use_byte=use_byte,
  230. det_thresh=det_thresh,
  231. num_classes=self.num_classes,
  232. min_box_area=min_box_area,
  233. vertical_ratio=vertical_ratio,
  234. match_thres=match_thres,
  235. conf_thres=conf_thres,
  236. low_conf_thres=low_conf_thres, )
  237. self.do_mtmct = False if mtmct_dir is None else True
  238. self.mtmct_dir = mtmct_dir
  239. def postprocess(self, inputs, result):
  240. # postprocess output of predictor
  241. keep_idx = result['boxes'][:, 1] > self.threshold
  242. result['boxes'] = result['boxes'][keep_idx]
  243. np_boxes_num = [len(result['boxes'])]
  244. if np_boxes_num[0] <= 0:
  245. print('[WARNNING] No object detected.')
  246. result = {'boxes': np.zeros([0, 6]), 'boxes_num': [0]}
  247. result = {k: v for k, v in result.items() if v is not None}
  248. return result
  249. def reidprocess(self, det_results, repeats=1):
  250. pred_dets = det_results['boxes'] # cls_id, score, x0, y0, x1, y1
  251. pred_xyxys = pred_dets[:, 2:6]
  252. ori_image = det_results['ori_image']
  253. ori_image_shape = ori_image.shape[:2]
  254. pred_xyxys, keep_idx = clip_box(pred_xyxys, ori_image_shape)
  255. if len(keep_idx[0]) == 0:
  256. det_results['boxes'] = np.zeros((1, 6), dtype=np.float32)
  257. det_results['embeddings'] = None
  258. return det_results
  259. pred_dets = pred_dets[keep_idx[0]]
  260. pred_xyxys = pred_dets[:, 2:6]
  261. w, h = self.tracker.input_size
  262. crops = get_crops(pred_xyxys, ori_image, w, h)
  263. # to keep fast speed, only use topk crops
  264. crops = crops[:50] # reid_batch_size
  265. det_results['crops'] = np.array(crops).astype('float32')
  266. det_results['boxes'] = pred_dets[:50]
  267. input_names = self.reid_predictor.get_input_names()
  268. for i in range(len(input_names)):
  269. input_tensor = self.reid_predictor.get_input_handle(input_names[i])
  270. input_tensor.copy_from_cpu(det_results[input_names[i]])
  271. # model prediction
  272. for i in range(repeats):
  273. self.reid_predictor.run()
  274. output_names = self.reid_predictor.get_output_names()
  275. feature_tensor = self.reid_predictor.get_output_handle(output_names[
  276. 0])
  277. pred_embs = feature_tensor.copy_to_cpu()
  278. det_results['embeddings'] = pred_embs
  279. return det_results
  280. def tracking(self, det_results, img=None):
  281. pred_dets = det_results['boxes'] # cls_id, score, x0, y0, x1, y1
  282. pred_embs = det_results.get('embeddings', None)
  283. if self.use_deepsort_tracker:
  284. # use DeepSORTTracker, only support singe class
  285. self.tracker.predict()
  286. online_targets = self.tracker.update(pred_dets, pred_embs)
  287. online_tlwhs, online_scores, online_ids = [], [], []
  288. if self.do_mtmct:
  289. online_tlbrs, online_feats = [], []
  290. for t in online_targets:
  291. if not t.is_confirmed() or t.time_since_update > 1:
  292. continue
  293. tlwh = t.to_tlwh()
  294. tscore = t.score
  295. tid = t.track_id
  296. if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
  297. 3] > self.tracker.vertical_ratio:
  298. continue
  299. online_tlwhs.append(tlwh)
  300. online_scores.append(tscore)
  301. online_ids.append(tid)
  302. if self.do_mtmct:
  303. online_tlbrs.append(t.to_tlbr())
  304. online_feats.append(t.feat)
  305. tracking_outs = {
  306. 'online_tlwhs': online_tlwhs,
  307. 'online_scores': online_scores,
  308. 'online_ids': online_ids,
  309. }
  310. if self.do_mtmct:
  311. seq_name = det_results['seq_name']
  312. frame_id = det_results['frame_id']
  313. tracking_outs['feat_data'] = {}
  314. for _tlbr, _id, _feat in zip(online_tlbrs, online_ids,
  315. online_feats):
  316. feat_data = {}
  317. feat_data['bbox'] = _tlbr
  318. feat_data['frame'] = f"{frame_id:06d}"
  319. feat_data['id'] = _id
  320. _imgname = f'{seq_name}_{_id}_{frame_id}.jpg'
  321. feat_data['imgname'] = _imgname
  322. feat_data['feat'] = _feat
  323. tracking_outs['feat_data'].update({_imgname: feat_data})
  324. return tracking_outs
  325. elif self.use_ocsort_tracker:
  326. # use OCSORTTracker, only support singe class
  327. online_targets = self.tracker.update(pred_dets, pred_embs)
  328. online_tlwhs = defaultdict(list)
  329. online_scores = defaultdict(list)
  330. online_ids = defaultdict(list)
  331. for t in online_targets:
  332. tlwh = [t[0], t[1], t[2] - t[0], t[3] - t[1]]
  333. tscore = float(t[4])
  334. tid = int(t[5])
  335. if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue
  336. if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
  337. 3] > self.tracker.vertical_ratio:
  338. continue
  339. if tlwh[2] * tlwh[3] > 0:
  340. online_tlwhs[0].append(tlwh)
  341. online_ids[0].append(tid)
  342. online_scores[0].append(tscore)
  343. tracking_outs = {
  344. 'online_tlwhs': online_tlwhs,
  345. 'online_scores': online_scores,
  346. 'online_ids': online_ids,
  347. }
  348. return tracking_outs
  349. elif self.use_botsort_tracker:
  350. # use BOTSORTTracker, only support singe class
  351. online_targets = self.tracker.update(pred_dets, img)
  352. online_tlwhs = defaultdict(list)
  353. online_scores = defaultdict(list)
  354. online_ids = defaultdict(list)
  355. for t in online_targets:
  356. tlwh = t.tlwh
  357. tid = t.track_id
  358. tscore = t.score
  359. if tlwh[2] * tlwh[3] <= self.tracker.min_box_area:
  360. continue
  361. online_tlwhs[0].append(tlwh)
  362. online_ids[0].append(tid)
  363. online_scores[0].append(tscore)
  364. tracking_outs = {
  365. 'online_tlwhs': online_tlwhs,
  366. 'online_scores': online_scores,
  367. 'online_ids': online_ids,
  368. }
  369. return tracking_outs
  370. else:
  371. # use ByteTracker, support multiple class
  372. online_tlwhs = defaultdict(list)
  373. online_scores = defaultdict(list)
  374. online_ids = defaultdict(list)
  375. if self.do_mtmct:
  376. online_tlbrs, online_feats = defaultdict(list), defaultdict(
  377. list)
  378. online_targets_dict = self.tracker.update(pred_dets, pred_embs)
  379. for cls_id in range(self.num_classes):
  380. online_targets = online_targets_dict[cls_id]
  381. for t in online_targets:
  382. tlwh = t.tlwh
  383. tid = t.track_id
  384. tscore = t.score
  385. if tlwh[2] * tlwh[3] <= self.tracker.min_box_area:
  386. continue
  387. if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
  388. 3] > self.tracker.vertical_ratio:
  389. continue
  390. online_tlwhs[cls_id].append(tlwh)
  391. online_ids[cls_id].append(tid)
  392. online_scores[cls_id].append(tscore)
  393. if self.do_mtmct:
  394. online_tlbrs[cls_id].append(t.tlbr)
  395. online_feats[cls_id].append(t.curr_feat)
  396. if self.do_mtmct:
  397. assert self.num_classes == 1, 'MTMCT only support single class.'
  398. tracking_outs = {
  399. 'online_tlwhs': online_tlwhs[0],
  400. 'online_scores': online_scores[0],
  401. 'online_ids': online_ids[0],
  402. }
  403. seq_name = det_results['seq_name']
  404. frame_id = det_results['frame_id']
  405. tracking_outs['feat_data'] = {}
  406. for _tlbr, _id, _feat in zip(online_tlbrs[0], online_ids[0],
  407. online_feats[0]):
  408. feat_data = {}
  409. feat_data['bbox'] = _tlbr
  410. feat_data['frame'] = f"{frame_id:06d}"
  411. feat_data['id'] = _id
  412. _imgname = f'{seq_name}_{_id}_{frame_id}.jpg'
  413. feat_data['imgname'] = _imgname
  414. feat_data['feat'] = _feat
  415. tracking_outs['feat_data'].update({_imgname: feat_data})
  416. return tracking_outs
  417. else:
  418. tracking_outs = {
  419. 'online_tlwhs': online_tlwhs,
  420. 'online_scores': online_scores,
  421. 'online_ids': online_ids,
  422. }
  423. return tracking_outs
  424. def predict_image(self,
  425. image_list,
  426. run_benchmark=False,
  427. repeats=1,
  428. visual=True,
  429. seq_name=None,
  430. reuse_det_result=False,
  431. frame_count=0):
  432. num_classes = self.num_classes
  433. image_list.sort()
  434. ids2names = self.pred_config.labels
  435. if self.do_mtmct:
  436. mot_features_dict = {} # cid_tid_fid feats
  437. else:
  438. mot_results = []
  439. for frame_id, img_file in enumerate(image_list):
  440. if self.do_mtmct:
  441. if frame_id % 10 == 0:
  442. print('Tracking frame: %d' % (frame_id))
  443. batch_image_list = [img_file] # bs=1 in MOT model
  444. frame, _ = decode_image(img_file, {})
  445. if run_benchmark:
  446. # preprocess
  447. inputs = self.preprocess(batch_image_list) # warmup
  448. self.det_times.preprocess_time_s.start()
  449. inputs = self.preprocess(batch_image_list)
  450. self.det_times.preprocess_time_s.end()
  451. # model prediction
  452. result_warmup = self.predict(repeats=repeats) # warmup
  453. self.det_times.inference_time_s.start()
  454. result = self.predict(repeats=repeats)
  455. self.det_times.inference_time_s.end(repeats=repeats)
  456. # postprocess
  457. result_warmup = self.postprocess(inputs, result) # warmup
  458. self.det_times.postprocess_time_s.start()
  459. det_result = self.postprocess(inputs, result)
  460. self.det_times.postprocess_time_s.end()
  461. # tracking
  462. if self.use_reid:
  463. det_result['frame_id'] = frame_id
  464. det_result['seq_name'] = seq_name
  465. det_result['ori_image'] = frame
  466. det_result = self.reidprocess(det_result)
  467. if self.use_botsort_tracker:
  468. result_warmup = self.tracking(det_result, batch_image_list)
  469. else:
  470. result_warmup = self.tracking(det_result)
  471. self.det_times.tracking_time_s.start()
  472. if self.use_reid:
  473. det_result = self.reidprocess(det_result)
  474. tracking_outs = self.tracking(det_result)
  475. self.det_times.tracking_time_s.end()
  476. self.det_times.img_num += 1
  477. cm, gm, gu = get_current_memory_mb()
  478. self.cpu_mem += cm
  479. self.gpu_mem += gm
  480. self.gpu_util += gu
  481. else:
  482. if frame_count > self.warmup_frame:
  483. self.det_times.preprocess_time_s.start()
  484. if not reuse_det_result:
  485. inputs = self.preprocess(batch_image_list)
  486. if frame_count > self.warmup_frame:
  487. self.det_times.preprocess_time_s.end()
  488. if frame_count > self.warmup_frame:
  489. self.det_times.inference_time_s.start()
  490. if not reuse_det_result:
  491. result = self.predict()
  492. if frame_count > self.warmup_frame:
  493. self.det_times.inference_time_s.end()
  494. if frame_count > self.warmup_frame:
  495. self.det_times.postprocess_time_s.start()
  496. if not reuse_det_result:
  497. det_result = self.postprocess(inputs, result)
  498. self.previous_det_result = det_result
  499. else:
  500. assert self.previous_det_result is not None
  501. det_result = self.previous_det_result
  502. if frame_count > self.warmup_frame:
  503. self.det_times.postprocess_time_s.end()
  504. # tracking process
  505. if frame_count > self.warmup_frame:
  506. self.det_times.tracking_time_s.start()
  507. if self.use_reid:
  508. det_result['frame_id'] = frame_id
  509. det_result['seq_name'] = seq_name
  510. det_result['ori_image'] = frame
  511. det_result = self.reidprocess(det_result)
  512. if self.use_botsort_tracker:
  513. tracking_outs = self.tracking(det_result, batch_image_list)
  514. else:
  515. tracking_outs = self.tracking(det_result)
  516. if frame_count > self.warmup_frame:
  517. self.det_times.tracking_time_s.end()
  518. self.det_times.img_num += 1
  519. online_tlwhs = tracking_outs['online_tlwhs']
  520. online_scores = tracking_outs['online_scores']
  521. online_ids = tracking_outs['online_ids']
  522. if self.do_mtmct:
  523. feat_data_dict = tracking_outs['feat_data']
  524. mot_features_dict = dict(mot_features_dict, **feat_data_dict)
  525. else:
  526. mot_results.append([online_tlwhs, online_scores, online_ids])
  527. if visual:
  528. if len(image_list) > 1 and frame_id % 10 == 0:
  529. print('Tracking frame {}'.format(frame_id))
  530. frame, _ = decode_image(img_file, {})
  531. if isinstance(online_tlwhs, defaultdict):
  532. im = plot_tracking_dict(
  533. frame,
  534. num_classes,
  535. online_tlwhs,
  536. online_ids,
  537. online_scores,
  538. frame_id=frame_id,
  539. ids2names=ids2names)
  540. else:
  541. im = plot_tracking(
  542. frame,
  543. online_tlwhs,
  544. online_ids,
  545. online_scores,
  546. frame_id=frame_id,
  547. ids2names=ids2names)
  548. save_dir = os.path.join(self.output_dir, seq_name)
  549. if not os.path.exists(save_dir):
  550. os.makedirs(save_dir)
  551. cv2.imwrite(
  552. os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
  553. if self.do_mtmct:
  554. return mot_features_dict
  555. else:
  556. return mot_results
  557. def predict_video(self, video_file, camera_id):
  558. video_out_name = 'output.mp4'
  559. if camera_id != -1:
  560. capture = cv2.VideoCapture(camera_id)
  561. else:
  562. capture = cv2.VideoCapture(video_file)
  563. video_out_name = os.path.split(video_file)[-1]
  564. # Get Video info : resolution, fps, frame count
  565. width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
  566. height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
  567. fps = int(capture.get(cv2.CAP_PROP_FPS))
  568. frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
  569. print("fps: %d, frame_count: %d" % (fps, frame_count))
  570. if not os.path.exists(self.output_dir):
  571. os.makedirs(self.output_dir)
  572. out_path = os.path.join(self.output_dir, video_out_name)
  573. video_format = 'mp4v'
  574. fourcc = cv2.VideoWriter_fourcc(*video_format)
  575. writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
  576. frame_id = 0
  577. timer = MOTTimer()
  578. results = defaultdict(list)
  579. num_classes = self.num_classes
  580. data_type = 'mcmot' if num_classes > 1 else 'mot'
  581. ids2names = self.pred_config.labels
  582. center_traj = None
  583. entrance = None
  584. records = None
  585. if self.draw_center_traj:
  586. center_traj = [{} for i in range(num_classes)]
  587. if num_classes == 1:
  588. id_set = set()
  589. interval_id_set = set()
  590. in_id_list = list()
  591. out_id_list = list()
  592. prev_center = dict()
  593. records = list()
  594. if self.do_entrance_counting or self.do_break_in_counting:
  595. if self.region_type == 'horizontal':
  596. entrance = [0, height / 2., width, height / 2.]
  597. elif self.region_type == 'vertical':
  598. entrance = [width / 2, 0., width / 2, height]
  599. elif self.region_type == 'custom':
  600. entrance = []
  601. assert len(
  602. self.region_polygon
  603. ) % 2 == 0, "region_polygon should be pairs of coords points when do break_in counting."
  604. for i in range(0, len(self.region_polygon), 2):
  605. entrance.append([
  606. self.region_polygon[i], self.region_polygon[i + 1]
  607. ])
  608. entrance.append([width, height])
  609. else:
  610. raise ValueError("region_type:{} is not supported.".format(
  611. self.region_type))
  612. video_fps = fps
  613. while (1):
  614. ret, frame = capture.read()
  615. if not ret:
  616. break
  617. if frame_id % 10 == 0:
  618. print('Tracking frame: %d' % (frame_id))
  619. timer.tic()
  620. mot_skip_frame_num = self.skip_frame_num
  621. reuse_det_result = False
  622. if mot_skip_frame_num > 1 and frame_id > 0 and frame_id % mot_skip_frame_num > 0:
  623. reuse_det_result = True
  624. seq_name = video_out_name.split('.')[0]
  625. mot_results = self.predict_image(
  626. [frame],
  627. visual=False,
  628. seq_name=seq_name,
  629. reuse_det_result=reuse_det_result,
  630. frame_count=frame_id)
  631. timer.toc()
  632. # bs=1 in MOT model
  633. online_tlwhs, online_scores, online_ids = mot_results[0]
  634. # flow statistic for one class, and only for bytetracker
  635. if num_classes == 1 and not self.use_deepsort_tracker and not self.use_ocsort_tracker:
  636. result = (frame_id + 1, online_tlwhs[0], online_scores[0],
  637. online_ids[0])
  638. statistic = flow_statistic(
  639. result,
  640. self.secs_interval,
  641. self.do_entrance_counting,
  642. self.do_break_in_counting,
  643. self.region_type,
  644. video_fps,
  645. entrance,
  646. id_set,
  647. interval_id_set,
  648. in_id_list,
  649. out_id_list,
  650. prev_center,
  651. records,
  652. data_type,
  653. ids2names=self.pred_config.labels)
  654. records = statistic['records']
  655. fps = 1. / timer.duration
  656. if self.use_deepsort_tracker or self.use_ocsort_tracker or self.use_botsort_tracker:
  657. # use DeepSORTTracker or OCSORTTracker, only support singe class
  658. if isinstance(online_tlwhs, defaultdict):
  659. online_tlwhs = online_tlwhs[0]
  660. online_scores = online_scores[0]
  661. online_ids = online_ids[0]
  662. results[0].append(
  663. (frame_id + 1, online_tlwhs, online_scores, online_ids))
  664. im = plot_tracking(
  665. frame,
  666. online_tlwhs,
  667. online_ids,
  668. online_scores,
  669. frame_id=frame_id,
  670. fps=fps,
  671. ids2names=ids2names,
  672. do_entrance_counting=self.do_entrance_counting,
  673. entrance=entrance)
  674. else:
  675. # use ByteTracker, support multiple class
  676. for cls_id in range(num_classes):
  677. results[cls_id].append(
  678. (frame_id + 1, online_tlwhs[cls_id],
  679. online_scores[cls_id], online_ids[cls_id]))
  680. im = plot_tracking_dict(
  681. frame,
  682. num_classes,
  683. online_tlwhs,
  684. online_ids,
  685. online_scores,
  686. frame_id=frame_id,
  687. fps=fps,
  688. ids2names=ids2names,
  689. do_entrance_counting=self.do_entrance_counting,
  690. entrance=entrance,
  691. records=records,
  692. center_traj=center_traj)
  693. writer.write(im)
  694. if camera_id != -1:
  695. cv2.imshow('Mask Detection', im)
  696. if cv2.waitKey(1) & 0xFF == ord('q'):
  697. break
  698. frame_id += 1
  699. if self.save_mot_txts:
  700. result_filename = os.path.join(
  701. self.output_dir, video_out_name.split('.')[-2] + '.txt')
  702. write_mot_results(result_filename, results)
  703. result_filename = os.path.join(
  704. self.output_dir,
  705. video_out_name.split('.')[-2] + '_flow_statistic.txt')
  706. f = open(result_filename, 'w')
  707. for line in records:
  708. f.write(line)
  709. print('Flow statistic save in {}'.format(result_filename))
  710. f.close()
  711. writer.release()
  712. def predict_mtmct(self, mtmct_dir, mtmct_cfg):
  713. cameras_bias = mtmct_cfg['cameras_bias']
  714. cid_bias = parse_bias(cameras_bias)
  715. scene_cluster = list(cid_bias.keys())
  716. # 1.zone releated parameters
  717. use_zone = mtmct_cfg.get('use_zone', False)
  718. zone_path = mtmct_cfg.get('zone_path', None)
  719. # 2.tricks parameters, can be used for other mtmct dataset
  720. use_ff = mtmct_cfg.get('use_ff', False)
  721. use_rerank = mtmct_cfg.get('use_rerank', False)
  722. # 3.camera releated parameters
  723. use_camera = mtmct_cfg.get('use_camera', False)
  724. use_st_filter = mtmct_cfg.get('use_st_filter', False)
  725. # 4.zone releated parameters
  726. use_roi = mtmct_cfg.get('use_roi', False)
  727. roi_dir = mtmct_cfg.get('roi_dir', False)
  728. mot_list_breaks = []
  729. cid_tid_dict = dict()
  730. output_dir = self.output_dir
  731. if not os.path.exists(output_dir):
  732. os.makedirs(output_dir)
  733. seqs = os.listdir(mtmct_dir)
  734. for seq in sorted(seqs):
  735. fpath = os.path.join(mtmct_dir, seq)
  736. if os.path.isfile(fpath) and _is_valid_video(fpath):
  737. seq = seq.split('.')[-2]
  738. print('ffmpeg processing of video {}'.format(fpath))
  739. frames_path = video2frames(
  740. video_path=fpath, outpath=mtmct_dir, frame_rate=25)
  741. fpath = os.path.join(mtmct_dir, seq)
  742. if os.path.isdir(fpath) == False:
  743. print('{} is not a image folder.'.format(fpath))
  744. continue
  745. if os.path.exists(os.path.join(fpath, 'img1')):
  746. fpath = os.path.join(fpath, 'img1')
  747. assert os.path.isdir(fpath), '{} should be a directory'.format(
  748. fpath)
  749. image_list = glob.glob(os.path.join(fpath, '*.jpg'))
  750. image_list.sort()
  751. assert len(image_list) > 0, '{} has no images.'.format(fpath)
  752. print('start tracking seq: {}'.format(seq))
  753. mot_features_dict = self.predict_image(
  754. image_list, visual=False, seq_name=seq)
  755. cid = int(re.sub('[a-z,A-Z]', "", seq))
  756. tid_data, mot_list_break = trajectory_fusion(
  757. mot_features_dict,
  758. cid,
  759. cid_bias,
  760. use_zone=use_zone,
  761. zone_path=zone_path)
  762. mot_list_breaks.append(mot_list_break)
  763. # single seq process
  764. for line in tid_data:
  765. tracklet = tid_data[line]
  766. tid = tracklet['tid']
  767. if (cid, tid) not in cid_tid_dict:
  768. cid_tid_dict[(cid, tid)] = tracklet
  769. map_tid = sub_cluster(
  770. cid_tid_dict,
  771. scene_cluster,
  772. use_ff=use_ff,
  773. use_rerank=use_rerank,
  774. use_camera=use_camera,
  775. use_st_filter=use_st_filter)
  776. pred_mtmct_file = os.path.join(output_dir, 'mtmct_result.txt')
  777. if use_camera:
  778. gen_res(pred_mtmct_file, scene_cluster, map_tid, mot_list_breaks)
  779. else:
  780. gen_res(
  781. pred_mtmct_file,
  782. scene_cluster,
  783. map_tid,
  784. mot_list_breaks,
  785. use_roi=use_roi,
  786. roi_dir=roi_dir)
  787. camera_results, cid_tid_fid_res = get_mtmct_matching_results(
  788. pred_mtmct_file)
  789. crops_dir = os.path.join(output_dir, 'mtmct_crops')
  790. save_mtmct_crops(
  791. cid_tid_fid_res, images_dir=mtmct_dir, crops_dir=crops_dir)
  792. save_dir = os.path.join(output_dir, 'mtmct_vis')
  793. save_mtmct_vis_results(
  794. camera_results,
  795. images_dir=mtmct_dir,
  796. save_dir=save_dir,
  797. save_videos=FLAGS.save_images)
  798. def main():
  799. deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
  800. with open(deploy_file) as f:
  801. yml_conf = yaml.safe_load(f)
  802. arch = yml_conf['arch']
  803. detector = SDE_Detector(
  804. FLAGS.model_dir,
  805. tracker_config=FLAGS.tracker_config,
  806. device=FLAGS.device,
  807. run_mode=FLAGS.run_mode,
  808. batch_size=1,
  809. trt_min_shape=FLAGS.trt_min_shape,
  810. trt_max_shape=FLAGS.trt_max_shape,
  811. trt_opt_shape=FLAGS.trt_opt_shape,
  812. trt_calib_mode=FLAGS.trt_calib_mode,
  813. cpu_threads=FLAGS.cpu_threads,
  814. enable_mkldnn=FLAGS.enable_mkldnn,
  815. output_dir=FLAGS.output_dir,
  816. threshold=FLAGS.threshold,
  817. save_images=FLAGS.save_images,
  818. save_mot_txts=FLAGS.save_mot_txts,
  819. draw_center_traj=FLAGS.draw_center_traj,
  820. secs_interval=FLAGS.secs_interval,
  821. skip_frame_num=FLAGS.skip_frame_num,
  822. warmup_frame=FLAGS.warmup_frame,
  823. do_entrance_counting=FLAGS.do_entrance_counting,
  824. do_break_in_counting=FLAGS.do_break_in_counting,
  825. region_type=FLAGS.region_type,
  826. region_polygon=FLAGS.region_polygon,
  827. reid_model_dir=FLAGS.reid_model_dir,
  828. mtmct_dir=FLAGS.mtmct_dir, )
  829. # predict from video file or camera video stream
  830. if FLAGS.video_file is not None or FLAGS.camera_id != -1:
  831. detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
  832. detector.det_times.info(average=True)
  833. elif FLAGS.mtmct_dir is not None:
  834. with open(FLAGS.mtmct_cfg) as f:
  835. mtmct_cfg = yaml.safe_load(f)
  836. detector.predict_mtmct(FLAGS.mtmct_dir, mtmct_cfg)
  837. else:
  838. # predict from image
  839. if FLAGS.image_dir is None and FLAGS.image_file is not None:
  840. assert FLAGS.batch_size == 1, "--batch_size should be 1 in MOT models."
  841. img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
  842. seq_name = FLAGS.image_dir.split('/')[-1]
  843. detector.predict_image(
  844. img_list, FLAGS.run_benchmark, repeats=10, seq_name=seq_name)
  845. if not FLAGS.run_benchmark:
  846. detector.det_times.info(average=True)
  847. else:
  848. mode = FLAGS.run_mode
  849. model_dir = FLAGS.model_dir
  850. model_info = {
  851. 'model_name': model_dir.strip('/').split('/')[-1],
  852. 'precision': mode.split('_')[-1]
  853. }
  854. bench_log(detector, img_list, model_info, name='MOT')
  855. if __name__ == '__main__':
  856. paddle.enable_static()
  857. parser = argsparser()
  858. FLAGS = parser.parse_args()
  859. print_arguments(FLAGS)
  860. FLAGS.device = FLAGS.device.upper()
  861. assert FLAGS.device in ['CPU', 'GPU', 'XPU'
  862. ], "device should be CPU, GPU or XPU"
  863. main()