ocsort_tracker.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/noahcao/OC_SORT/blob/master/trackers/ocsort_tracker/ocsort.py
  16. """
  17. import numpy as np
  18. from ..matching.ocsort_matching import associate, linear_assignment, iou_batch, associate_only_iou
  19. from ..motion.ocsort_kalman_filter import OCSORTKalmanFilter
  20. from ppdet.core.workspace import register, serializable
  21. def k_previous_obs(observations, cur_age, k):
  22. if len(observations) == 0:
  23. return [-1, -1, -1, -1, -1]
  24. for i in range(k):
  25. dt = k - i
  26. if cur_age - dt in observations:
  27. return observations[cur_age - dt]
  28. max_age = max(observations.keys())
  29. return observations[max_age]
  30. def convert_bbox_to_z(bbox):
  31. """
  32. Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
  33. [x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
  34. the aspect ratio
  35. """
  36. w = bbox[2] - bbox[0]
  37. h = bbox[3] - bbox[1]
  38. x = bbox[0] + w / 2.
  39. y = bbox[1] + h / 2.
  40. s = w * h # scale is just area
  41. r = w / float(h + 1e-6)
  42. return np.array([x, y, s, r]).reshape((4, 1))
  43. def convert_x_to_bbox(x, score=None):
  44. """
  45. Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
  46. [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
  47. """
  48. w = np.sqrt(x[2] * x[3])
  49. h = x[2] / w
  50. if (score == None):
  51. return np.array(
  52. [x[0] - w / 2., x[1] - h / 2., x[0] + w / 2.,
  53. x[1] + h / 2.]).reshape((1, 4))
  54. else:
  55. score = np.array([score])
  56. return np.array([
  57. x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2., score
  58. ]).reshape((1, 5))
  59. def speed_direction(bbox1, bbox2):
  60. cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
  61. cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
  62. speed = np.array([cy2 - cy1, cx2 - cx1])
  63. norm = np.sqrt((cy2 - cy1)**2 + (cx2 - cx1)**2) + 1e-6
  64. return speed / norm
  65. class KalmanBoxTracker(object):
  66. """
  67. This class represents the internal state of individual tracked objects observed as bbox.
  68. Args:
  69. bbox (np.array): bbox in [x1,y1,x2,y2,score] format.
  70. delta_t (int): delta_t of previous observation
  71. """
  72. count = 0
  73. def __init__(self, bbox, delta_t=3):
  74. self.kf = OCSORTKalmanFilter(dim_x=7, dim_z=4)
  75. self.kf.F = np.array([[1., 0, 0, 0, 1., 0, 0], [0, 1., 0, 0, 0, 1., 0],
  76. [0, 0, 1., 0, 0, 0, 1], [0, 0, 0, 1., 0, 0, 0],
  77. [0, 0, 0, 0, 1., 0, 0], [0, 0, 0, 0, 0, 1., 0],
  78. [0, 0, 0, 0, 0, 0, 1.]])
  79. self.kf.H = np.array([[1., 0, 0, 0, 0, 0, 0], [0, 1., 0, 0, 0, 0, 0],
  80. [0, 0, 1., 0, 0, 0, 0], [0, 0, 0, 1., 0, 0, 0]])
  81. self.kf.R[2:, 2:] *= 10.
  82. self.kf.P[4:, 4:] *= 1000.
  83. # give high uncertainty to the unobservable initial velocities
  84. self.kf.P *= 10.
  85. self.kf.Q[-1, -1] *= 0.01
  86. self.kf.Q[4:, 4:] *= 0.01
  87. self.score = bbox[4]
  88. self.kf.x[:4] = convert_bbox_to_z(bbox)
  89. self.time_since_update = 0
  90. self.id = KalmanBoxTracker.count
  91. KalmanBoxTracker.count += 1
  92. self.history = []
  93. self.hits = 0
  94. self.hit_streak = 0
  95. self.age = 0
  96. """
  97. NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same for the return of
  98. function k_previous_obs. It is ugly and I do not like it. But to support generate observation array in a
  99. fast and unified way, which you would see below k_observations = np.array([k_previous_obs(...]]), let's bear it for now.
  100. """
  101. self.last_observation = np.array([-1, -1, -1, -1, -1]) # placeholder
  102. self.observations = dict()
  103. self.history_observations = []
  104. self.velocity = None
  105. self.delta_t = delta_t
  106. def update(self, bbox, angle_cost=False):
  107. """
  108. Updates the state vector with observed bbox.
  109. """
  110. if bbox is not None:
  111. if angle_cost and self.last_observation.sum(
  112. ) >= 0: # no previous observation
  113. previous_box = None
  114. for i in range(self.delta_t):
  115. dt = self.delta_t - i
  116. if self.age - dt in self.observations:
  117. previous_box = self.observations[self.age - dt]
  118. break
  119. if previous_box is None:
  120. previous_box = self.last_observation
  121. """
  122. Estimate the track speed direction with observations \Delta t steps away
  123. """
  124. self.velocity = speed_direction(previous_box, bbox)
  125. """
  126. Insert new observations. This is a ugly way to maintain both self.observations
  127. and self.history_observations. Bear it for the moment.
  128. """
  129. self.last_observation = bbox
  130. self.observations[self.age] = bbox
  131. self.history_observations.append(bbox)
  132. self.time_since_update = 0
  133. self.history = []
  134. self.hits += 1
  135. self.hit_streak += 1
  136. self.kf.update(convert_bbox_to_z(bbox))
  137. else:
  138. self.kf.update(bbox)
  139. def predict(self):
  140. """
  141. Advances the state vector and returns the predicted bounding box estimate.
  142. """
  143. if ((self.kf.x[6] + self.kf.x[2]) <= 0):
  144. self.kf.x[6] *= 0.0
  145. self.kf.predict()
  146. self.age += 1
  147. if (self.time_since_update > 0):
  148. self.hit_streak = 0
  149. self.time_since_update += 1
  150. self.history.append(convert_x_to_bbox(self.kf.x, score=self.score))
  151. return self.history[-1]
  152. def get_state(self):
  153. return convert_x_to_bbox(self.kf.x, score=self.score)
  154. @register
  155. @serializable
  156. class OCSORTTracker(object):
  157. """
  158. OCSORT tracker, support single class
  159. Args:
  160. det_thresh (float): threshold of detection score
  161. max_age (int): maximum number of missed misses before a track is deleted
  162. min_hits (int): minimum hits for associate
  163. iou_threshold (float): iou threshold for associate
  164. delta_t (int): delta_t of previous observation
  165. inertia (float): vdc_weight of angle_diff_cost for associate
  166. vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
  167. bad results. If set <= 0 means no need to filter bboxes,usually set
  168. 1.6 for pedestrian tracking.
  169. min_box_area (int): min box area to filter out low quality boxes
  170. use_byte (bool): Whether use ByteTracker, default False
  171. """
  172. def __init__(self,
  173. det_thresh=0.6,
  174. max_age=30,
  175. min_hits=3,
  176. iou_threshold=0.3,
  177. delta_t=3,
  178. inertia=0.2,
  179. vertical_ratio=-1,
  180. min_box_area=0,
  181. use_byte=False,
  182. use_angle_cost=False):
  183. self.det_thresh = det_thresh
  184. self.max_age = max_age
  185. self.min_hits = min_hits
  186. self.iou_threshold = iou_threshold
  187. self.delta_t = delta_t
  188. self.inertia = inertia
  189. self.vertical_ratio = vertical_ratio
  190. self.min_box_area = min_box_area
  191. self.use_byte = use_byte
  192. self.use_angle_cost = use_angle_cost
  193. self.trackers = []
  194. self.frame_count = 0
  195. KalmanBoxTracker.count = 0
  196. def update(self, pred_dets, pred_embs=None):
  197. """
  198. Args:
  199. pred_dets (np.array): Detection results of the image, the shape is
  200. [N, 6], means 'cls_id, score, x0, y0, x1, y1'.
  201. pred_embs (np.array): Embedding results of the image, the shape is
  202. [N, 128] or [N, 512], default as None.
  203. Return:
  204. tracking boxes (np.array): [M, 6], means 'x0, y0, x1, y1, score, id'.
  205. """
  206. if pred_dets is None:
  207. return np.empty((0, 6))
  208. self.frame_count += 1
  209. bboxes = pred_dets[:, 2:]
  210. scores = pred_dets[:, 1:2]
  211. dets = np.concatenate((bboxes, scores), axis=1)
  212. scores = scores.squeeze(-1)
  213. inds_low = scores > 0.1
  214. inds_high = scores < self.det_thresh
  215. inds_second = np.logical_and(inds_low, inds_high)
  216. # self.det_thresh > score > 0.1, for second matching
  217. dets_second = dets[inds_second] # detections for second matching
  218. remain_inds = scores > self.det_thresh
  219. dets = dets[remain_inds]
  220. # get predicted locations from existing trackers.
  221. trks = np.zeros((len(self.trackers), 5))
  222. to_del = []
  223. ret = []
  224. for t, trk in enumerate(trks):
  225. pos = self.trackers[t].predict()[0]
  226. trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
  227. if np.any(np.isnan(pos)):
  228. to_del.append(t)
  229. trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
  230. for t in reversed(to_del):
  231. self.trackers.pop(t)
  232. if self.use_angle_cost:
  233. velocities = np.array([
  234. trk.velocity if trk.velocity is not None else np.array((0, 0))
  235. for trk in self.trackers
  236. ])
  237. k_observations = np.array([
  238. k_previous_obs(trk.observations, trk.age, self.delta_t)
  239. for trk in self.trackers
  240. ])
  241. last_boxes = np.array([trk.last_observation for trk in self.trackers])
  242. """
  243. First round of association
  244. """
  245. if self.use_angle_cost:
  246. matched, unmatched_dets, unmatched_trks = associate(
  247. dets, trks, self.iou_threshold, velocities, k_observations,
  248. self.inertia)
  249. else:
  250. matched, unmatched_dets, unmatched_trks = associate_only_iou(
  251. dets, trks, self.iou_threshold)
  252. for m in matched:
  253. self.trackers[m[1]].update(
  254. dets[m[0], :], angle_cost=self.use_angle_cost)
  255. """
  256. Second round of associaton by OCR
  257. """
  258. # BYTE association
  259. if self.use_byte and len(dets_second) > 0 and unmatched_trks.shape[
  260. 0] > 0:
  261. u_trks = trks[unmatched_trks]
  262. iou_left = iou_batch(
  263. dets_second,
  264. u_trks) # iou between low score detections and unmatched tracks
  265. iou_left = np.array(iou_left)
  266. if iou_left.max() > self.iou_threshold:
  267. """
  268. NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
  269. get a higher performance especially on MOT17/MOT20 datasets. But we keep it
  270. uniform here for simplicity
  271. """
  272. matched_indices = linear_assignment(-iou_left)
  273. to_remove_trk_indices = []
  274. for m in matched_indices:
  275. det_ind, trk_ind = m[0], unmatched_trks[m[1]]
  276. if iou_left[m[0], m[1]] < self.iou_threshold:
  277. continue
  278. self.trackers[trk_ind].update(
  279. dets_second[det_ind, :], angle_cost=self.use_angle_cost)
  280. to_remove_trk_indices.append(trk_ind)
  281. unmatched_trks = np.setdiff1d(unmatched_trks,
  282. np.array(to_remove_trk_indices))
  283. if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
  284. left_dets = dets[unmatched_dets]
  285. left_trks = last_boxes[unmatched_trks]
  286. iou_left = iou_batch(left_dets, left_trks)
  287. iou_left = np.array(iou_left)
  288. if iou_left.max() > self.iou_threshold:
  289. """
  290. NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
  291. get a higher performance especially on MOT17/MOT20 datasets. But we keep it
  292. uniform here for simplicity
  293. """
  294. rematched_indices = linear_assignment(-iou_left)
  295. to_remove_det_indices = []
  296. to_remove_trk_indices = []
  297. for m in rematched_indices:
  298. det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[
  299. 1]]
  300. if iou_left[m[0], m[1]] < self.iou_threshold:
  301. continue
  302. self.trackers[trk_ind].update(
  303. dets[det_ind, :], angle_cost=self.use_angle_cost)
  304. to_remove_det_indices.append(det_ind)
  305. to_remove_trk_indices.append(trk_ind)
  306. unmatched_dets = np.setdiff1d(unmatched_dets,
  307. np.array(to_remove_det_indices))
  308. unmatched_trks = np.setdiff1d(unmatched_trks,
  309. np.array(to_remove_trk_indices))
  310. for m in unmatched_trks:
  311. self.trackers[m].update(None)
  312. # create and initialise new trackers for unmatched detections
  313. for i in unmatched_dets:
  314. trk = KalmanBoxTracker(dets[i, :], delta_t=self.delta_t)
  315. self.trackers.append(trk)
  316. i = len(self.trackers)
  317. for trk in reversed(self.trackers):
  318. if trk.last_observation.sum() < 0:
  319. d = trk.get_state()[0]
  320. else:
  321. d = trk.last_observation # tlbr + score
  322. if (trk.time_since_update < 1) and (
  323. trk.hit_streak >= self.min_hits or
  324. self.frame_count <= self.min_hits):
  325. # +1 as MOT benchmark requires positive
  326. ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1))
  327. i -= 1
  328. # remove dead tracklet
  329. if (trk.time_since_update > self.max_age):
  330. self.trackers.pop(i)
  331. if (len(ret) > 0):
  332. return np.concatenate(ret)
  333. return np.empty((0, 6))