botsort_tracker.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/WWangYuHsiang/SMILEtrack/blob/main/BoT-SORT/tracker/bot_sort.py
  16. """
  17. import cv2
  18. import matplotlib.pyplot as plt
  19. import numpy as np
  20. from collections import deque
  21. from ..matching import jde_matching as matching
  22. from ..motion import GMC
  23. from .base_jde_tracker import TrackState, STrack
  24. from .base_jde_tracker import joint_stracks, sub_stracks, remove_duplicate_stracks
  25. from ..motion import KalmanFilter
  # Single-class BoT-SORT tracker with ByteTrack-style two-stage IoU association.
class BOTSORTTracker(object):
    """
    BoT-SORT tracker, single class.

    Each call to ``update`` consumes one frame of detections and returns the
    tracks currently held in ``tracked_stracks``. Association is done by IoU
    in two stages: confirmed and lost tracks are first matched against
    high-score detections, then the still-unmatched tracked tracks are
    matched against the remaining low-score detections.

    Args:
        track_high_thresh (float): detections scoring above this value take
            part in the first association; default 0.3.
        track_low_thresh (float): detections scoring at or below this value
            are discarded; default 0.2.
        new_track_thresh (float): minimum score an unmatched high-score
            detection needs to start a new track; default 0.4.
        match_thresh (float): threshold passed to ``linear_assignment`` for
            the first-stage IoU-distance matrix; default 0.7.
        track_buffer (int): tracking reserved frames (at 30 fps) before a
            lost track is removed; default 30.
        min_box_area (float): minimum box area; stored on the instance but
            not used inside this class; default 0.
        camera_motion (bool): whether to apply global motion compensation
            (GMC) before association; default False.
        cmc_method (str): GMC method name; default 'sparseOptFlow'.
        frame_rate (int): video frame rate; the lost-track buffer is
            ``int(frame_rate / 30.0 * track_buffer)`` frames; default 30.
    """

    def __init__(self,
                 track_high_thresh=0.3,
                 track_low_thresh=0.2,
                 new_track_thresh=0.4,
                 match_thresh=0.7,
                 track_buffer=30,
                 min_box_area=0,
                 camera_motion=False,
                 cmc_method='sparseOptFlow',
                 frame_rate=30):
        self.tracked_stracks = []  # type: list[STrack]
        self.lost_stracks = []  # type: list[STrack]
        # NOTE(review): this history grows without bound over a long sequence
        # and is only appended to inside this class; cap it if memory matters.
        self.removed_stracks = []  # type: list[STrack]
        self.frame_id = 0
        self.track_high_thresh = track_high_thresh
        self.track_low_thresh = track_low_thresh
        self.new_track_thresh = new_track_thresh
        self.match_thresh = match_thresh
        self.buffer_size = int(frame_rate / 30.0 * track_buffer)
        self.max_time_lost = self.buffer_size
        self.kalman_filter = KalmanFilter()
        self.min_box_area = min_box_area
        self.camera_motion = camera_motion
        self.gmc = GMC(method=cmc_method)

    def _split_detections(self, output_results):
        """Split one frame of detections into high- and low-score groups.

        Rows scoring at or below ``track_low_thresh`` are dropped. Of the
        rest, rows scoring above ``track_high_thresh`` form the high group
        and all remaining rows form the low group, so every kept detection
        lands in exactly one group.

        Returns:
            tuple: ``((boxes, scores, classes), (boxes, scores, classes))``
            for the high and low groups. Every element is an empty list when
            ``output_results`` is empty.
        """
        if not len(output_results):
            return ([], [], []), ([], [], [])
        boxes = output_results[:, 2:6]
        scores = output_results[:, 1]
        classes = output_results[:, 0]
        # Remove bad detections.
        keep = scores > self.track_low_thresh
        boxes, scores, classes = boxes[keep], scores[keep], classes[keep]
        high = scores > self.track_high_thresh
        # Complement of `high`: a score exactly equal to track_high_thresh is
        # no longer dropped from both association stages.
        low = np.logical_not(high)
        return ((boxes[high], scores[high], classes[high]),
                (boxes[low], scores[low], classes[low]))

    @staticmethod
    def _to_stracks(boxes, scores, classes):
        """Wrap tlbr boxes and their scores/classes as new ``STrack`` objects."""
        return [
            STrack(STrack.tlbr_to_tlwh(tlbr), s, c)
            for tlbr, s, c in zip(boxes, scores, classes)
        ]

    def update(self, output_results, img=None):
        """Advance the tracker by one frame.

        Args:
            output_results: detections for this frame, indexed as ``[:, 0]``
                class id, ``[:, 1]`` score and ``[:, 2:6]`` box in tlbr
                order (presumably an (N, 6) numpy array — TODO confirm with
                caller). May be empty.
            img: frame image container; ``img[0]`` is passed to GMC. Required
                when ``camera_motion`` is enabled.

        Returns:
            list[STrack]: every track in ``tracked_stracks`` after this
            frame, including tracks that are not yet activated.

        Raises:
            ValueError: if ``camera_motion`` is enabled and ``img`` is None.
                Raised before any tracker state is modified.
        """
        if self.camera_motion and img is None:
            raise ValueError("img is required when camera_motion is enabled")

        self.frame_id += 1
        activated_stracks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []

        high, low = self._split_detections(output_results)
        dets, scores_keep, classes_keep = high
        dets_second, scores_second, classes_second = low

        detections = self._to_stracks(dets, scores_keep, classes_keep)

        # Separate not-yet-activated tracks from confirmed ones.
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if track.is_activated:
                tracked_stracks.append(track)
            else:
                unconfirmed.append(track)

        # Step 2: first association, confirmed + lost tracks vs high-score
        # detections.
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        # Predict the current location with the Kalman filter.
        STrack.multi_predict(strack_pool, self.kalman_filter)
        if self.camera_motion:
            # Compensate camera motion before measuring IoU.
            warp = self.gmc.apply(img[0], dets)
            STrack.multi_gmc(strack_pool, warp)
            STrack.multi_gmc(unconfirmed, warp)

        ious_dists = matching.iou_distance(strack_pool, detections)
        matches, u_track, u_detection = matching.linear_assignment(
            ious_dists, thresh=self.match_thresh)
        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_stracks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        # Step 3: second association, still-unmatched tracked tracks vs
        # low-score detections (fixed threshold 0.5).
        detections_second = self._to_stracks(dets_second, scores_second,
                                             classes_second)
        r_tracked_stracks = [
            strack_pool[i] for i in u_track
            if strack_pool[i].state == TrackState.Tracked
        ]
        dists = matching.iou_distance(r_tracked_stracks, detections_second)
        matches, u_track, _ = matching.linear_assignment(dists, thresh=0.5)
        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections_second[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_stracks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)
        for it in u_track:
            track = r_tracked_stracks[it]
            if track.state != TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)

        # Unconfirmed tracks (usually seen in only one frame so far) get one
        # chance against the leftover high-score detections; unmatched ones
        # are removed.
        detections = [detections[i] for i in u_detection]
        dists = matching.iou_distance(unconfirmed, detections)
        matches, u_unconfirmed, u_detection = matching.linear_assignment(
            dists, thresh=0.7)
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_stracks.append(unconfirmed[itracked])
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)

        # Step 4: start new tracks from sufficiently confident leftovers.
        for inew in u_detection:
            track = detections[inew]
            if track.score < self.new_track_thresh:
                continue
            track.activate(self.kalman_filter, self.frame_id)
            activated_stracks.append(track)

        # Step 5: remove lost tracks that exceeded the buffer.
        for track in self.lost_stracks:
            if self.frame_id - track.end_frame > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

        # Merge this frame's results into the persistent track lists.
        self.tracked_stracks = [
            t for t in self.tracked_stracks if t.state == TrackState.Tracked
        ]
        self.tracked_stracks = joint_stracks(self.tracked_stracks,
                                             activated_stracks)
        self.tracked_stracks = joint_stracks(self.tracked_stracks,
                                             refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        # Drop this frame's removals right away. Previously the subtraction
        # used the removal history *before* it was extended, so timed-out
        # tracks lingered one frame in lost_stracks, where they could be
        # re-matched and revived and were recorded as removed twice.
        self.lost_stracks = sub_stracks(self.lost_stracks, removed_stracks)
        self.removed_stracks.extend(removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(
            self.tracked_stracks, self.lost_stracks)
        # Unconfirmed (not yet activated) tracks are intentionally included.
        return list(self.tracked_stracks)