deepsort_tracker.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/nwojke/deep_sort/blob/master/deep_sort/tracker.py
  16. """
  17. import numpy as np
  18. from ..motion import KalmanFilter
  19. from ..matching.deepsort_matching import NearestNeighborDistanceMetric
  20. from ..matching.deepsort_matching import iou_cost, min_cost_matching, matching_cascade, gate_cost_matrix
  21. from .base_sde_tracker import Track
  22. from ..utils import Detection
  23. from ppdet.core.workspace import register, serializable
  24. from ppdet.utils.logger import setup_logger
  25. logger = setup_logger(__name__)
  26. __all__ = ['DeepSORTTracker']
  27. @register
  28. @serializable
  29. class DeepSORTTracker(object):
  30. """
  31. DeepSORT tracker
  32. Args:
  33. input_size (list): input feature map size to reid model, [h, w] format,
  34. [64, 192] as default.
  35. min_box_area (int): min box area to filter out low quality boxes
  36. vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
  37. bad results, set 1.6 default for pedestrian tracking. If set <=0
  38. means no need to filter bboxes.
  39. budget (int): If not None, fix samples per class to at most this number.
  40. Removes the oldest samples when the budget is reached.
  41. max_age (int): maximum number of missed misses before a track is deleted
  42. n_init (float): Number of frames that a track remains in initialization
  43. phase. Number of consecutive detections before the track is confirmed.
  44. The track state is set to `Deleted` if a miss occurs within the first
  45. `n_init` frames.
  46. metric_type (str): either "euclidean" or "cosine", the distance metric
  47. used for measurement to track association.
  48. matching_threshold (float): samples with larger distance are
  49. considered an invalid match.
  50. max_iou_distance (float): max iou distance threshold
  51. motion (object): KalmanFilter instance
  52. """
  53. def __init__(self,
  54. input_size=[64, 192],
  55. min_box_area=0,
  56. vertical_ratio=-1,
  57. budget=100,
  58. max_age=70,
  59. n_init=3,
  60. metric_type='cosine',
  61. matching_threshold=0.2,
  62. max_iou_distance=0.9,
  63. motion='KalmanFilter'):
  64. self.input_size = input_size
  65. self.min_box_area = min_box_area
  66. self.vertical_ratio = vertical_ratio
  67. self.max_age = max_age
  68. self.n_init = n_init
  69. self.metric = NearestNeighborDistanceMetric(metric_type,
  70. matching_threshold, budget)
  71. self.max_iou_distance = max_iou_distance
  72. if motion == 'KalmanFilter':
  73. self.motion = KalmanFilter()
  74. self.tracks = []
  75. self._next_id = 1
  76. def predict(self):
  77. """
  78. Propagate track state distributions one time step forward.
  79. This function should be called once every time step, before `update`.
  80. """
  81. for track in self.tracks:
  82. track.predict(self.motion)
  83. def update(self, pred_dets, pred_embs):
  84. """
  85. Perform measurement update and track management.
  86. Args:
  87. pred_dets (np.array): Detection results of the image, the shape is
  88. [N, 6], means 'cls_id, score, x0, y0, x1, y1'.
  89. pred_embs (np.array): Embedding results of the image, the shape is
  90. [N, 128], usually pred_embs.shape[1] is a multiple of 128.
  91. """
  92. pred_cls_ids = pred_dets[:, 0:1]
  93. pred_scores = pred_dets[:, 1:2]
  94. pred_xyxys = pred_dets[:, 2:6]
  95. pred_tlwhs = np.concatenate((pred_xyxys[:, 0:2], pred_xyxys[:, 2:4] - pred_xyxys[:, 0:2] + 1), axis=1)
  96. detections = [
  97. Detection(tlwh, score, feat, cls_id)
  98. for tlwh, score, feat, cls_id in zip(pred_tlwhs, pred_scores,
  99. pred_embs, pred_cls_ids)
  100. ]
  101. # Run matching cascade.
  102. matches, unmatched_tracks, unmatched_detections = \
  103. self._match(detections)
  104. # Update track set.
  105. for track_idx, detection_idx in matches:
  106. self.tracks[track_idx].update(self.motion,
  107. detections[detection_idx])
  108. for track_idx in unmatched_tracks:
  109. self.tracks[track_idx].mark_missed()
  110. for detection_idx in unmatched_detections:
  111. self._initiate_track(detections[detection_idx])
  112. self.tracks = [t for t in self.tracks if not t.is_deleted()]
  113. # Update distance metric.
  114. active_targets = [t.track_id for t in self.tracks if t.is_confirmed()]
  115. features, targets = [], []
  116. for track in self.tracks:
  117. if not track.is_confirmed():
  118. continue
  119. features += track.features
  120. targets += [track.track_id for _ in track.features]
  121. track.features = []
  122. self.metric.partial_fit(
  123. np.asarray(features), np.asarray(targets), active_targets)
  124. output_stracks = self.tracks
  125. return output_stracks
  126. def _match(self, detections):
  127. def gated_metric(tracks, dets, track_indices, detection_indices):
  128. features = np.array([dets[i].feature for i in detection_indices])
  129. targets = np.array([tracks[i].track_id for i in track_indices])
  130. cost_matrix = self.metric.distance(features, targets)
  131. cost_matrix = gate_cost_matrix(self.motion, cost_matrix, tracks,
  132. dets, track_indices,
  133. detection_indices)
  134. return cost_matrix
  135. # Split track set into confirmed and unconfirmed tracks.
  136. confirmed_tracks = [
  137. i for i, t in enumerate(self.tracks) if t.is_confirmed()
  138. ]
  139. unconfirmed_tracks = [
  140. i for i, t in enumerate(self.tracks) if not t.is_confirmed()
  141. ]
  142. # Associate confirmed tracks using appearance features.
  143. matches_a, unmatched_tracks_a, unmatched_detections = \
  144. matching_cascade(
  145. gated_metric, self.metric.matching_threshold, self.max_age,
  146. self.tracks, detections, confirmed_tracks)
  147. # Associate remaining tracks together with unconfirmed tracks using IOU.
  148. iou_track_candidates = unconfirmed_tracks + [
  149. k for k in unmatched_tracks_a
  150. if self.tracks[k].time_since_update == 1
  151. ]
  152. unmatched_tracks_a = [
  153. k for k in unmatched_tracks_a
  154. if self.tracks[k].time_since_update != 1
  155. ]
  156. matches_b, unmatched_tracks_b, unmatched_detections = \
  157. min_cost_matching(
  158. iou_cost, self.max_iou_distance, self.tracks,
  159. detections, iou_track_candidates, unmatched_detections)
  160. matches = matches_a + matches_b
  161. unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))
  162. return matches, unmatched_tracks, unmatched_detections
  163. def _initiate_track(self, detection):
  164. mean, covariance = self.motion.initiate(detection.to_xyah())
  165. self.tracks.append(
  166. Track(mean, covariance, self._next_id, self.n_init, self.max_age,
  167. detection.cls_id, detection.score, detection.feature))
  168. self._next_id += 1