deepsort_matching.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/nwojke/deep_sort/tree/master/deep_sort
"""

import numpy as np
from scipy.optimize import linear_sum_assignment

from ..motion import kalman_filter

INFTY_COST = 1e+5

__all__ = [
    'iou_1toN',
    'iou_cost',
    '_nn_euclidean_distance',
    '_nn_cosine_distance',
    'NearestNeighborDistanceMetric',
    'min_cost_matching',
    'matching_cascade',
    'gate_cost_matrix',
]


def iou_1toN(bbox, candidates):
    """
    Compute intersection over union (IoU) of one box against N candidates.

    Args:
        bbox (ndarray): A bounding box in format `(top left x, top left y, width, height)`.
        candidates (ndarray): A matrix of candidate bounding boxes (one per row) in the
            same format as `bbox`.

    Returns:
        ious (ndarray): The intersection over union in [0, 1] between the `bbox`
            and each candidate. A higher score means a larger fraction of the
            `bbox` is occluded by the candidate.
    """
    bbox_tl = bbox[:2]
    bbox_br = bbox[:2] + bbox[2:]
    candidates_tl = candidates[:, :2]
    candidates_br = candidates[:, :2] + candidates[:, 2:]

    tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis],
               np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
    br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis],
               np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
    wh = np.maximum(0., br - tl)

    area_intersection = wh.prod(axis=1)
    area_bbox = bbox[2:].prod()
    area_candidates = candidates[:, 2:].prod(axis=1)
    ious = area_intersection / (area_bbox + area_candidates - area_intersection)
    return ious
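

# Illustrative sketch (not part of the upstream DeepSORT code): it shows how
# `iou_1toN` compares one box against a batch of candidates, all given in
# `(top left x, top left y, width, height)` format. The boxes are made-up values.
def _demo_iou_1toN():
    bbox = np.array([0., 0., 10., 10.])          # one 10x10 box at the origin
    candidates = np.array([[5., 5., 10., 10.],   # overlaps bbox by a 5x5 region
                           [20., 20., 5., 5.]])  # no overlap at all
    ious = iou_1toN(bbox, candidates)
    # Expected: roughly [0.1429, 0.] -> 25 / (100 + 100 - 25) for the first candidate.
    return ious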


def iou_cost(tracks, detections, track_indices=None, detection_indices=None):
    """
    IoU distance metric.

    Args:
        tracks (list[Track]): A list of tracks.
        detections (list[Detection]): A list of detections.
        track_indices (Optional[list[int]]): A list of indices to tracks that
            should be matched. Defaults to all `tracks`.
        detection_indices (Optional[list[int]]): A list of indices to detections
            that should be matched. Defaults to all `detections`.

    Returns:
        cost_matrix (ndarray): A cost matrix of shape len(track_indices),
            len(detection_indices) where entry (i, j) is
            `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
    """
    if track_indices is None:
        track_indices = np.arange(len(tracks))
    if detection_indices is None:
        detection_indices = np.arange(len(detections))

    cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
    for row, track_idx in enumerate(track_indices):
        if tracks[track_idx].time_since_update > 1:
            cost_matrix[row, :] = 1e+5
            continue

        bbox = tracks[track_idx].to_tlwh()
        candidates = np.asarray([detections[i].tlwh for i in detection_indices])
        cost_matrix[row, :] = 1. - iou_1toN(bbox, candidates)
    return cost_matrix


def _nn_euclidean_distance(s, q):
    """
    Compute pair-wise squared Euclidean distance between points in `s` and `q`.

    Args:
        s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
        q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.

    Returns:
        distances (ndarray): A vector of length L that contains for each entry in `q` the
            smallest squared Euclidean distance to a sample in `s`.
    """
    s, q = np.asarray(s), np.asarray(q)
    if len(s) == 0 or len(q) == 0:
        return np.zeros((len(s), len(q)))
    s2, q2 = np.square(s).sum(axis=1), np.square(q).sum(axis=1)
    distances = -2. * np.dot(s, q.T) + s2[:, None] + q2[None, :]
    distances = np.clip(distances, 0., float(np.inf))
    return np.maximum(0.0, distances.min(axis=0))


def _nn_cosine_distance(s, q):
    """
    Compute pair-wise cosine distance between points in `s` and `q`.

    Args:
        s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
        q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.

    Returns:
        distances (ndarray): A vector of length L that contains for each entry in `q` the
            smallest cosine distance to a sample in `s`.
    """
    s = np.asarray(s) / np.linalg.norm(s, axis=1, keepdims=True)
    q = np.asarray(q) / np.linalg.norm(q, axis=1, keepdims=True)
    distances = 1. - np.dot(s, q.T)
    return distances.min(axis=0)
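

# Illustrative sketch (not part of the upstream DeepSORT code): both helpers
# reduce an N x L pairwise distance matrix to the per-query nearest-neighbor
# distance. The feature vectors below are made-up values.
def _demo_nn_distances():
    samples = np.array([[1., 0.], [0., 1.]])   # N=2 stored samples, M=2 dims
    queries = np.array([[0.9, 0.1]])           # L=1 query feature
    d_euc = _nn_euclidean_distance(samples, queries)  # smallest squared distance, ~[0.02]
    d_cos = _nn_cosine_distance(samples, queries)     # smallest cosine distance, ~[0.0061]
    return d_euc, d_cos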


class NearestNeighborDistanceMetric(object):
    """
    A nearest neighbor distance metric that, for each target, returns
    the closest distance to any sample that has been observed so far.

    Args:
        metric (str): Either "euclidean" or "cosine".
        matching_threshold (float): The matching threshold. Samples with larger
            distance are considered an invalid match.
        budget (Optional[int]): If not None, fix samples per class to at most
            this number. Removes the oldest samples when the budget is reached.

    Attributes:
        samples (Dict[int -> List[ndarray]]): A dictionary that maps from target
            identities to the list of samples that have been observed so far.
    """

    def __init__(self, metric, matching_threshold, budget=None):
        if metric == "euclidean":
            self._metric = _nn_euclidean_distance
        elif metric == "cosine":
            self._metric = _nn_cosine_distance
        else:
            raise ValueError(
                "Invalid metric; must be either 'euclidean' or 'cosine'")
        self.matching_threshold = matching_threshold
        self.budget = budget
        self.samples = {}

    def partial_fit(self, features, targets, active_targets):
        """
        Update the distance metric with new data.

        Args:
            features (ndarray): An NxM matrix of N features of dimensionality M.
            targets (ndarray): An integer array of associated target identities.
            active_targets (List[int]): A list of targets that are currently
                present in the scene.
        """
        for feature, target in zip(features, targets):
            self.samples.setdefault(target, []).append(feature)
            if self.budget is not None:
                self.samples[target] = self.samples[target][-self.budget:]
        self.samples = {k: self.samples[k] for k in active_targets}

    def distance(self, features, targets):
        """
        Compute distance between features and targets.

        Args:
            features (ndarray): An NxM matrix of N features of dimensionality M.
            targets (list[int]): A list of targets to match the given `features` against.

        Returns:
            cost_matrix (ndarray): a cost matrix of shape len(targets), len(features),
                where element (i, j) contains the closest squared distance between
                `targets[i]` and `features[j]`.
        """
        cost_matrix = np.zeros((len(targets), len(features)))
        for i, target in enumerate(targets):
            cost_matrix[i, :] = self._metric(self.samples[target], features)
        return cost_matrix
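

# Illustrative sketch (not part of the upstream DeepSORT code): the metric
# accumulates appearance embeddings per track id via `partial_fit` and then
# scores new detection features via `distance`. The threshold, budget and
# feature values below are made up.
def _demo_nn_metric():
    metric = NearestNeighborDistanceMetric("cosine", matching_threshold=0.2, budget=100)
    # One frame's worth of embeddings for track ids 1 and 2.
    feats = np.array([[1., 0., 0.], [0., 1., 0.]])
    metric.partial_fit(feats, targets=np.array([1, 2]), active_targets=[1, 2])
    # Cost of matching two new detection features against the stored samples.
    new_feats = np.array([[0.9, 0.1, 0.], [0., 0.8, 0.6]])
    cost = metric.distance(new_feats, targets=[1, 2])  # shape (2 targets, 2 features)
    return cost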


def min_cost_matching(distance_metric,
                      max_distance,
                      tracks,
                      detections,
                      track_indices=None,
                      detection_indices=None):
    """
    Solve linear assignment problem.

    Args:
        distance_metric :
            Callable[[List[Track], List[Detection], List[int], List[int]], ndarray]
            The distance metric is given a list of tracks and detections as
            well as a list of N track indices and M detection indices. The
            metric should return the NxM dimensional cost matrix, where element
            (i, j) is the association cost between the i-th track in the given
            track indices and the j-th detection in the given detection_indices.
        max_distance (float): Gating threshold. Associations with cost larger
            than this value are disregarded.
        tracks (list[Track]): A list of predicted tracks at the current time
            step.
        detections (list[Detection]): A list of detections at the current time
            step.
        track_indices (list[int]): List of track indices that maps rows in
            `cost_matrix` to tracks in `tracks`.
        detection_indices (List[int]): List of detection indices that maps
            columns in `cost_matrix` to detections in `detections`.

    Returns:
        A tuple (List[(int, int)], List[int], List[int]) with the following
        three entries:
        * A list of matched track and detection indices.
        * A list of unmatched track indices.
        * A list of unmatched detection indices.
    """
    if track_indices is None:
        track_indices = np.arange(len(tracks))
    if detection_indices is None:
        detection_indices = np.arange(len(detections))

    if len(detection_indices) == 0 or len(track_indices) == 0:
        return [], track_indices, detection_indices  # Nothing to match.

    cost_matrix = distance_metric(tracks, detections, track_indices,
                                  detection_indices)
    cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5
    indices = linear_sum_assignment(cost_matrix)

    matches, unmatched_tracks, unmatched_detections = [], [], []
    for col, detection_idx in enumerate(detection_indices):
        if col not in indices[1]:
            unmatched_detections.append(detection_idx)
    for row, track_idx in enumerate(track_indices):
        if row not in indices[0]:
            unmatched_tracks.append(track_idx)
    for row, col in zip(indices[0], indices[1]):
        track_idx = track_indices[row]
        detection_idx = detection_indices[col]
        if cost_matrix[row, col] > max_distance:
            unmatched_tracks.append(track_idx)
            unmatched_detections.append(detection_idx)
        else:
            matches.append((track_idx, detection_idx))
    return matches, unmatched_tracks, unmatched_detections
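

# Illustrative sketch (not part of the upstream DeepSORT code): `_FakeTrack`
# and `_FakeDetection` are hypothetical stand-ins for the real Track and
# Detection classes; they expose only the attributes that `iou_cost` reads
# (`time_since_update`, `to_tlwh()`, `tlwh`).
def _demo_min_cost_matching():
    class _FakeTrack(object):
        def __init__(self, tlwh, time_since_update=1):
            self._tlwh = np.asarray(tlwh, dtype=np.float64)
            self.time_since_update = time_since_update

        def to_tlwh(self):
            return self._tlwh

    class _FakeDetection(object):
        def __init__(self, tlwh):
            self.tlwh = np.asarray(tlwh, dtype=np.float64)

    tracks = [_FakeTrack([0., 0., 10., 10.]), _FakeTrack([50., 50., 10., 10.])]
    detections = [_FakeDetection([52., 48., 10., 10.]), _FakeDetection([1., 1., 10., 10.])]
    # With the IoU metric, the assignment pairs track 0 with detection 1 and
    # track 1 with detection 0; max_distance=0.7 corresponds to IoU >= 0.3.
    return min_cost_matching(iou_cost, 0.7, tracks, detections)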


def matching_cascade(distance_metric,
                     max_distance,
                     cascade_depth,
                     tracks,
                     detections,
                     track_indices=None,
                     detection_indices=None):
    """
    Run matching cascade.

    Args:
        distance_metric :
            Callable[[List[Track], List[Detection], List[int], List[int]], ndarray]
            The distance metric is given a list of tracks and detections as
            well as a list of N track indices and M detection indices. The
            metric should return the NxM dimensional cost matrix, where element
            (i, j) is the association cost between the i-th track in the given
            track indices and the j-th detection in the given detection_indices.
        max_distance (float): Gating threshold. Associations with cost larger
            than this value are disregarded.
        cascade_depth (int): The cascade depth, should be set to the maximum
            track age.
        tracks (list[Track]): A list of predicted tracks at the current time
            step.
        detections (list[Detection]): A list of detections at the current time
            step.
        track_indices (list[int]): List of track indices that maps rows in
            `cost_matrix` to tracks in `tracks`.
        detection_indices (List[int]): List of detection indices that maps
            columns in `cost_matrix` to detections in `detections`.

    Returns:
        A tuple (List[(int, int)], List[int], List[int]) with the following
        three entries:
        * A list of matched track and detection indices.
        * A list of unmatched track indices.
        * A list of unmatched detection indices.
    """
    if track_indices is None:
        track_indices = list(range(len(tracks)))
    if detection_indices is None:
        detection_indices = list(range(len(detections)))

    unmatched_detections = detection_indices
    matches = []
    for level in range(cascade_depth):
        if len(unmatched_detections) == 0:  # No detections left
            break

        track_indices_l = [
            k for k in track_indices
            if tracks[k].time_since_update == 1 + level
        ]
        if len(track_indices_l) == 0:  # Nothing to match at this level
            continue

        matches_l, _, unmatched_detections = \
            min_cost_matching(
                distance_metric, max_distance, tracks, detections,
                track_indices_l, unmatched_detections)
        matches += matches_l
    unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches))
    return matches, unmatched_tracks, unmatched_detections
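

# Illustrative sketch (not part of the upstream DeepSORT code): the cascade
# matches recently updated tracks first (time_since_update == 1) and older
# tracks only at later levels, so fresh tracks get priority for detections.
# `_AgedTrack` is a hypothetical stand-in exposing only `time_since_update`,
# and the toy metric below just mimics the signature of `iou_cost`.
def _demo_matching_cascade():
    class _AgedTrack(object):
        def __init__(self, time_since_update):
            self.time_since_update = time_since_update

    def toy_metric(tracks, detections, track_indices, detection_indices):
        # Every track/detection pair gets the same small cost.
        return np.full((len(track_indices), len(detection_indices)), 0.1)

    tracks = [_AgedTrack(1), _AgedTrack(3)]   # one fresh track, one stale track
    detections = [object()]                   # a single (opaque) detection
    # With one detection available, the fresh track (index 0) is matched at
    # cascade level 0 and the stale track is left unmatched.
    return matching_cascade(toy_metric, max_distance=0.5, cascade_depth=5,
                            tracks=tracks, detections=detections)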


def gate_cost_matrix(kf,
                     cost_matrix,
                     tracks,
                     detections,
                     track_indices,
                     detection_indices,
                     gated_cost=INFTY_COST,
                     only_position=False):
    """
    Invalidate infeasible entries in cost matrix based on the state
    distributions obtained by Kalman filtering.

    Args:
        kf (object): The Kalman filter.
        cost_matrix (ndarray): The NxM dimensional cost matrix, where N is the
            number of track indices and M is the number of detection indices,
            such that entry (i, j) is the association cost between
            `tracks[track_indices[i]]` and `detections[detection_indices[j]]`.
        tracks (list[Track]): A list of predicted tracks at the current time
            step.
        detections (list[Detection]): A list of detections at the current time
            step.
        track_indices (List[int]): List of track indices that maps rows in
            `cost_matrix` to tracks in `tracks`.
        detection_indices (List[int]): List of detection indices that maps
            columns in `cost_matrix` to detections in `detections`.
        gated_cost (Optional[float]): Entries in the cost matrix corresponding
            to infeasible associations are set to this value. Defaults to a very
            large value.
        only_position (Optional[bool]): If True, only the x, y position of the
            state distribution is considered during gating. Default False.

    Returns:
        cost_matrix (ndarray): The gated cost matrix.
    """
    gating_dim = 2 if only_position else 4
    gating_threshold = kalman_filter.chi2inv95[gating_dim]
    measurements = np.asarray(
        [detections[i].to_xyah() for i in detection_indices])
    for row, track_idx in enumerate(track_indices):
        track = tracks[track_idx]
        gating_distance = kf.gating_distance(track.mean, track.covariance,
                                             measurements, only_position)
        cost_matrix[row, gating_distance > gating_threshold] = gated_cost
    return cost_matrix
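

# Illustrative sketch (not part of the upstream DeepSORT code): `_FakeKF`,
# `_FakeStateTrack` and `_FakeXYAHDetection` are hypothetical stand-ins used
# only to show the gating mechanics; in the real tracker `kf` is the Kalman
# filter from `..motion.kalman_filter` and tracks carry the predicted
# `mean`/`covariance` of their state distribution.
def _demo_gate_cost_matrix():
    class _FakeXYAHDetection(object):
        def __init__(self, xyah):
            self._xyah = np.asarray(xyah, dtype=np.float64)

        def to_xyah(self):
            return self._xyah

    class _FakeKF(object):
        def gating_distance(self, mean, covariance, measurements, only_position):
            # Pretend the second measurement lies far outside the gate.
            return np.array([1.0, 100.0])

    class _FakeStateTrack(object):
        mean = np.zeros(8)
        covariance = np.eye(8)

    cost_matrix = np.array([[0.2, 0.3]])
    detections = [_FakeXYAHDetection([0., 0., 1., 50.]),
                  _FakeXYAHDetection([900., 900., 1., 50.])]
    gated = gate_cost_matrix(_FakeKF(), cost_matrix, [_FakeStateTrack()],
                             detections, track_indices=[0],
                             detection_indices=[0, 1])
    # The second column is overwritten with INFTY_COST because its gating
    # distance (100.0) exceeds chi2inv95[4] (about 9.49).
    return gated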