cot_loss.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. import numpy as np
  21. from ppdet.core.workspace import register
  22. __all__ = ['COTLoss']
  23. @register
  24. class COTLoss(nn.Layer):
  25. __shared__ = ['num_classes']
  26. def __init__(self,
  27. num_classes=80,
  28. cot_scale=1,
  29. cot_lambda=1):
  30. super(COTLoss, self).__init__()
  31. self.cot_scale = cot_scale
  32. self.cot_lambda = cot_lambda
  33. self.num_classes = num_classes
  34. def forward(self, scores, targets, cot_relation):
  35. cls_name = 'loss_bbox_cls_cot'
  36. loss_bbox = {}
  37. tgt_labels, tgt_bboxes, tgt_gt_inds = targets
  38. tgt_labels = paddle.concat(tgt_labels) if len(
  39. tgt_labels) > 1 else tgt_labels[0]
  40. mask = (tgt_labels < self.num_classes)
  41. valid_inds = paddle.nonzero(tgt_labels >= 0).flatten()
  42. if valid_inds.shape[0] == 0:
  43. loss_bbox[cls_name] = paddle.zeros([1], dtype='float32')
  44. else:
  45. tgt_labels = tgt_labels.cast('int64')
  46. valid_cot_targets = []
  47. for i in range(tgt_labels.shape[0]):
  48. train_label = tgt_labels[i]
  49. if train_label < self.num_classes:
  50. valid_cot_targets.append(cot_relation[train_label])
  51. coco_targets = paddle.to_tensor(valid_cot_targets)
  52. coco_targets.stop_gradient = True
  53. coco_loss = - coco_targets * F.log_softmax(scores[mask][:, :-1] * self.cot_scale)
  54. loss_bbox[cls_name] = self.cot_lambda * paddle.mean(paddle.sum(coco_loss, axis=-1))
  55. return loss_bbox