det_db_loss.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. # copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/DB_loss.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. from paddle import nn
  22. from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
  23. import paddle
  24. import paddle.nn.functional as F
  class CrossEntropyLoss(nn.Layer):
      """
      Implements the cross entropy loss function.

      Args:
          weight (tuple|list|ndarray|Tensor, optional): A manual rescaling weight
              given to each class. Its length must be equal to the number of
              classes. Stored as a float32 tensor. Default ``None``.
          ignore_index (int64, optional): Specifies a target value that is ignored
              and does not contribute to the input gradient. Default ``255``.
          top_k_percent_pixels (float, optional): the value lies in [0.0, 1.0].
              When its value < 1.0, only compute the loss for the top k percent
              pixels (e.g., the top 20% pixels). This is useful for hard pixel
              mining. At least one pixel is always kept. Default ``1.0``.
          data_format (str, optional): The tensor format to use, 'NCHW' or
              'NHWC'. Default ``'NCHW'``.
      """

      def __init__(self,
                   weight=None,
                   ignore_index=255,
                   top_k_percent_pixels=1.0,
                   data_format='NCHW'):
          super(CrossEntropyLoss, self).__init__()
          self.ignore_index = ignore_index
          self.top_k_percent_pixels = top_k_percent_pixels
          # Added to the denominators in _post_process_loss to avoid dividing by 0.
          self.EPS = 1e-8
          self.data_format = data_format
          if weight is not None:
              self.weight = paddle.to_tensor(weight, dtype='float32')
          else:
              self.weight = None

      def forward(self, logit, label, semantic_weights=None):
          """
          Forward computation.

          Args:
              logit (Tensor): Logit tensor, the data type is float32, float64.
                  Shape is (N, C), where C is number of classes, and if shape is
                  more than 2D, this is (N, C, D1, D2,..., Dk), k >= 1 for
                  'NCHW' (class axis last instead for 'NHWC').
              label (Tensor): Label tensor, cast to int64. Shape is (N), where
                  each value is 0 <= label[i] <= C-1 or ``ignore_index``, and if
                  shape is more than 2D, this is (N, D1, D2,..., Dk), k >= 1.
              semantic_weights (Tensor, optional): Weights about loss for each
                  pixel, shape is the same as label. Default: None.

          Returns:
              (Tensor): The average loss.

          Raises:
              ValueError: if ``weight`` was given and its length differs from
                  the number of classes in ``logit``.
          """
          channel_axis = 1 if self.data_format == 'NCHW' else -1
          if self.weight is not None and logit.shape[channel_axis] != len(
                  self.weight):
              raise ValueError(
                  'The number of weights = {} must be the same as the number of classes = {}.'
                  .format(len(self.weight), logit.shape[channel_axis]))
          if channel_axis == 1 and logit.ndim > 2:
              # F.cross_entropy below uses its default axis=-1, so move the class
              # axis last: (N, C, D1, ..., Dk) -> (N, D1, ..., Dk, C). Deriving the
              # permutation from ndim (rather than a fixed [0, 2, 3, 1]) keeps the
              # 4-D case unchanged while also accepting other ranks; a 2-D (N, C)
              # input already has the class axis last.
              perm = [0] + list(range(2, logit.ndim)) + [1]
              logit = paddle.transpose(logit, perm)
          label = label.astype('int64')
          # NOTE: upstream reports that ignore_index is not honoured by
          # F.cross_entropy, and that a 255 in the label makes the op fail for
          # paddle <= 2.1.3 (fixed in the develop version). Ignored pixels are
          # zeroed independently by the explicit mask in _post_process_loss.
          loss = F.cross_entropy(
              logit,
              label,
              ignore_index=self.ignore_index,
              reduction='none',
              weight=self.weight)
          return self._post_process_loss(logit, label, semantic_weights, loss)

      def _post_process_loss(self, logit, label, semantic_weights, loss):
          """
          Consider mask and top_k to calculate the final loss.

          Args:
              logit (Tensor): Logit tensor with the class axis last (as passed in
                  by ``forward``); only ``logit.shape[-1]`` is read here.
              label (Tensor): int64 label tensor, (N) or (N, D1, ..., Dk).
              semantic_weights (Tensor|None): Weights about loss for each pixel,
                  shape is the same as label.
              loss (Tensor): Per-element output of cross_entropy. Its shape is
                  either the same as label, or label's shape plus a trailing
                  axis of size 1 (which is squeezed away here).

          Returns:
              (Tensor): The average loss, normalised by the mean (masked,
              class-weighted) coefficient of the contributing pixels.
          """
          mask = label != self.ignore_index
          mask = paddle.cast(mask, 'float32')
          label.stop_gradient = True
          mask.stop_gradient = True

          if loss.ndim > mask.ndim:
              loss = paddle.squeeze(loss, axis=-1)
          loss = loss * mask
          if semantic_weights is not None:
              loss = loss * semantic_weights

          if self.weight is not None:
              # F.one_hot requires 0 <= index < num_classes, so ignored pixels
              # (e.g. 255) are mapped to class 0 before encoding; their
              # coefficient is then zeroed with the mask so they carry no weight.
              safe_label = label * mask.astype(label.dtype)
              _one_hot = F.one_hot(safe_label, logit.shape[-1])
              coef = paddle.sum(_one_hot * self.weight, axis=-1) * mask
          else:
              # float32 ones shaped like the label, so ``mask * coef`` below does
              # not rely on an implicit int64 -> float32 cast.
              coef = paddle.ones_like(mask)

          if self.top_k_percent_pixels == 1.0:
              avg_loss = paddle.mean(loss) / (paddle.mean(mask * coef) + self.EPS)
          else:
              loss = loss.reshape((-1,))
              # Keep at least one pixel: with k == 0 there is nothing to average.
              top_k_pixels = max(
                  1, int(self.top_k_percent_pixels * loss.numel()))
              loss, indices = paddle.topk(loss, top_k_pixels)
              coef = coef.reshape((-1,))
              coef = paddle.gather(coef, indices)
              coef.stop_gradient = True
              coef = coef.astype('float32')
              avg_loss = loss.mean() / (paddle.mean(coef) + self.EPS)
          return avg_loss
  class DBLoss(nn.Layer):
      """
      Differentiable Binarization (DB) loss.

      Sums three weighted map losses computed from ``predicts['maps']``:
      ``alpha`` * BalanceLoss on the shrink map, ``beta`` * MaskL1Loss on the
      threshold map, and DiceLoss on the binary map. When ``num_classes > 1``
      a cross-entropy term on ``predicts['classes']`` is added as well.

      Args:
          balance_loss (bool): forwarded to ``BalanceLoss(balance_loss=...)``.
          main_loss_type (str): forwarded to ``BalanceLoss(main_loss_type=...)``.
          alpha (float): multiplier for the shrink-map loss.
          beta (float): multiplier for the threshold-map loss.
          ohem_ratio (float): forwarded to ``BalanceLoss`` as ``negative_ratio``.
          eps (float): forwarded to ``DiceLoss`` and ``MaskL1Loss``.
          num_classes (int): a class-prediction loss is added when this is > 1.
          **kwargs: accepted for config compatibility and ignored.
      """

      def __init__(self,
                   balance_loss=True,
                   main_loss_type='DiceLoss',
                   alpha=5,
                   beta=10,
                   ohem_ratio=3,
                   eps=1e-6,
                   num_classes=1,
                   **kwargs):
          super(DBLoss, self).__init__()
          self.alpha, self.beta = alpha, beta
          self.num_classes = num_classes
          # Sub-losses are registered in the same order as before.
          self.dice_loss = DiceLoss(eps=eps)
          self.l1_loss = MaskL1Loss(eps=eps)
          self.bce_loss = BalanceLoss(
              balance_loss=balance_loss,
              main_loss_type=main_loss_type,
              negative_ratio=ohem_ratio)
          self.loss_func = CrossEntropyLoss()

      def forward(self, predicts, labels):
          """
          Compute the DB losses.

          Args:
              predicts (dict): ``'maps'`` is a 4-D tensor whose channels 0, 1
                  and 2 (axis 1) are the shrink, threshold and binary maps;
                  ``'classes'`` is also read when ``num_classes > 1``.
              labels (sequence): ``labels[0]`` is not used here. The rest are
                  (threshold_map, threshold_mask, shrink_map, shrink_mask),
                  followed by ``class_mask`` when ``num_classes > 1``.

          Returns:
              dict: ``'loss'`` (the sum of all terms) plus each individual
              weighted term; ``'loss_classes'`` only when ``num_classes > 1``.
          """
          multi_class = self.num_classes > 1
          predict_maps = predicts['maps']
          if multi_class:
              predict_classes = predicts['classes']
              *map_targets, class_mask = labels[1:]
          else:
              map_targets = labels[1:]
          (label_threshold_map, label_threshold_mask,
           label_shrink_map, label_shrink_mask) = map_targets

          shrink_maps, threshold_maps, binary_maps = (
              predict_maps[:, channel, :, :] for channel in range(3))

          loss_shrink_maps = self.alpha * self.bce_loss(
              shrink_maps, label_shrink_map, label_shrink_mask)
          loss_threshold_maps = self.beta * self.l1_loss(
              threshold_maps, label_threshold_map, label_threshold_mask)
          loss_binary_maps = self.dice_loss(
              binary_maps, label_shrink_map, label_shrink_mask)

          # Total of the map terms; the class term (if any) is added on top.
          loss_all = loss_shrink_maps + loss_threshold_maps + loss_binary_maps
          if multi_class:
              loss_classes = self.loss_func(predict_classes, class_mask)
              loss_all = loss_all + loss_classes

          losses = {
              'loss': loss_all,
              "loss_shrink_maps": loss_shrink_maps,
              "loss_threshold_maps": loss_threshold_maps,
              "loss_binary_maps": loss_binary_maps,
          }
          if multi_class:
              losses["loss_classes"] = loss_classes
          return losses