@@ -24,6 +24,129 @@ from paddle import nn

from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
+import paddle
+import paddle.nn.functional as F
+
+
+class CrossEntropyLoss(nn.Layer):
+    """
+    Implements the cross entropy loss function.
+
+    Args:
+        weight (tuple|list|ndarray|Tensor, optional): A manual rescaling weight
+            given to each class. Its length must equal the number of classes.
+            Default ``None``.
+        ignore_index (int64, optional): Specifies a target value that is ignored
+            and does not contribute to the input gradient. Default ``255``.
+        top_k_percent_pixels (float, optional): A value in [0.0, 1.0]. When it is
+            less than 1.0, the loss is computed only over the top k percent of
+            pixels (e.g., the top 20%), which is useful for hard pixel mining.
+            Default ``1.0``.
+        data_format (str, optional): The tensor layout, 'NCHW' or 'NHWC'.
+            Default ``'NCHW'``.
+    """
+
+    def __init__(self,
+                 weight=None,
+                 ignore_index=255,
+                 top_k_percent_pixels=1.0,
+                 data_format='NCHW'):
+        super(CrossEntropyLoss, self).__init__()
+        self.ignore_index = ignore_index
+        self.top_k_percent_pixels = top_k_percent_pixels
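+        # Small constant that keeps the loss-normalisation denominators non-zero.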
+        self.EPS = 1e-8
+        self.data_format = data_format
+        if weight is not None:
+            self.weight = paddle.to_tensor(weight, dtype='float32')
+        else:
+            self.weight = None
+
+    def forward(self, logit, label, semantic_weights=None):
+        """
+        Forward computation.
+
+        Args:
+            logit (Tensor): Logit tensor of data type float32 or float64. The
+                shape is (N, C), where C is the number of classes; for inputs
+                of more than 2 dimensions the shape is (N, C, D1, D2, ..., Dk),
+                k >= 1.
+            label (Tensor): Label tensor of data type int64. The shape is (N),
+                where each value satisfies 0 <= label[i] <= C-1; for inputs of
+                more than 2 dimensions the shape is (N, D1, D2, ..., Dk), k >= 1.
+            semantic_weights (Tensor, optional): Per-pixel weights applied to
+                the loss, with the same shape as label. Default: None.
+        Returns:
+            (Tensor): The average loss.
+        """
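+        # Locate the class-channel axis for the configured data layout.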
+        channel_axis = 1 if self.data_format == 'NCHW' else -1
+        if self.weight is not None and logit.shape[channel_axis] != len(
+                self.weight):
+            raise ValueError(
+                'The number of weights = {} must be the same as the number of classes = {}.'
+                .format(len(self.weight), logit.shape[channel_axis]))
+
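+        # The cross_entropy call below assumes channels-last input, so NCHW
+        # logits are transposed to NHWC first.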
+        if channel_axis == 1:
+            logit = paddle.transpose(logit, [0, 2, 3, 1])
+        label = label.astype('int64')
+        # F.cross_entropy does not handle ignore_index correctly in older
+        # releases: with paddle <= 2.1.3, a label value of 255 makes the
+        # cross_entropy OP raise an error. This is fixed on the paddle
+        # develop branch.
+        loss = F.cross_entropy(
+            logit,
+            label,
+            ignore_index=self.ignore_index,
+            reduction='none',
+            weight=self.weight)
+
+        return self._post_process_loss(logit, label, semantic_weights, loss)
+
+    def _post_process_loss(self, logit, label, semantic_weights, loss):
+        """
+        Apply the ignore mask and optional top-k selection to obtain the
+        final loss.
+
+        Args:
+            logit (Tensor): Logit tensor of data type float32 or float64. The
+                shape is (N, C), where C is the number of classes; for inputs
+                of more than 2 dimensions the shape is (N, C, D1, D2, ..., Dk),
+                k >= 1.
+            label (Tensor): Label tensor of data type int64. The shape is (N),
+                where each value satisfies 0 <= label[i] <= C-1; for inputs of
+                more than 2 dimensions the shape is (N, D1, D2, ..., Dk), k >= 1.
+            semantic_weights (Tensor, optional): Per-pixel weights applied to
+                the loss, with the same shape as label.
+            loss (Tensor): Loss tensor produced by cross_entropy. If soft_label
+                is False in cross_entropy, the loss has the same shape as label;
+                if soft_label is True, its shape is (N, D1, D2, ..., Dk, 1).
+        Returns:
+            (Tensor): The average loss.
+        """
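+        # Build a float mask that zeroes out pixels labeled ignore_index.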
+        mask = label != self.ignore_index
+        mask = paddle.cast(mask, 'float32')
+        label.stop_gradient = True
+        mask.stop_gradient = True
+
+        if loss.ndim > mask.ndim:
+            loss = paddle.squeeze(loss, axis=-1)
+        loss = loss * mask
+        if semantic_weights is not None:
+            loss = loss * semantic_weights
+
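+        # Recover each pixel's class weight so the averaged loss is
+        # normalised consistently with the weighted cross entropy.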
+        if self.weight is not None:
+            _one_hot = F.one_hot(label, logit.shape[-1])
+            coef = paddle.sum(_one_hot * self.weight, axis=-1)
+        else:
+            coef = paddle.ones_like(label)
+
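+        # Hard pixel mining: when top_k_percent_pixels < 1.0, only the
+        # largest per-pixel losses are averaged.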
+        if self.top_k_percent_pixels == 1.0:
+            avg_loss = paddle.mean(loss) / (paddle.mean(mask * coef) + self.EPS)
+        else:
+            loss = loss.reshape((-1, ))
+            top_k_pixels = int(self.top_k_percent_pixels * loss.numel())
+            loss, indices = paddle.topk(loss, top_k_pixels)
+            coef = coef.reshape((-1, ))
+            coef = paddle.gather(coef, indices)
+            coef.stop_gradient = True
+            coef = coef.astype('float32')
+            avg_loss = loss.mean() / (paddle.mean(coef) + self.EPS)
+
+        return avg_loss
+

class DBLoss(nn.Layer):
    """
@@ -39,21 +162,27 @@ class DBLoss(nn.Layer):
                 beta=10,
                 ohem_ratio=3,
                 eps=1e-6,
+                 num_classes=1,
                 **kwargs):
        super(DBLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
+        self.num_classes = num_classes
        self.dice_loss = DiceLoss(eps=eps)
        self.l1_loss = MaskL1Loss(eps=eps)
        self.bce_loss = BalanceLoss(
            balance_loss=balance_loss,
            main_loss_type=main_loss_type,
            negative_ratio=ohem_ratio)
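+        # Cross-entropy over the per-pixel class map; only used when
+        # num_classes > 1.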
+        self.loss_func = CrossEntropyLoss()

    def forward(self, predicts, labels):
        predict_maps = predicts['maps']
-        label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[
-            1:]
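+        # Multi-class DB predicts an extra per-pixel class map, and the
+        # label list then carries a matching class_mask.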
+        if self.num_classes > 1:
+            predict_classes = predicts['classes']
+            label_threshold_map, label_threshold_mask, label_shrink_map, \
+                label_shrink_mask, class_mask = labels[1:]
+        else:
+            label_threshold_map, label_threshold_mask, label_shrink_map, \
+                label_shrink_mask = labels[1:]
        shrink_maps = predict_maps[:, 0, :, :]
        threshold_maps = predict_maps[:, 1, :, :]
        binary_maps = predict_maps[:, 2, :, :]
@@ -67,10 +196,23 @@ class DBLoss(nn.Layer):
        loss_shrink_maps = self.alpha * loss_shrink_maps
        loss_threshold_maps = self.beta * loss_threshold_maps

-        loss_all = loss_shrink_maps + loss_threshold_maps \
-            + loss_binary_maps
-        losses = {'loss': loss_all, \
-                  "loss_shrink_maps": loss_shrink_maps, \
-                  "loss_threshold_maps": loss_threshold_maps, \
-                  "loss_binary_maps": loss_binary_maps}
+
+        # When multi-class prediction is enabled, add the classification loss.
+        if self.num_classes > 1:
+            loss_classes = self.loss_func(predict_classes, class_mask)
+
+            loss_all = loss_shrink_maps + loss_threshold_maps \
+                + loss_binary_maps + loss_classes
+
+            losses = {'loss': loss_all,
+                      "loss_shrink_maps": loss_shrink_maps,
+                      "loss_threshold_maps": loss_threshold_maps,
+                      "loss_binary_maps": loss_binary_maps,
+                      "loss_classes": loss_classes}
+        else:
+            loss_all = loss_shrink_maps + loss_threshold_maps \
+                + loss_binary_maps
+
+            losses = {'loss': loss_all,
+                      "loss_shrink_maps": loss_shrink_maps,
+                      "loss_threshold_maps": loss_threshold_maps,
+                      "loss_binary_maps": loss_binary_maps}
        return losses
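
For context, a minimal sketch of how the new multi-class branch might be exercised. The predicts keys ('maps', 'classes') and the label ordering follow the diff above; the shapes, the num_classes value, and the random tensors are illustrative assumptions only, and the remaining DBLoss constructor arguments are assumed to keep their defaults.

    import paddle

    # Hypothetical setup: batch of 2, 3 text classes, 64x64 feature maps.
    loss_fn = DBLoss(num_classes=3)
    predicts = {
        'maps': paddle.rand([2, 3, 64, 64]),     # shrink / threshold / binary maps
        'classes': paddle.rand([2, 3, 64, 64]),  # per-pixel class logits (NCHW)
    }
    labels = [
        paddle.zeros([1]),                  # labels[0] is not used by this loss
        paddle.rand([2, 64, 64]),           # label_threshold_map
        paddle.rand([2, 64, 64]),           # label_threshold_mask
        paddle.rand([2, 64, 64]),           # label_shrink_map
        paddle.rand([2, 64, 64]),           # label_shrink_mask
        paddle.randint(0, 3, [2, 64, 64]),  # class_mask with int64 class ids
    ]
    losses = loss_fn(predicts, labels)      # includes 'loss_classes' when num_classes > 1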