botsort_tracker.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/WWangYuHsiang/SMILEtrack/blob/main/BoT-SORT/tracker/bot_sort.py
  16. """
  17. import cv2
  18. import matplotlib.pyplot as plt
  19. import numpy as np
  20. from collections import deque
  21. from ..matching import jde_matching as matching
  22. from ..motion import GMC
  23. from .base_jde_tracker import TrackState, STrack
  24. from .base_jde_tracker import joint_stracks, sub_stracks, remove_duplicate_stracks
  25. from ..motion import KalmanFilter
  26. from ppdet.core.workspace import register, serializable
  27. @register
  28. @serializable
  29. class BOTSORTTracker(object):
  30. """
  31. BOTSORT tracker, support single class
  32. Args:
  33. track_high_thresh (float): threshold of detection high score
  34. track_low_thresh (float): threshold of remove detection score
  35. new_track_thresh (float): threshold of new track score
  36. match_thresh (float): iou threshold for associate
  37. track_buffer (int): tracking reserved frames,default 30
  38. min_box_area (float): reserved min box
  39. camera_motion (bool): Whether use camera motion, default False
  40. cmc_method (str): camera motion method,defalut sparseOptFlow
  41. frame_rate (int): fps buffer_size=int(frame_rate / 30.0 * track_buffer)
  42. """
  43. def __init__(self,
  44. track_high_thresh=0.3,
  45. track_low_thresh=0.2,
  46. new_track_thresh=0.4,
  47. match_thresh=0.7,
  48. track_buffer=30,
  49. min_box_area=0,
  50. camera_motion=False,
  51. cmc_method='sparseOptFlow',
  52. frame_rate=30):
  53. self.tracked_stracks = [] # type: list[STrack]
  54. self.lost_stracks = [] # type: list[STrack]
  55. self.removed_stracks = [] # type: list[STrack]
  56. self.frame_id = 0
  57. self.track_high_thresh = track_high_thresh
  58. self.track_low_thresh = track_low_thresh
  59. self.new_track_thresh = new_track_thresh
  60. self.match_thresh = match_thresh
  61. self.buffer_size = int(frame_rate / 30.0 * track_buffer)
  62. self.max_time_lost = self.buffer_size
  63. self.kalman_filter = KalmanFilter()
  64. self.min_box_area = min_box_area
  65. self.camera_motion = camera_motion
  66. self.gmc = GMC(method=cmc_method)
  67. def update(self, output_results, img=None):
  68. self.frame_id += 1
  69. activated_starcks = []
  70. refind_stracks = []
  71. lost_stracks = []
  72. removed_stracks = []
  73. if len(output_results):
  74. bboxes = output_results[:, 2:6]
  75. scores = output_results[:, 1]
  76. classes = output_results[:, 0]
  77. # Remove bad detections
  78. lowest_inds = scores > self.track_low_thresh
  79. bboxes = bboxes[lowest_inds]
  80. scores = scores[lowest_inds]
  81. classes = classes[lowest_inds]
  82. # Find high threshold detections
  83. remain_inds = scores > self.track_high_thresh
  84. dets = bboxes[remain_inds]
  85. scores_keep = scores[remain_inds]
  86. classes_keep = classes[remain_inds]
  87. else:
  88. bboxes = []
  89. scores = []
  90. classes = []
  91. dets = []
  92. scores_keep = []
  93. classes_keep = []
  94. if len(dets) > 0:
  95. '''Detections'''
  96. detections = [
  97. STrack(STrack.tlbr_to_tlwh(tlbr), s, c)
  98. for (tlbr, s, c) in zip(dets, scores_keep, classes_keep)
  99. ]
  100. else:
  101. detections = []
  102. ''' Add newly detected tracklets to tracked_stracks'''
  103. unconfirmed = []
  104. tracked_stracks = [] # type: list[STrack]
  105. for track in self.tracked_stracks:
  106. if not track.is_activated:
  107. unconfirmed.append(track)
  108. else:
  109. tracked_stracks.append(track)
  110. ''' Step 2: First association, with high score detection boxes'''
  111. strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
  112. # Predict the current location with KF
  113. STrack.multi_predict(strack_pool, self.kalman_filter)
  114. # Fix camera motion
  115. if self.camera_motion:
  116. warp = self.gmc.apply(img[0], dets)
  117. STrack.multi_gmc(strack_pool, warp)
  118. STrack.multi_gmc(unconfirmed, warp)
  119. # Associate with high score detection boxes
  120. ious_dists = matching.iou_distance(strack_pool, detections)
  121. matches, u_track, u_detection = matching.linear_assignment(
  122. ious_dists, thresh=self.match_thresh)
  123. for itracked, idet in matches:
  124. track = strack_pool[itracked]
  125. det = detections[idet]
  126. if track.state == TrackState.Tracked:
  127. track.update(detections[idet], self.frame_id)
  128. activated_starcks.append(track)
  129. else:
  130. track.re_activate(det, self.frame_id, new_id=False)
  131. refind_stracks.append(track)
  132. ''' Step 3: Second association, with low score detection boxes'''
  133. if len(scores):
  134. inds_high = scores < self.track_high_thresh
  135. inds_low = scores > self.track_low_thresh
  136. inds_second = np.logical_and(inds_low, inds_high)
  137. dets_second = bboxes[inds_second]
  138. scores_second = scores[inds_second]
  139. classes_second = classes[inds_second]
  140. else:
  141. dets_second = []
  142. scores_second = []
  143. classes_second = []
  144. # association the untrack to the low score detections
  145. if len(dets_second) > 0:
  146. '''Detections'''
  147. detections_second = [
  148. STrack(STrack.tlbr_to_tlwh(tlbr), s, c) for (tlbr, s, c) in
  149. zip(dets_second, scores_second, classes_second)
  150. ]
  151. else:
  152. detections_second = []
  153. r_tracked_stracks = [
  154. strack_pool[i] for i in u_track
  155. if strack_pool[i].state == TrackState.Tracked
  156. ]
  157. dists = matching.iou_distance(r_tracked_stracks, detections_second)
  158. matches, u_track, u_detection_second = matching.linear_assignment(
  159. dists, thresh=0.5)
  160. for itracked, idet in matches:
  161. track = r_tracked_stracks[itracked]
  162. det = detections_second[idet]
  163. if track.state == TrackState.Tracked:
  164. track.update(det, self.frame_id)
  165. activated_starcks.append(track)
  166. else:
  167. track.re_activate(det, self.frame_id, new_id=False)
  168. refind_stracks.append(track)
  169. for it in u_track:
  170. track = r_tracked_stracks[it]
  171. if not track.state == TrackState.Lost:
  172. track.mark_lost()
  173. lost_stracks.append(track)
  174. '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
  175. detections = [detections[i] for i in u_detection]
  176. dists = matching.iou_distance(unconfirmed, detections)
  177. matches, u_unconfirmed, u_detection = matching.linear_assignment(
  178. dists, thresh=0.7)
  179. for itracked, idet in matches:
  180. unconfirmed[itracked].update(detections[idet], self.frame_id)
  181. activated_starcks.append(unconfirmed[itracked])
  182. for it in u_unconfirmed:
  183. track = unconfirmed[it]
  184. track.mark_removed()
  185. removed_stracks.append(track)
  186. """ Step 4: Init new stracks"""
  187. for inew in u_detection:
  188. track = detections[inew]
  189. if track.score < self.new_track_thresh:
  190. continue
  191. track.activate(self.kalman_filter, self.frame_id)
  192. activated_starcks.append(track)
  193. """ Step 5: Update state"""
  194. for track in self.lost_stracks:
  195. if self.frame_id - track.end_frame > self.max_time_lost:
  196. track.mark_removed()
  197. removed_stracks.append(track)
  198. """ Merge """
  199. self.tracked_stracks = [
  200. t for t in self.tracked_stracks if t.state == TrackState.Tracked
  201. ]
  202. self.tracked_stracks = joint_stracks(self.tracked_stracks,
  203. activated_starcks)
  204. self.tracked_stracks = joint_stracks(self.tracked_stracks,
  205. refind_stracks)
  206. self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
  207. self.lost_stracks.extend(lost_stracks)
  208. self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
  209. self.removed_stracks.extend(removed_stracks)
  210. self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(
  211. self.tracked_stracks, self.lost_stracks)
  212. # output_stracks = [track for track in self.tracked_stracks if track.is_activated]
  213. output_stracks = [track for track in self.tracked_stracks]
  214. return output_stracks