ssod_utils.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn.functional as F
  16. def align_weak_strong_shape(data_weak, data_strong):
  17. max_shape_x = max(data_strong['image'].shape[2],
  18. data_weak['image'].shape[2])
  19. max_shape_y = max(data_strong['image'].shape[3],
  20. data_weak['image'].shape[3])
  21. scale_x_s = max_shape_x / data_strong['image'].shape[2]
  22. scale_y_s = max_shape_y / data_strong['image'].shape[3]
  23. scale_x_w = max_shape_x / data_weak['image'].shape[2]
  24. scale_y_w = max_shape_y / data_weak['image'].shape[3]
  25. target_size = [max_shape_x, max_shape_y]
  26. if scale_x_s != 1 or scale_y_s != 1:
  27. data_strong['image'] = F.interpolate(
  28. data_strong['image'],
  29. size=target_size,
  30. mode='bilinear',
  31. align_corners=False)
  32. if 'gt_bbox' in data_strong:
  33. gt_bboxes = data_strong['gt_bbox']
  34. for i in range(len(gt_bboxes)):
  35. if len(gt_bboxes[i]) > 0:
  36. gt_bboxes[i][:, 0::2] = gt_bboxes[i][:, 0::2] * scale_x_s
  37. gt_bboxes[i][:, 1::2] = gt_bboxes[i][:, 1::2] * scale_y_s
  38. data_strong['gt_bbox'] = gt_bboxes
  39. if scale_x_w != 1 or scale_y_w != 1:
  40. data_weak['image'] = F.interpolate(
  41. data_weak['image'],
  42. size=target_size,
  43. mode='bilinear',
  44. align_corners=False)
  45. if 'gt_bbox' in data_weak:
  46. gt_bboxes = data_weak['gt_bbox']
  47. for i in range(len(gt_bboxes)):
  48. if len(gt_bboxes[i]) > 0:
  49. gt_bboxes[i][:, 0::2] = gt_bboxes[i][:, 0::2] * scale_x_w
  50. gt_bboxes[i][:, 1::2] = gt_bboxes[i][:, 1::2] * scale_y_w
  51. data_weak['gt_bbox'] = gt_bboxes
  52. return data_weak, data_strong
  53. def permute_to_N_HWA_K(tensor, K):
  54. """
  55. Transpose/reshape a tensor from (N, (A x K), H, W) to (N, (HxWxA), K)
  56. """
  57. assert tensor.dim() == 4, tensor.shape
  58. N, _, H, W = tensor.shape
  59. tensor = tensor.reshape([N, -1, K, H, W]).transpose([0, 3, 4, 1, 2])
  60. tensor = tensor.reshape([N, -1, K])
  61. return tensor
  62. def QFLv2(pred_sigmoid,
  63. teacher_sigmoid,
  64. weight=None,
  65. beta=2.0,
  66. reduction='mean'):
  67. pt = pred_sigmoid
  68. zerolabel = paddle.zeros_like(pt)
  69. loss = F.binary_cross_entropy(
  70. pred_sigmoid, zerolabel, reduction='none') * pt.pow(beta)
  71. pos = weight > 0
  72. pt = teacher_sigmoid[pos] - pred_sigmoid[pos]
  73. loss[pos] = F.binary_cross_entropy(
  74. pred_sigmoid[pos], teacher_sigmoid[pos],
  75. reduction='none') * pt.pow(beta)
  76. valid = weight >= 0
  77. if reduction == "mean":
  78. loss = loss[valid].mean()
  79. elif reduction == "sum":
  80. loss = loss[valid].sum()
  81. return loss