mtmct.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from pptracking.python.mot.visualize import plot_tracking
  15. from python.visualize import visualize_attr
  16. import os
  17. import re
  18. import cv2
  19. import gc
  20. import numpy as np
  21. try:
  22. from sklearn import preprocessing
  23. from sklearn.cluster import AgglomerativeClustering
  24. except:
  25. print(
  26. 'Warning: Unable to use MTMCT in PP-Human, please install sklearn, for example: `pip install sklearn`'
  27. )
  28. pass
  29. import pandas as pd
  30. from tqdm import tqdm
  31. from functools import reduce
  32. import warnings
  33. warnings.filterwarnings("ignore")
  34. def gen_restxt(output_dir_filename, map_tid, cid_tid_dict):
  35. pattern = re.compile(r'c(\d)_t(\d)')
  36. f_w = open(output_dir_filename, 'w')
  37. for key, res in cid_tid_dict.items():
  38. cid, tid = pattern.search(key).groups()
  39. cid = int(cid) + 1
  40. rects = res["rects"]
  41. frames = res["frames"]
  42. for idx, bbox in enumerate(rects):
  43. bbox[0][3:] -= bbox[0][1:3]
  44. fid = frames[idx] + 1
  45. rect = [max(int(x), 0) for x in bbox[0][1:]]
  46. if key in map_tid:
  47. new_tid = map_tid[key]
  48. f_w.write(
  49. str(cid) + ' ' + str(new_tid) + ' ' + str(fid) + ' ' +
  50. ' '.join(map(str, rect)) + '\n')
  51. print('gen_res: write file in {}'.format(output_dir_filename))
  52. f_w.close()
  53. def get_mtmct_matching_results(pred_mtmct_file, secs_interval=0.5,
  54. video_fps=20):
  55. res = np.loadtxt(pred_mtmct_file) # 'cid, tid, fid, x1, y1, w, h, -1, -1'
  56. camera_ids = list(map(int, np.unique(res[:, 0])))
  57. res = res[:, :7]
  58. # each line in res: 'cid, tid, fid, x1, y1, w, h'
  59. camera_tids = []
  60. camera_results = dict()
  61. for c_id in camera_ids:
  62. camera_results[c_id] = res[res[:, 0] == c_id]
  63. tids = np.unique(camera_results[c_id][:, 1])
  64. tids = list(map(int, tids))
  65. camera_tids.append(tids)
  66. # select common tids throughout each video
  67. common_tids = reduce(np.intersect1d, camera_tids)
  68. # get mtmct matching results by cid_tid_fid_results[c_id][t_id][f_id]
  69. cid_tid_fid_results = dict()
  70. cid_tid_to_fids = dict()
  71. interval = int(secs_interval * video_fps) # preferably less than 10
  72. for c_id in camera_ids:
  73. cid_tid_fid_results[c_id] = dict()
  74. cid_tid_to_fids[c_id] = dict()
  75. for t_id in common_tids:
  76. tid_mask = camera_results[c_id][:, 1] == t_id
  77. cid_tid_fid_results[c_id][t_id] = dict()
  78. camera_trackid_results = camera_results[c_id][tid_mask]
  79. fids = np.unique(camera_trackid_results[:, 2])
  80. fids = fids[fids % interval == 0]
  81. fids = list(map(int, fids))
  82. cid_tid_to_fids[c_id][t_id] = fids
  83. for f_id in fids:
  84. st_frame = f_id
  85. ed_frame = f_id + interval
  86. st_mask = camera_trackid_results[:, 2] >= st_frame
  87. ed_mask = camera_trackid_results[:, 2] < ed_frame
  88. frame_mask = np.logical_and(st_mask, ed_mask)
  89. cid_tid_fid_results[c_id][t_id][f_id] = camera_trackid_results[
  90. frame_mask]
  91. return camera_results, cid_tid_fid_results
  92. def save_mtmct_vis_results(camera_results, captures, output_dir,
  93. multi_res=None):
  94. # camera_results: 'cid, tid, fid, x1, y1, w, h'
  95. camera_ids = list(camera_results.keys())
  96. import shutil
  97. save_dir = os.path.join(output_dir, 'mtmct_vis')
  98. if os.path.exists(save_dir):
  99. shutil.rmtree(save_dir)
  100. os.makedirs(save_dir)
  101. for idx, video_file in enumerate(captures):
  102. capture = cv2.VideoCapture(video_file)
  103. cid = camera_ids[idx]
  104. basename = os.path.basename(video_file)
  105. video_out_name = "vis_" + basename
  106. out_path = os.path.join(save_dir, video_out_name)
  107. print("Start visualizing output video: {}".format(out_path))
  108. # Get Video info : resolution, fps, frame count
  109. width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
  110. height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
  111. fps = int(capture.get(cv2.CAP_PROP_FPS))
  112. frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
  113. fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
  114. writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
  115. frame_id = 0
  116. while (1):
  117. if frame_id % 50 == 0:
  118. print('frame id: ', frame_id)
  119. ret, frame = capture.read()
  120. frame_id += 1
  121. if not ret:
  122. if frame_id == 1:
  123. print("video read failed!")
  124. break
  125. frame_results = camera_results[cid][camera_results[cid][:, 2] ==
  126. frame_id]
  127. boxes = frame_results[:, -4:]
  128. ids = frame_results[:, 1]
  129. image = plot_tracking(frame, boxes, ids, frame_id=frame_id, fps=fps)
  130. # add attr vis
  131. if multi_res:
  132. tid_list = multi_res.keys() # c0_t1, c0_t2...
  133. all_attr_result = [multi_res[i]["attrs"]
  134. for i in tid_list] # all cid_tid result
  135. if any(
  136. all_attr_result
  137. ): # at least one cid_tid[attrs] is not None will goes to attrs_vis
  138. attr_res = []
  139. cid_str = 'c' + str(cid - 1) + "_"
  140. for k in tid_list:
  141. if not k.startswith(cid_str):
  142. continue
  143. if (frame_id - 1) >= len(multi_res[k]['attrs']):
  144. t_attr = None
  145. else:
  146. t_attr = multi_res[k]['attrs'][frame_id - 1]
  147. attr_res.append(t_attr)
  148. assert len(attr_res) == len(boxes)
  149. image = visualize_attr(
  150. image, attr_res, boxes, is_mtmct=True)
  151. writer.write(image)
  152. writer.release()
  153. def get_euclidean(x, y, **kwargs):
  154. m = x.shape[0]
  155. n = y.shape[0]
  156. distmat = (np.power(x, 2).sum(axis=1, keepdims=True).repeat(
  157. n, axis=1) + np.power(y, 2).sum(axis=1, keepdims=True).repeat(
  158. m, axis=1).T)
  159. distmat -= np.dot(2 * x, y.T)
  160. return distmat
  161. def cosine_similarity(x, y, eps=1e-12):
  162. """
  163. Computes cosine similarity between two tensors.
  164. Value == 1 means the same vector
  165. Value == 0 means perpendicular vectors
  166. """
  167. x_n, y_n = np.linalg.norm(
  168. x, axis=1, keepdims=True), np.linalg.norm(
  169. y, axis=1, keepdims=True)
  170. x_norm = x / np.maximum(x_n, eps * np.ones_like(x_n))
  171. y_norm = y / np.maximum(y_n, eps * np.ones_like(y_n))
  172. sim_mt = np.dot(x_norm, y_norm.T)
  173. return sim_mt
  174. def get_cosine(x, y, eps=1e-12):
  175. """
  176. Computes cosine distance between two tensors.
  177. The cosine distance is the inverse cosine similarity
  178. -> cosine_distance = abs(-cosine_distance) to make it
  179. similar in behaviour to euclidean distance
  180. """
  181. sim_mt = cosine_similarity(x, y, eps)
  182. return sim_mt
  183. def get_dist_mat(x, y, func_name="euclidean"):
  184. if func_name == "cosine":
  185. dist_mat = get_cosine(x, y)
  186. elif func_name == "euclidean":
  187. dist_mat = get_euclidean(x, y)
  188. print("Using {} as distance function during evaluation".format(func_name))
  189. return dist_mat
  190. def intracam_ignore(st_mask, cid_tids):
  191. count = len(cid_tids)
  192. for i in range(count):
  193. for j in range(count):
  194. if cid_tids[i][1] == cid_tids[j][1]:
  195. st_mask[i, j] = 0.
  196. return st_mask
  197. def get_sim_matrix_new(cid_tid_dict, cid_tids):
  198. # Note: camera independent get_sim_matrix function,
  199. # which is different from the one in camera_utils.py.
  200. count = len(cid_tids)
  201. q_arr = np.array(
  202. [cid_tid_dict[cid_tids[i]]['mean_feat'] for i in range(count)])
  203. g_arr = np.array(
  204. [cid_tid_dict[cid_tids[i]]['mean_feat'] for i in range(count)])
  205. #compute distmat
  206. distmat = get_dist_mat(q_arr, g_arr, func_name="cosine")
  207. #mask the element which belongs to same video
  208. st_mask = np.ones((count, count), dtype=np.float32)
  209. st_mask = intracam_ignore(st_mask, cid_tids)
  210. sim_matrix = distmat * st_mask
  211. np.fill_diagonal(sim_matrix, 0.)
  212. return 1. - sim_matrix
  213. def get_match(cluster_labels):
  214. cluster_dict = dict()
  215. cluster = list()
  216. for i, l in enumerate(cluster_labels):
  217. if l in list(cluster_dict.keys()):
  218. cluster_dict[l].append(i)
  219. else:
  220. cluster_dict[l] = [i]
  221. for idx in cluster_dict:
  222. cluster.append(cluster_dict[idx])
  223. return cluster
  224. def get_cid_tid(cluster_labels, cid_tids):
  225. cluster = list()
  226. for labels in cluster_labels:
  227. cid_tid_list = list()
  228. for label in labels:
  229. cid_tid_list.append(cid_tids[label])
  230. cluster.append(cid_tid_list)
  231. return cluster
  232. def get_labels(cid_tid_dict, cid_tids):
  233. #compute cost matrix between features
  234. cost_matrix = get_sim_matrix_new(cid_tid_dict, cid_tids)
  235. #cluster all the features
  236. cluster1 = AgglomerativeClustering(
  237. n_clusters=None,
  238. distance_threshold=0.5,
  239. affinity='precomputed',
  240. linkage='complete')
  241. cluster_labels1 = cluster1.fit_predict(cost_matrix)
  242. labels = get_match(cluster_labels1)
  243. sub_cluster = get_cid_tid(labels, cid_tids)
  244. return labels
  245. def sub_cluster(cid_tid_dict):
  246. '''
  247. cid_tid_dict: all camera_id and track_id
  248. '''
  249. #get all keys
  250. cid_tids = sorted([key for key in cid_tid_dict.keys()])
  251. #cluster all trackid
  252. clu = get_labels(cid_tid_dict, cid_tids)
  253. #relabel every cluster groups
  254. new_clu = list()
  255. for c_list in clu:
  256. new_clu.append([cid_tids[c] for c in c_list])
  257. cid_tid_label = dict()
  258. for i, c_list in enumerate(new_clu):
  259. for c in c_list:
  260. cid_tid_label[c] = i + 1
  261. return cid_tid_label
  262. def distill_idfeat(mot_res):
  263. qualities_list = mot_res["qualities"]
  264. feature_list = mot_res["features"]
  265. rects = mot_res["rects"]
  266. qualities_new = []
  267. feature_new = []
  268. #filter rect less than 100*20
  269. for idx, rect in enumerate(rects):
  270. conf, xmin, ymin, xmax, ymax = rect[0]
  271. if (xmax - xmin) * (ymax - ymin) and (xmax > xmin) > 2000:
  272. qualities_new.append(qualities_list[idx])
  273. feature_new.append(feature_list[idx])
  274. #take all features if available rect is less than 2
  275. if len(qualities_new) < 2:
  276. qualities_new = qualities_list
  277. feature_new = feature_list
  278. #if available frames number is more than 200, take one frame data per 20 frames
  279. skipf = 1
  280. if len(qualities_new) > 20:
  281. skipf = 2
  282. quality_skip = np.array(qualities_new[::skipf])
  283. feature_skip = np.array(feature_new[::skipf])
  284. #sort features with image qualities, take the most trustworth features
  285. topk_argq = np.argsort(quality_skip)[::-1]
  286. if (quality_skip > 0.6).sum() > 1:
  287. topk_feat = feature_skip[topk_argq[quality_skip > 0.6]]
  288. else:
  289. topk_feat = feature_skip[topk_argq]
  290. #get final features by mean or cluster, at most take five
  291. mean_feat = np.mean(topk_feat[:5], axis=0)
  292. return mean_feat
  293. def res2dict(multi_res):
  294. cid_tid_dict = {}
  295. for cid, c_res in enumerate(multi_res):
  296. for tid, res in c_res.items():
  297. key = "c" + str(cid) + "_t" + str(tid)
  298. if key not in cid_tid_dict:
  299. if len(res["features"]) == 0:
  300. continue
  301. cid_tid_dict[key] = res
  302. cid_tid_dict[key]['mean_feat'] = distill_idfeat(res)
  303. return cid_tid_dict
  304. def mtmct_process(multi_res, captures, mtmct_vis=True, output_dir="output"):
  305. cid_tid_dict = res2dict(multi_res)
  306. if len(cid_tid_dict) == 0:
  307. print("no tracking result found, mtmct will be skiped.")
  308. return
  309. map_tid = sub_cluster(cid_tid_dict)
  310. if not os.path.exists(output_dir):
  311. os.mkdir(output_dir)
  312. pred_mtmct_file = os.path.join(output_dir, 'mtmct_result.txt')
  313. gen_restxt(pred_mtmct_file, map_tid, cid_tid_dict)
  314. if mtmct_vis:
  315. camera_results, cid_tid_fid_res = get_mtmct_matching_results(
  316. pred_mtmct_file)
  317. save_mtmct_vis_results(
  318. camera_results,
  319. captures,
  320. output_dir=output_dir,
  321. multi_res=cid_tid_dict)