# fcos.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn.functional as F

from ppdet.core.workspace import register, create

from .meta_arch import BaseArch
from ..ssod_utils import permute_to_N_HWA_K, QFLv2
from ..losses import GIoULoss

__all__ = ['FCOS']
  24. @register
  25. class FCOS(BaseArch):
  26. """
  27. FCOS network, see https://arxiv.org/abs/1904.01355
  28. Args:
  29. backbone (object): backbone instance
  30. neck (object): 'FPN' instance
  31. fcos_head (object): 'FCOSHead' instance
  32. """
  33. __category__ = 'architecture'
  34. def __init__(self, backbone, neck='FPN', fcos_head='FCOSHead'):
  35. super(FCOS, self).__init__()
  36. self.backbone = backbone
  37. self.neck = neck
  38. self.fcos_head = fcos_head
  39. self.is_teacher = False
  40. @classmethod
  41. def from_config(cls, cfg, *args, **kwargs):
  42. backbone = create(cfg['backbone'])
  43. kwargs = {'input_shape': backbone.out_shape}
  44. neck = create(cfg['neck'], **kwargs)
  45. kwargs = {'input_shape': neck.out_shape}
  46. fcos_head = create(cfg['fcos_head'], **kwargs)
  47. return {
  48. 'backbone': backbone,
  49. 'neck': neck,
  50. "fcos_head": fcos_head,
  51. }
  52. def _forward(self):
  53. body_feats = self.backbone(self.inputs)
  54. fpn_feats = self.neck(body_feats)
  55. self.is_teacher = self.inputs.get('is_teacher', False)
  56. if self.training or self.is_teacher:
  57. losses = self.fcos_head(fpn_feats, self.inputs)
  58. return losses
  59. else:
  60. fcos_head_outs = self.fcos_head(fpn_feats)
  61. bbox_pred, bbox_num = self.fcos_head.post_process(
  62. fcos_head_outs, self.inputs['scale_factor'])
  63. return {'bbox': bbox_pred, 'bbox_num': bbox_num}
  64. def get_loss(self):
  65. return self._forward()
  66. def get_pred(self):
  67. return self._forward()
  68. def get_loss_keys(self):
  69. return ['loss_cls', 'loss_box', 'loss_quality']
  70. def get_distill_loss(self,
  71. fcos_head_outs,
  72. teacher_fcos_head_outs,
  73. ratio=0.01):
  74. student_logits, student_deltas, student_quality = fcos_head_outs
  75. teacher_logits, teacher_deltas, teacher_quality = teacher_fcos_head_outs
  76. nc = student_logits[0].shape[1]
  77. student_logits = paddle.concat(
  78. [
  79. _.transpose([0, 2, 3, 1]).reshape([-1, nc])
  80. for _ in student_logits
  81. ],
  82. axis=0)
  83. teacher_logits = paddle.concat(
  84. [
  85. _.transpose([0, 2, 3, 1]).reshape([-1, nc])
  86. for _ in teacher_logits
  87. ],
  88. axis=0)
  89. student_deltas = paddle.concat(
  90. [
  91. _.transpose([0, 2, 3, 1]).reshape([-1, 4])
  92. for _ in student_deltas
  93. ],
  94. axis=0)
  95. teacher_deltas = paddle.concat(
  96. [
  97. _.transpose([0, 2, 3, 1]).reshape([-1, 4])
  98. for _ in teacher_deltas
  99. ],
  100. axis=0)
  101. student_quality = paddle.concat(
  102. [
  103. _.transpose([0, 2, 3, 1]).reshape([-1, 1])
  104. for _ in student_quality
  105. ],
  106. axis=0)
  107. teacher_quality = paddle.concat(
  108. [
  109. _.transpose([0, 2, 3, 1]).reshape([-1, 1])
  110. for _ in teacher_quality
  111. ],
  112. axis=0)
  113. with paddle.no_grad():
  114. # Region Selection
  115. count_num = int(teacher_logits.shape[0] * ratio)
  116. teacher_probs = F.sigmoid(teacher_logits)
  117. max_vals = paddle.max(teacher_probs, 1)
  118. sorted_vals, sorted_inds = paddle.topk(max_vals,
  119. teacher_logits.shape[0])
  120. mask = paddle.zeros_like(max_vals)
  121. mask[sorted_inds[:count_num]] = 1.
  122. fg_num = sorted_vals[:count_num].sum()
  123. b_mask = mask > 0
  124. # distill_loss_cls
  125. loss_logits = QFLv2(
  126. F.sigmoid(student_logits),
  127. teacher_probs,
  128. weight=mask,
  129. reduction="sum") / fg_num
  130. # distill_loss_box
  131. inputs = paddle.concat(
  132. (-student_deltas[b_mask][..., :2], student_deltas[b_mask][..., 2:]),
  133. axis=-1)
  134. targets = paddle.concat(
  135. (-teacher_deltas[b_mask][..., :2], teacher_deltas[b_mask][..., 2:]),
  136. axis=-1)
  137. iou_loss = GIoULoss(reduction='mean')
  138. loss_deltas = iou_loss(inputs, targets)
  139. # distill_loss_quality
  140. loss_quality = F.binary_cross_entropy(
  141. F.sigmoid(student_quality[b_mask]),
  142. F.sigmoid(teacher_quality[b_mask]),
  143. reduction='mean')
  144. return {
  145. "distill_loss_cls": loss_logits,
  146. "distill_loss_box": loss_deltas,
  147. "distill_loss_quality": loss_quality,
  148. "fg_sum": fg_num,
  149. }