jde_matching.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/matching.py
  16. """
  17. try:
  18. import lap
  19. except:
  20. print(
  21. 'Warning: Unable to use JDE/FairMOT/ByteTrack, please install lap, for example: `pip install lap`, see https://github.com/gatagat/lap'
  22. )
  23. pass
  24. import scipy
  25. import numpy as np
  26. from scipy.spatial.distance import cdist
  27. from ..motion import kalman_filter
  28. import warnings
  # NOTE(review): this call installs a global "ignore" filter the moment the
  # module is imported, so it silences warnings raised anywhere in the
  # process, not only in this file. Consider scoping it with
  # ``warnings.catch_warnings()`` around the specific noisy calls instead.
  warnings.filterwarnings("ignore")

  # Public matching helpers exported by ``from ... import *``.
  __all__ = [
      'merge_matches', 'linear_assignment', 'bbox_ious',
      'iou_distance', 'embedding_distance', 'fuse_motion',
  ]
  def merge_matches(m1, m2, shape):
      """Compose two index matchings through a shared middle index space.

      ``m1`` pairs indices of a set O with a set P, and ``m2`` pairs indices of
      P with a set Q. A pair ``(o, q)`` is produced whenever some ``p`` exists
      with ``(o, p)`` in ``m1`` and ``(p, q)`` in ``m2``.

      Args:
          m1: Sequence of ``(o, p)`` index pairs; may be empty.
          m2: Sequence of ``(p, q)`` index pairs; may be empty.
          shape: ``(O, P, Q)``, the sizes of the three index spaces.

      Returns:
          ``(match, unmatched_O, unmatched_Q)``:
          ``match`` is a list of ``(o, q)`` pairs. ``unmatched_O`` and
          ``unmatched_Q`` are ascending tuples of the O and Q indices that
          appear in no composed pair.
      """
      O, P, Q = shape
      # reshape(-1, 2) turns an empty match list into a (0, 2) array so the
      # column indexing below works instead of raising IndexError.
      m1 = np.asarray(m1, dtype=np.int64).reshape(-1, 2)
      m2 = np.asarray(m2, dtype=np.int64).reshape(-1, 2)
      M1 = scipy.sparse.coo_matrix(
          (np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P))
      M2 = scipy.sparse.coo_matrix(
          (np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q))
      # Sparse matrix product: entry (o, q) is non-zero iff some p links them.
      rows, cols = (M1 @ M2).nonzero()
      match = list(zip(rows, cols))
      matched_O = {i for i, _ in match}
      matched_Q = {j for _, j in match}
      unmatched_O = tuple(i for i in range(O) if i not in matched_O)
      unmatched_Q = tuple(j for j in range(Q) if j not in matched_Q)
      return match, unmatched_O, unmatched_Q
  def linear_assignment(cost_matrix, thresh):
      """Solve a thresholded linear assignment problem with ``lap.lapjv``.

      Args:
          cost_matrix: 2-D cost array of shape ``(A, B)``.
          thresh: Forwarded to ``lap.lapjv`` as ``cost_limit`` (together with
              ``extend_cost=True``). See the lap documentation for how pairs
              above this limit are left unassigned.

      Returns:
          ``(matches, unmatched_a, unmatched_b)``:
          ``matches`` is an integer array of shape ``(K, 2)`` holding
          ``(row, col)`` pairs; ``K`` may be 0. ``unmatched_a`` and
          ``unmatched_b`` hold the unassigned row and column indices. They are
          tuples when the matrix is empty and index arrays otherwise.

      Raises:
          RuntimeError: the matrix is non-empty and ``lap`` cannot be imported.
      """
      cost_matrix = np.asarray(cost_matrix)
      # Nothing to assign: answer directly, so empty inputs do not need lap.
      if cost_matrix.size == 0:
          return (np.empty((0, 2), dtype=int),
                  tuple(range(cost_matrix.shape[0])),
                  tuple(range(cost_matrix.shape[1])))
      try:
          import lap
      except Exception as e:
          # Kept broad so any import failure surfaces as RuntimeError, as
          # before; the original cause is now chained for debugging.
          raise RuntimeError(
              'Unable to use JDE/FairMOT/ByteTrack, please install lap, for example: `pip install lap`, see https://github.com/gatagat/lap'
          ) from e
      _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
      # x[i] is the column assigned to row i; a negative value means unassigned.
      rows = np.where(x >= 0)[0]
      matches = np.stack([rows, x[rows]], axis=1)
      unmatched_a = np.where(x < 0)[0]
      unmatched_b = np.where(y < 0)[0]
      return matches, unmatched_a, unmatched_b
  def bbox_ious(atlbrs, btlbrs):
      """Compute the pairwise IoU matrix between two sets of boxes.

      Boxes are rows whose first four columns are ``(x1, y1, x2, y2)``. Widths
      and heights use the inclusive-pixel convention (``x2 - x1 + 1``).

      Args:
          atlbrs: ``N`` boxes, array-like of shape ``(N, 4)``.
          btlbrs: ``K`` boxes, array-like of shape ``(K, 4)``.

      Returns:
          ``float32`` array of shape ``(N, K)``. Pairs that do not overlap
          (or whose overlap is NaN) are 0.
      """
      boxes = np.ascontiguousarray(atlbrs, dtype=np.float32)
      query_boxes = np.ascontiguousarray(btlbrs, dtype=np.float32)
      N = boxes.shape[0]
      K = query_boxes.shape[0]
      ious = np.zeros((N, K), dtype=boxes.dtype)
      if N * K == 0:
          return ious
      # Broadcast (N, 1) against (1, K) instead of a Python double loop.
      iw = (np.minimum(boxes[:, None, 2], query_boxes[None, :, 2]) -
            np.maximum(boxes[:, None, 0], query_boxes[None, :, 0]) + 1)
      ih = (np.minimum(boxes[:, None, 3], query_boxes[None, :, 3]) -
            np.maximum(boxes[:, None, 1], query_boxes[None, :, 1]) + 1)
      overlap = (iw > 0) & (ih > 0)
      box_areas = ((boxes[:, 2] - boxes[:, 0] + 1) *
                   (boxes[:, 3] - boxes[:, 1] + 1))
      query_areas = ((query_boxes[:, 2] - query_boxes[:, 0] + 1) *
                     (query_boxes[:, 3] - query_boxes[:, 1] + 1))
      inter = iw * ih
      # Union is summed in float32 like the original, then widened for the
      # division (the original converted it with float()).
      union = (box_areas[:, None] + query_areas[None, :] - inter).astype(
          np.float64)
      # Only divide where boxes overlap; every other entry stays 0.
      np.divide(inter, union, out=ious, where=overlap)
      return ious
  def iou_distance(atracks, btracks):
      """Return the ``1 - IoU`` cost matrix between two collections of tracks.

      Each argument is either a sequence of ``STrack``-like objects exposing a
      ``tlbr`` attribute, or a sequence of ``np.ndarray`` boxes. If either
      non-empty input begins with an ``np.ndarray``, both inputs are handed to
      ``bbox_ious`` as-is. Otherwise ``.tlbr`` is read from every element.
      """
      def _starts_with_array(seq):
          return len(seq) > 0 and isinstance(seq[0], np.ndarray)

      if _starts_with_array(atracks) or _starts_with_array(btracks):
          atlbrs, btlbrs = atracks, btracks
      else:
          atlbrs = [trk.tlbr for trk in atracks]
          btlbrs = [trk.tlbr for trk in btracks]
      return 1 - bbox_ious(atlbrs, btlbrs)
  def embedding_distance(tracks, detections, metric='euclidean'):
      """Return the feature-distance cost matrix between tracks and detections.

      Args:
          tracks: Sequence of objects with a ``smooth_feat`` attribute (rows).
          detections: Sequence of objects with a ``curr_feat`` attribute
              (columns).
          metric: Distance name forwarded to ``scipy.spatial.distance.cdist``.

      Returns:
          Array of shape ``(len(tracks), len(detections))`` with negative
          distances clipped to 0. If either input is empty, a ``float32`` zero
          matrix is returned; otherwise the dtype is whatever ``cdist`` yields.
      """
      n_tracks, n_dets = len(tracks), len(detections)
      if n_tracks * n_dets == 0:
          return np.zeros((n_tracks, n_dets), dtype=np.float32)
      det_feats = np.asarray(
          [det.curr_feat for det in detections], dtype=np.float32)
      trk_feats = np.asarray(
          [trk.smooth_feat for trk in tracks], dtype=np.float32)
      # Clamp at 0 so tiny negative values (e.g. floating-point rounding)
      # never become negative costs.
      return np.maximum(0.0, cdist(trk_feats, det_feats, metric))
  def fuse_motion(kf,
                  cost_matrix,
                  tracks,
                  detections,
                  only_position=False,
                  lambda_=0.98):
      """Blend ``cost_matrix`` with Mahalanobis gating distances, in place.

      Args:
          kf: Object providing ``gating_distance(mean, covariance,
              measurements, only_position, metric=...)``.
          cost_matrix: Cost array with one row per track; columns are
              presumably indexed by detection. Modified in place.
          tracks: Objects exposing ``mean`` and ``covariance``.
          detections: Objects exposing ``to_xyah()``.
          only_position: Selects the 2-dim (True) or 4-dim (False)
              ``kalman_filter.chi2inv95`` threshold. Also forwarded to
              ``kf.gating_distance``.
          lambda_: Weight kept on the existing cost; ``1 - lambda_`` goes to
              the gating distance.

      Returns:
          The same ``cost_matrix`` object. Entries whose gating distance
          exceeds the threshold become ``inf``; every entry of each track row
          is blended with that row's gating distances.
      """
      if cost_matrix.size == 0:
          return cost_matrix
      dim = 2 if only_position else 4
      threshold = kalman_filter.chi2inv95[dim]
      xyah = np.asarray([det.to_xyah() for det in detections])
      for i, trk in enumerate(tracks):
          maha = kf.gating_distance(
              trk.mean, trk.covariance, xyah, only_position, metric='maha')
          # Gate first: inf survives the blend below, so gated pairs stay inf.
          cost_matrix[i, maha > threshold] = np.inf
          cost_matrix[i] = lambda_ * cost_matrix[i] + (1 - lambda_) * maha
      return cost_matrix