- # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- This code is adapted from:
- https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/DB_loss.py
- """
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- from paddle import nn
- from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
- import paddle
- import paddle.nn.functional as F
- class CrossEntropyLoss(nn.Layer):
- """
- Implements the cross entropy loss function.
- Args:
- weight (tuple|list|ndarray|Tensor, optional): A manual rescaling weight
- given to each class. Its length must be equal to the number of classes.
- Default ``None``.
- ignore_index (int64, optional): Specifies a target value that is ignored
- and does not contribute to the input gradient. Default ``255``.
- top_k_percent_pixels (float, optional): a value in [0.0, 1.0]. When it is
- less than 1.0, only the loss of the top k percent hardest pixels (e.g., the
- top 20%) is computed, which is useful for hard pixel mining. Default ``1.0``.
- data_format (str, optional): The tensor format to use, 'NCHW' or 'NHWC'. Default ``'NCHW'``.
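- Example:
- A minimal usage sketch (not taken from the original code; the 4-class NCHW
- shapes below are illustrative assumptions):
- loss_fn = CrossEntropyLoss(ignore_index=255, top_k_percent_pixels=0.2)
- logit = paddle.rand([2, 4, 32, 32])  # (N, C, H, W)
- label = paddle.randint(0, 4, [2, 32, 32], dtype='int64')  # (N, H, W)
- loss = loss_fn(logit, label)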
- """
- def __init__(self,
- weight=None,
- ignore_index=255,
- top_k_percent_pixels=1.0,
- data_format='NCHW'):
- super(CrossEntropyLoss, self).__init__()
- self.ignore_index = ignore_index
- self.top_k_percent_pixels = top_k_percent_pixels
- self.EPS = 1e-8
- self.data_format = data_format
- if weight is not None:
- self.weight = paddle.to_tensor(weight, dtype='float32')
- else:
- self.weight = None
- def forward(self, logit, label, semantic_weights=None):
- """
- Forward computation.
- Args:
- logit (Tensor): Logit tensor, the data type is float32 or float64. Shape is
- (N, C), where C is the number of classes; if the shape has more than two
- dimensions, it is (N, C, D1, D2, ..., Dk), k >= 1.
- label (Tensor): Label tensor, the data type is int64. Shape is (N), where each
- value satisfies 0 <= label[i] <= C-1; if the shape has more dimensions,
- it is (N, D1, D2, ..., Dk), k >= 1.
- semantic_weights (Tensor, optional): Per-pixel weights applied to the loss,
- with the same shape as label. Default: None.
- Returns:
- (Tensor): The average loss.
- """
- channel_axis = 1 if self.data_format == 'NCHW' else -1
- if self.weight is not None and logit.shape[channel_axis] != len(
- self.weight):
- raise ValueError(
- 'The number of weights = {} must be the same as the number of classes = {}.'
- .format(len(self.weight), logit.shape[channel_axis]))
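- # Move the class dimension to the last axis (NCHW -> NHWC) so that it matches
- # the default class axis (-1) expected by F.cross_entropy below.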
- if channel_axis == 1:
- logit = paddle.transpose(logit, [0, 2, 3, 1])
- label = label.astype('int64')
- # Note: the ignore_index handling in F.cross_entropy is problematic and needs to be fixed.
- # When the label contains 255 and the paddle version is <= 2.1.3, the cross_entropy OP
- # raises an error; this is fixed in the paddle develop branch.
- loss = F.cross_entropy(
- logit,
- label,
- ignore_index=self.ignore_index,
- reduction='none',
- weight=self.weight)
- return self._post_process_loss(logit, label, semantic_weights, loss)
- def _post_process_loss(self, logit, label, semantic_weights, loss):
- """
- Apply the ignore mask and optional top-k hard pixel mining to compute the final loss.
- Args:
- logit (Tensor): Logit tensor, the data type is float32 or float64. Shape is
- (N, C), where C is the number of classes; if the shape has more than two
- dimensions, it is (N, C, D1, D2, ..., Dk), k >= 1.
- label (Tensor): Label tensor, the data type is int64. Shape is (N), where each
- value satisfies 0 <= label[i] <= C-1; if the shape has more dimensions,
- it is (N, D1, D2, ..., Dk), k >= 1.
- semantic_weights (Tensor, optional): Per-pixel weights applied to the loss,
- with the same shape as label.
- loss (Tensor): Loss tensor which is the output of cross_entropy. If soft_label
- is False in cross_entropy, the shape of loss should be the same as the label.
- If soft_label is True in cross_entropy, the shape of loss should be
- (N, D1, D2,..., Dk, 1).
- Returns:
- (Tensor): The average loss.
- """
- mask = label != self.ignore_index
- mask = paddle.cast(mask, 'float32')
- label.stop_gradient = True
- mask.stop_gradient = True
- if loss.ndim > mask.ndim:
- loss = paddle.squeeze(loss, axis=-1)
- loss = loss * mask
- if semantic_weights is not None:
- loss = loss * semantic_weights
- if self.weight is not None:
- _one_hot = F.one_hot(label, logit.shape[-1])
- coef = paddle.sum(_one_hot * self.weight, axis=-1)
- else:
- coef = paddle.ones_like(label)
- if self.top_k_percent_pixels == 1.0:
- avg_loss = paddle.mean(loss) / (paddle.mean(mask * coef) + self.EPS)
- else:
- loss = loss.reshape((-1,))
- top_k_pixels = int(self.top_k_percent_pixels * loss.numel())
- loss, indices = paddle.topk(loss, top_k_pixels)
- coef = coef.reshape((-1,))
- coef = paddle.gather(coef, indices)
- coef.stop_gradient = True
- coef = coef.astype('float32')
- avg_loss = loss.mean() / (paddle.mean(coef) + self.EPS)
- return avg_loss
- class DBLoss(nn.Layer):
- """
- Differentiable Binarization (DB) Loss Function
- Args:
- param (dict): the hyper-parameters for the DB loss
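- The total loss combines the three map losses computed in ``forward``:
- loss = alpha * loss_shrink_maps + beta * loss_threshold_maps + loss_binary_maps
- (plus an extra ``loss_classes`` cross-entropy term when num_classes > 1).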
- """
- def __init__(self,
- balance_loss=True,
- main_loss_type='DiceLoss',
- alpha=5,
- beta=10,
- ohem_ratio=3,
- eps=1e-6,
- num_classes=1,
- **kwargs):
- super(DBLoss, self).__init__()
- self.alpha = alpha
- self.beta = beta
- self.num_classes = num_classes
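- # Sub-losses for the three predicted maps: BalanceLoss (with OHEM) for the
- # shrink/probability map, MaskL1Loss for the threshold map, DiceLoss for the
- # binarization map, and CrossEntropyLoss for the optional multi-class branch.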
- self.dice_loss = DiceLoss(eps=eps)
- self.l1_loss = MaskL1Loss(eps=eps)
- self.bce_loss = BalanceLoss(
- balance_loss=balance_loss,
- main_loss_type=main_loss_type,
- negative_ratio=ohem_ratio)
- self.loss_func = CrossEntropyLoss()
- def forward(self, predicts, labels):
- predict_maps = predicts['maps']
- if self.num_classes > 1:
- predict_classes = predicts['classes']
- label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask, class_mask = labels[1:]
- else:
- label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[1:]
- shrink_maps = predict_maps[:, 0, :, :]
- threshold_maps = predict_maps[:, 1, :, :]
- binary_maps = predict_maps[:, 2, :, :]
- loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map,
- label_shrink_mask)
- loss_threshold_maps = self.l1_loss(threshold_maps, label_threshold_map,
- label_threshold_mask)
- loss_binary_maps = self.dice_loss(binary_maps, label_shrink_map,
- label_shrink_mask)
- loss_shrink_maps = self.alpha * loss_shrink_maps
- loss_threshold_maps = self.beta * loss_threshold_maps
- # Handle the extra classification branch when num_classes > 1
- if self.num_classes > 1:
- loss_classes = self.loss_func(predict_classes, class_mask)
- loss_all = loss_shrink_maps + loss_threshold_maps + loss_binary_maps + loss_classes
- losses = {'loss': loss_all,
- "loss_shrink_maps": loss_shrink_maps,
- "loss_threshold_maps": loss_threshold_maps,
- "loss_binary_maps": loss_binary_maps,
- "loss_classes": loss_classes}
- else:
- loss_all = loss_shrink_maps + loss_threshold_maps + loss_binary_maps
- losses = {'loss': loss_all,
- "loss_shrink_maps": loss_shrink_maps,
- "loss_threshold_maps": loss_threshold_maps,
- "loss_binary_maps": loss_binary_maps}
- return losses
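- # A minimal usage sketch (not part of the original module; the tensor shapes
- # below are illustrative assumptions). ``forward`` expects predicts['maps']
- # with the shrink, threshold and binary maps stacked on the channel axis, and
- # a labels list whose entries after the first are the label maps and masks:
- #
- # loss_fn = DBLoss(balance_loss=True, main_loss_type='DiceLoss', alpha=5, beta=10, ohem_ratio=3)
- # predicts = {'maps': paddle.rand([1, 3, 640, 640])}
- # labels = [None,  # image slot, unused by the loss
- #           paddle.rand([1, 640, 640]),  # label_threshold_map
- #           paddle.ones([1, 640, 640]),  # label_threshold_mask
- #           paddle.randint(0, 2, [1, 640, 640]).astype('float32'),  # label_shrink_map
- #           paddle.ones([1, 640, 640])]  # label_shrink_mask
- # losses = loss_fn(predicts, labels)
- # print(losses['loss'])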