ppyoloe_contrast_head.py 8.4 KB


# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from ppdet.core.workspace import register
  18. from ..bbox_utils import batch_distance2bbox
  19. from ..losses import GIoULoss
  20. from ..initializer import bias_init_with_prob, constant_, normal_
  21. from ..assigners.utils import generate_anchors_for_grid_cell
  22. from ppdet.modeling.backbones.cspresnet import ConvBNLayer
  23. from ppdet.modeling.ops import get_static_shape, get_act_fn
  24. from ppdet.modeling.layers import MultiClassNMS
  25. from ppdet.modeling.heads.ppyoloe_head import PPYOLOEHead
  26. __all__ = ['PPYOLOEContrastHead']
  27. @register
  28. class PPYOLOEContrastHead(PPYOLOEHead):
  29. __shared__ = [
  30. 'num_classes', 'eval_size', 'trt', 'exclude_nms',
  31. 'exclude_post_process', 'use_shared_conv'
  32. ]
  33. __inject__ = ['static_assigner', 'assigner', 'nms', 'contrast_loss']
  34. def __init__(self,
  35. in_channels=[1024, 512, 256],
  36. num_classes=80,
  37. act='swish',
  38. fpn_strides=(32, 16, 8),
  39. grid_cell_scale=5.0,
  40. grid_cell_offset=0.5,
  41. reg_max=16,
  42. reg_range=None,
  43. static_assigner_epoch=4,
  44. use_varifocal_loss=True,
  45. static_assigner='ATSSAssigner',
  46. assigner='TaskAlignedAssigner',
  47. contrast_loss='SupContrast',
  48. nms='MultiClassNMS',
  49. eval_size=None,
  50. loss_weight={
  51. 'class': 1.0,
  52. 'iou': 2.5,
  53. 'dfl': 0.5,
  54. },
  55. trt=False,
  56. exclude_nms=False,
  57. exclude_post_process=False,
  58. use_shared_conv=True):
  59. super().__init__(in_channels,
  60. num_classes,
  61. act,
  62. fpn_strides,
  63. grid_cell_scale,
  64. grid_cell_offset,
  65. reg_max,
  66. reg_range,
  67. static_assigner_epoch,
  68. use_varifocal_loss,
  69. static_assigner,
  70. assigner,
  71. nms,
  72. eval_size,
  73. loss_weight,
  74. trt,
  75. exclude_nms,
  76. exclude_post_process,
  77. use_shared_conv)
  78. assert len(in_channels) > 0, "len(in_channels) should > 0"
  79. self.contrast_loss = contrast_loss
  80. self.contrast_encoder = nn.LayerList()
  81. for in_c in self.in_channels:
  82. self.contrast_encoder.append(
  83. nn.Conv2D(
  84. in_c, 128, 3, padding=1))
  85. self._init_contrast_encoder()
  86. def _init_contrast_encoder(self):
  87. bias_en = bias_init_with_prob(0.01)
  88. for en_ in self.contrast_encoder:
  89. constant_(en_.weight)
  90. constant_(en_.bias, bias_en)
  91. def forward_train(self, feats, targets):
  92. anchors, anchor_points, num_anchors_list, stride_tensor = \
  93. generate_anchors_for_grid_cell(
  94. feats, self.fpn_strides, self.grid_cell_scale,
  95. self.grid_cell_offset)
  96. cls_score_list, reg_distri_list = [], []
  97. contrast_encoder_list = []
  98. for i, feat in enumerate(feats):
  99. avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
  100. cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) +
  101. feat)
  102. reg_distri = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
  103. contrast_logit = self.contrast_encoder[i](self.stem_cls[i](feat, avg_feat) +
  104. feat)
  105. contrast_encoder_list.append(contrast_logit.flatten(2).transpose([0, 2, 1]))
  106. # cls and reg
  107. cls_score = F.sigmoid(cls_logit)
  108. cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
  109. reg_distri_list.append(reg_distri.flatten(2).transpose([0, 2, 1]))
  110. cls_score_list = paddle.concat(cls_score_list, axis=1)
  111. reg_distri_list = paddle.concat(reg_distri_list, axis=1)
  112. contrast_encoder_list = paddle.concat(contrast_encoder_list, axis=1)
  113. return self.get_loss([
  114. cls_score_list, reg_distri_list, contrast_encoder_list, anchors, anchor_points,
  115. num_anchors_list, stride_tensor
  116. ], targets)
  117. def get_loss(self, head_outs, gt_meta):
  118. pred_scores, pred_distri, pred_contrast_encoder, anchors,\
  119. anchor_points, num_anchors_list, stride_tensor = head_outs
  120. anchor_points_s = anchor_points / stride_tensor
  121. pred_bboxes = self._bbox_decode(anchor_points_s, pred_distri)
  122. gt_labels = gt_meta['gt_class']
  123. gt_bboxes = gt_meta['gt_bbox']
  124. pad_gt_mask = gt_meta['pad_gt_mask']
  125. # label assignment
  126. if gt_meta['epoch_id'] < self.static_assigner_epoch:
  127. assigned_labels, assigned_bboxes, assigned_scores = \
  128. self.static_assigner(
  129. anchors,
  130. num_anchors_list,
  131. gt_labels,
  132. gt_bboxes,
  133. pad_gt_mask,
  134. bg_index=self.num_classes,
  135. pred_bboxes=pred_bboxes.detach() * stride_tensor)
  136. alpha_l = 0.25
  137. else:
  138. if self.sm_use:
  139. assigned_labels, assigned_bboxes, assigned_scores = \
  140. self.assigner(
  141. pred_scores.detach(),
  142. pred_bboxes.detach() * stride_tensor,
  143. anchor_points,
  144. stride_tensor,
  145. gt_labels,
  146. gt_bboxes,
  147. pad_gt_mask,
  148. bg_index=self.num_classes)
  149. else:
  150. assigned_labels, assigned_bboxes, assigned_scores = \
  151. self.assigner(
  152. pred_scores.detach(),
  153. pred_bboxes.detach() * stride_tensor,
  154. anchor_points,
  155. num_anchors_list,
  156. gt_labels,
  157. gt_bboxes,
  158. pad_gt_mask,
  159. bg_index=self.num_classes)
  160. alpha_l = -1
  161. # rescale bbox
  162. assigned_bboxes /= stride_tensor
  163. # cls loss
  164. if self.use_varifocal_loss:
  165. one_hot_label = F.one_hot(assigned_labels,
  166. self.num_classes + 1)[..., :-1]
  167. loss_cls = self._varifocal_loss(pred_scores, assigned_scores,
  168. one_hot_label)
  169. else:
  170. loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha_l)
  171. assigned_scores_sum = assigned_scores.sum()
  172. if paddle.distributed.get_world_size() > 1:
  173. paddle.distributed.all_reduce(assigned_scores_sum)
  174. assigned_scores_sum /= paddle.distributed.get_world_size()
  175. assigned_scores_sum = paddle.clip(assigned_scores_sum, min=1.)
  176. loss_cls /= assigned_scores_sum
  177. loss_l1, loss_iou, loss_dfl = \
  178. self._bbox_loss(pred_distri, pred_bboxes, anchor_points_s,
  179. assigned_labels, assigned_bboxes, assigned_scores,
  180. assigned_scores_sum)
  181. # contrast loss
  182. loss_contrast = self.contrast_loss(pred_contrast_encoder.reshape([-1, pred_contrast_encoder.shape[-1]]), \
  183. assigned_labels.reshape([-1]), assigned_scores.max(-1).reshape([-1]))
  184. loss = self.loss_weight['class'] * loss_cls + \
  185. self.loss_weight['iou'] * loss_iou + \
  186. self.loss_weight['dfl'] * loss_dfl + \
  187. self.loss_weight['contrast'] * loss_contrast
  188. out_dict = {
  189. 'loss': loss,
  190. 'loss_cls': loss_cls,
  191. 'loss_iou': loss_iou,
  192. 'loss_dfl': loss_dfl,
  193. 'loss_l1': loss_l1,
  194. 'loss_contrast': loss_contrast
  195. }
  196. return out_dict