9.2 KB

  1. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. #     http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on the BoT-SORT tracker (https://github.com/NirAharon/BoT-SORT).
  16. """
  17. import cv2
  18. import matplotlib.pyplot as plt
  19. import numpy as np
  20. from collections import deque
  21. from ..matching import jde_matching as matching
  22. from ..motion import GMC
  23. from .base_jde_tracker import TrackState, STrack
  24. from .base_jde_tracker import joint_stracks, sub_stracks, remove_duplicate_stracks
  25. from ..motion import KalmanFilter
  26. class BOTSORTTracker(object):
  27. """
  28. BOTSORT tracker, support single class
  29. Args:
  30. track_high_thresh (float): threshold of detection high score
  31. track_low_thresh (float): threshold of remove detection score
  32. new_track_thresh (float): threshold of new track score
  33. match_thresh (float): iou threshold for associate
  34. track_buffer (int): tracking reserved frames,default 30
  35. min_box_area (float): reserved min box
  36. camera_motion (bool): Whether use camera motion, default False
  37. cmc_method (str): camera motion method,defalut sparseOptFlow
  38. frame_rate (int): fps buffer_size=int(frame_rate / 30.0 * track_buffer)
  39. """
  40. def __init__(self,
  41. track_high_thresh=0.3,
  42. track_low_thresh=0.2,
  43. new_track_thresh=0.4,
  44. match_thresh=0.7,
  45. track_buffer=30,
  46. min_box_area=0,
  47. camera_motion=False,
  48. cmc_method='sparseOptFlow',
  49. frame_rate=30):
  50. self.tracked_stracks = [] # type: list[STrack]
  51. self.lost_stracks = [] # type: list[STrack]
  52. self.removed_stracks = [] # type: list[STrack]
  53. self.frame_id = 0
  54. self.track_high_thresh = track_high_thresh
  55. self.track_low_thresh = track_low_thresh
  56. self.new_track_thresh = new_track_thresh
  57. self.match_thresh = match_thresh
  58. self.buffer_size = int(frame_rate / 30.0 * track_buffer)
  59. self.max_time_lost = self.buffer_size
  60. self.kalman_filter = KalmanFilter()
  61. self.min_box_area = min_box_area
  62. self.camera_motion = camera_motion
  63. self.gmc = GMC(method=cmc_method)
  64. def update(self, output_results, img=None):
  65. self.frame_id += 1
  66. activated_starcks = []
  67. refind_stracks = []
  68. lost_stracks = []
  69. removed_stracks = []
  70. if len(output_results):
  71. bboxes = output_results[:, 2:6]
  72. scores = output_results[:, 1]
  73. classes = output_results[:, 0]
  74. # Remove bad detections
  75. lowest_inds = scores > self.track_low_thresh
  76. bboxes = bboxes[lowest_inds]
  77. scores = scores[lowest_inds]
  78. classes = classes[lowest_inds]
  79. # Find high threshold detections
  80. remain_inds = scores > self.track_high_thresh
  81. dets = bboxes[remain_inds]
  82. scores_keep = scores[remain_inds]
  83. classes_keep = classes[remain_inds]
  84. else:
  85. bboxes = []
  86. scores = []
  87. classes = []
  88. dets = []
  89. scores_keep = []
  90. classes_keep = []
  91. if len(dets) > 0:
  92. '''Detections'''
  93. detections = [
  94. STrack(STrack.tlbr_to_tlwh(tlbr), s, c)
  95. for (tlbr, s, c) in zip(dets, scores_keep, classes_keep)
  96. ]
  97. else:
  98. detections = []
  99. ''' Add newly detected tracklets to tracked_stracks'''
  100. unconfirmed = []
  101. tracked_stracks = [] # type: list[STrack]
  102. for track in self.tracked_stracks:
  103. if not track.is_activated:
  104. unconfirmed.append(track)
  105. else:
  106. tracked_stracks.append(track)
  107. ''' Step 2: First association, with high score detection boxes'''
  108. strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
  109. # Predict the current location with KF
  110. STrack.multi_predict(strack_pool, self.kalman_filter)
  111. # Fix camera motion
  112. if self.camera_motion:
  113. warp = self.gmc.apply(img[0], dets)
  114. STrack.multi_gmc(strack_pool, warp)
  115. STrack.multi_gmc(unconfirmed, warp)
  116. # Associate with high score detection boxes
  117. ious_dists = matching.iou_distance(strack_pool, detections)
  118. matches, u_track, u_detection = matching.linear_assignment(
  119. ious_dists, thresh=self.match_thresh)
  120. for itracked, idet in matches:
  121. track = strack_pool[itracked]
  122. det = detections[idet]
  123. if track.state == TrackState.Tracked:
  124. track.update(detections[idet], self.frame_id)
  125. activated_starcks.append(track)
  126. else:
  127. track.re_activate(det, self.frame_id, new_id=False)
  128. refind_stracks.append(track)
  129. ''' Step 3: Second association, with low score detection boxes'''
  130. if len(scores):
  131. inds_high = scores < self.track_high_thresh
  132. inds_low = scores > self.track_low_thresh
  133. inds_second = np.logical_and(inds_low, inds_high)
  134. dets_second = bboxes[inds_second]
  135. scores_second = scores[inds_second]
  136. classes_second = classes[inds_second]
  137. else:
  138. dets_second = []
  139. scores_second = []
  140. classes_second = []
  141. # association the untrack to the low score detections
  142. if len(dets_second) > 0:
  143. '''Detections'''
  144. detections_second = [
  145. STrack(STrack.tlbr_to_tlwh(tlbr), s, c) for (tlbr, s, c) in
  146. zip(dets_second, scores_second, classes_second)
  147. ]
  148. else:
  149. detections_second = []
  150. r_tracked_stracks = [
  151. strack_pool[i] for i in u_track
  152. if strack_pool[i].state == TrackState.Tracked
  153. ]
  154. dists = matching.iou_distance(r_tracked_stracks, detections_second)
  155. matches, u_track, u_detection_second = matching.linear_assignment(
  156. dists, thresh=0.5)
  157. for itracked, idet in matches:
  158. track = r_tracked_stracks[itracked]
  159. det = detections_second[idet]
  160. if track.state == TrackState.Tracked:
  161. track.update(det, self.frame_id)
  162. activated_starcks.append(track)
  163. else:
  164. track.re_activate(det, self.frame_id, new_id=False)
  165. refind_stracks.append(track)
  166. for it in u_track:
  167. track = r_tracked_stracks[it]
  168. if not track.state == TrackState.Lost:
  169. track.mark_lost()
  170. lost_stracks.append(track)
  171. '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
  172. detections = [detections[i] for i in u_detection]
  173. dists = matching.iou_distance(unconfirmed, detections)
  174. matches, u_unconfirmed, u_detection = matching.linear_assignment(
  175. dists, thresh=0.7)
  176. for itracked, idet in matches:
  177. unconfirmed[itracked].update(detections[idet], self.frame_id)
  178. activated_starcks.append(unconfirmed[itracked])
  179. for it in u_unconfirmed:
  180. track = unconfirmed[it]
  181. track.mark_removed()
  182. removed_stracks.append(track)
  183. """ Step 4: Init new stracks"""
  184. for inew in u_detection:
  185. track = detections[inew]
  186. if track.score < self.new_track_thresh:
  187. continue
  188. track.activate(self.kalman_filter, self.frame_id)
  189. activated_starcks.append(track)
  190. """ Step 5: Update state"""
  191. for track in self.lost_stracks:
  192. if self.frame_id - track.end_frame > self.max_time_lost:
  193. track.mark_removed()
  194. removed_stracks.append(track)
  195. """ Merge """
  196. self.tracked_stracks = [
  197. t for t in self.tracked_stracks if t.state == TrackState.Tracked
  198. ]
  199. self.tracked_stracks = joint_stracks(self.tracked_stracks,
  200. activated_starcks)
  201. self.tracked_stracks = joint_stracks(self.tracked_stracks,
  202. refind_stracks)
  203. self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
  204. self.lost_stracks.extend(lost_stracks)
  205. self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
  206. self.removed_stracks.extend(removed_stracks)
  207. self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(
  208. self.tracked_stracks, self.lost_stracks)
  209. # output_stracks = [track for track in self.tracked_stracks if track.is_activated]
  210. output_stracks = [track for track in self.tracked_stracks]
  211. return output_stracks