base_jde_tracker.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
  16. """
  17. import numpy as np
  18. from collections import defaultdict
  19. from collections import deque, OrderedDict
  20. from ..matching import jde_matching as matching
# Public names exported by `from <this module> import *`.
__all__ = [
    'TrackState',
    'BaseTrack',
    'STrack',
    'joint_stracks',
    'sub_stracks',
    'remove_duplicate_stracks',
]
  29. class TrackState(object):
  30. New = 0
  31. Tracked = 1
  32. Lost = 2
  33. Removed = 3
  34. class BaseTrack(object):
  35. _count_dict = defaultdict(int) # support single class and multi classes
  36. track_id = 0
  37. is_activated = False
  38. state = TrackState.New
  39. history = OrderedDict()
  40. features = []
  41. curr_feat = None
  42. score = 0
  43. start_frame = 0
  44. frame_id = 0
  45. time_since_update = 0
  46. # multi-camera
  47. location = (np.inf, np.inf)
  48. @property
  49. def end_frame(self):
  50. return self.frame_id
  51. @staticmethod
  52. def next_id(cls_id):
  53. BaseTrack._count_dict[cls_id] += 1
  54. return BaseTrack._count_dict[cls_id]
  55. # @even: reset track id
  56. @staticmethod
  57. def init_count(num_classes):
  58. """
  59. Initiate _count for all object classes
  60. :param num_classes:
  61. """
  62. for cls_id in range(num_classes):
  63. BaseTrack._count_dict[cls_id] = 0
  64. @staticmethod
  65. def reset_track_count(cls_id):
  66. BaseTrack._count_dict[cls_id] = 0
  67. def activate(self, *args):
  68. raise NotImplementedError
  69. def predict(self):
  70. raise NotImplementedError
  71. def update(self, *args, **kwargs):
  72. raise NotImplementedError
  73. def mark_lost(self):
  74. self.state = TrackState.Lost
  75. def mark_removed(self):
  76. self.state = TrackState.Removed
  77. class STrack(BaseTrack):
  78. def __init__(self, tlwh, score, cls_id, buff_size=30, temp_feat=None):
  79. # wait activate
  80. self._tlwh = np.asarray(tlwh, dtype=np.float32)
  81. self.score = score
  82. self.cls_id = cls_id
  83. self.track_len = 0
  84. self.kalman_filter = None
  85. self.mean, self.covariance = None, None
  86. self.is_activated = False
  87. self.use_reid = True if temp_feat is not None else False
  88. if self.use_reid:
  89. self.smooth_feat = None
  90. self.update_features(temp_feat)
  91. self.features = deque([], maxlen=buff_size)
  92. self.alpha = 0.9
  93. def update_features(self, feat):
  94. # L2 normalizing, this function has no use for BYTETracker
  95. feat /= np.linalg.norm(feat)
  96. self.curr_feat = feat
  97. if self.smooth_feat is None:
  98. self.smooth_feat = feat
  99. else:
  100. self.smooth_feat = self.alpha * self.smooth_feat + (1.0 - self.alpha
  101. ) * feat
  102. self.features.append(feat)
  103. self.smooth_feat /= np.linalg.norm(self.smooth_feat)
  104. def predict(self):
  105. mean_state = self.mean.copy()
  106. if self.state != TrackState.Tracked:
  107. mean_state[7] = 0
  108. self.mean, self.covariance = self.kalman_filter.predict(mean_state,
  109. self.covariance)
  110. @staticmethod
  111. def multi_predict(tracks, kalman_filter):
  112. if len(tracks) > 0:
  113. multi_mean = np.asarray([track.mean.copy() for track in tracks])
  114. multi_covariance = np.asarray(
  115. [track.covariance for track in tracks])
  116. for i, st in enumerate(tracks):
  117. if st.state != TrackState.Tracked:
  118. multi_mean[i][7] = 0
  119. multi_mean, multi_covariance = kalman_filter.multi_predict(
  120. multi_mean, multi_covariance)
  121. for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
  122. tracks[i].mean = mean
  123. tracks[i].covariance = cov
  124. @staticmethod
  125. def multi_gmc(stracks, H=np.eye(2, 3)):
  126. if len(stracks) > 0:
  127. multi_mean = np.asarray([st.mean.copy() for st in stracks])
  128. multi_covariance = np.asarray([st.covariance for st in stracks])
  129. R = H[:2, :2]
  130. R8x8 = np.kron(np.eye(4, dtype=float), R)
  131. t = H[:2, 2]
  132. for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
  133. mean = R8x8.dot(mean)
  134. mean[:2] += t
  135. cov = R8x8.dot(cov).dot(R8x8.transpose())
  136. stracks[i].mean = mean
  137. stracks[i].covariance = cov
  138. def reset_track_id(self):
  139. self.reset_track_count(self.cls_id)
  140. def activate(self, kalman_filter, frame_id):
  141. """Start a new track"""
  142. self.kalman_filter = kalman_filter
  143. # update track id for the object class
  144. self.track_id = self.next_id(self.cls_id)
  145. self.mean, self.covariance = self.kalman_filter.initiate(
  146. self.tlwh_to_xyah(self._tlwh))
  147. self.track_len = 0
  148. self.state = TrackState.Tracked # set flag 'tracked'
  149. if frame_id == 1: # to record the first frame's detection result
  150. self.is_activated = True
  151. self.frame_id = frame_id
  152. self.start_frame = frame_id
  153. def re_activate(self, new_track, frame_id, new_id=False):
  154. self.mean, self.covariance = self.kalman_filter.update(
  155. self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh))
  156. if self.use_reid:
  157. self.update_features(new_track.curr_feat)
  158. self.track_len = 0
  159. self.state = TrackState.Tracked
  160. self.is_activated = True
  161. self.frame_id = frame_id
  162. if new_id: # update track id for the object class
  163. self.track_id = self.next_id(self.cls_id)
  164. def update(self, new_track, frame_id, update_feature=True):
  165. self.frame_id = frame_id
  166. self.track_len += 1
  167. new_tlwh = new_track.tlwh
  168. self.mean, self.covariance = self.kalman_filter.update(
  169. self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
  170. self.state = TrackState.Tracked # set flag 'tracked'
  171. self.is_activated = True # set flag 'activated'
  172. self.score = new_track.score
  173. if update_feature and self.use_reid:
  174. self.update_features(new_track.curr_feat)
  175. @property
  176. def tlwh(self):
  177. """Get current position in bounding box format `(top left x, top left y,
  178. width, height)`.
  179. """
  180. if self.mean is None:
  181. return self._tlwh.copy()
  182. ret = self.mean[:4].copy()
  183. ret[2] *= ret[3]
  184. ret[:2] -= ret[2:] / 2
  185. return ret
  186. @property
  187. def tlbr(self):
  188. """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
  189. `(top left, bottom right)`.
  190. """
  191. ret = self.tlwh.copy()
  192. ret[2:] += ret[:2]
  193. return ret
  194. @staticmethod
  195. def tlwh_to_xyah(tlwh):
  196. """Convert bounding box to format `(center x, center y, aspect ratio,
  197. height)`, where the aspect ratio is `width / height`.
  198. """
  199. ret = np.asarray(tlwh).copy()
  200. ret[:2] += ret[2:] / 2
  201. ret[2] /= ret[3]
  202. return ret
  203. def to_xyah(self):
  204. return self.tlwh_to_xyah(self.tlwh)
  205. @staticmethod
  206. def tlbr_to_tlwh(tlbr):
  207. ret = np.asarray(tlbr).copy()
  208. ret[2:] -= ret[:2]
  209. return ret
  210. @staticmethod
  211. def tlwh_to_tlbr(tlwh):
  212. ret = np.asarray(tlwh).copy()
  213. ret[2:] += ret[:2]
  214. return ret
  215. def __repr__(self):
  216. return 'OT_({}-{})_({}-{})'.format(self.cls_id, self.track_id,
  217. self.start_frame, self.end_frame)
  218. def joint_stracks(tlista, tlistb):
  219. exists = {}
  220. res = []
  221. for t in tlista:
  222. exists[t.track_id] = 1
  223. res.append(t)
  224. for t in tlistb:
  225. tid = t.track_id
  226. if not exists.get(tid, 0):
  227. exists[tid] = 1
  228. res.append(t)
  229. return res
  230. def sub_stracks(tlista, tlistb):
  231. stracks = {}
  232. for t in tlista:
  233. stracks[t.track_id] = t
  234. for t in tlistb:
  235. tid = t.track_id
  236. if stracks.get(tid, 0):
  237. del stracks[tid]
  238. return list(stracks.values())
  239. def remove_duplicate_stracks(stracksa, stracksb):
  240. pdist = matching.iou_distance(stracksa, stracksb)
  241. pairs = np.where(pdist < 0.15)
  242. dupa, dupb = list(), list()
  243. for p, q in zip(*pairs):
  244. timep = stracksa[p].frame_id - stracksa[p].start_frame
  245. timeq = stracksb[q].frame_id - stracksb[q].start_frame
  246. if timep > timeq:
  247. dupb.append(q)
  248. else:
  249. dupa.append(p)
  250. resa = [t for i, t in enumerate(stracksa) if not i in dupa]
  251. resb = [t for i, t in enumerate(stracksb) if not i in dupb]
  252. return resa, resb