postprocess.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/LCFractal/AIC21-MTMC/tree/main/reid/reid-matching/tools
"""
import os
import re
import cv2
from tqdm import tqdm
import numpy as np
try:
    import motmetrics as mm
except:
    print(
        'Warning: Unable to use motmetrics in MTMCT in PP-Tracking, please install motmetrics, for example: `pip install motmetrics`, see https://github.com/longcw/py-motmetrics'
    )
    pass
from functools import reduce

from .utils import parse_pt_gt, parse_pt, compare_dataframes_mtmc
from .utils import get_labels, getData, gen_new_mot
from .camera_utils import get_labels_with_camera
from .zone import Zone
from ..visualize import plot_tracking

__all__ = [
    'trajectory_fusion',
    'sub_cluster',
    'gen_res',
    'print_mtmct_result',
    'get_mtmct_matching_results',
    'save_mtmct_crops',
    'save_mtmct_vis_results',
]

def trajectory_fusion(mot_feature, cid, cid_bias, use_zone=False, zone_path=''):
    cur_bias = cid_bias[cid]
    mot_list_break = {}
    if use_zone:
        zones = Zone(zone_path=zone_path)
        zones.set_cam(cid)
        mot_list = parse_pt(mot_feature, zones)
    else:
        mot_list = parse_pt(mot_feature)

    if use_zone:
        mot_list = zones.break_mot(mot_list, cid)
        mot_list = zones.filter_mot(mot_list, cid)  # filter by zone
        mot_list = zones.filter_bbox(mot_list, cid)  # filter bbox

    mot_list_break = gen_new_mot(mot_list)  # save break feature for gen result

    tid_data = dict()
    for tid in mot_list:
        tracklet = mot_list[tid]
        if len(tracklet) <= 1:
            continue
        frame_list = list(tracklet.keys())
        frame_list.sort()
        zone_list = [tracklet[f]['zone'] for f in frame_list]
        # only keep features from boxes whose area is larger than 2000 pixels
        feature_list = [
            tracklet[f]['feat'] for f in frame_list
            if (tracklet[f]['bbox'][3] - tracklet[f]['bbox'][1]) *
            (tracklet[f]['bbox'][2] - tracklet[f]['bbox'][0]) > 2000
        ]
        if len(feature_list) < 2:
            feature_list = [tracklet[f]['feat'] for f in frame_list]
        io_time = [
            cur_bias + frame_list[0] / 10., cur_bias + frame_list[-1] / 10.
        ]
        all_feat = np.array([feat for feat in feature_list])
        mean_feat = np.mean(all_feat, axis=0)
        tid_data[tid] = {
            'cam': cid,
            'tid': tid,
            'mean_feat': mean_feat,
            'zone_list': zone_list,
            'frame_list': frame_list,
            'tracklet': tracklet,
            'io_time': io_time
        }
    return tid_data, mot_list_break
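
# Illustrative sketch (not executed by this module): trajectory_fusion is
# typically called once per camera and its per-track outputs merged into a
# single dict keyed by (camera_id, track_id). The names `mot_features`,
# `cids` and `cid_bias` below are assumptions standing in for the upstream
# per-camera ReID features, camera ids and time-bias table.
#
#   cid_tid_dict, mot_list_breaks = {}, []
#   for mot_feature, cid in zip(mot_features, cids):
#       tid_data, mot_list_break = trajectory_fusion(
#           mot_feature, cid, cid_bias, use_zone=False, zone_path='')
#       mot_list_breaks.append(mot_list_break)
#       for tid in tid_data:
#           cid_tid_dict[(cid, tid)] = tid_data[tid]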

def sub_cluster(cid_tid_dict,
                scene_cluster,
                use_ff=True,
                use_rerank=True,
                use_camera=False,
                use_st_filter=False):
    '''
    cid_tid_dict: dict keyed by (camera_id, track_id), holding the fused
        per-tracklet info returned by trajectory_fusion
    scene_cluster: camera ids of the scene, e.g. [41, 42, 43, 44, 45, 46]
        for the AIC21 MTMCT S06 test videos
    '''
    assert (len(scene_cluster) != 0), "Error: scene_cluster must not be empty"
    cid_tids = sorted(
        [key for key in cid_tid_dict.keys() if key[0] in scene_cluster])
    if use_camera:
        clu = get_labels_with_camera(
            cid_tid_dict,
            cid_tids,
            use_ff=use_ff,
            use_rerank=use_rerank,
            use_st_filter=use_st_filter)
    else:
        clu = get_labels(
            cid_tid_dict,
            cid_tids,
            use_ff=use_ff,
            use_rerank=use_rerank,
            use_st_filter=use_st_filter)
    new_clu = list()
    for c_list in clu:
        if len(c_list) <= 1: continue
        cam_list = [cid_tids[c][0] for c in c_list]
        if len(cam_list) != len(set(cam_list)): continue
        new_clu.append([cid_tids[c] for c in c_list])
    all_clu = new_clu
    cid_tid_label = dict()
    for i, c_list in enumerate(all_clu):
        for c in c_list:
            cid_tid_label[c] = i + 1
    return cid_tid_label
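
# Illustrative sketch: clustering the fused tracklets across cameras.
# `cid_tid_dict` is the merged dict built from trajectory_fusion outputs
# (see the sketch after trajectory_fusion above); `scene_cluster` lists the
# camera ids of the scene.
#
#   scene_cluster = [41, 42, 43, 44, 45, 46]  # e.g. AIC21 S06
#   map_tid = sub_cluster(
#       cid_tid_dict,
#       scene_cluster,
#       use_ff=True,
#       use_rerank=True,
#       use_camera=False,
#       use_st_filter=False)
#   # map_tid: {(camera_id, track_id): global_id}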

def gen_res(output_dir_filename,
            scene_cluster,
            map_tid,
            mot_list_breaks,
            use_roi=False,
            roi_dir=''):
    f_w = open(output_dir_filename, 'w')
    for idx, mot_feature in enumerate(mot_list_breaks):
        cid = scene_cluster[idx]
        img_rects = parse_pt_gt(mot_feature)
        if use_roi:
            assert (roi_dir != ''), "Error: roi_dir should not be empty!"
            roi = cv2.imread(os.path.join(roi_dir, f'c{cid:03d}/roi.jpg'), 0)
            height, width = roi.shape

        for fid in img_rects:
            tid_rects = img_rects[fid]
            fid = int(fid) + 1
            for tid_rect in tid_rects:
                tid = tid_rect[0]
                rect = tid_rect[1:]
                cx = 0.5 * rect[0] + 0.5 * rect[2]
                cy = 0.5 * rect[1] + 0.5 * rect[3]
                w = rect[2] - rect[0]
                w = min(w * 1.2, w + 40)
                h = rect[3] - rect[1]
                h = min(h * 1.2, h + 40)
                rect[2] -= rect[0]
                rect[3] -= rect[1]
                rect[0] = max(0, rect[0])
                rect[1] = max(0, rect[1])
                x1, y1 = max(0, cx - 0.5 * w), max(0, cy - 0.5 * h)
                if use_roi:
                    x2, y2 = min(width, cx + 0.5 * w), min(height, cy + 0.5 * h)
                else:
                    x2, y2 = cx + 0.5 * w, cy + 0.5 * h
                w, h = x2 - x1, y2 - y1
                new_rect = list(map(int, [x1, y1, w, h]))
                rect = list(map(int, rect))
                if (cid, tid) in map_tid:
                    new_tid = map_tid[(cid, tid)]
                    f_w.write(
                        str(cid) + ' ' + str(new_tid) + ' ' + str(fid) + ' ' +
                        ' '.join(map(str, new_rect)) + ' -1 -1'
                        '\n')
    print('gen_res: wrote results to {}'.format(output_dir_filename))
    f_w.close()
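
# Illustrative sketch: writing the cross-camera result file. `map_tid` comes
# from sub_cluster and `mot_list_breaks` is the list collected alongside
# trajectory_fusion; the output path below is an assumption.
#
#   pred_mtmct_file = 'output/mtmct_result.txt'
#   gen_res(pred_mtmct_file, scene_cluster, map_tid, mot_list_breaks,
#           use_roi=False, roi_dir='')
#   # each written line: 'cid tid fid x1 y1 w h -1 -1'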

def print_mtmct_result(gt_file, pred_file):
    names = [
        'CameraId', 'Id', 'FrameId', 'X', 'Y', 'Width', 'Height', 'Xworld',
        'Yworld'
    ]
    gt = getData(gt_file, names=names)
    pred = getData(pred_file, names=names)
    summary = compare_dataframes_mtmc(gt, pred)
    print('MTMCT summary: ', summary.columns.tolist())

    formatters = {
        'idf1': '{:2.2f}'.format,
        'idp': '{:2.2f}'.format,
        'idr': '{:2.2f}'.format,
        'mota': '{:2.2f}'.format
    }
    summary = summary[['idf1', 'idp', 'idr', 'mota']]
    summary.loc[:, 'idp'] *= 100
    summary.loc[:, 'idr'] *= 100
    summary.loc[:, 'idf1'] *= 100
    summary.loc[:, 'mota'] *= 100

    try:
        import motmetrics as mm
    except Exception as e:
        raise RuntimeError(
            'Unable to use motmetrics in MTMCT in PP-Tracking, please install motmetrics, for example: `pip install motmetrics`, see https://github.com/longcw/py-motmetrics'
        )
    print(
        mm.io.render_summary(
            summary,
            formatters=formatters,
            namemap=mm.io.motchallenge_metric_names))
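
# Illustrative sketch: evaluating the fused result against a ground-truth
# file in the same 9-column format. The gt path is an assumption.
#
#   print_mtmct_result('dataset/S06/gt/gt.txt', pred_mtmct_file)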

def get_mtmct_matching_results(pred_mtmct_file, secs_interval=0.5,
                               video_fps=20):
    res = np.loadtxt(pred_mtmct_file)  # 'cid, tid, fid, x1, y1, w, h, -1, -1'
    camera_ids = list(map(int, np.unique(res[:, 0])))

    res = res[:, :7]
    # each line in res: 'cid, tid, fid, x1, y1, w, h'

    camera_tids = []
    camera_results = dict()
    for c_id in camera_ids:
        camera_results[c_id] = res[res[:, 0] == c_id]
        tids = np.unique(camera_results[c_id][:, 1])
        tids = list(map(int, tids))
        camera_tids.append(tids)

    # select common tids throughout each video
    common_tids = reduce(np.intersect1d, camera_tids)
    if len(common_tids) == 0:
        print(
            'No common tracked ids in these videos, please check your MOT result or select new videos.'
        )
        return None, None

    # get mtmct matching results by cid_tid_fid_results[c_id][t_id][f_id]
    cid_tid_fid_results = dict()
    cid_tid_to_fids = dict()
    interval = int(secs_interval * video_fps)  # preferably less than 10
    for c_id in camera_ids:
        cid_tid_fid_results[c_id] = dict()
        cid_tid_to_fids[c_id] = dict()
        for t_id in common_tids:
            tid_mask = camera_results[c_id][:, 1] == t_id
            cid_tid_fid_results[c_id][t_id] = dict()

            camera_trackid_results = camera_results[c_id][tid_mask]
            fids = np.unique(camera_trackid_results[:, 2])
            fids = fids[fids % interval == 0]
            fids = list(map(int, fids))
            cid_tid_to_fids[c_id][t_id] = fids

            for f_id in fids:
                st_frame = f_id
                ed_frame = f_id + interval

                st_mask = camera_trackid_results[:, 2] >= st_frame
                ed_mask = camera_trackid_results[:, 2] < ed_frame
                frame_mask = np.logical_and(st_mask, ed_mask)
                cid_tid_fid_results[c_id][t_id][f_id] = camera_trackid_results[
                    frame_mask]
    return camera_results, cid_tid_fid_results
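
# Illustrative sketch: loading the fused result file and sampling matches
# every `secs_interval` seconds. Both return values are keyed by camera id.
#
#   camera_results, cid_tid_fid_res = get_mtmct_matching_results(
#       pred_mtmct_file, secs_interval=0.5, video_fps=20)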

def save_mtmct_crops(cid_tid_fid_res,
                     images_dir,
                     crops_dir,
                     width=300,
                     height=200):
    camera_ids = cid_tid_fid_res.keys()
    seqs_folder = os.listdir(images_dir)
    seqs = []
    for x in seqs_folder:
        if os.path.isdir(os.path.join(images_dir, x)):
            seqs.append(x)
    assert len(seqs) == len(camera_ids)
    seqs.sort()

    if not os.path.exists(crops_dir):
        os.makedirs(crops_dir)

    common_tids = list(cid_tid_fid_res[list(camera_ids)[0]].keys())

    # save crops with the name pattern 'tid_cid_fid.jpg'
    for t_id in common_tids:
        for i, c_id in enumerate(camera_ids):
            infer_dir = os.path.join(images_dir, seqs[i])
            if os.path.exists(os.path.join(infer_dir, 'img1')):
                infer_dir = os.path.join(infer_dir, 'img1')
            all_images = os.listdir(infer_dir)
            all_images.sort()

            for f_id in cid_tid_fid_res[c_id][t_id].keys():
                frame_idx = f_id - 1 if f_id > 0 else 0
                im_path = os.path.join(infer_dir, all_images[frame_idx])

                im = cv2.imread(im_path)  # (H, W, 3)

                # only select one track
                track = cid_tid_fid_res[c_id][t_id][f_id][0]
                cid, tid, fid, x1, y1, w, h = [int(v) for v in track]
                clip = im[y1:(y1 + h), x1:(x1 + w)]
                clip = cv2.resize(clip, (width, height))

                cv2.imwrite(
                    os.path.join(crops_dir,
                                 'tid{:06d}_cid{:06d}_fid{:06d}.jpg'.format(
                                     tid, cid, fid)), clip)

            print("Finished cropping images of tracked_id {} in camera {}".
                  format(t_id, c_id))

def save_mtmct_vis_results(camera_results,
                           images_dir,
                           save_dir,
                           save_videos=False):
    # camera_results: 'cid, tid, fid, x1, y1, w, h'
    camera_ids = camera_results.keys()
    seqs_folder = os.listdir(images_dir)
    seqs = []
    for x in seqs_folder:
        if os.path.isdir(os.path.join(images_dir, x)):
            seqs.append(x)
    assert len(seqs) == len(camera_ids)
    seqs.sort()

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    for i, c_id in enumerate(camera_ids):
        print("Start visualization for camera {} of sequence {}.".format(
            c_id, seqs[i]))
        cid_save_dir = os.path.join(save_dir, '{}'.format(seqs[i]))
        if not os.path.exists(cid_save_dir):
            os.makedirs(cid_save_dir)

        infer_dir = os.path.join(images_dir, seqs[i])
        if os.path.exists(os.path.join(infer_dir, 'img1')):
            infer_dir = os.path.join(infer_dir, 'img1')
        all_images = os.listdir(infer_dir)
        all_images.sort()

        for f_id, im_path in enumerate(all_images):
            img = cv2.imread(os.path.join(infer_dir, im_path))
            tracks = camera_results[c_id][camera_results[c_id][:, 2] == f_id]
            if tracks.shape[0] > 0:
                tracked_ids = tracks[:, 1]
                xywhs = tracks[:, 3:]
                online_im = plot_tracking(
                    img, xywhs, tracked_ids, scores=None, frame_id=f_id)
            else:
                online_im = img
                print('Frame {} of seq {} has no tracking results.'.format(
                    f_id, seqs[i]))

            cv2.imwrite(
                os.path.join(cid_save_dir, '{:05d}.jpg'.format(f_id)),
                online_im)
            if f_id % 40 == 0:
                print('Processing frame {}'.format(f_id))

        if save_videos:
            output_video_path = os.path.join(cid_save_dir, '..',
                                             '{}_mtmct_vis.mp4'.format(seqs[i]))
            cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(
                cid_save_dir, output_video_path)
            os.system(cmd_str)
            print('Saved video of camera sequence {} to {}.'.format(
                seqs[i], output_video_path))
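
# Illustrative sketch: saving crops and per-camera visualizations from the
# matching results above. `images_dir` is assumed to hold one sub-folder of
# frames per camera sequence; the directory names below are assumptions.
#
#   if camera_results is not None:
#       save_mtmct_crops(cid_tid_fid_res, images_dir='dataset/S06',
#                        crops_dir='output/mtmct_crops')
#       save_mtmct_vis_results(camera_results, images_dir='dataset/S06',
#                              save_dir='output/mtmct_vis', save_videos=True)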