# fcos_loss.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from ppdet.core.workspace import register
  21. from ppdet.modeling import ops
  22. __all__ = ['FCOSLoss']
  23. def flatten_tensor(inputs, channel_first=False):
  24. """
  25. Flatten a Tensor
  26. Args:
  27. inputs (Tensor): 4-D Tensor with shape [N, C, H, W] or [N, H, W, C]
  28. channel_first (bool): If true the dimension order of Tensor is
  29. [N, C, H, W], otherwise is [N, H, W, C]
  30. Return:
  31. output_channel_last (Tensor): The flattened Tensor in channel_last style
  32. """
  33. if channel_first:
  34. input_channel_last = paddle.transpose(inputs, perm=[0, 2, 3, 1])
  35. else:
  36. input_channel_last = inputs
  37. output_channel_last = paddle.flatten(
  38. input_channel_last, start_axis=0, stop_axis=2)
  39. return output_channel_last
  40. @register
  41. class FCOSLoss(nn.Layer):
  42. """
  43. FCOSLoss
  44. Args:
  45. loss_alpha (float): alpha in focal loss
  46. loss_gamma (float): gamma in focal loss
  47. iou_loss_type (str): location loss type, IoU/GIoU/LINEAR_IoU
  48. reg_weights (float): weight for location loss
  49. quality (str): quality branch, centerness/iou
  50. """
  51. def __init__(self,
  52. loss_alpha=0.25,
  53. loss_gamma=2.0,
  54. iou_loss_type="giou",
  55. reg_weights=1.0,
  56. quality='centerness'):
  57. super(FCOSLoss, self).__init__()
  58. self.loss_alpha = loss_alpha
  59. self.loss_gamma = loss_gamma
  60. self.iou_loss_type = iou_loss_type
  61. self.reg_weights = reg_weights
  62. self.quality = quality
  63. def __iou_loss(self,
  64. pred,
  65. targets,
  66. positive_mask,
  67. weights=None,
  68. return_iou=False):
  69. """
  70. Calculate the loss for location prediction
  71. Args:
  72. pred (Tensor): bounding boxes prediction
  73. targets (Tensor): targets for positive samples
  74. positive_mask (Tensor): mask of positive samples
  75. weights (Tensor): weights for each positive samples
  76. Return:
  77. loss (Tensor): location loss
  78. """
  79. plw = pred[:, 0] * positive_mask
  80. pth = pred[:, 1] * positive_mask
  81. prw = pred[:, 2] * positive_mask
  82. pbh = pred[:, 3] * positive_mask
  83. tlw = targets[:, 0] * positive_mask
  84. tth = targets[:, 1] * positive_mask
  85. trw = targets[:, 2] * positive_mask
  86. tbh = targets[:, 3] * positive_mask
  87. tlw.stop_gradient = True
  88. trw.stop_gradient = True
  89. tth.stop_gradient = True
  90. tbh.stop_gradient = True
  91. ilw = paddle.minimum(plw, tlw)
  92. irw = paddle.minimum(prw, trw)
  93. ith = paddle.minimum(pth, tth)
  94. ibh = paddle.minimum(pbh, tbh)
  95. clw = paddle.maximum(plw, tlw)
  96. crw = paddle.maximum(prw, trw)
  97. cth = paddle.maximum(pth, tth)
  98. cbh = paddle.maximum(pbh, tbh)
  99. area_predict = (plw + prw) * (pth + pbh)
  100. area_target = (tlw + trw) * (tth + tbh)
  101. area_inter = (ilw + irw) * (ith + ibh)
  102. ious = (area_inter + 1.0) / (
  103. area_predict + area_target - area_inter + 1.0)
  104. ious = ious * positive_mask
  105. if return_iou:
  106. return ious
  107. if self.iou_loss_type.lower() == "linear_iou":
  108. loss = 1.0 - ious
  109. elif self.iou_loss_type.lower() == "giou":
  110. area_uniou = area_predict + area_target - area_inter
  111. area_circum = (clw + crw) * (cth + cbh) + 1e-7
  112. giou = ious - (area_circum - area_uniou) / area_circum
  113. loss = 1.0 - giou
  114. elif self.iou_loss_type.lower() == "iou":
  115. loss = 0.0 - paddle.log(ious)
  116. else:
  117. raise KeyError
  118. if weights is not None:
  119. loss = loss * weights
  120. return loss
  121. def forward(self, cls_logits, bboxes_reg, centerness, tag_labels,
  122. tag_bboxes, tag_center):
  123. """
  124. Calculate the loss for classification, location and centerness
  125. Args:
  126. cls_logits (list): list of Tensor, which is predicted
  127. score for all anchor points with shape [N, M, C]
  128. bboxes_reg (list): list of Tensor, which is predicted
  129. offsets for all anchor points with shape [N, M, 4]
  130. centerness (list): list of Tensor, which is predicted
  131. centerness for all anchor points with shape [N, M, 1]
  132. tag_labels (list): list of Tensor, which is category
  133. targets for each anchor point
  134. tag_bboxes (list): list of Tensor, which is bounding
  135. boxes targets for positive samples
  136. tag_center (list): list of Tensor, which is centerness
  137. targets for positive samples
  138. Return:
  139. loss (dict): loss composed by classification loss, bounding box
  140. """
  141. cls_logits_flatten_list = []
  142. bboxes_reg_flatten_list = []
  143. centerness_flatten_list = []
  144. tag_labels_flatten_list = []
  145. tag_bboxes_flatten_list = []
  146. tag_center_flatten_list = []
  147. num_lvl = len(cls_logits)
  148. for lvl in range(num_lvl):
  149. cls_logits_flatten_list.append(
  150. flatten_tensor(cls_logits[lvl], True))
  151. bboxes_reg_flatten_list.append(
  152. flatten_tensor(bboxes_reg[lvl], True))
  153. centerness_flatten_list.append(
  154. flatten_tensor(centerness[lvl], True))
  155. tag_labels_flatten_list.append(
  156. flatten_tensor(tag_labels[lvl], False))
  157. tag_bboxes_flatten_list.append(
  158. flatten_tensor(tag_bboxes[lvl], False))
  159. tag_center_flatten_list.append(
  160. flatten_tensor(tag_center[lvl], False))
  161. cls_logits_flatten = paddle.concat(cls_logits_flatten_list, axis=0)
  162. bboxes_reg_flatten = paddle.concat(bboxes_reg_flatten_list, axis=0)
  163. centerness_flatten = paddle.concat(centerness_flatten_list, axis=0)
  164. tag_labels_flatten = paddle.concat(tag_labels_flatten_list, axis=0)
  165. tag_bboxes_flatten = paddle.concat(tag_bboxes_flatten_list, axis=0)
  166. tag_center_flatten = paddle.concat(tag_center_flatten_list, axis=0)
  167. tag_labels_flatten.stop_gradient = True
  168. tag_bboxes_flatten.stop_gradient = True
  169. tag_center_flatten.stop_gradient = True
  170. mask_positive_bool = tag_labels_flatten > 0
  171. mask_positive_bool.stop_gradient = True
  172. mask_positive_float = paddle.cast(mask_positive_bool, dtype="float32")
  173. mask_positive_float.stop_gradient = True
  174. num_positive_fp32 = paddle.sum(mask_positive_float)
  175. num_positive_fp32.stop_gradient = True
  176. num_positive_int32 = paddle.cast(num_positive_fp32, dtype="int32")
  177. num_positive_int32 = num_positive_int32 * 0 + 1
  178. num_positive_int32.stop_gradient = True
  179. normalize_sum = paddle.sum(tag_center_flatten * mask_positive_float)
  180. normalize_sum.stop_gradient = True
  181. # 1. cls_logits: sigmoid_focal_loss
  182. # expand onehot labels
  183. num_classes = cls_logits_flatten.shape[-1]
  184. tag_labels_flatten = paddle.squeeze(tag_labels_flatten, axis=-1)
  185. tag_labels_flatten_bin = F.one_hot(
  186. tag_labels_flatten, num_classes=1 + num_classes)
  187. tag_labels_flatten_bin = tag_labels_flatten_bin[:, 1:]
  188. # sigmoid_focal_loss
  189. cls_loss = F.sigmoid_focal_loss(
  190. cls_logits_flatten, tag_labels_flatten_bin) / num_positive_fp32
  191. if self.quality == 'centerness':
  192. # 2. bboxes_reg: giou_loss
  193. mask_positive_float = paddle.squeeze(mask_positive_float, axis=-1)
  194. tag_center_flatten = paddle.squeeze(tag_center_flatten, axis=-1)
  195. reg_loss = self.__iou_loss(
  196. bboxes_reg_flatten,
  197. tag_bboxes_flatten,
  198. mask_positive_float,
  199. weights=tag_center_flatten)
  200. reg_loss = reg_loss * mask_positive_float / normalize_sum
  201. # 3. centerness: sigmoid_cross_entropy_with_logits_loss
  202. centerness_flatten = paddle.squeeze(centerness_flatten, axis=-1)
  203. quality_loss = ops.sigmoid_cross_entropy_with_logits(
  204. centerness_flatten, tag_center_flatten)
  205. quality_loss = quality_loss * mask_positive_float / num_positive_fp32
  206. elif self.quality == 'iou':
  207. # 2. bboxes_reg: giou_loss
  208. mask_positive_float = paddle.squeeze(mask_positive_float, axis=-1)
  209. tag_center_flatten = paddle.squeeze(tag_center_flatten, axis=-1)
  210. reg_loss = self.__iou_loss(
  211. bboxes_reg_flatten,
  212. tag_bboxes_flatten,
  213. mask_positive_float,
  214. weights=None)
  215. reg_loss = reg_loss * mask_positive_float / num_positive_fp32
  216. # num_positive_fp32 is num_foreground
  217. # 3. centerness: sigmoid_cross_entropy_with_logits_loss
  218. centerness_flatten = paddle.squeeze(centerness_flatten, axis=-1)
  219. gt_ious = self.__iou_loss(
  220. bboxes_reg_flatten,
  221. tag_bboxes_flatten,
  222. mask_positive_float,
  223. weights=None,
  224. return_iou=True)
  225. quality_loss = ops.sigmoid_cross_entropy_with_logits(
  226. centerness_flatten, gt_ious)
  227. quality_loss = quality_loss * mask_positive_float / num_positive_fp32
  228. else:
  229. raise Exception(f'Unknown quality type: {self.quality}')
  230. loss_all = {
  231. "loss_cls": paddle.sum(cls_loss),
  232. "loss_box": paddle.sum(reg_loss),
  233. "loss_quality": paddle.sum(quality_loss),
  234. }
  235. return loss_all