base_sde_tracker.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/nwojke/deep_sort/blob/master/deep_sort/track.py
  16. """
  17. import datetime
  18. from ppdet.core.workspace import register, serializable
  # Names exported by `from <this module> import *`.
  __all__ = [
      "TrackState",
      "Track",
  ]
  class TrackState(object):
      """
      Lifecycle states of a single-target track.

      A freshly created track starts out `Tentative` while evidence is being
      collected, is promoted to `Confirmed` once enough evidence exists, and is
      marked `Deleted` when it should be removed from the set of active tracks.
      """
      # Plain integer codes (not an `enum.Enum`), compared with `==`.
      Tentative, Confirmed, Deleted = 1, 2, 3
  @register
  @serializable
  class Track(object):
      """
      A single target track with state space `(x, y, a, h)` and associated
      velocities, where `(x, y)` is the center of the bounding box, `a` is the
      aspect ratio and `h` is the height.

      Args:
          mean (ndarray): Mean vector of the initial state distribution.
          covariance (ndarray): Covariance matrix of the initial state distribution.
          track_id (int): A unique track identifier.
          n_init (int): Number of consecutive detections before the track is
              confirmed. The track state is set to `Deleted` if a miss occurs
              within the first `n_init` frames.
          max_age (int): The maximum number of consecutive misses before the
              track state is set to `Deleted`.
          cls_id (int): The category id of the tracked box.
          score (float): The confidence score of the tracked box.
          feature (Optional[ndarray]): Feature vector of the detection this track
              originates from. If not None, this feature is added to the
              `features` cache.

      Attributes:
          hits (int): Total number of measurement updates.
          age (int): Total number of frames since first occurrence.
          time_since_update (int): Total number of frames since last measurement
              update.
          state (TrackState): The current track state.
          features (List[ndarray]): A cache of features. On each measurement
              update, the associated feature vector is added to this list when
              it is not None.
          feat (Optional[ndarray]): The most recently seen feature vector; may be
              None.
          start_time (datetime.datetime): Local (naive) wall-clock time at which
              the track object was created.
      """

      def __init__(self,
                   mean,
                   covariance,
                   track_id,
                   n_init,
                   max_age,
                   cls_id,
                   score,
                   feature=None):
          self.mean = mean
          self.covariance = covariance
          self.track_id = track_id
          # The creating detection counts as the first hit and the first frame.
          self.hits = 1
          self.age = 1
          self.time_since_update = 0
          self.cls_id = cls_id
          self.score = score
          self.start_time = datetime.datetime.now()
          self.state = TrackState.Tentative

          self.features = []
          self.feat = feature
          if feature is not None:
              self.features.append(feature)

          self._n_init = n_init
          self._max_age = max_age

      def to_tlwh(self):
          """Get position in format `(top left x, top left y, width, height)`."""
          # mean[:4] is (center x, center y, aspect ratio, height).
          ret = self.mean[:4].copy()
          ret[2] *= ret[3]  # width = aspect ratio * height
          ret[:2] -= ret[2:] / 2  # shift center to top-left corner
          return ret

      def to_tlbr(self):
          """Get position in bounding box format `(min x, min y, max x, max y)`."""
          ret = self.to_tlwh()
          ret[2:] = ret[:2] + ret[2:]
          return ret

      def predict(self, kalman_filter):
          """
          Propagate the state distribution to the current time step using a Kalman
          filter prediction step.

          Args:
              kalman_filter: Object exposing `predict(mean, covariance)` that
                  returns the propagated `(mean, covariance)`.
          """
          self.mean, self.covariance = kalman_filter.predict(self.mean,
                                                             self.covariance)
          self.age += 1
          # Reset to 0 by `update`; read by `mark_missed` to enforce `max_age`.
          self.time_since_update += 1

      def update(self, kalman_filter, detection):
          """
          Perform Kalman filter measurement update step and update the associated
          detection feature cache.

          Args:
              kalman_filter: Object exposing `update(mean, covariance, measurement)`
                  that returns the corrected `(mean, covariance)`.
              detection: The associated detection; its `to_xyah()`, `feature`,
                  `cls_id` and `score` are read.
          """
          self.mean, self.covariance = kalman_filter.update(self.mean,
                                                            self.covariance,
                                                            detection.to_xyah())
          # Same rule as __init__: never put a missing (None) feature into the
          # cache, so it holds only real feature vectors.
          if detection.feature is not None:
              self.features.append(detection.feature)
          self.feat = detection.feature
          self.cls_id = detection.cls_id
          self.score = detection.score

          self.hits += 1
          self.time_since_update = 0
          if self.state == TrackState.Tentative and self.hits >= self._n_init:
              self.state = TrackState.Confirmed

      def mark_missed(self):
          """Mark this track as missed (no association at the current time step).

          A tentative track is deleted on its first miss; any other track is
          deleted once `time_since_update` exceeds `max_age`.
          """
          if self.state == TrackState.Tentative:
              self.state = TrackState.Deleted
          elif self.time_since_update > self._max_age:
              self.state = TrackState.Deleted

      def is_tentative(self):
          """Returns True if this track is tentative (unconfirmed)."""
          return self.state == TrackState.Tentative

      def is_confirmed(self):
          """Returns True if this track is confirmed."""
          return self.state == TrackState.Confirmed

      def is_deleted(self):
          """Returns True if this track is dead and should be deleted."""
          return self.state == TrackState.Deleted