123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import paddle
- import paddle.nn.functional as F
- from ppdet.core.workspace import register, serializable
- __all__ = ['SOLOv2Loss']
@register
@serializable
class SOLOv2Loss(object):
    """
    SOLOv2Loss

    Combines a dice loss on the predicted instance masks with a sigmoid
    focal loss on the category predictions.

    Args:
        ins_loss_weight (float): Weight of instance loss.
        focal_loss_gamma (float): Gamma parameter for focal loss.
        focal_loss_alpha (float): Alpha parameter for focal loss.
    """

    def __init__(self,
                 ins_loss_weight=3.0,
                 focal_loss_gamma=2.0,
                 focal_loss_alpha=0.25):
        self.ins_loss_weight = ins_loss_weight
        self.focal_loss_gamma = focal_loss_gamma
        self.focal_loss_alpha = focal_loss_alpha

    def _dice_loss(self, input, target):
        """
        Per-row dice loss: ``1 - 2 * sum(p * t) / (sum(p * p) + sum(t * t))``.

        Both tensors are flattened to ``[N, -1]`` using their own first
        dimension, so each row is scored independently. A constant 0.001 is
        added to each squared sum so the denominator never reaches zero, even
        for an all-zero prediction and target.

        Args:
            input (Tensor): Predicted mask values (the caller passes sigmoid
                probabilities), first dimension N.
            target (Tensor): Target masks, element-wise compatible with
                ``input`` after flattening.

        Returns:
            Tensor: Dice loss with one value per row, shape ``[N]``.
        """
        input = paddle.reshape(input, shape=(paddle.shape(input)[0], -1))
        target = paddle.reshape(target, shape=(paddle.shape(target)[0], -1))
        a = paddle.sum(input * target, axis=1)
        b = paddle.sum(input * input, axis=1) + 0.001
        c = paddle.sum(target * target, axis=1) + 0.001
        d = (2 * a) / (b + c)
        return 1 - d

    def __call__(self, ins_pred_list, ins_label_list, cate_preds, cate_labels,
                 num_ins):
        """
        Get loss of network of SOLOv2.

        Args:
            ins_pred_list (list): Tensor list of instance branch output.
                Entries that are None are skipped.
            ins_label_list (list): List of instance labels per batch, paired
                element-wise with ``ins_pred_list``.
            cate_preds (Tensor): Concatenated category branch output; its
                last dimension is taken as the number of classes.
            cate_labels (Tensor): Concatenated category labels per batch
                (presumably 1-D -- the one-hot result is sliced along axis 1).
                Label 0 produces an all-zero target row (background); labels
                1..num_classes map to the columns of ``cate_preds``.
            num_ins (int): Number of positive samples in a mini-batch.

        Returns:
            loss_ins (Tensor): The instance loss of SOLOv2 network.
            loss_cate (Tensor): The category loss of SOLOv2 network.
        """
        # 1. Use dice loss to calculate the instance (mask) loss.
        loss_ins = []
        total_weights = paddle.zeros(shape=[1], dtype='float32')
        for ins_pred, ins_label in zip(ins_pred_list, ins_label_list):
            if ins_pred is None:
                continue
            ins_label = paddle.cast(ins_label, 'float32')
            # Reshape labels to [num_masks, H, W] using the prediction's
            # spatial size.
            ins_label = paddle.reshape(
                ins_label,
                shape=[-1, paddle.shape(ins_pred)[-2],
                       paddle.shape(ins_pred)[-1]])
            # Only masks with at least one foreground pixel contribute.
            weights = paddle.cast(
                paddle.sum(ins_label, axis=[1, 2]) > 0, 'float32')
            ins_pred = F.sigmoid(ins_pred)
            dice_out = paddle.multiply(
                self._dice_loss(ins_pred, ins_label), weights)
            total_weights += paddle.sum(weights)
            loss_ins.append(dice_out)
        # total_weights counts non-empty masks, so it is a whole number. If
        # no mask in the batch is non-empty, every dice term above was
        # multiplied by 0 and a plain division would be 0 / 0 = NaN. Clamping
        # to 1 gives a zero loss in that case and is a no-op otherwise.
        total_weights = paddle.clip(total_weights, min=1.0)
        loss_ins = paddle.sum(paddle.concat(loss_ins)) / total_weights
        loss_ins = loss_ins * self.ins_loss_weight

        # 2. Use sigmoid focal loss to calculate the category loss.
        # One-hot encode with an extra slot at index 0, then drop that slot
        # so label 0 becomes an all-zero (background) target.
        num_classes = cate_preds.shape[-1]
        cate_labels_bin = F.one_hot(cate_labels, num_classes=num_classes + 1)
        cate_labels_bin = cate_labels_bin[:, 1:]
        loss_cate = F.sigmoid_focal_loss(
            cate_preds,
            label=cate_labels_bin,
            # +1 keeps the normalizer non-zero when num_ins is 0.
            normalizer=num_ins + 1.,
            gamma=self.focal_loss_gamma,
            alpha=self.focal_loss_alpha)
        return loss_ins, loss_cate
|