camera_utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
This code is based on https://github.com/LCFractal/AIC21-MTMC/tree/main/reid/reid-matching/tools
Note: the following code is tightly coupled to the camera parameters of the AIC21 test set S06,
so it can only be used on S06 and cannot be used for other MTMCT datasets.
  18. """
  19. import numpy as np
  20. try:
  21. from sklearn.cluster import AgglomerativeClustering
  22. except:
  23. print(
  24. 'Warning: Unable to use MTMCT in PP-Tracking, please install sklearn, for example: `pip install sklearn`'
  25. )
  26. pass
  27. from .utils import get_dire, get_match, get_cid_tid, combin_feature, combin_cluster
  28. from .utils import normalize, intracam_ignore, visual_rerank
  29. __all__ = [
  30. 'st_filter',
  31. 'get_labels_with_camera',
  32. ]
  33. CAM_DIST = [[0, 40, 55, 100, 120, 145], [40, 0, 15, 60, 80, 105],
  34. [55, 15, 0, 40, 65, 90], [100, 60, 40, 0, 20, 45],
  35. [120, 80, 65, 20, 0, 25], [145, 105, 90, 45, 25, 0]]
def st_filter(st_mask, cid_tids, cid_tid_dict):
    """Zero out tracklet pairs rejected by hand-tuned spatio-temporal rules.

    For every ordered pair (i, j) of tracklets, a cascade of timing and
    entry/exit-zone rules tuned for the AIC21 S06 cameras decides whether
    the pair may be matched. When a pair is rejected, both
    ``st_mask[i, j]`` and ``st_mask[j, i]`` are set to 0.0. Entries are
    only ever cleared here, never set to a non-zero value.

    Args:
        st_mask: 2-D array indexed as ``st_mask[i, j]`` with i and j in
            ``range(len(cid_tids))``. It is modified in place.
        cid_tids: sequence of keys into ``cid_tid_dict``.
        cid_tid_dict: maps each key to a tracklet dict. This function reads
            three entries from it:
            - ``'cam'``: the camera id. It must lie in 41..46 because it
              indexes ``CAM_DIST`` as ``cam - 41``.
            - ``'zone_list'``: passed to ``get_dire``.
            - ``'io_time'``: indexed at [0] and [1]; presumably the tracklet's
              first and last time stamp -- TODO confirm.

    Returns:
        The same ``st_mask`` object after in-place modification.
    """
    count = len(cid_tids)
    for i in range(count):
        i_tracklet = cid_tid_dict[cid_tids[i]]
        i_cid = i_tracklet['cam']
        # Presumably (entry zone, exit zone); subcam_list unpacks the same
        # call as ``zs, ze``.
        i_dire = get_dire(i_tracklet['zone_list'], i_cid)
        i_iot = i_tracklet['io_time']
        for j in range(count):
            j_tracklet = cid_tid_dict[cid_tids[j]]
            j_cid = j_tracklet['cam']
            # NOTE(review): recomputed for every i (count * count calls); it
            # could be computed once per tracklet if get_dire is pure.
            j_dire = get_dire(j_tracklet['zone_list'], j_cid)
            j_iot = j_tracklet['io_time']
            match_dire = True
            # Per-camera-pair margin applied to the io_time bounds below.
            cam_dist = CAM_DIST[i_cid - 41][j_cid - 41]
            # if time overlapped: reject when j's start or end time falls
            # inside i's time window widened by cam_dist on both sides
            if i_iot[0] - cam_dist < j_iot[0] and j_iot[0] < i_iot[
                    1] + cam_dist:
                match_dire = False
            if i_iot[0] - cam_dist < j_iot[1] and j_iot[1] < i_iot[
                    1] + cam_dist:
                match_dire = False
            # not match after go out (i exits through zone 1 or 2)
            if i_dire[1] in [1, 2]:  # i out
                if i_iot[0] < j_iot[1] + cam_dist:
                    match_dire = False
            if i_dire[1] in [1, 2]:
                if i_dire[0] in [3] and i_cid > j_cid:
                    match_dire = False
                if i_dire[0] in [4] and i_cid < j_cid:
                    match_dire = False
            # Special cases for the first (41) and last (46) camera.
            if i_cid in [41] and i_dire[1] in [4]:
                if i_iot[0] < j_iot[1] + cam_dist:
                    match_dire = False
                # NOTE(review): 199 is presumably the end of the S06 time
                # range -- TODO confirm.
                if i_iot[1] > 199:
                    match_dire = False
            if i_cid in [46] and i_dire[1] in [3]:
                if i_iot[0] < j_iot[1] + cam_dist:
                    match_dire = False
            # match after come into (i enters through zone 1 or 2)
            if i_dire[0] in [1, 2]:
                if i_iot[1] > j_iot[0] - cam_dist:
                    match_dire = False
            if i_dire[0] in [1, 2]:
                if i_dire[1] in [3] and i_cid > j_cid:
                    match_dire = False
                if i_dire[1] in [4] and i_cid < j_cid:
                    match_dire = False
            # The direction rules below are skipped when either tracklet
            # enters and exits through the same zone (3 or 4).
            is_ignore = False
            if ((i_dire[0] == i_dire[1] and i_dire[0] in [3, 4]) or
                (j_dire[0] == j_dire[1] and j_dire[0] in [3, 4])):
                is_ignore = True
            if not is_ignore:
                # direction conflict
                if (i_dire[0] in [3] and j_dire[0] in [4]) or (
                        i_dire[1] in [3] and j_dire[1] in [4]):
                    match_dire = False
                # filter before going next scene
                if i_dire[1] in [3] and i_cid < j_cid:
                    if i_iot[1] > j_iot[1] - cam_dist:
                        match_dire = False
                if i_dire[1] in [4] and i_cid > j_cid:
                    if i_iot[1] > j_iot[1] - cam_dist:
                        match_dire = False
                if i_dire[0] in [3] and i_cid < j_cid:
                    if i_iot[0] < j_iot[0] + cam_dist:
                        match_dire = False
                if i_dire[0] in [4] and i_cid > j_cid:
                    if i_iot[0] < j_iot[0] + cam_dist:
                        match_dire = False
                # NOTE(review): the '## 3-30', '## 4-1' and '## 4-7' tags are
                # kept from upstream; their meaning is not documented here.
                ## 3-30
                ## 4-1
                if i_dire[0] in [3] and i_cid > j_cid:
                    if i_iot[1] > j_iot[0] - cam_dist:
                        match_dire = False
                if i_dire[0] in [4] and i_cid < j_cid:
                    if i_iot[1] > j_iot[0] - cam_dist:
                        match_dire = False
                # filter before going next scene
                ## 4-7
                if i_dire[1] in [3] and i_cid > j_cid:
                    if i_iot[0] < j_iot[1] + cam_dist:
                        match_dire = False
                if i_dire[1] in [4] and i_cid < j_cid:
                    if i_iot[0] < j_iot[1] + cam_dist:
                        match_dire = False
            else:
                # NOTE(review): 199 and 1 presumably mark the end and start
                # of the S06 time range -- TODO confirm.
                if i_iot[1] > 199:
                    if i_dire[0] in [3] and i_cid < j_cid:
                        if i_iot[0] < j_iot[0] + cam_dist:
                            match_dire = False
                    if i_dire[0] in [4] and i_cid > j_cid:
                        if i_iot[0] < j_iot[0] + cam_dist:
                            match_dire = False
                    if i_dire[0] in [3] and i_cid > j_cid:
                        match_dire = False
                    if i_dire[0] in [4] and i_cid < j_cid:
                        match_dire = False
                if i_iot[0] < 1:
                    if i_dire[1] in [3] and i_cid > j_cid:
                        match_dire = False
                    if i_dire[1] in [4] and i_cid < j_cid:
                        match_dire = False
            # Rejection is applied symmetrically to (i, j) and (j, i).
            if not match_dire:
                st_mask[i, j] = 0.0
                st_mask[j, i] = 0.0
    return st_mask
  142. def subcam_list(cid_tid_dict, cid_tids):
  143. sub_3_4 = dict()
  144. sub_4_3 = dict()
  145. for cid_tid in cid_tids:
  146. cid, tid = cid_tid
  147. tracklet = cid_tid_dict[cid_tid]
  148. zs, ze = get_dire(tracklet['zone_list'], cid)
  149. if zs in [3] and cid not in [46]: # 4 to 3
  150. if not cid + 1 in sub_4_3:
  151. sub_4_3[cid + 1] = []
  152. sub_4_3[cid + 1].append(cid_tid)
  153. if ze in [4] and cid not in [41]: # 4 to 3
  154. if not cid in sub_4_3:
  155. sub_4_3[cid] = []
  156. sub_4_3[cid].append(cid_tid)
  157. if zs in [4] and cid not in [41]: # 3 to 4
  158. if not cid - 1 in sub_3_4:
  159. sub_3_4[cid - 1] = []
  160. sub_3_4[cid - 1].append(cid_tid)
  161. if ze in [3] and cid not in [46]: # 3 to 4
  162. if not cid in sub_3_4:
  163. sub_3_4[cid] = []
  164. sub_3_4[cid].append(cid_tid)
  165. sub_cid_tids = dict()
  166. for i in sub_3_4:
  167. sub_cid_tids[(i, i + 1)] = sub_3_4[i]
  168. for i in sub_4_3:
  169. sub_cid_tids[(i, i - 1)] = sub_4_3[i]
  170. return sub_cid_tids
  171. def subcam_list2(cid_tid_dict, cid_tids):
  172. sub_dict = dict()
  173. for cid_tid in cid_tids:
  174. cid, tid = cid_tid
  175. if cid not in [41]:
  176. if not cid in sub_dict:
  177. sub_dict[cid] = []
  178. sub_dict[cid].append(cid_tid)
  179. if cid not in [46]:
  180. if not cid + 1 in sub_dict:
  181. sub_dict[cid + 1] = []
  182. sub_dict[cid + 1].append(cid_tid)
  183. return sub_dict
  184. def get_sim_matrix(cid_tid_dict,
  185. cid_tids,
  186. use_ff=True,
  187. use_rerank=True,
  188. use_st_filter=False):
  189. # Note: camera releated get_sim_matrix function,
  190. # which is different from the one in utils.py.
  191. count = len(cid_tids)
  192. q_arr = np.array(
  193. [cid_tid_dict[cid_tids[i]]['mean_feat'] for i in range(count)])
  194. g_arr = np.array(
  195. [cid_tid_dict[cid_tids[i]]['mean_feat'] for i in range(count)])
  196. q_arr = normalize(q_arr, axis=1)
  197. g_arr = normalize(g_arr, axis=1)
  198. st_mask = np.ones((count, count), dtype=np.float32)
  199. st_mask = intracam_ignore(st_mask, cid_tids)
  200. # different from utils.py
  201. if use_st_filter:
  202. st_mask = st_filter(st_mask, cid_tids, cid_tid_dict)
  203. visual_sim_matrix = visual_rerank(
  204. q_arr, g_arr, cid_tids, use_ff=use_ff, use_rerank=use_rerank)
  205. visual_sim_matrix = visual_sim_matrix.astype('float32')
  206. np.set_printoptions(precision=3)
  207. sim_matrix = visual_sim_matrix * st_mask
  208. np.fill_diagonal(sim_matrix, 0)
  209. return sim_matrix
  210. def get_labels_with_camera(cid_tid_dict,
  211. cid_tids,
  212. use_ff=True,
  213. use_rerank=True,
  214. use_st_filter=False):
  215. # 1st cluster
  216. sub_cid_tids = subcam_list(cid_tid_dict, cid_tids)
  217. sub_labels = dict()
  218. dis_thrs = [0.7, 0.5, 0.5, 0.5, 0.5, 0.7, 0.5, 0.5, 0.5, 0.5]
  219. for i, sub_c_to_c in enumerate(sub_cid_tids):
  220. sim_matrix = get_sim_matrix(
  221. cid_tid_dict,
  222. sub_cid_tids[sub_c_to_c],
  223. use_ff=use_ff,
  224. use_rerank=use_rerank,
  225. use_st_filter=use_st_filter)
  226. cluster_labels = AgglomerativeClustering(
  227. n_clusters=None,
  228. distance_threshold=1 - dis_thrs[i],
  229. affinity='precomputed',
  230. linkage='complete').fit_predict(1 - sim_matrix)
  231. labels = get_match(cluster_labels)
  232. cluster_cid_tids = get_cid_tid(labels, sub_cid_tids[sub_c_to_c])
  233. sub_labels[sub_c_to_c] = cluster_cid_tids
  234. labels, sub_cluster = combin_cluster(sub_labels, cid_tids)
  235. # 2nd cluster
  236. cid_tid_dict_new = combin_feature(cid_tid_dict, sub_cluster)
  237. sub_cid_tids = subcam_list2(cid_tid_dict_new, cid_tids)
  238. sub_labels = dict()
  239. for i, sub_c_to_c in enumerate(sub_cid_tids):
  240. sim_matrix = get_sim_matrix(
  241. cid_tid_dict_new,
  242. sub_cid_tids[sub_c_to_c],
  243. use_ff=use_ff,
  244. use_rerank=use_rerank,
  245. use_st_filter=use_st_filter)
  246. cluster_labels = AgglomerativeClustering(
  247. n_clusters=None,
  248. distance_threshold=1 - 0.1,
  249. affinity='precomputed',
  250. linkage='complete').fit_predict(1 - sim_matrix)
  251. labels = get_match(cluster_labels)
  252. cluster_cid_tids = get_cid_tid(labels, sub_cid_tids[sub_c_to_c])
  253. sub_labels[sub_c_to_c] = cluster_cid_tids
  254. labels, sub_cluster = combin_cluster(sub_labels, cid_tids)
  255. return labels