center_tracker.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/xingyizhou/CenterTrack/blob/master/src/lib/utils/tracker.py
  16. """
  17. import copy
  18. import numpy as np
  19. import sklearn
  20. from ppdet.core.workspace import register, serializable
  21. from ppdet.utils.logger import setup_logger
  22. logger = setup_logger(__name__)
  23. __all__ = ['CenterTracker']
  24. @register
  25. @serializable
  26. class CenterTracker(object):
  27. __shared__ = ['num_classes']
  28. def __init__(self,
  29. num_classes=1,
  30. min_box_area=0,
  31. vertical_ratio=-1,
  32. track_thresh=0.4,
  33. pre_thresh=0.5,
  34. new_thresh=0.4,
  35. out_thresh=0.4,
  36. hungarian=False):
  37. self.num_classes = num_classes
  38. self.min_box_area = min_box_area
  39. self.vertical_ratio = vertical_ratio
  40. self.track_thresh = track_thresh
  41. self.pre_thresh = max(track_thresh, pre_thresh)
  42. self.new_thresh = max(track_thresh, new_thresh)
  43. self.out_thresh = max(track_thresh, out_thresh)
  44. self.hungarian = hungarian
  45. self.reset()
  46. def init_track(self, results):
  47. print('Initialize tracking!')
  48. for item in results:
  49. if item['score'] > self.new_thresh:
  50. self.id_count += 1
  51. item['tracking_id'] = self.id_count
  52. if not ('ct' in item):
  53. bbox = item['bbox']
  54. item['ct'] = [(bbox[0] + bbox[2]) / 2,
  55. (bbox[1] + bbox[3]) / 2]
  56. self.tracks.append(item)
  57. def reset(self):
  58. self.id_count = 0
  59. self.tracks = []
  60. def update(self, results, public_det=None):
  61. N = len(results)
  62. M = len(self.tracks)
  63. dets = np.array([det['ct'] + det['tracking'] for det in results],
  64. np.float32) # N x 2
  65. track_size = np.array([((track['bbox'][2] - track['bbox'][0]) * \
  66. (track['bbox'][3] - track['bbox'][1])) \
  67. for track in self.tracks], np.float32) # M
  68. track_cat = np.array([track['class'] for track in self.tracks],
  69. np.int32) # M
  70. item_size = np.array([((item['bbox'][2] - item['bbox'][0]) * \
  71. (item['bbox'][3] - item['bbox'][1])) \
  72. for item in results], np.float32) # N
  73. item_cat = np.array([item['class'] for item in results], np.int32) # N
  74. tracks = np.array([pre_det['ct'] for pre_det in self.tracks],
  75. np.float32) # M x 2
  76. dist = (((tracks.reshape(1, -1, 2) - \
  77. dets.reshape(-1, 1, 2)) ** 2).sum(axis=2)) # N x M
  78. invalid = ((dist > track_size.reshape(1, M)) + \
  79. (dist > item_size.reshape(N, 1)) + \
  80. (item_cat.reshape(N, 1) != track_cat.reshape(1, M))) > 0
  81. dist = dist + invalid * 1e18
  82. if self.hungarian:
  83. item_score = np.array([item['score'] for item in results],
  84. np.float32)
  85. dist[dist > 1e18] = 1e18
  86. from sklearn.utils.linear_assignment_ import linear_assignment
  87. matched_indices = linear_assignment(dist)
  88. else:
  89. matched_indices = greedy_assignment(copy.deepcopy(dist))
  90. unmatched_dets = [d for d in range(dets.shape[0]) \
  91. if not (d in matched_indices[:, 0])]
  92. unmatched_tracks = [d for d in range(tracks.shape[0]) \
  93. if not (d in matched_indices[:, 1])]
  94. if self.hungarian:
  95. matches = []
  96. for m in matched_indices:
  97. if dist[m[0], m[1]] > 1e16:
  98. unmatched_dets.append(m[0])
  99. unmatched_tracks.append(m[1])
  100. else:
  101. matches.append(m)
  102. matches = np.array(matches).reshape(-1, 2)
  103. else:
  104. matches = matched_indices
  105. ret = []
  106. for m in matches:
  107. track = results[m[0]]
  108. track['tracking_id'] = self.tracks[m[1]]['tracking_id']
  109. ret.append(track)
  110. # Private detection: create tracks for all un-matched detections
  111. for i in unmatched_dets:
  112. track = results[i]
  113. if track['score'] > self.new_thresh:
  114. self.id_count += 1
  115. track['tracking_id'] = self.id_count
  116. ret.append(track)
  117. self.tracks = ret
  118. return ret
  119. def greedy_assignment(dist):
  120. matched_indices = []
  121. if dist.shape[1] == 0:
  122. return np.array(matched_indices, np.int32).reshape(-1, 2)
  123. for i in range(dist.shape[0]):
  124. j = dist[i].argmin()
  125. if dist[i][j] < 1e16:
  126. dist[:, j] = 1e18
  127. matched_indices.append([i, j])
  128. return np.array(matched_indices, np.int32).reshape(-1, 2)