base_sde_tracker.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/nwojke/deep_sort/blob/master/deep_sort/track.py
  16. """
  17. import datetime
  18. __all__ = ['TrackState', 'Track']
  19. class TrackState(object):
  20. """
  21. Enumeration type for the single target track state. Newly created tracks are
  22. classified as `tentative` until enough evidence has been collected. Then,
  23. the track state is changed to `confirmed`. Tracks that are no longer alive
  24. are classified as `deleted` to mark them for removal from the set of active
  25. tracks.
  26. """
  27. Tentative = 1
  28. Confirmed = 2
  29. Deleted = 3
  30. class Track(object):
  31. """
  32. A single target track with state space `(x, y, a, h)` and associated
  33. velocities, where `(x, y)` is the center of the bounding box, `a` is the
  34. aspect ratio and `h` is the height.
  35. Args:
  36. mean (ndarray): Mean vector of the initial state distribution.
  37. covariance (ndarray): Covariance matrix of the initial state distribution.
  38. track_id (int): A unique track identifier.
  39. n_init (int): Number of consecutive detections before the track is confirmed.
  40. The track state is set to `Deleted` if a miss occurs within the first
  41. `n_init` frames.
  42. max_age (int): The maximum number of consecutive misses before the track
  43. state is set to `Deleted`.
  44. cls_id (int): The category id of the tracked box.
  45. score (float): The confidence score of the tracked box.
  46. feature (Optional[ndarray]): Feature vector of the detection this track
  47. originates from. If not None, this feature is added to the `features` cache.
  48. Attributes:
  49. hits (int): Total number of measurement updates.
  50. age (int): Total number of frames since first occurance.
  51. time_since_update (int): Total number of frames since last measurement
  52. update.
  53. state (TrackState): The current track state.
  54. features (List[ndarray]): A cache of features. On each measurement update,
  55. the associated feature vector is added to this list.
  56. """
  57. def __init__(self,
  58. mean,
  59. covariance,
  60. track_id,
  61. n_init,
  62. max_age,
  63. cls_id,
  64. score,
  65. feature=None):
  66. self.mean = mean
  67. self.covariance = covariance
  68. self.track_id = track_id
  69. self.hits = 1
  70. self.age = 1
  71. self.time_since_update = 0
  72. self.cls_id = cls_id
  73. self.score = score
  74. self.start_time = datetime.datetime.now()
  75. self.state = TrackState.Tentative
  76. self.features = []
  77. self.feat = feature
  78. if feature is not None:
  79. self.features.append(feature)
  80. self._n_init = n_init
  81. self._max_age = max_age
  82. def to_tlwh(self):
  83. """Get position in format `(top left x, top left y, width, height)`."""
  84. ret = self.mean[:4].copy()
  85. ret[2] *= ret[3]
  86. ret[:2] -= ret[2:] / 2
  87. return ret
  88. def to_tlbr(self):
  89. """Get position in bounding box format `(min x, miny, max x, max y)`."""
  90. ret = self.to_tlwh()
  91. ret[2:] = ret[:2] + ret[2:]
  92. return ret
  93. def predict(self, kalman_filter):
  94. """
  95. Propagate the state distribution to the current time step using a Kalman
  96. filter prediction step.
  97. """
  98. self.mean, self.covariance = kalman_filter.predict(self.mean,
  99. self.covariance)
  100. self.age += 1
  101. self.time_since_update += 1
  102. def update(self, kalman_filter, detection):
  103. """
  104. Perform Kalman filter measurement update step and update the associated
  105. detection feature cache.
  106. """
  107. self.mean, self.covariance = kalman_filter.update(self.mean,
  108. self.covariance,
  109. detection.to_xyah())
  110. self.features.append(detection.feature)
  111. self.feat = detection.feature
  112. self.cls_id = detection.cls_id
  113. self.score = detection.score
  114. self.hits += 1
  115. self.time_since_update = 0
  116. if self.state == TrackState.Tentative and self.hits >= self._n_init:
  117. self.state = TrackState.Confirmed
  118. def mark_missed(self):
  119. """Mark this track as missed (no association at the current time step).
  120. """
  121. if self.state == TrackState.Tentative:
  122. self.state = TrackState.Deleted
  123. elif self.time_since_update > self._max_age:
  124. self.state = TrackState.Deleted
  125. def is_tentative(self):
  126. """Returns True if this track is tentative (unconfirmed)."""
  127. return self.state == TrackState.Tentative
  128. def is_confirmed(self):
  129. """Returns True if this track is confirmed."""
  130. return self.state == TrackState.Confirmed
  131. def is_deleted(self):
  132. """Returns True if this track is dead and should be deleted."""
  133. return self.state == TrackState.Deleted