deepsort_tracker.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/nwojke/deep_sort/blob/master/deep_sort/tracker.py
  16. """
  17. import numpy as np
  18. from ..motion import KalmanFilter
  19. from ..matching.deepsort_matching import NearestNeighborDistanceMetric
  20. from ..matching.deepsort_matching import iou_cost, min_cost_matching, matching_cascade, gate_cost_matrix
  21. from .base_sde_tracker import Track
  22. from ..utils import Detection
  23. __all__ = ['DeepSORTTracker']
  24. class DeepSORTTracker(object):
  25. """
  26. DeepSORT tracker
  27. Args:
  28. input_size (list): input feature map size to reid model, [h, w] format,
  29. [64, 192] as default.
  30. min_box_area (int): min box area to filter out low quality boxes
  31. vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
  32. bad results, set 1.6 default for pedestrian tracking. If set <=0
  33. means no need to filter bboxes.
  34. budget (int): If not None, fix samples per class to at most this number.
  35. Removes the oldest samples when the budget is reached.
  36. max_age (int): maximum number of missed misses before a track is deleted
  37. n_init (float): Number of frames that a track remains in initialization
  38. phase. Number of consecutive detections before the track is confirmed.
  39. The track state is set to `Deleted` if a miss occurs within the first
  40. `n_init` frames.
  41. metric_type (str): either "euclidean" or "cosine", the distance metric
  42. used for measurement to track association.
  43. matching_threshold (float): samples with larger distance are
  44. considered an invalid match.
  45. max_iou_distance (float): max iou distance threshold
  46. motion (object): KalmanFilter instance
  47. """
  48. def __init__(self,
  49. input_size=[64, 192],
  50. min_box_area=0,
  51. vertical_ratio=-1,
  52. budget=100,
  53. max_age=70,
  54. n_init=3,
  55. metric_type='cosine',
  56. matching_threshold=0.2,
  57. max_iou_distance=0.9,
  58. motion='KalmanFilter'):
  59. self.input_size = input_size
  60. self.min_box_area = min_box_area
  61. self.vertical_ratio = vertical_ratio
  62. self.max_age = max_age
  63. self.n_init = n_init
  64. self.metric = NearestNeighborDistanceMetric(metric_type,
  65. matching_threshold, budget)
  66. self.max_iou_distance = max_iou_distance
  67. if motion == 'KalmanFilter':
  68. self.motion = KalmanFilter()
  69. self.tracks = []
  70. self._next_id = 1
  71. def predict(self):
  72. """
  73. Propagate track state distributions one time step forward.
  74. This function should be called once every time step, before `update`.
  75. """
  76. for track in self.tracks:
  77. track.predict(self.motion)
  78. def update(self, pred_dets, pred_embs):
  79. """
  80. Perform measurement update and track management.
  81. Args:
  82. pred_dets (np.array): Detection results of the image, the shape is
  83. [N, 6], means 'cls_id, score, x0, y0, x1, y1'.
  84. pred_embs (np.array): Embedding results of the image, the shape is
  85. [N, 128], usually pred_embs.shape[1] is a multiple of 128.
  86. """
  87. pred_cls_ids = pred_dets[:, 0:1]
  88. pred_scores = pred_dets[:, 1:2]
  89. pred_xyxys = pred_dets[:, 2:6]
  90. pred_tlwhs = np.concatenate((pred_xyxys[:, 0:2], pred_xyxys[:, 2:4] - pred_xyxys[:, 0:2] + 1), axis=1)
  91. detections = [
  92. Detection(tlwh, score, feat, cls_id)
  93. for tlwh, score, feat, cls_id in zip(pred_tlwhs, pred_scores,
  94. pred_embs, pred_cls_ids)
  95. ]
  96. # Run matching cascade.
  97. matches, unmatched_tracks, unmatched_detections = \
  98. self._match(detections)
  99. # Update track set.
  100. for track_idx, detection_idx in matches:
  101. self.tracks[track_idx].update(self.motion,
  102. detections[detection_idx])
  103. for track_idx in unmatched_tracks:
  104. self.tracks[track_idx].mark_missed()
  105. for detection_idx in unmatched_detections:
  106. self._initiate_track(detections[detection_idx])
  107. self.tracks = [t for t in self.tracks if not t.is_deleted()]
  108. # Update distance metric.
  109. active_targets = [t.track_id for t in self.tracks if t.is_confirmed()]
  110. features, targets = [], []
  111. for track in self.tracks:
  112. if not track.is_confirmed():
  113. continue
  114. features += track.features
  115. targets += [track.track_id for _ in track.features]
  116. track.features = []
  117. self.metric.partial_fit(
  118. np.asarray(features), np.asarray(targets), active_targets)
  119. output_stracks = self.tracks
  120. return output_stracks
  121. def _match(self, detections):
  122. def gated_metric(tracks, dets, track_indices, detection_indices):
  123. features = np.array([dets[i].feature for i in detection_indices])
  124. targets = np.array([tracks[i].track_id for i in track_indices])
  125. cost_matrix = self.metric.distance(features, targets)
  126. cost_matrix = gate_cost_matrix(self.motion, cost_matrix, tracks,
  127. dets, track_indices,
  128. detection_indices)
  129. return cost_matrix
  130. # Split track set into confirmed and unconfirmed tracks.
  131. confirmed_tracks = [
  132. i for i, t in enumerate(self.tracks) if t.is_confirmed()
  133. ]
  134. unconfirmed_tracks = [
  135. i for i, t in enumerate(self.tracks) if not t.is_confirmed()
  136. ]
  137. # Associate confirmed tracks using appearance features.
  138. matches_a, unmatched_tracks_a, unmatched_detections = \
  139. matching_cascade(
  140. gated_metric, self.metric.matching_threshold, self.max_age,
  141. self.tracks, detections, confirmed_tracks)
  142. # Associate remaining tracks together with unconfirmed tracks using IOU.
  143. iou_track_candidates = unconfirmed_tracks + [
  144. k for k in unmatched_tracks_a
  145. if self.tracks[k].time_since_update == 1
  146. ]
  147. unmatched_tracks_a = [
  148. k for k in unmatched_tracks_a
  149. if self.tracks[k].time_since_update != 1
  150. ]
  151. matches_b, unmatched_tracks_b, unmatched_detections = \
  152. min_cost_matching(
  153. iou_cost, self.max_iou_distance, self.tracks,
  154. detections, iou_track_candidates, unmatched_detections)
  155. matches = matches_a + matches_b
  156. unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))
  157. return matches, unmatched_tracks, unmatched_detections
  158. def _initiate_track(self, detection):
  159. mean, covariance = self.motion.initiate(detection.to_xyah())
  160. self.tracks.append(
  161. Track(mean, covariance, self._next_id, self.n_init, self.max_age,
  162. detection.cls_id, detection.score, detection.feature))
  163. self._next_id += 1