detr_loss.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ppdet.core.workspace import register
from .iou_loss import GIoULoss
from ..transformers import bbox_cxcywh_to_xyxy, sigmoid_focal_loss

__all__ = ['DETRLoss', 'DINOLoss']

@register
class DETRLoss(nn.Layer):
    __shared__ = ['num_classes', 'use_focal_loss']
    __inject__ = ['matcher']

    def __init__(self,
                 num_classes=80,
                 matcher='HungarianMatcher',
                 loss_coeff={
                     'class': 1,
                     'bbox': 5,
                     'giou': 2,
                     'no_object': 0.1,
                     'mask': 1,
                     'dice': 1
                 },
                 aux_loss=True,
                 use_focal_loss=False):
        r"""
        Args:
            num_classes (int): The number of classes.
            matcher (HungarianMatcher): It computes an assignment between the
                targets and the predictions of the network.
            loss_coeff (dict): The coefficient of each loss term.
            aux_loss (bool): If 'aux_loss = True', the loss at each decoder
                layer is also computed.
            use_focal_loss (bool): Whether to use focal loss for classification.
        """
        super(DETRLoss, self).__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.loss_coeff = loss_coeff
        self.aux_loss = aux_loss
        self.use_focal_loss = use_focal_loss

        # For softmax cross-entropy, expand the scalar class coefficient into
        # a per-class weight vector, with a separate weight for the background
        # ("no object") class at index num_classes.
        if not self.use_focal_loss:
            self.loss_coeff['class'] = paddle.full([num_classes + 1],
                                                   loss_coeff['class'])
            self.loss_coeff['class'][-1] = loss_coeff['no_object']
        self.giou_loss = GIoULoss()

    def _get_loss_class(self,
                        logits,
                        gt_class,
                        match_indices,
                        bg_index,
                        num_gts,
                        postfix=""):
        # logits: [b, query, num_classes], gt_class: list[[n, 1]]
        name_class = "loss_class" + postfix
        if logits is None:
            return {name_class: paddle.zeros([1])}
        # Start from an all-background target and scatter the matched
        # ground-truth labels into the assigned query slots.
        target_label = paddle.full(logits.shape[:2], bg_index, dtype='int64')
        bs, num_query_objects = target_label.shape
        if sum(len(a) for a in gt_class) > 0:
            index, updates = self._get_index_updates(num_query_objects,
                                                     gt_class, match_indices)
            target_label = paddle.scatter(
                target_label.reshape([-1, 1]), index, updates.astype('int64'))
            target_label = target_label.reshape([bs, num_query_objects])
        if self.use_focal_loss:
            # Drop the background channel: focal loss represents background
            # as the all-zero one-hot vector.
            target_label = F.one_hot(target_label,
                                     self.num_classes + 1)[..., :-1]
        return {
            name_class: self.loss_coeff['class'] * sigmoid_focal_loss(
                logits, target_label, num_gts / num_query_objects)
            if self.use_focal_loss else F.cross_entropy(
                logits, target_label, weight=self.loss_coeff['class'])
        }

    def _get_loss_bbox(self, boxes, gt_bbox, match_indices, num_gts,
                       postfix=""):
        # boxes: [b, query, 4], gt_bbox: list[[n, 4]]
        name_bbox = "loss_bbox" + postfix
        name_giou = "loss_giou" + postfix
        if boxes is None:
            return {name_bbox: paddle.zeros([1]), name_giou: paddle.zeros([1])}
        loss = dict()
        if sum(len(a) for a in gt_bbox) == 0:
            loss[name_bbox] = paddle.to_tensor([0.])
            loss[name_giou] = paddle.to_tensor([0.])
            return loss
        src_bbox, target_bbox = self._get_src_target_assign(boxes, gt_bbox,
                                                            match_indices)
        loss[name_bbox] = self.loss_coeff['bbox'] * F.l1_loss(
            src_bbox, target_bbox, reduction='sum') / num_gts
        loss[name_giou] = self.giou_loss(
            bbox_cxcywh_to_xyxy(src_bbox), bbox_cxcywh_to_xyxy(target_bbox))
        loss[name_giou] = loss[name_giou].sum() / num_gts
        loss[name_giou] = self.loss_coeff['giou'] * loss[name_giou]
        return loss

    def _get_loss_mask(self, masks, gt_mask, match_indices, num_gts,
                       postfix=""):
        # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
        name_mask = "loss_mask" + postfix
        name_dice = "loss_dice" + postfix
        if masks is None:
            return {name_mask: paddle.zeros([1]), name_dice: paddle.zeros([1])}
        loss = dict()
        if sum(len(a) for a in gt_mask) == 0:
            loss[name_mask] = paddle.to_tensor([0.])
            loss[name_dice] = paddle.to_tensor([0.])
            return loss
        src_masks, target_masks = self._get_src_target_assign(masks, gt_mask,
                                                              match_indices)
        # Upsample the predicted masks to the ground-truth resolution before
        # computing the losses.
        src_masks = F.interpolate(
            src_masks.unsqueeze(0), size=target_masks.shape[-2:],
            mode="bilinear")[0]
        loss[name_mask] = self.loss_coeff['mask'] * F.sigmoid_focal_loss(
            src_masks, target_masks,
            paddle.to_tensor([num_gts], dtype='float32'))
        loss[name_dice] = self.loss_coeff['dice'] * self._dice_loss(
            src_masks, target_masks, num_gts)
        return loss

    def _dice_loss(self, inputs, targets, num_gts):
        inputs = F.sigmoid(inputs)
        inputs = inputs.flatten(1)
        targets = targets.flatten(1)
        numerator = 2 * (inputs * targets).sum(1)
        denominator = inputs.sum(-1) + targets.sum(-1)
        loss = 1 - (numerator + 1) / (denominator + 1)
        return loss.sum() / num_gts
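    # Dice loss above: 1 - (2*|X ∩ Y| + 1) / (|X| + |Y| + 1), summed over the
    # matched masks and normalized by num_gts; the +1 smoothing terms keep the
    # ratio finite when a matched mask is empty.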

    def _get_loss_aux(self,
                      boxes,
                      logits,
                      gt_bbox,
                      gt_class,
                      bg_index,
                      num_gts,
                      match_indices=None,
                      postfix=""):
        if boxes is None and logits is None:
            return {
                "loss_class_aux" + postfix: paddle.zeros([1]),
                "loss_bbox_aux" + postfix: paddle.zeros([1]),
                "loss_giou_aux" + postfix: paddle.zeros([1])
            }
        loss_class = []
        loss_bbox = []
        loss_giou = []
        # Accumulate class/bbox/giou losses over the intermediate decoder
        # layers; fixed match_indices (denoising training) take precedence
        # over the matcher.
        for aux_boxes, aux_logits in zip(boxes, logits):
            if match_indices is None:
                match_indices = self.matcher(aux_boxes, aux_logits, gt_bbox,
                                             gt_class)
            loss_class.append(
                self._get_loss_class(aux_logits, gt_class, match_indices,
                                     bg_index, num_gts,
                                     postfix)['loss_class' + postfix])
            loss_ = self._get_loss_bbox(aux_boxes, gt_bbox, match_indices,
                                        num_gts, postfix)
            loss_bbox.append(loss_['loss_bbox' + postfix])
            loss_giou.append(loss_['loss_giou' + postfix])
        loss = {
            "loss_class_aux" + postfix: paddle.add_n(loss_class),
            "loss_bbox_aux" + postfix: paddle.add_n(loss_bbox),
            "loss_giou_aux" + postfix: paddle.add_n(loss_giou)
        }
        return loss

    def _get_index_updates(self, num_query_objects, target, match_indices):
        # Flatten the per-image (query, gt) matches into indices over the
        # whole batch: query index + image index * num_query_objects.
        batch_idx = paddle.concat([
            paddle.full_like(src, i) for i, (src, _) in enumerate(match_indices)
        ])
        src_idx = paddle.concat([src for (src, _) in match_indices])
        src_idx += (batch_idx * num_query_objects)
        target_assign = paddle.concat([
            paddle.gather(
                t, dst, axis=0) for t, (_, dst) in zip(target, match_indices)
        ])
        return src_idx, target_assign

    def _get_src_target_assign(self, src, target, match_indices):
        # Gather the matched predictions and their assigned ground truths;
        # images without any match contribute empty tensors.
        src_assign = paddle.concat([
            paddle.gather(
                t, I, axis=0) if len(I) > 0 else paddle.zeros([0, t.shape[-1]])
            for t, (I, _) in zip(src, match_indices)
        ])
        target_assign = paddle.concat([
            paddle.gather(
                t, J, axis=0) if len(J) > 0 else paddle.zeros([0, t.shape[-1]])
            for t, (_, J) in zip(target, match_indices)
        ])
        return src_assign, target_assign
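
    # `match_indices` format used throughout: one (src_idx, dst_idx) pair of
    # int64 tensors per image, where src_idx selects the matched queries and
    # dst_idx the assigned ground-truth rows. E.g. if queries 4 and 7 of an
    # image are matched to its ground truths 1 and 0, the pair is
    # (Tensor([4, 7]), Tensor([1, 0])).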

    def forward(self,
                boxes,
                logits,
                gt_bbox,
                gt_class,
                masks=None,
                gt_mask=None,
                postfix="",
                **kwargs):
        r"""
        Args:
            boxes (Tensor|None): [l, b, query, 4]
            logits (Tensor|None): [l, b, query, num_classes]
            gt_bbox (List(Tensor)): list[[n, 4]]
            gt_class (List(Tensor)): list[[n, 1]]
            masks (Tensor, optional): [b, query, h, w]
            gt_mask (List(Tensor), optional): list[[n, H, W]]
            postfix (str): postfix of loss name
        """
        if "match_indices" in kwargs:
            match_indices = kwargs["match_indices"]
        else:
            match_indices = self.matcher(boxes[-1].detach(),
                                         logits[-1].detach(), gt_bbox,
                                         gt_class)
        # Normalizer: number of GT objects, averaged across ranks and clipped
        # to at least 1; scaled by the number of denoising groups if any.
        num_gts = sum(len(a) for a in gt_bbox)
        num_gts = paddle.to_tensor([num_gts], dtype="float32")
        if paddle.distributed.get_world_size() > 1:
            paddle.distributed.all_reduce(num_gts)
            num_gts /= paddle.distributed.get_world_size()
        num_gts = paddle.clip(num_gts, min=1.) * kwargs.get("dn_num_group", 1.)

        total_loss = dict()
        total_loss.update(
            self._get_loss_class(logits[-1] if logits is not None else None,
                                 gt_class, match_indices, self.num_classes,
                                 num_gts, postfix))
        total_loss.update(
            self._get_loss_bbox(boxes[-1] if boxes is not None else None,
                                gt_bbox, match_indices, num_gts, postfix))
        if masks is not None and gt_mask is not None:
            total_loss.update(
                self._get_loss_mask(masks, gt_mask, match_indices, num_gts,
                                    postfix))
        if self.aux_loss:
            if "match_indices" not in kwargs:
                match_indices = None
            total_loss.update(
                self._get_loss_aux(
                    boxes[:-1] if boxes is not None else None,
                    logits[:-1] if logits is not None else None, gt_bbox,
                    gt_class, self.num_classes, num_gts, match_indices,
                    postfix))
        return total_loss
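

# DINOLoss extends DETRLoss with DINO's denoising-training branch: outputs
# decoded from noised copies of the ground truths (dn_out_bboxes,
# dn_out_logits) are supervised with the same losses under a fixed
# query-to-GT assignment instead of Hungarian matching, reported with a
# "_dn" postfix.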
@register
class DINOLoss(DETRLoss):
    def forward(self,
                boxes,
                logits,
                gt_bbox,
                gt_class,
                masks=None,
                gt_mask=None,
                postfix="",
                dn_out_bboxes=None,
                dn_out_logits=None,
                dn_meta=None,
                **kwargs):
        total_loss = super(DINOLoss, self).forward(boxes, logits, gt_bbox,
                                                   gt_class)

        # denoising training loss
        if dn_meta is not None:
            dn_positive_idx, dn_num_group = \
                dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
            assert len(gt_class) == len(dn_positive_idx)

            # Build the fixed denoising match indices: each ground truth is
            # assigned to its positive denoising queries, repeated once per
            # denoising group.
            dn_match_indices = []
            for i in range(len(gt_class)):
                num_gt = len(gt_class[i])
                if num_gt > 0:
                    gt_idx = paddle.arange(end=num_gt, dtype="int64")
                    gt_idx = gt_idx.unsqueeze(0).tile(
                        [dn_num_group, 1]).flatten()
                    assert len(gt_idx) == len(dn_positive_idx[i])
                    dn_match_indices.append((dn_positive_idx[i], gt_idx))
                else:
                    dn_match_indices.append((paddle.zeros(
                        [0], dtype="int64"), paddle.zeros(
                            [0], dtype="int64")))
        else:
            dn_match_indices, dn_num_group = None, 1.

        dn_loss = super(DINOLoss, self).forward(
            dn_out_bboxes,
            dn_out_logits,
            gt_bbox,
            gt_class,
            postfix="_dn",
            match_indices=dn_match_indices,
            dn_num_group=dn_num_group)
        total_loss.update(dn_loss)
        return total_loss
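

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the original file).
    # It drives DETRLoss with random tensors and a dummy matcher that pairs
    # the first n queries of each image with its n ground truths; the real
    # pipeline injects HungarianMatcher through the ppdet registry.
    def dummy_matcher(boxes, logits, gt_bbox, gt_class):
        # One (src_idx, dst_idx) pair per image, both of length num_gt.
        return [(paddle.arange(len(b), dtype='int64'),
                 paddle.arange(len(b), dtype='int64')) for b in gt_bbox]

    num_layers, bs, num_queries, num_classes = 2, 2, 10, 80
    boxes = paddle.rand([num_layers, bs, num_queries, 4])
    logits = paddle.rand([num_layers, bs, num_queries, num_classes])
    gt_bbox = [paddle.rand([3, 4]), paddle.rand([1, 4])]
    gt_class = [
        paddle.randint(0, num_classes, [3, 1]),
        paddle.randint(0, num_classes, [1, 1])
    ]
    loss_fn = DETRLoss(
        num_classes=num_classes, matcher=dummy_matcher, use_focal_loss=True)
    print({k: float(v)
           for k, v in loss_fn(boxes, logits, gt_bbox, gt_class).items()})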