base_jde_tracker.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
  16. """
  17. import numpy as np
  18. from collections import defaultdict
  19. from collections import deque, OrderedDict
  20. from ..matching import jde_matching as matching
  21. from ppdet.core.workspace import register, serializable
  22. import warnings
  23. warnings.filterwarnings("ignore")
  24. __all__ = [
  25. 'TrackState',
  26. 'BaseTrack',
  27. 'STrack',
  28. 'joint_stracks',
  29. 'sub_stracks',
  30. 'remove_duplicate_stracks',
  31. ]
  32. class TrackState(object):
  33. New = 0
  34. Tracked = 1
  35. Lost = 2
  36. Removed = 3
  37. @register
  38. @serializable
  39. class BaseTrack(object):
  40. _count_dict = defaultdict(int) # support single class and multi classes
  41. track_id = 0
  42. is_activated = False
  43. state = TrackState.New
  44. history = OrderedDict()
  45. features = []
  46. curr_feat = None
  47. score = 0
  48. start_frame = 0
  49. frame_id = 0
  50. time_since_update = 0
  51. # multi-camera
  52. location = (np.inf, np.inf)
  53. @property
  54. def end_frame(self):
  55. return self.frame_id
  56. @staticmethod
  57. def next_id(cls_id):
  58. BaseTrack._count_dict[cls_id] += 1
  59. return BaseTrack._count_dict[cls_id]
  60. # @even: reset track id
  61. @staticmethod
  62. def init_count(num_classes):
  63. """
  64. Initiate _count for all object classes
  65. :param num_classes:
  66. """
  67. for cls_id in range(num_classes):
  68. BaseTrack._count_dict[cls_id] = 0
  69. @staticmethod
  70. def reset_track_count(cls_id):
  71. BaseTrack._count_dict[cls_id] = 0
  72. def activate(self, *args):
  73. raise NotImplementedError
  74. def predict(self):
  75. raise NotImplementedError
  76. def update(self, *args, **kwargs):
  77. raise NotImplementedError
  78. def mark_lost(self):
  79. self.state = TrackState.Lost
  80. def mark_removed(self):
  81. self.state = TrackState.Removed
  82. @register
  83. @serializable
  84. class STrack(BaseTrack):
  85. def __init__(self, tlwh, score, cls_id, buff_size=30, temp_feat=None):
  86. # wait activate
  87. self._tlwh = np.asarray(tlwh, dtype=np.float32)
  88. self.score = score
  89. self.cls_id = cls_id
  90. self.track_len = 0
  91. self.kalman_filter = None
  92. self.mean, self.covariance = None, None
  93. self.is_activated = False
  94. self.use_reid = True if temp_feat is not None else False
  95. if self.use_reid:
  96. self.smooth_feat = None
  97. self.update_features(temp_feat)
  98. self.features = deque([], maxlen=buff_size)
  99. self.alpha = 0.9
  100. def update_features(self, feat):
  101. # L2 normalizing, this function has no use for BYTETracker
  102. feat /= np.linalg.norm(feat)
  103. self.curr_feat = feat
  104. if self.smooth_feat is None:
  105. self.smooth_feat = feat
  106. else:
  107. self.smooth_feat = self.alpha * self.smooth_feat + (1.0 - self.alpha
  108. ) * feat
  109. self.features.append(feat)
  110. self.smooth_feat /= np.linalg.norm(self.smooth_feat)
  111. def predict(self):
  112. mean_state = self.mean.copy()
  113. if self.state != TrackState.Tracked:
  114. mean_state[7] = 0
  115. self.mean, self.covariance = self.kalman_filter.predict(mean_state,
  116. self.covariance)
  117. @staticmethod
  118. def multi_predict(tracks, kalman_filter):
  119. if len(tracks) > 0:
  120. multi_mean = np.asarray([track.mean.copy() for track in tracks])
  121. multi_covariance = np.asarray(
  122. [track.covariance for track in tracks])
  123. for i, st in enumerate(tracks):
  124. if st.state != TrackState.Tracked:
  125. multi_mean[i][7] = 0
  126. multi_mean, multi_covariance = kalman_filter.multi_predict(
  127. multi_mean, multi_covariance)
  128. for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
  129. tracks[i].mean = mean
  130. tracks[i].covariance = cov
  131. @staticmethod
  132. def multi_gmc(stracks, H=np.eye(2, 3)):
  133. if len(stracks) > 0:
  134. multi_mean = np.asarray([st.mean.copy() for st in stracks])
  135. multi_covariance = np.asarray([st.covariance for st in stracks])
  136. R = H[:2, :2]
  137. R8x8 = np.kron(np.eye(4, dtype=float), R)
  138. t = H[:2, 2]
  139. for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
  140. mean = R8x8.dot(mean)
  141. mean[:2] += t
  142. cov = R8x8.dot(cov).dot(R8x8.transpose())
  143. stracks[i].mean = mean
  144. stracks[i].covariance = cov
  145. def reset_track_id(self):
  146. self.reset_track_count(self.cls_id)
  147. def activate(self, kalman_filter, frame_id):
  148. """Start a new track"""
  149. self.kalman_filter = kalman_filter
  150. # update track id for the object class
  151. self.track_id = self.next_id(self.cls_id)
  152. self.mean, self.covariance = self.kalman_filter.initiate(
  153. self.tlwh_to_xyah(self._tlwh))
  154. self.track_len = 0
  155. self.state = TrackState.Tracked # set flag 'tracked'
  156. if frame_id == 1: # to record the first frame's detection result
  157. self.is_activated = True
  158. self.frame_id = frame_id
  159. self.start_frame = frame_id
  160. def re_activate(self, new_track, frame_id, new_id=False):
  161. self.mean, self.covariance = self.kalman_filter.update(
  162. self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh))
  163. if self.use_reid:
  164. self.update_features(new_track.curr_feat)
  165. self.track_len = 0
  166. self.state = TrackState.Tracked
  167. self.is_activated = True
  168. self.frame_id = frame_id
  169. if new_id: # update track id for the object class
  170. self.track_id = self.next_id(self.cls_id)
  171. def update(self, new_track, frame_id, update_feature=True):
  172. self.frame_id = frame_id
  173. self.track_len += 1
  174. new_tlwh = new_track.tlwh
  175. self.mean, self.covariance = self.kalman_filter.update(
  176. self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
  177. self.state = TrackState.Tracked # set flag 'tracked'
  178. self.is_activated = True # set flag 'activated'
  179. self.score = new_track.score
  180. if update_feature and self.use_reid:
  181. self.update_features(new_track.curr_feat)
  182. @property
  183. def tlwh(self):
  184. """Get current position in bounding box format `(top left x, top left y,
  185. width, height)`.
  186. """
  187. if self.mean is None:
  188. return self._tlwh.copy()
  189. ret = self.mean[:4].copy()
  190. ret[2] *= ret[3]
  191. ret[:2] -= ret[2:] / 2
  192. return ret
  193. @property
  194. def tlbr(self):
  195. """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
  196. `(top left, bottom right)`.
  197. """
  198. ret = self.tlwh.copy()
  199. ret[2:] += ret[:2]
  200. return ret
  201. @staticmethod
  202. def tlwh_to_xyah(tlwh):
  203. """Convert bounding box to format `(center x, center y, aspect ratio,
  204. height)`, where the aspect ratio is `width / height`.
  205. """
  206. ret = np.asarray(tlwh).copy()
  207. ret[:2] += ret[2:] / 2
  208. ret[2] /= ret[3]
  209. return ret
  210. def to_xyah(self):
  211. return self.tlwh_to_xyah(self.tlwh)
  212. @staticmethod
  213. def tlbr_to_tlwh(tlbr):
  214. ret = np.asarray(tlbr).copy()
  215. ret[2:] -= ret[:2]
  216. return ret
  217. @staticmethod
  218. def tlwh_to_tlbr(tlwh):
  219. ret = np.asarray(tlwh).copy()
  220. ret[2:] += ret[:2]
  221. return ret
  222. def __repr__(self):
  223. return 'OT_({}-{})_({}-{})'.format(self.cls_id, self.track_id,
  224. self.start_frame, self.end_frame)
  225. def joint_stracks(tlista, tlistb):
  226. exists = {}
  227. res = []
  228. for t in tlista:
  229. exists[t.track_id] = 1
  230. res.append(t)
  231. for t in tlistb:
  232. tid = t.track_id
  233. if not exists.get(tid, 0):
  234. exists[tid] = 1
  235. res.append(t)
  236. return res
  237. def sub_stracks(tlista, tlistb):
  238. stracks = {}
  239. for t in tlista:
  240. stracks[t.track_id] = t
  241. for t in tlistb:
  242. tid = t.track_id
  243. if stracks.get(tid, 0):
  244. del stracks[tid]
  245. return list(stracks.values())
  246. def remove_duplicate_stracks(stracksa, stracksb):
  247. pdist = matching.iou_distance(stracksa, stracksb)
  248. pairs = np.where(pdist < 0.15)
  249. dupa, dupb = list(), list()
  250. for p, q in zip(*pairs):
  251. timep = stracksa[p].frame_id - stracksa[p].start_frame
  252. timeq = stracksb[q].frame_id - stracksb[q].start_frame
  253. if timep > timeq:
  254. dupb.append(q)
  255. else:
  256. dupa.append(p)
  257. resa = [t for i, t in enumerate(stracksa) if not i in dupa]
  258. resb = [t for i, t in enumerate(stracksb) if not i in dupb]
  259. return resa, resb