jde_tracker.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
  16. """
  17. import numpy as np
  18. from collections import defaultdict
  19. from ..matching import jde_matching as matching
  20. from ..motion import KalmanFilter
  21. from .base_jde_tracker import TrackState, STrack
  22. from .base_jde_tracker import joint_stracks, sub_stracks, remove_duplicate_stracks
  23. from ppdet.core.workspace import register, serializable
  24. from ppdet.utils.logger import setup_logger
  25. logger = setup_logger(__name__)
  26. __all__ = ['JDETracker']
  27. @register
  28. @serializable
  29. class JDETracker(object):
  30. __shared__ = ['num_classes']
  31. """
  32. JDE tracker, support single class and multi classes
  33. Args:
  34. use_byte (bool): Whether use ByteTracker, default False
  35. num_classes (int): the number of classes
  36. det_thresh (float): threshold of detection score
  37. track_buffer (int): buffer for tracker
  38. min_box_area (int): min box area to filter out low quality boxes
  39. vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
  40. bad results. If set <= 0 means no need to filter bboxes,usually set
  41. 1.6 for pedestrian tracking.
  42. tracked_thresh (float): linear assignment threshold of tracked
  43. stracks and detections
  44. r_tracked_thresh (float): linear assignment threshold of
  45. tracked stracks and unmatched detections
  46. unconfirmed_thresh (float): linear assignment threshold of
  47. unconfirmed stracks and unmatched detections
  48. conf_thres (float): confidence threshold for tracking, also used in
  49. ByteTracker as higher confidence threshold
  50. match_thres (float): linear assignment threshold of tracked
  51. stracks and detections in ByteTracker
  52. low_conf_thres (float): lower confidence threshold for tracking in
  53. ByteTracker
  54. input_size (list): input feature map size to reid model, [h, w] format,
  55. [64, 192] as default.
  56. motion (str): motion model, KalmanFilter as default
  57. metric_type (str): either "euclidean" or "cosine", the distance metric
  58. used for measurement to track association.
  59. """
  60. def __init__(self,
  61. use_byte=False,
  62. num_classes=1,
  63. det_thresh=0.3,
  64. track_buffer=30,
  65. min_box_area=0,
  66. vertical_ratio=0,
  67. tracked_thresh=0.7,
  68. r_tracked_thresh=0.5,
  69. unconfirmed_thresh=0.7,
  70. conf_thres=0,
  71. match_thres=0.8,
  72. low_conf_thres=0.2,
  73. input_size=[64, 192],
  74. motion='KalmanFilter',
  75. metric_type='euclidean'):
  76. self.use_byte = use_byte
  77. self.num_classes = num_classes
  78. self.det_thresh = det_thresh if not use_byte else conf_thres + 0.1
  79. self.track_buffer = track_buffer
  80. self.min_box_area = min_box_area
  81. self.vertical_ratio = vertical_ratio
  82. self.tracked_thresh = tracked_thresh
  83. self.r_tracked_thresh = r_tracked_thresh
  84. self.unconfirmed_thresh = unconfirmed_thresh
  85. self.conf_thres = conf_thres
  86. self.match_thres = match_thres
  87. self.low_conf_thres = low_conf_thres
  88. self.input_size = input_size
  89. if motion == 'KalmanFilter':
  90. self.motion = KalmanFilter()
  91. self.metric_type = metric_type
  92. self.frame_id = 0
  93. self.tracked_tracks_dict = defaultdict(list) # dict(list[STrack])
  94. self.lost_tracks_dict = defaultdict(list) # dict(list[STrack])
  95. self.removed_tracks_dict = defaultdict(list) # dict(list[STrack])
  96. self.max_time_lost = 0
  97. # max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer)
  98. def update(self, pred_dets, pred_embs=None):
  99. """
  100. Processes the image frame and finds bounding box(detections).
  101. Associates the detection with corresponding tracklets and also handles
  102. lost, removed, refound and active tracklets.
  103. Args:
  104. pred_dets (np.array): Detection results of the image, the shape is
  105. [N, 6], means 'cls_id, score, x0, y0, x1, y1'.
  106. pred_embs (np.array): Embedding results of the image, the shape is
  107. [N, 128] or [N, 512].
  108. Return:
  109. output_stracks_dict (dict(list)): The list contains information
  110. regarding the online_tracklets for the received image tensor.
  111. """
  112. self.frame_id += 1
  113. if self.frame_id == 1:
  114. STrack.init_count(self.num_classes)
  115. activated_tracks_dict = defaultdict(list)
  116. refined_tracks_dict = defaultdict(list)
  117. lost_tracks_dict = defaultdict(list)
  118. removed_tracks_dict = defaultdict(list)
  119. output_tracks_dict = defaultdict(list)
  120. pred_dets_dict = defaultdict(list)
  121. pred_embs_dict = defaultdict(list)
  122. # unify single and multi classes detection and embedding results
  123. for cls_id in range(self.num_classes):
  124. cls_idx = (pred_dets[:, 0:1] == cls_id).squeeze(-1)
  125. pred_dets_dict[cls_id] = pred_dets[cls_idx]
  126. if pred_embs is not None:
  127. pred_embs_dict[cls_id] = pred_embs[cls_idx]
  128. else:
  129. pred_embs_dict[cls_id] = None
  130. for cls_id in range(self.num_classes):
  131. """ Step 1: Get detections by class"""
  132. pred_dets_cls = pred_dets_dict[cls_id]
  133. pred_embs_cls = pred_embs_dict[cls_id]
  134. remain_inds = (pred_dets_cls[:, 1:2] > self.conf_thres).squeeze(-1)
  135. if remain_inds.sum() > 0:
  136. pred_dets_cls = pred_dets_cls[remain_inds]
  137. if pred_embs_cls is None:
  138. # in original ByteTrack
  139. detections = [
  140. STrack(
  141. STrack.tlbr_to_tlwh(tlbrs[2:6]),
  142. tlbrs[1],
  143. cls_id,
  144. 30,
  145. temp_feat=None) for tlbrs in pred_dets_cls
  146. ]
  147. else:
  148. pred_embs_cls = pred_embs_cls[remain_inds]
  149. detections = [
  150. STrack(
  151. STrack.tlbr_to_tlwh(tlbrs[2:6]), tlbrs[1], cls_id,
  152. 30, temp_feat) for (tlbrs, temp_feat) in
  153. zip(pred_dets_cls, pred_embs_cls)
  154. ]
  155. else:
  156. detections = []
  157. ''' Add newly detected tracklets to tracked_stracks'''
  158. unconfirmed_dict = defaultdict(list)
  159. tracked_tracks_dict = defaultdict(list)
  160. for track in self.tracked_tracks_dict[cls_id]:
  161. if not track.is_activated:
  162. # previous tracks which are not active in the current frame are added in unconfirmed list
  163. unconfirmed_dict[cls_id].append(track)
  164. else:
  165. # Active tracks are added to the local list 'tracked_stracks'
  166. tracked_tracks_dict[cls_id].append(track)
  167. """ Step 2: First association, with embedding"""
  168. # building tracking pool for the current frame
  169. track_pool_dict = defaultdict(list)
  170. track_pool_dict[cls_id] = joint_stracks(
  171. tracked_tracks_dict[cls_id], self.lost_tracks_dict[cls_id])
  172. # Predict the current location with KalmanFilter
  173. STrack.multi_predict(track_pool_dict[cls_id], self.motion)
  174. if pred_embs_cls is None:
  175. # in original ByteTrack
  176. dists = matching.iou_distance(track_pool_dict[cls_id],
  177. detections)
  178. matches, u_track, u_detection = matching.linear_assignment(
  179. dists, thresh=self.match_thres) # not self.tracked_thresh
  180. else:
  181. dists = matching.embedding_distance(
  182. track_pool_dict[cls_id],
  183. detections,
  184. metric=self.metric_type)
  185. dists = matching.fuse_motion(
  186. self.motion, dists, track_pool_dict[cls_id], detections)
  187. matches, u_track, u_detection = matching.linear_assignment(
  188. dists, thresh=self.tracked_thresh)
  189. for i_tracked, idet in matches:
  190. # i_tracked is the id of the track and idet is the detection
  191. track = track_pool_dict[cls_id][i_tracked]
  192. det = detections[idet]
  193. if track.state == TrackState.Tracked:
  194. # If the track is active, add the detection to the track
  195. track.update(detections[idet], self.frame_id)
  196. activated_tracks_dict[cls_id].append(track)
  197. else:
  198. # We have obtained a detection from a track which is not active,
  199. # hence put the track in refind_stracks list
  200. track.re_activate(det, self.frame_id, new_id=False)
  201. refined_tracks_dict[cls_id].append(track)
  202. # None of the steps below happen if there are no undetected tracks.
  203. """ Step 3: Second association, with IOU"""
  204. if self.use_byte:
  205. inds_low = pred_dets_dict[cls_id][:, 1:2] > self.low_conf_thres
  206. inds_high = pred_dets_dict[cls_id][:, 1:2] < self.conf_thres
  207. inds_second = np.logical_and(inds_low, inds_high).squeeze(-1)
  208. pred_dets_cls_second = pred_dets_dict[cls_id][inds_second]
  209. # association the untrack to the low score detections
  210. if len(pred_dets_cls_second) > 0:
  211. if pred_embs_dict[cls_id] is None:
  212. # in original ByteTrack
  213. detections_second = [
  214. STrack(
  215. STrack.tlbr_to_tlwh(tlbrs[2:6]),
  216. tlbrs[1],
  217. cls_id,
  218. 30,
  219. temp_feat=None)
  220. for tlbrs in pred_dets_cls_second
  221. ]
  222. else:
  223. pred_embs_cls_second = pred_embs_dict[cls_id][
  224. inds_second]
  225. detections_second = [
  226. STrack(
  227. STrack.tlbr_to_tlwh(tlbrs[2:6]), tlbrs[1],
  228. cls_id, 30, temp_feat) for (tlbrs, temp_feat) in
  229. zip(pred_dets_cls_second, pred_embs_cls_second)
  230. ]
  231. else:
  232. detections_second = []
  233. r_tracked_stracks = [
  234. track_pool_dict[cls_id][i] for i in u_track
  235. if track_pool_dict[cls_id][i].state == TrackState.Tracked
  236. ]
  237. dists = matching.iou_distance(r_tracked_stracks,
  238. detections_second)
  239. matches, u_track, u_detection_second = matching.linear_assignment(
  240. dists, thresh=0.4) # not r_tracked_thresh
  241. else:
  242. detections = [detections[i] for i in u_detection]
  243. r_tracked_stracks = []
  244. for i in u_track:
  245. if track_pool_dict[cls_id][i].state == TrackState.Tracked:
  246. r_tracked_stracks.append(track_pool_dict[cls_id][i])
  247. dists = matching.iou_distance(r_tracked_stracks, detections)
  248. matches, u_track, u_detection = matching.linear_assignment(
  249. dists, thresh=self.r_tracked_thresh)
  250. for i_tracked, idet in matches:
  251. track = r_tracked_stracks[i_tracked]
  252. det = detections[
  253. idet] if not self.use_byte else detections_second[idet]
  254. if track.state == TrackState.Tracked:
  255. track.update(det, self.frame_id)
  256. activated_tracks_dict[cls_id].append(track)
  257. else:
  258. track.re_activate(det, self.frame_id, new_id=False)
  259. refined_tracks_dict[cls_id].append(track)
  260. for it in u_track:
  261. track = r_tracked_stracks[it]
  262. if not track.state == TrackState.Lost:
  263. track.mark_lost()
  264. lost_tracks_dict[cls_id].append(track)
  265. '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
  266. detections = [detections[i] for i in u_detection]
  267. dists = matching.iou_distance(unconfirmed_dict[cls_id], detections)
  268. matches, u_unconfirmed, u_detection = matching.linear_assignment(
  269. dists, thresh=self.unconfirmed_thresh)
  270. for i_tracked, idet in matches:
  271. unconfirmed_dict[cls_id][i_tracked].update(detections[idet],
  272. self.frame_id)
  273. activated_tracks_dict[cls_id].append(unconfirmed_dict[cls_id][
  274. i_tracked])
  275. for it in u_unconfirmed:
  276. track = unconfirmed_dict[cls_id][it]
  277. track.mark_removed()
  278. removed_tracks_dict[cls_id].append(track)
  279. """ Step 4: Init new stracks"""
  280. for inew in u_detection:
  281. track = detections[inew]
  282. if track.score < self.det_thresh:
  283. continue
  284. track.activate(self.motion, self.frame_id)
  285. activated_tracks_dict[cls_id].append(track)
  286. """ Step 5: Update state"""
  287. for track in self.lost_tracks_dict[cls_id]:
  288. if self.frame_id - track.end_frame > self.max_time_lost:
  289. track.mark_removed()
  290. removed_tracks_dict[cls_id].append(track)
  291. self.tracked_tracks_dict[cls_id] = [
  292. t for t in self.tracked_tracks_dict[cls_id]
  293. if t.state == TrackState.Tracked
  294. ]
  295. self.tracked_tracks_dict[cls_id] = joint_stracks(
  296. self.tracked_tracks_dict[cls_id], activated_tracks_dict[cls_id])
  297. self.tracked_tracks_dict[cls_id] = joint_stracks(
  298. self.tracked_tracks_dict[cls_id], refined_tracks_dict[cls_id])
  299. self.lost_tracks_dict[cls_id] = sub_stracks(
  300. self.lost_tracks_dict[cls_id], self.tracked_tracks_dict[cls_id])
  301. self.lost_tracks_dict[cls_id].extend(lost_tracks_dict[cls_id])
  302. self.lost_tracks_dict[cls_id] = sub_stracks(
  303. self.lost_tracks_dict[cls_id], self.removed_tracks_dict[cls_id])
  304. self.removed_tracks_dict[cls_id].extend(removed_tracks_dict[cls_id])
  305. self.tracked_tracks_dict[cls_id], self.lost_tracks_dict[
  306. cls_id] = remove_duplicate_stracks(
  307. self.tracked_tracks_dict[cls_id],
  308. self.lost_tracks_dict[cls_id])
  309. # get scores of lost tracks
  310. output_tracks_dict[cls_id] = [
  311. track for track in self.tracked_tracks_dict[cls_id]
  312. if track.is_activated
  313. ]
  314. logger.debug('===========Frame {}=========='.format(self.frame_id))
  315. logger.debug('Activated: {}'.format(
  316. [track.track_id for track in activated_tracks_dict[cls_id]]))
  317. logger.debug('Refind: {}'.format(
  318. [track.track_id for track in refined_tracks_dict[cls_id]]))
  319. logger.debug('Lost: {}'.format(
  320. [track.track_id for track in lost_tracks_dict[cls_id]]))
  321. logger.debug('Removed: {}'.format(
  322. [track.track_id for track in removed_tracks_dict[cls_id]]))
  323. return output_tracks_dict