zone.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
This code is based on https://github.com/LCFractal/AIC21-MTMC/tree/main/reid/reid-matching/tools
Note: The following code is tightly coupled to the zones of the AIC21 test-set S06,
so it can only be used for S06 and cannot be used for other MTMCT datasets.
  18. """
  19. import os
  20. import cv2
  21. import numpy as np
  22. try:
  23. from sklearn.cluster import AgglomerativeClustering
  24. except:
  25. print(
  26. 'Warning: Unable to use MTMCT in PP-Tracking, please install sklearn, for example: `pip install sklearn`'
  27. )
  28. pass
# Size-ratio threshold used by Zone.filter_bbox: a box is kept only when both
# its width and its height exceed BBOX_B times the reference width/height.
BBOX_B = 10 / 15
  30. class Zone(object):
  31. def __init__(self, zone_path='datasets/zone'):
  32. # 0: b 1: g 3: r 123:w
  33. # w r not high speed
  34. # b g high speed
  35. assert zone_path != '', "Error: zone_path is not empty!"
  36. zones = {}
  37. for img_name in os.listdir(zone_path):
  38. camnum = int(img_name.split('.')[0][-3:])
  39. zone_img = cv2.imread(os.path.join(zone_path, img_name))
  40. zones[camnum] = zone_img
  41. self.zones = zones
  42. self.current_cam = 0
  43. def set_cam(self, cam):
  44. self.current_cam = cam
  45. def get_zone(self, bbox):
  46. cx = int((bbox[0] + bbox[2]) / 2)
  47. cy = int((bbox[1] + bbox[3]) / 2)
  48. pix = self.zones[self.current_cam][max(cy - 1, 0), max(cx - 1, 0), :]
  49. zone_num = 0
  50. if pix[0] > 50 and pix[1] > 50 and pix[2] > 50: # w
  51. zone_num = 1
  52. if pix[0] < 50 and pix[1] < 50 and pix[2] > 50: # r
  53. zone_num = 2
  54. if pix[0] < 50 and pix[1] > 50 and pix[2] < 50: # g
  55. zone_num = 3
  56. if pix[0] > 50 and pix[1] < 50 and pix[2] < 50: # b
  57. zone_num = 4
  58. return zone_num
  59. def is_ignore(self, zone_list, frame_list, cid):
  60. # 0 not in any corssroad, 1 white 2 red 3 green 4 bule
  61. zs, ze = zone_list[0], zone_list[-1]
  62. fs, fe = frame_list[0], frame_list[-1]
  63. if zs == ze:
  64. # if always on one section, excluding
  65. if ze in [1, 2]:
  66. return 2
  67. if zs != 0 and 0 in zone_list:
  68. return 0
  69. if fe - fs > 1500:
  70. return 2
  71. if fs < 2:
  72. if cid in [45]:
  73. if ze in [3, 4]:
  74. return 1
  75. else:
  76. return 2
  77. if fe > 1999:
  78. if cid in [41]:
  79. if ze not in [3]:
  80. return 2
  81. else:
  82. return 0
  83. if fs < 2 or fe > 1999:
  84. if ze in [3, 4]:
  85. return 0
  86. if ze in [3, 4]:
  87. return 1
  88. return 2
  89. else:
  90. # if camera section change
  91. if cid in [41, 42, 43, 44, 45, 46]:
  92. # come from road extension, exclusing
  93. if zs == 1 and ze == 2:
  94. return 2
  95. if zs == 2 and ze == 1:
  96. return 2
  97. if cid in [41]:
  98. # On 41 camera, no vehicle come into 42 camera
  99. if (zs in [1, 2]) and ze == 4:
  100. return 2
  101. if zs == 4 and (ze in [1, 2]):
  102. return 2
  103. if cid in [46]:
  104. # On 46 camera,no vehicle come into 45
  105. if (zs in [1, 2]) and ze == 3:
  106. return 2
  107. if zs == 3 and (ze in [1, 2]):
  108. return 2
  109. return 0
  110. def filter_mot(self, mot_list, cid):
  111. new_mot_list = dict()
  112. sub_mot_list = dict()
  113. for tracklet in mot_list:
  114. tracklet_dict = mot_list[tracklet]
  115. frame_list = list(tracklet_dict.keys())
  116. frame_list.sort()
  117. zone_list = []
  118. for f in frame_list:
  119. zone_list.append(tracklet_dict[f]['zone'])
  120. if self.is_ignore(zone_list, frame_list, cid) == 0:
  121. new_mot_list[tracklet] = tracklet_dict
  122. if self.is_ignore(zone_list, frame_list, cid) == 1:
  123. sub_mot_list[tracklet] = tracklet_dict
  124. return new_mot_list
  125. def filter_bbox(self, mot_list, cid):
  126. new_mot_list = dict()
  127. yh = self.zones[cid].shape[0]
  128. for tracklet in mot_list:
  129. tracklet_dict = mot_list[tracklet]
  130. frame_list = list(tracklet_dict.keys())
  131. frame_list.sort()
  132. bbox_list = []
  133. for f in frame_list:
  134. bbox_list.append(tracklet_dict[f]['bbox'])
  135. bbox_x = [b[0] for b in bbox_list]
  136. bbox_y = [b[1] for b in bbox_list]
  137. bbox_w = [b[2] - b[0] for b in bbox_list]
  138. bbox_h = [b[3] - b[1] for b in bbox_list]
  139. new_frame_list = list()
  140. if 0 in bbox_x or 0 in bbox_y:
  141. b0 = [
  142. i for i, f in enumerate(frame_list)
  143. if bbox_x[i] < 5 or bbox_y[i] + bbox_h[i] > yh - 5
  144. ]
  145. if len(b0) == len(frame_list):
  146. if cid in [41, 42, 44, 45, 46]:
  147. continue
  148. max_w = max(bbox_w)
  149. max_h = max(bbox_h)
  150. for i, f in enumerate(frame_list):
  151. if bbox_w[i] > max_w * BBOX_B and bbox_h[
  152. i] > max_h * BBOX_B:
  153. new_frame_list.append(f)
  154. else:
  155. l_i, r_i = 0, len(frame_list) - 1
  156. if len(b0) == 0:
  157. continue
  158. if b0[0] == 0:
  159. for i in range(len(b0) - 1):
  160. if b0[i] + 1 == b0[i + 1]:
  161. l_i = b0[i + 1]
  162. else:
  163. break
  164. if b0[-1] == len(frame_list) - 1:
  165. for i in range(len(b0) - 1):
  166. i = len(b0) - 1 - i
  167. if b0[i] - 1 == b0[i - 1]:
  168. r_i = b0[i - 1]
  169. else:
  170. break
  171. max_lw, max_lh = bbox_w[l_i], bbox_h[l_i]
  172. max_rw, max_rh = bbox_w[r_i], bbox_h[r_i]
  173. for i, f in enumerate(frame_list):
  174. if i < l_i:
  175. if bbox_w[i] > max_lw * BBOX_B and bbox_h[
  176. i] > max_lh * BBOX_B:
  177. new_frame_list.append(f)
  178. elif i > r_i:
  179. if bbox_w[i] > max_rw * BBOX_B and bbox_h[
  180. i] > max_rh * BBOX_B:
  181. new_frame_list.append(f)
  182. else:
  183. new_frame_list.append(f)
  184. new_tracklet_dict = dict()
  185. for f in new_frame_list:
  186. new_tracklet_dict[f] = tracklet_dict[f]
  187. new_mot_list[tracklet] = new_tracklet_dict
  188. else:
  189. new_mot_list[tracklet] = tracklet_dict
  190. return new_mot_list
  191. def break_mot(self, mot_list, cid):
  192. new_mot_list = dict()
  193. new_num_tracklets = max(mot_list) + 1
  194. for tracklet in mot_list:
  195. tracklet_dict = mot_list[tracklet]
  196. frame_list = list(tracklet_dict.keys())
  197. frame_list.sort()
  198. zone_list = []
  199. back_tracklet = False
  200. new_zone_f = 0
  201. pre_frame = frame_list[0]
  202. time_break = False
  203. for f in frame_list:
  204. if f - pre_frame > 100:
  205. if cid in [44, 45]:
  206. time_break = True
  207. break
  208. if not cid in [41, 44, 45, 46]:
  209. break
  210. pre_frame = f
  211. new_zone = tracklet_dict[f]['zone']
  212. if len(zone_list) > 0 and zone_list[-1] == new_zone:
  213. continue
  214. if new_zone_f > 1:
  215. if len(zone_list) > 1 and new_zone in zone_list:
  216. back_tracklet = True
  217. zone_list.append(new_zone)
  218. new_zone_f = 0
  219. else:
  220. new_zone_f += 1
  221. if back_tracklet:
  222. new_tracklet_dict = dict()
  223. pre_bbox = -1
  224. pre_arrow = 0
  225. have_break = False
  226. for f in frame_list:
  227. now_bbox = tracklet_dict[f]['bbox']
  228. if type(pre_bbox) == int:
  229. if pre_bbox == -1:
  230. pre_bbox = now_bbox
  231. now_arrow = now_bbox[0] - pre_bbox[0]
  232. if pre_arrow * now_arrow < 0 and len(
  233. new_tracklet_dict) > 15 and not have_break:
  234. new_mot_list[tracklet] = new_tracklet_dict
  235. new_tracklet_dict = dict()
  236. have_break = True
  237. if have_break:
  238. tracklet_dict[f]['id'] = new_num_tracklets
  239. new_tracklet_dict[f] = tracklet_dict[f]
  240. pre_bbox, pre_arrow = now_bbox, now_arrow
  241. if have_break:
  242. new_mot_list[new_num_tracklets] = new_tracklet_dict
  243. new_num_tracklets += 1
  244. else:
  245. new_mot_list[tracklet] = new_tracklet_dict
  246. elif time_break:
  247. new_tracklet_dict = dict()
  248. have_break = False
  249. pre_frame = frame_list[0]
  250. for f in frame_list:
  251. if f - pre_frame > 100:
  252. new_mot_list[tracklet] = new_tracklet_dict
  253. new_tracklet_dict = dict()
  254. have_break = True
  255. new_tracklet_dict[f] = tracklet_dict[f]
  256. pre_frame = f
  257. if have_break:
  258. new_mot_list[new_num_tracklets] = new_tracklet_dict
  259. new_num_tracklets += 1
  260. else:
  261. new_mot_list[tracklet] = new_tracklet_dict
  262. else:
  263. new_mot_list[tracklet] = tracklet_dict
  264. return new_mot_list
  265. def intra_matching(self, mot_list, sub_mot_list):
  266. sub_zone_dict = dict()
  267. new_mot_list = dict()
  268. new_mot_list, new_sub_mot_list = self.do_intra_matching2(mot_list,
  269. sub_mot_list)
  270. return new_mot_list
  271. def do_intra_matching2(self, mot_list, sub_list):
  272. new_zone_dict = dict()
  273. def get_trac_info(tracklet1):
  274. t1_f = list(tracklet1)
  275. t1_f.sort()
  276. t1_fs = t1_f[0]
  277. t1_fe = t1_f[-1]
  278. t1_zs = tracklet1[t1_fs]['zone']
  279. t1_ze = tracklet1[t1_fe]['zone']
  280. t1_boxs = tracklet1[t1_fs]['bbox']
  281. t1_boxe = tracklet1[t1_fe]['bbox']
  282. t1_boxs = [(t1_boxs[2] + t1_boxs[0]) / 2,
  283. (t1_boxs[3] + t1_boxs[1]) / 2]
  284. t1_boxe = [(t1_boxe[2] + t1_boxe[0]) / 2,
  285. (t1_boxe[3] + t1_boxe[1]) / 2]
  286. return t1_fs, t1_fe, t1_zs, t1_ze, t1_boxs, t1_boxe
  287. for t1id in sub_list:
  288. tracklet1 = sub_list[t1id]
  289. if tracklet1 == -1:
  290. continue
  291. t1_fs, t1_fe, t1_zs, t1_ze, t1_boxs, t1_boxe = get_trac_info(
  292. tracklet1)
  293. sim_dict = dict()
  294. for t2id in mot_list:
  295. tracklet2 = mot_list[t2id]
  296. t2_fs, t2_fe, t2_zs, t2_ze, t2_boxs, t2_boxe = get_trac_info(
  297. tracklet2)
  298. if t1_ze == t2_zs:
  299. if abs(t2_fs - t1_fe) < 5 and abs(t2_boxe[0] - t1_boxs[
  300. 0]) < 50 and abs(t2_boxe[1] - t1_boxs[1]) < 50:
  301. t1_feat = tracklet1[t1_fe]['feat']
  302. t2_feat = tracklet2[t2_fs]['feat']
  303. sim_dict[t2id] = np.matmul(t1_feat, t2_feat)
  304. if t1_zs == t2_ze:
  305. if abs(t2_fe - t1_fs) < 5 and abs(t2_boxs[0] - t1_boxe[
  306. 0]) < 50 and abs(t2_boxs[1] - t1_boxe[1]) < 50:
  307. t1_feat = tracklet1[t1_fs]['feat']
  308. t2_feat = tracklet2[t2_fe]['feat']
  309. sim_dict[t2id] = np.matmul(t1_feat, t2_feat)
  310. if len(sim_dict) > 0:
  311. max_sim = 0
  312. max_id = 0
  313. for t2id in sim_dict:
  314. if sim_dict[t2id] > max_sim:
  315. sim_dict[t2id] = max_sim
  316. max_id = t2id
  317. if max_sim > 0.5:
  318. t2 = mot_list[max_id]
  319. for t1f in tracklet1:
  320. if t1f not in t2:
  321. tracklet1[t1f]['id'] = max_id
  322. t2[t1f] = tracklet1[t1f]
  323. mot_list[max_id] = t2
  324. sub_list[t1id] = -1
  325. return mot_list, sub_list
  326. def do_intra_matching(self, sub_zone_dict, sub_zone):
  327. new_zone_dict = dict()
  328. id_list = list(sub_zone_dict)
  329. id2index = dict()
  330. for index, id in enumerate(id_list):
  331. id2index[id] = index
  332. def get_trac_info(tracklet1):
  333. t1_f = list(tracklet1)
  334. t1_f.sort()
  335. t1_fs = t1_f[0]
  336. t1_fe = t1_f[-1]
  337. t1_zs = tracklet1[t1_fs]['zone']
  338. t1_ze = tracklet1[t1_fe]['zone']
  339. t1_boxs = tracklet1[t1_fs]['bbox']
  340. t1_boxe = tracklet1[t1_fe]['bbox']
  341. t1_boxs = [(t1_boxs[2] + t1_boxs[0]) / 2,
  342. (t1_boxs[3] + t1_boxs[1]) / 2]
  343. t1_boxe = [(t1_boxe[2] + t1_boxe[0]) / 2,
  344. (t1_boxe[3] + t1_boxe[1]) / 2]
  345. return t1_fs, t1_fe, t1_zs, t1_ze, t1_boxs, t1_boxe
  346. sim_matrix = np.zeros([len(id_list), len(id_list)])
  347. for t1id in sub_zone_dict:
  348. tracklet1 = sub_zone_dict[t1id]
  349. t1_fs, t1_fe, t1_zs, t1_ze, t1_boxs, t1_boxe = get_trac_info(
  350. tracklet1)
  351. t1_feat = tracklet1[t1_fe]['feat']
  352. for t2id in sub_zone_dict:
  353. if t1id == t2id:
  354. continue
  355. tracklet2 = sub_zone_dict[t2id]
  356. t2_fs, t2_fe, t2_zs, t2_ze, t2_boxs, t2_boxe = get_trac_info(
  357. tracklet2)
  358. if t1_zs != t1_ze and t2_ze != t2_zs or t1_fe > t2_fs:
  359. continue
  360. if abs(t1_boxe[0] - t2_boxs[0]) > 50 or abs(t1_boxe[1] -
  361. t2_boxs[1]) > 50:
  362. continue
  363. if t2_fs - t1_fe > 5:
  364. continue
  365. t2_feat = tracklet2[t2_fs]['feat']
  366. sim_matrix[id2index[t1id], id2index[t2id]] = np.matmul(t1_feat,
  367. t2_feat)
  368. sim_matrix[id2index[t2id], id2index[t1id]] = np.matmul(t1_feat,
  369. t2_feat)
  370. sim_matrix = 1 - sim_matrix
  371. cluster_labels = AgglomerativeClustering(
  372. n_clusters=None,
  373. distance_threshold=0.7,
  374. affinity='precomputed',
  375. linkage='complete').fit_predict(sim_matrix)
  376. new_zone_dict = dict()
  377. label2id = dict()
  378. for index, label in enumerate(cluster_labels):
  379. tracklet = sub_zone_dict[id_list[index]]
  380. if label not in label2id:
  381. new_id = tracklet[list(tracklet)[0]]
  382. new_tracklet = dict()
  383. else:
  384. new_id = label2id[label]
  385. new_tracklet = new_zone_dict[label2id[label]]
  386. for tf in tracklet:
  387. tracklet[tf]['id'] = new_id
  388. new_tracklet[tf] = tracklet[tf]
  389. new_zone_dict[label] = new_tracklet
  390. return new_zone_dict