centertrack.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import math
import numpy as np
import paddle

from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
from ..keypoint_utils import affine_transform
from ppdet.data.transform.op_helper import gaussian_radius, gaussian2D, draw_umich_gaussian

__all__ = ['CenterTrack']

@register
class CenterTrack(BaseArch):
    """
    CenterTrack network, see http://arxiv.org/abs/2004.01177

    Args:
        detector (object): 'CenterNet' instance
        plugin_head (object): 'CenterTrackHead' instance
        tracker (object): 'CenterTracker' instance
    """
    __category__ = 'architecture'
    __shared__ = ['mot_metric']
    def __init__(self,
                 detector='CenterNet',
                 plugin_head='CenterTrackHead',
                 tracker='CenterTracker',
                 mot_metric=False):
        super(CenterTrack, self).__init__()
        self.detector = detector
        self.plugin_head = plugin_head
        self.tracker = tracker
        self.mot_metric = mot_metric
        self.pre_image = None
        self.deploy = False
    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        detector = create(cfg['detector'])
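        # The plugin head consumes the detector's feature maps, so its input
        # shape is taken from the neck when one is configured, otherwise from
        # the backbone output shape.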
        detector_out_shape = detector.neck and detector.neck.out_shape or detector.backbone.out_shape

        kwargs = {'input_shape': detector_out_shape}
        plugin_head = create(cfg['plugin_head'], **kwargs)
        tracker = create(cfg['tracker'])

        return {
            'detector': detector,
            'plugin_head': plugin_head,
            'tracker': tracker,
        }
    def _forward(self):
        if self.training:
            det_outs = self.detector(self.inputs)
            neck_feat = det_outs['neck_feat']
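
            # collect the individual loss terms from the detector and the
            # plugin head, then sum their totals into the overall training loss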
            losses = {}
            for k, v in det_outs.items():
                if 'loss' not in k: continue
                losses.update({k: v})

            plugin_outs = self.plugin_head(neck_feat, self.inputs)
            for k, v in plugin_outs.items():
                if 'loss' not in k: continue
                losses.update({k: v})

            losses['loss'] = det_outs['det_loss'] + plugin_outs['plugin_loss']
            return losses
        else:
            if not self.mot_metric:
                # detection, support bs>=1
                det_outs = self.detector(self.inputs)
                return {
                    'bbox': det_outs['bbox'],
                    'bbox_num': det_outs['bbox_num']
                }
            else:
                # MOT, only support bs=1
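                # Tracking conditions the detector on the previous frame and a
                # heatmap rendered from previous detections; in deploy mode
                # 'pre_image' and 'pre_hm' are expected to be fed in externally.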
                if not self.deploy:
                    if self.pre_image is None:
                        self.pre_image = self.inputs['image']
                        # initializing tracker for the first frame
                        self.tracker.init_track([])
                    self.inputs['pre_image'] = self.pre_image
                    self.pre_image = self.inputs[
                        'image']  # Note: update for next image

                    # render input heatmap from tracker status
                    pre_hm = self.get_additional_inputs(
                        self.tracker.tracks, self.inputs, with_hm=True)
                    self.inputs['pre_hm'] = paddle.to_tensor(pre_hm)

                # model inference
                det_outs = self.detector(self.inputs)
                neck_feat = det_outs['neck_feat']

                result = self.plugin_head(
                    neck_feat, self.inputs, det_outs['bbox'],
                    det_outs['bbox_inds'], det_outs['topk_clses'],
                    det_outs['topk_ys'], det_outs['topk_xs'])

                if not self.deploy:
                    # convert the cropped and 4x downsampled output coordinate
                    # system back to the input image coordinate system
                    result = self.plugin_head.centertrack_post_process(
                        result, self.inputs, self.tracker.out_thresh)
                return result
    def get_pred(self):
        return self._forward()

    def get_loss(self):
        return self._forward()

    def reset_tracking(self):
        self.tracker.reset()
        self.pre_image = None
    def get_additional_inputs(self, dets, meta, with_hm=True):
        # Render input heatmap from previous trackings.
        trans_input = meta['trans_input'][0].numpy()
        inp_width, inp_height = int(meta['inp_width'][0]), int(meta[
            'inp_height'][0])
        input_hm = np.zeros((1, inp_height, inp_width), dtype=np.float32)

        for det in dets:
            if det['score'] < self.tracker.pre_thresh:
                continue
            bbox = affine_transform_bbox(det['bbox'], trans_input, inp_width,
                                         inp_height)
            h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
            if (h > 0 and w > 0):
                radius = gaussian_radius(
                    (math.ceil(h), math.ceil(w)), min_overlap=0.7)
                radius = max(0, int(radius))
                ct = np.array(
                    [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2],
                    dtype=np.float32)
                ct_int = ct.astype(np.int32)
                if with_hm:
                    input_hm[0] = draw_umich_gaussian(input_hm[0], ct_int,
                                                      radius)
        if with_hm:
            input_hm = input_hm[np.newaxis]
        return input_hm

def affine_transform_bbox(bbox, trans, width, height):
    # Warp an [x1, y1, x2, y2] box with the affine matrix `trans` and clip it
    # to the input resolution.
    bbox = np.array(copy.deepcopy(bbox), dtype=np.float32)
    bbox[:2] = affine_transform(bbox[:2], trans)
    bbox[2:] = affine_transform(bbox[2:], trans)
    bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, width - 1)
    bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, height - 1)
    return bbox
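

# Example usage (a minimal sketch; the config path and flag settings below are
# illustrative, not prescribed by this module): build the CenterTrack
# architecture from a PaddleDetection config and switch it into per-frame
# tracking mode before running a video.
#
#   from ppdet.core.workspace import load_config, create
#
#   cfg = load_config('configs/mot/centertrack/centertrack_dla34_70e_mot17half.yml')
#   model = create(cfg.architecture)   # instantiated via CenterTrack.from_config
#   model.eval()
#   model.mot_metric = True            # MOT inference path, batch size must be 1
#   model.reset_tracking()             # clear tracker state before a new sequence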