9.2 KB

  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. #     http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on
  16. """
  17. import numpy as np
  18. from collections import defaultdict
  19. from collections import deque, OrderedDict
  20. from ..matching import jde_matching as matching
  21. __all__ = [
  22. 'TrackState',
  23. 'BaseTrack',
  24. 'STrack',
  25. 'joint_stracks',
  26. 'sub_stracks',
  27. 'remove_duplicate_stracks',
  28. ]
  29. class TrackState(object):
  30. New = 0
  31. Tracked = 1
  32. Lost = 2
  33. Removed = 3
  34. class BaseTrack(object):
  35. _count_dict = defaultdict(int) # support single class and multi classes
  36. track_id = 0
  37. is_activated = False
  38. state = TrackState.New
  39. history = OrderedDict()
  40. features = []
  41. curr_feat = None
  42. score = 0
  43. start_frame = 0
  44. frame_id = 0
  45. time_since_update = 0
  46. # multi-camera
  47. location = (np.inf, np.inf)
  48. @property
  49. def end_frame(self):
  50. return self.frame_id
  51. @staticmethod
  52. def next_id(cls_id):
  53. BaseTrack._count_dict[cls_id] += 1
  54. return BaseTrack._count_dict[cls_id]
  55. # @even: reset track id
  56. @staticmethod
  57. def init_count(num_classes):
  58. """
  59. Initiate _count for all object classes
  60. :param num_classes:
  61. """
  62. for cls_id in range(num_classes):
  63. BaseTrack._count_dict[cls_id] = 0
  64. @staticmethod
  65. def reset_track_count(cls_id):
  66. BaseTrack._count_dict[cls_id] = 0
  67. def activate(self, *args):
  68. raise NotImplementedError
  69. def predict(self):
  70. raise NotImplementedError
  71. def update(self, *args, **kwargs):
  72. raise NotImplementedError
  73. def mark_lost(self):
  74. self.state = TrackState.Lost
  75. def mark_removed(self):
  76. self.state = TrackState.Removed
  77. class STrack(BaseTrack):
  78. def __init__(self, tlwh, score, cls_id, buff_size=30, temp_feat=None):
  79. # wait activate
  80. self._tlwh = np.asarray(tlwh, dtype=np.float32)
  81. self.score = score
  82. self.cls_id = cls_id
  83. self.track_len = 0
  84. self.kalman_filter = None
  85. self.mean, self.covariance = None, None
  86. self.is_activated = False
  87. self.use_reid = True if temp_feat is not None else False
  88. if self.use_reid:
  89. self.smooth_feat = None
  90. self.update_features(temp_feat)
  91. self.features = deque([], maxlen=buff_size)
  92. self.alpha = 0.9
  93. def update_features(self, feat):
  94. # L2 normalizing, this function has no use for BYTETracker
  95. feat /= np.linalg.norm(feat)
  96. self.curr_feat = feat
  97. if self.smooth_feat is None:
  98. self.smooth_feat = feat
  99. else:
  100. self.smooth_feat = self.alpha * self.smooth_feat + (1.0 - self.alpha
  101. ) * feat
  102. self.features.append(feat)
  103. self.smooth_feat /= np.linalg.norm(self.smooth_feat)
  104. def predict(self):
  105. mean_state = self.mean.copy()
  106. if self.state != TrackState.Tracked:
  107. mean_state[7] = 0
  108. self.mean, self.covariance = self.kalman_filter.predict(mean_state,
  109. self.covariance)
  110. @staticmethod
  111. def multi_predict(tracks, kalman_filter):
  112. if len(tracks) > 0:
  113. multi_mean = np.asarray([track.mean.copy() for track in tracks])
  114. multi_covariance = np.asarray(
  115. [track.covariance for track in tracks])
  116. for i, st in enumerate(tracks):
  117. if st.state != TrackState.Tracked:
  118. multi_mean[i][7] = 0
  119. multi_mean, multi_covariance = kalman_filter.multi_predict(
  120. multi_mean, multi_covariance)
  121. for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
  122. tracks[i].mean = mean
  123. tracks[i].covariance = cov
  124. @staticmethod
  125. def multi_gmc(stracks, H=np.eye(2, 3)):
  126. if len(stracks) > 0:
  127. multi_mean = np.asarray([st.mean.copy() for st in stracks])
  128. multi_covariance = np.asarray([st.covariance for st in stracks])
  129. R = H[:2, :2]
  130. R8x8 = np.kron(np.eye(4, dtype=float), R)
  131. t = H[:2, 2]
  132. for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
  133. mean =
  134. mean[:2] += t
  135. cov =
  136. stracks[i].mean = mean
  137. stracks[i].covariance = cov
  138. def reset_track_id(self):
  139. self.reset_track_count(self.cls_id)
  140. def activate(self, kalman_filter, frame_id):
  141. """Start a new track"""
  142. self.kalman_filter = kalman_filter
  143. # update track id for the object class
  144. self.track_id = self.next_id(self.cls_id)
  145. self.mean, self.covariance = self.kalman_filter.initiate(
  146. self.tlwh_to_xyah(self._tlwh))
  147. self.track_len = 0
  148. self.state = TrackState.Tracked # set flag 'tracked'
  149. if frame_id == 1: # to record the first frame's detection result
  150. self.is_activated = True
  151. self.frame_id = frame_id
  152. self.start_frame = frame_id
  153. def re_activate(self, new_track, frame_id, new_id=False):
  154. self.mean, self.covariance = self.kalman_filter.update(
  155. self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh))
  156. if self.use_reid:
  157. self.update_features(new_track.curr_feat)
  158. self.track_len = 0
  159. self.state = TrackState.Tracked
  160. self.is_activated = True
  161. self.frame_id = frame_id
  162. if new_id: # update track id for the object class
  163. self.track_id = self.next_id(self.cls_id)
  164. def update(self, new_track, frame_id, update_feature=True):
  165. self.frame_id = frame_id
  166. self.track_len += 1
  167. new_tlwh = new_track.tlwh
  168. self.mean, self.covariance = self.kalman_filter.update(
  169. self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
  170. self.state = TrackState.Tracked # set flag 'tracked'
  171. self.is_activated = True # set flag 'activated'
  172. self.score = new_track.score
  173. if update_feature and self.use_reid:
  174. self.update_features(new_track.curr_feat)
  175. @property
  176. def tlwh(self):
  177. """Get current position in bounding box format `(top left x, top left y,
  178. width, height)`.
  179. """
  180. if self.mean is None:
  181. return self._tlwh.copy()
  182. ret = self.mean[:4].copy()
  183. ret[2] *= ret[3]
  184. ret[:2] -= ret[2:] / 2
  185. return ret
  186. @property
  187. def tlbr(self):
  188. """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
  189. `(top left, bottom right)`.
  190. """
  191. ret = self.tlwh.copy()
  192. ret[2:] += ret[:2]
  193. return ret
  194. @staticmethod
  195. def tlwh_to_xyah(tlwh):
  196. """Convert bounding box to format `(center x, center y, aspect ratio,
  197. height)`, where the aspect ratio is `width / height`.
  198. """
  199. ret = np.asarray(tlwh).copy()
  200. ret[:2] += ret[2:] / 2
  201. ret[2] /= ret[3]
  202. return ret
  203. def to_xyah(self):
  204. return self.tlwh_to_xyah(self.tlwh)
  205. @staticmethod
  206. def tlbr_to_tlwh(tlbr):
  207. ret = np.asarray(tlbr).copy()
  208. ret[2:] -= ret[:2]
  209. return ret
  210. @staticmethod
  211. def tlwh_to_tlbr(tlwh):
  212. ret = np.asarray(tlwh).copy()
  213. ret[2:] += ret[:2]
  214. return ret
  215. def __repr__(self):
  216. return 'OT_({}-{})_({}-{})'.format(self.cls_id, self.track_id,
  217. self.start_frame, self.end_frame)
  218. def joint_stracks(tlista, tlistb):
  219. exists = {}
  220. res = []
  221. for t in tlista:
  222. exists[t.track_id] = 1
  223. res.append(t)
  224. for t in tlistb:
  225. tid = t.track_id
  226. if not exists.get(tid, 0):
  227. exists[tid] = 1
  228. res.append(t)
  229. return res
  230. def sub_stracks(tlista, tlistb):
  231. stracks = {}
  232. for t in tlista:
  233. stracks[t.track_id] = t
  234. for t in tlistb:
  235. tid = t.track_id
  236. if stracks.get(tid, 0):
  237. del stracks[tid]
  238. return list(stracks.values())
  239. def remove_duplicate_stracks(stracksa, stracksb):
  240. pdist = matching.iou_distance(stracksa, stracksb)
  241. pairs = np.where(pdist < 0.15)
  242. dupa, dupb = list(), list()
  243. for p, q in zip(*pairs):
  244. timep = stracksa[p].frame_id - stracksa[p].start_frame
  245. timeq = stracksb[q].frame_id - stracksb[q].start_frame
  246. if timep > timeq:
  247. dupb.append(q)
  248. else:
  249. dupa.append(p)
  250. resa = [t for i, t in enumerate(stracksa) if not i in dupa]
  251. resb = [t for i, t in enumerate(stracksb) if not i in dupb]
  252. return resa, resb