# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ppdet.core.workspace import register
from .iou_loss import GIoULoss
from ..transformers import bbox_cxcywh_to_xyxy, sigmoid_focal_loss

__all__ = ['DETRLoss', 'DINOLoss']


@register
class DETRLoss(nn.Layer):
    __shared__ = ['num_classes', 'use_focal_loss']
    __inject__ = ['matcher']

    def __init__(self,
                 num_classes=80,
                 matcher='HungarianMatcher',
                 loss_coeff={
                     'class': 1,
                     'bbox': 5,
                     'giou': 2,
                     'no_object': 0.1,
                     'mask': 1,
                     'dice': 1
                 },
                 aux_loss=True,
                 use_focal_loss=False):
- r"""
- Args:
- num_classes (int): The number of classes.
- matcher (HungarianMatcher): It computes an assignment between the targets
- and the predictions of the network.
- loss_coeff (dict): The coefficient of loss.
- aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
- use_focal_loss (bool): Use focal loss or not.
- """
        super(DETRLoss, self).__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.loss_coeff = loss_coeff
        self.aux_loss = aux_loss
        self.use_focal_loss = use_focal_loss

        if not self.use_focal_loss:
            # For softmax cross-entropy, build a per-class weight vector whose
            # last entry (the background class) is scaled by 'no_object'.
            self.loss_coeff['class'] = paddle.full([num_classes + 1],
                                                   loss_coeff['class'])
            self.loss_coeff['class'][-1] = loss_coeff['no_object']
        self.giou_loss = GIoULoss()
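
    # Illustrative construction (a sketch, not taken from this file): in
    # PaddleDetection the matcher is normally injected from the YAML config,
    # but the layer can also be built directly. The HungarianMatcher keyword
    # names below are assumptions for illustration only:
    #
    #   matcher = HungarianMatcher(matcher_coeff={'class': 1, 'bbox': 5, 'giou': 2})
    #   criterion = DETRLoss(num_classes=80, matcher=matcher, aux_loss=True)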

    def _get_loss_class(self,
                        logits,
                        gt_class,
                        match_indices,
                        bg_index,
                        num_gts,
                        postfix=""):
        # logits: [b, query, num_classes], gt_class: list[[n, 1]]
        name_class = "loss_class" + postfix
        if logits is None:
            return {name_class: paddle.zeros([1])}

        target_label = paddle.full(logits.shape[:2], bg_index, dtype='int64')
        bs, num_query_objects = target_label.shape
        if sum(len(a) for a in gt_class) > 0:
            index, updates = self._get_index_updates(num_query_objects,
                                                     gt_class, match_indices)
            target_label = paddle.scatter(
                target_label.reshape([-1, 1]), index, updates.astype('int64'))
            target_label = target_label.reshape([bs, num_query_objects])
        if self.use_focal_loss:
            target_label = F.one_hot(target_label,
                                     self.num_classes + 1)[..., :-1]
        return {
            name_class: self.loss_coeff['class'] * sigmoid_focal_loss(
                logits, target_label, num_gts / num_query_objects)
            if self.use_focal_loss else F.cross_entropy(
                logits, target_label, weight=self.loss_coeff['class'])
        }

    def _get_loss_bbox(self, boxes, gt_bbox, match_indices, num_gts,
                       postfix=""):
        # boxes: [b, query, 4], gt_bbox: list[[n, 4]]
        name_bbox = "loss_bbox" + postfix
        name_giou = "loss_giou" + postfix
        if boxes is None:
            return {name_bbox: paddle.zeros([1]), name_giou: paddle.zeros([1])}

        loss = dict()
        if sum(len(a) for a in gt_bbox) == 0:
            loss[name_bbox] = paddle.to_tensor([0.])
            loss[name_giou] = paddle.to_tensor([0.])
            return loss

        src_bbox, target_bbox = self._get_src_target_assign(boxes, gt_bbox,
                                                            match_indices)
        loss[name_bbox] = self.loss_coeff['bbox'] * F.l1_loss(
            src_bbox, target_bbox, reduction='sum') / num_gts
        loss[name_giou] = self.giou_loss(
            bbox_cxcywh_to_xyxy(src_bbox), bbox_cxcywh_to_xyxy(target_bbox))
        loss[name_giou] = loss[name_giou].sum() / num_gts
        loss[name_giou] = self.loss_coeff['giou'] * loss[name_giou]
        return loss

    def _get_loss_mask(self, masks, gt_mask, match_indices, num_gts,
                       postfix=""):
        # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
        name_mask = "loss_mask" + postfix
        name_dice = "loss_dice" + postfix
        if masks is None:
            return {name_mask: paddle.zeros([1]), name_dice: paddle.zeros([1])}

        loss = dict()
        if sum(len(a) for a in gt_mask) == 0:
            loss[name_mask] = paddle.to_tensor([0.])
            loss[name_dice] = paddle.to_tensor([0.])
            return loss

        src_masks, target_masks = self._get_src_target_assign(masks, gt_mask,
                                                              match_indices)
        src_masks = F.interpolate(
            src_masks.unsqueeze(0),
            size=target_masks.shape[-2:],
            mode="bilinear")[0]
        loss[name_mask] = self.loss_coeff['mask'] * F.sigmoid_focal_loss(
            src_masks,
            target_masks,
            paddle.to_tensor(
                [num_gts], dtype='float32'))
        loss[name_dice] = self.loss_coeff['dice'] * self._dice_loss(
            src_masks, target_masks, num_gts)
        return loss
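
    # The dice loss below follows the standard smoothed formulation
    #   1 - (2 * |X * Y| + 1) / (|X| + |Y| + 1)
    # computed per matched mask over flattened pixels, then normalized by the
    # number of ground-truth objects.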

    def _dice_loss(self, inputs, targets, num_gts):
        inputs = F.sigmoid(inputs)
        inputs = inputs.flatten(1)
        targets = targets.flatten(1)
        numerator = 2 * (inputs * targets).sum(1)
        denominator = inputs.sum(-1) + targets.sum(-1)
        loss = 1 - (numerator + 1) / (denominator + 1)
        return loss.sum() / num_gts

    def _get_loss_aux(self,
                      boxes,
                      logits,
                      gt_bbox,
                      gt_class,
                      bg_index,
                      num_gts,
                      match_indices=None,
                      postfix=""):
        if boxes is None and logits is None:
            return {
                "loss_class_aux" + postfix: paddle.zeros([1]),
                "loss_bbox_aux" + postfix: paddle.zeros([1]),
                "loss_giou_aux" + postfix: paddle.zeros([1])
            }

        loss_class = []
        loss_bbox = []
        loss_giou = []
        for aux_boxes, aux_logits in zip(boxes, logits):
            if match_indices is None:
                # When no indices are passed in (the non-denoising path), the
                # assignment computed for the first auxiliary layer is reused
                # for the remaining layers.
                match_indices = self.matcher(aux_boxes, aux_logits, gt_bbox,
                                             gt_class)
            loss_class.append(
                self._get_loss_class(aux_logits, gt_class, match_indices,
                                     bg_index, num_gts, postfix)['loss_class' +
                                                                 postfix])
            loss_ = self._get_loss_bbox(aux_boxes, gt_bbox, match_indices,
                                        num_gts, postfix)
            loss_bbox.append(loss_['loss_bbox' + postfix])
            loss_giou.append(loss_['loss_giou' + postfix])
        loss = {
            "loss_class_aux" + postfix: paddle.add_n(loss_class),
            "loss_bbox_aux" + postfix: paddle.add_n(loss_bbox),
            "loss_giou_aux" + postfix: paddle.add_n(loss_giou)
        }
        return loss

    def _get_index_updates(self, num_query_objects, target, match_indices):
        batch_idx = paddle.concat([
            paddle.full_like(src, i) for i, (src, _) in enumerate(match_indices)
        ])
        src_idx = paddle.concat([src for (src, _) in match_indices])
        src_idx += (batch_idx * num_query_objects)
        target_assign = paddle.concat([
            paddle.gather(
                t, dst, axis=0) for t, (_, dst) in zip(target, match_indices)
        ])
        return src_idx, target_assign
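
    # Tiny worked example (hypothetical values): with num_query_objects=100 and
    # match_indices=[(Tensor([3]), Tensor([0])), (Tensor([7]), Tensor([0]))],
    # image 0's query 3 maps to flattened index 3 and image 1's query 7 maps to
    # 1 * 100 + 7 = 107; target_assign stacks the matched GT rows in the same
    # order, ready for paddle.scatter over the flattened predictions.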

    def _get_src_target_assign(self, src, target, match_indices):
        src_assign = paddle.concat([
            paddle.gather(
                t, I, axis=0) if len(I) > 0 else paddle.zeros([0, t.shape[-1]])
            for t, (I, _) in zip(src, match_indices)
        ])
        target_assign = paddle.concat([
            paddle.gather(
                t, J, axis=0) if len(J) > 0 else paddle.zeros([0, t.shape[-1]])
            for t, (_, J) in zip(target, match_indices)
        ])
        return src_assign, target_assign

    def forward(self,
                boxes,
                logits,
                gt_bbox,
                gt_class,
                masks=None,
                gt_mask=None,
                postfix="",
                **kwargs):
        r"""
        Args:
            boxes (Tensor|None): [l, b, query, 4]
            logits (Tensor|None): [l, b, query, num_classes]
            gt_bbox (List(Tensor)): list[[n, 4]]
            gt_class (List(Tensor)): list[[n, 1]]
            masks (Tensor, optional): [b, query, h, w]
            gt_mask (List(Tensor), optional): list[[n, H, W]]
            postfix (str): postfix of loss name
        """
- if "match_indices" in kwargs:
- match_indices = kwargs["match_indices"]
- else:
- match_indices = self.matcher(boxes[-1].detach(),
- logits[-1].detach(), gt_bbox, gt_class)
- num_gts = sum(len(a) for a in gt_bbox)
- num_gts = paddle.to_tensor([num_gts], dtype="float32")
- if paddle.distributed.get_world_size() > 1:
- paddle.distributed.all_reduce(num_gts)
- num_gts /= paddle.distributed.get_world_size()
- num_gts = paddle.clip(num_gts, min=1.) * kwargs.get("dn_num_group", 1.)
- total_loss = dict()
- total_loss.update(
- self._get_loss_class(logits[
- -1] if logits is not None else None, gt_class, match_indices,
- self.num_classes, num_gts, postfix))
- total_loss.update(
- self._get_loss_bbox(boxes[-1] if boxes is not None else None,
- gt_bbox, match_indices, num_gts, postfix))
- if masks is not None and gt_mask is not None:
- total_loss.update(
- self._get_loss_mask(masks if masks is not None else None,
- gt_mask, match_indices, num_gts, postfix))
- if self.aux_loss:
- if "match_indices" not in kwargs:
- match_indices = None
- total_loss.update(
- self._get_loss_aux(
- boxes[:-1] if boxes is not None else None, logits[:-1]
- if logits is not None else None, gt_bbox, gt_class,
- self.num_classes, num_gts, match_indices, postfix))
- return total_loss
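
    # A minimal call sketch (shapes are illustrative, matching the docstring):
    # for a 6-layer decoder with 100 queries on an 80-class dataset,
    #   boxes:  [6, b, 100, 4] normalized cxcywh predictions
    #   logits: [6, b, 100, 80] classification logits
    #   gt_bbox / gt_class: per-image tensors of shape [n_i, 4] / [n_i, 1]
    #   losses = criterion(boxes, logits, gt_bbox, gt_class)
    # returns a dict such as {'loss_class': ..., 'loss_bbox': ...,
    # 'loss_giou': ..., 'loss_class_aux': ...}.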


@register
class DINOLoss(DETRLoss):
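    # DINO adds a contrastive denoising (DN) branch: noised copies of the GT
    # boxes are fed to the decoder, and their losses (postfix "_dn") reuse the
    # DETR losses with precomputed match indices instead of Hungarian matching.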
    def forward(self,
                boxes,
                logits,
                gt_bbox,
                gt_class,
                masks=None,
                gt_mask=None,
                postfix="",
                dn_out_bboxes=None,
                dn_out_logits=None,
                dn_meta=None,
                **kwargs):
        total_loss = super(DINOLoss, self).forward(boxes, logits, gt_bbox,
                                                   gt_class)

        # denoising training loss
        if dn_meta is not None:
            dn_positive_idx, dn_num_group = \
                dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
            assert len(gt_class) == len(dn_positive_idx)

            # denoising match indices
            dn_match_indices = []
            for i in range(len(gt_class)):
                num_gt = len(gt_class[i])
                if num_gt > 0:
                    gt_idx = paddle.arange(end=num_gt, dtype="int64")
                    gt_idx = gt_idx.unsqueeze(0).tile(
                        [dn_num_group, 1]).flatten()
                    assert len(gt_idx) == len(dn_positive_idx[i])
                    dn_match_indices.append((dn_positive_idx[i], gt_idx))
                else:
                    dn_match_indices.append((paddle.zeros(
                        [0], dtype="int64"), paddle.zeros(
                            [0], dtype="int64")))
        else:
            dn_match_indices, dn_num_group = None, 1.
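
        # With dn_num_group denoising groups, each GT box gets one positive
        # query per group: e.g. 3 GT boxes and 2 groups (hypothetical values)
        # yield gt_idx = [0, 1, 2, 0, 1, 2], matching every group one-to-one
        # to the same ground-truth objects.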

        dn_loss = super(DINOLoss, self).forward(
            dn_out_bboxes,
            dn_out_logits,
            gt_bbox,
            gt_class,
            postfix="_dn",
            match_indices=dn_match_indices,
            dn_num_group=dn_num_group)
        total_loss.update(dn_loss)
        return total_loss