# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from detrex (https://github.com/IDEA-Research/detrex)
# Copyright 2022 The IDEA Authors. All rights reserved.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ..bbox_utils import bbox_overlaps

__all__ = [
    '_get_clones', 'bbox_overlaps', 'bbox_cxcywh_to_xyxy',
    'bbox_xyxy_to_cxcywh', 'sigmoid_focal_loss', 'inverse_sigmoid',
    'deformable_attention_core_func'
]


def _get_clones(module, N):
    return nn.LayerList([copy.deepcopy(module) for _ in range(N)])


def bbox_cxcywh_to_xyxy(x):
    cxcy, wh = paddle.split(x, 2, axis=-1)
    return paddle.concat([cxcy - 0.5 * wh, cxcy + 0.5 * wh], axis=-1)


def bbox_xyxy_to_cxcywh(x):
    x1, y1, x2, y2 = x.split(4, axis=-1)
    return paddle.concat(
        [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)], axis=-1)
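

# Illustrative sketch (not part of the original module): round-tripping a box
# through the two conversions above. The values are assumed normalized to
# [0, 1], consistent with how boxes are used elsewhere in this file.
def _example_bbox_conversion():
    boxes_cxcywh = paddle.to_tensor([[0.5, 0.5, 0.2, 0.4]])
    boxes_xyxy = bbox_cxcywh_to_xyxy(boxes_cxcywh)
    # boxes_xyxy -> [[0.4, 0.3, 0.6, 0.7]]
    assert paddle.allclose(bbox_xyxy_to_cxcywh(boxes_xyxy), boxes_cxcywh)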


def sigmoid_focal_loss(logit, label, normalizer=1.0, alpha=0.25, gamma=2.0):
    prob = F.sigmoid(logit)
    ce_loss = F.binary_cross_entropy_with_logits(logit, label, reduction="none")
    p_t = prob * label + (1 - prob) * (1 - label)
    loss = ce_loss * ((1 - p_t)**gamma)

    if alpha >= 0:
        alpha_t = alpha * label + (1 - alpha) * (1 - label)
        loss = alpha_t * loss

    return loss.mean(1).sum() / normalizer
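

# Illustrative sketch (not part of the original module): focal loss on dummy
# per-query class logits. The [bs, num_queries, num_classes] shapes are an
# assumption for demonstration; `label` must be a float one-hot target because
# binary_cross_entropy_with_logits expects float labels.
def _example_sigmoid_focal_loss():
    logit = paddle.randn([2, 10, 80])
    label = F.one_hot(paddle.randint(0, 80, [2, 10]), 80).astype('float32')
    loss = sigmoid_focal_loss(logit, label, normalizer=10.0)
    return loss  # scalar tensor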


def inverse_sigmoid(x, eps=1e-6):
    x = x.clip(min=0., max=1.)
    return paddle.log(x / (1 - x + eps) + eps)
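

# Illustrative sketch (not part of the original module): inverse_sigmoid maps
# probabilities back to logit space, so F.sigmoid(inverse_sigmoid(p)) ~= p for
# values away from the clipped endpoints.
def _example_inverse_sigmoid():
    p = paddle.to_tensor([0.1, 0.5, 0.9])
    assert paddle.allclose(F.sigmoid(inverse_sigmoid(p)), p, atol=1e-4)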


def deformable_attention_core_func(value, value_spatial_shapes,
                                   value_level_start_index, sampling_locations,
                                   attention_weights):
    """
    Args:
        value (Tensor): [bs, value_length, n_head, c]
        value_spatial_shapes (Tensor): [n_levels, 2]
        value_level_start_index (Tensor): [n_levels]
        sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2]
        attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points]

    Returns:
        output (Tensor): [bs, query_length, n_head * c]
    """
    bs, _, n_head, c = value.shape
    _, Len_q, _, n_levels, n_points, _ = sampling_locations.shape

    value_list = value.split(
        value_spatial_shapes.prod(1).split(n_levels), axis=1)
    # map sampling locations from [0, 1] to the [-1, 1] range grid_sample expects
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (h, w) in enumerate(value_spatial_shapes):
        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = value_list[level].flatten(2).transpose(
            [0, 2, 1]).reshape([bs * n_head, c, h, w])
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(
            [0, 2, 1, 3, 4]).flatten(0, 1)
        # N_*M_, D_, Lq_, P_
        sampling_value_l_ = F.grid_sample(
            value_l_,
            sampling_grid_l_,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
    attention_weights = attention_weights.transpose([0, 2, 1, 3, 4]).reshape(
        [bs * n_head, 1, Len_q, n_levels * n_points])
    output = (paddle.stack(
        sampling_value_list, axis=-2).flatten(-2) *
              attention_weights).sum(-1).reshape([bs, n_head * c, Len_q])

    return output.transpose([0, 2, 1])
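

# Illustrative sketch (not part of the original module): calling the core
# function with dummy tensors of the documented shapes. The feature-map sizes
# (8x8 and 4x4) and head/point counts are arbitrary assumptions; value_length
# must equal the sum of h * w over all levels.
def _example_deformable_attention():
    bs, n_head, c = 2, 4, 8
    spatial_shapes = paddle.to_tensor([[8, 8], [4, 4]], dtype='int64')
    value_length = int(spatial_shapes.prod(1).sum())  # 64 + 16 = 80
    level_start_index = paddle.to_tensor([0, 64], dtype='int64')
    Len_q, n_levels, n_points = 10, 2, 4
    value = paddle.randn([bs, value_length, n_head, c])
    sampling_locations = paddle.rand([bs, Len_q, n_head, n_levels, n_points, 2])
    attention_weights = F.softmax(
        paddle.rand([bs, Len_q, n_head, n_levels * n_points]),
        axis=-1).reshape([bs, Len_q, n_head, n_levels, n_points])
    out = deformable_attention_core_func(value, spatial_shapes,
                                         level_start_index,
                                         sampling_locations, attention_weights)
    assert out.shape == [bs, Len_q, n_head * c]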


def get_valid_ratio(mask):
    # mask: [b, H, W], assumed to be 1 inside the valid (un-padded) region
    _, H, W = paddle.shape(mask)
    valid_ratio_h = paddle.sum(mask[:, :, 0], 1) / H
    valid_ratio_w = paddle.sum(mask[:, 0, :], 1) / W
    # [b, 2]
    return paddle.stack([valid_ratio_w, valid_ratio_h], -1)
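

# Illustrative sketch (not part of the original module): for a 4x4 mask whose
# top-left 2x3 block is valid, the valid ratios are w = 3/4 and h = 2/4. The
# mask convention (1 = valid) is an assumption consistent with the sums above.
def _example_get_valid_ratio():
    mask = paddle.zeros([1, 4, 4])
    mask[:, :2, :3] = 1.
    ratio = get_valid_ratio(mask)  # -> [[0.75, 0.5]]
    return ratio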


def get_contrastive_denoising_training_group(targets,
                                             num_classes,
                                             num_queries,
                                             class_embed,
                                             num_denoising=100,
                                             label_noise_ratio=0.5,
                                             box_noise_scale=1.0):
    if num_denoising <= 0:
        return None, None, None, None
    num_gts = [len(t) for t in targets["gt_class"]]
    max_gt_num = max(num_gts)
    if max_gt_num == 0:
        return None, None, None, None

    num_group = num_denoising // max_gt_num
    num_group = 1 if num_group == 0 else num_group
    # pad gt to max_gt_num of the batch
    bs = len(targets["gt_class"])
    input_query_class = paddle.full(
        [bs, max_gt_num], num_classes, dtype='int32')
    input_query_bbox = paddle.zeros([bs, max_gt_num, 4])
    pad_gt_mask = paddle.zeros([bs, max_gt_num])
    for i in range(bs):
        num_gt = num_gts[i]
        if num_gt > 0:
            input_query_class[i, :num_gt] = targets["gt_class"][i].squeeze(-1)
            input_query_bbox[i, :num_gt] = targets["gt_bbox"][i]
            pad_gt_mask[i, :num_gt] = 1
    # each group has positive and negative queries
    input_query_class = input_query_class.tile([1, 2 * num_group])
    input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1])
    pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group])
    # positive and negative mask
    negative_gt_mask = paddle.zeros([bs, max_gt_num * 2, 1])
    negative_gt_mask[:, max_gt_num:] = 1
    negative_gt_mask = negative_gt_mask.tile([1, num_group, 1])
    positive_gt_mask = 1 - negative_gt_mask
    # contrastive denoising training positive index
    positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
    dn_positive_idx = paddle.nonzero(positive_gt_mask)[:, 1]
    dn_positive_idx = paddle.split(dn_positive_idx,
                                   [n * num_group for n in num_gts])
    # total denoising queries
    num_denoising = int(max_gt_num * 2 * num_group)

    if label_noise_ratio > 0:
        input_query_class = input_query_class.flatten()
        pad_gt_mask = pad_gt_mask.flatten()
        # flip each label with probability (label_noise_ratio * 0.5)
        mask = paddle.rand(input_query_class.shape) < (label_noise_ratio * 0.5)
        chosen_idx = paddle.nonzero(mask * pad_gt_mask).squeeze(-1)
        # replace the chosen labels with random class indices
        new_label = paddle.randint_like(
            chosen_idx, 0, num_classes, dtype=input_query_class.dtype)
        input_query_class.scatter_(chosen_idx, new_label)
        input_query_class.reshape_([bs, num_denoising])
        pad_gt_mask.reshape_([bs, num_denoising])

    if box_noise_scale > 0:
        known_bbox = bbox_cxcywh_to_xyxy(input_query_bbox)
        diff = paddle.tile(input_query_bbox[..., 2:] * 0.5,
                           [1, 1, 2]) * box_noise_scale
        rand_sign = paddle.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
        rand_part = paddle.rand(input_query_bbox.shape)
        # negative queries are shifted further away than positive ones
        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (
            1 - negative_gt_mask)
        rand_part *= rand_sign
        known_bbox += rand_part * diff
        known_bbox.clip_(min=0.0, max=1.0)
        input_query_bbox = bbox_xyxy_to_cxcywh(known_bbox)
        input_query_bbox.clip_(min=0.0, max=1.0)

    # the extra zero row embeds the padding class id (num_classes)
    class_embed = paddle.concat(
        [class_embed, paddle.zeros([1, class_embed.shape[-1]])])
    input_query_class = paddle.gather(
        class_embed, input_query_class.flatten(),
        axis=0).reshape([bs, num_denoising, -1])

    tgt_size = num_denoising + num_queries
    attn_mask = paddle.ones([tgt_size, tgt_size]) < 0
    # matching queries cannot see the reconstruction (denoising) queries
    attn_mask[num_denoising:, :num_denoising] = True
    # denoising queries in different groups cannot see each other
    for i in range(num_group):
        if i == 0:
            attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), max_gt_num *
                      2 * (i + 1):num_denoising] = True
        if i == num_group - 1:
            attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), :max_gt_num *
                      i * 2] = True
        else:
            attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), max_gt_num *
                      2 * (i + 1):num_denoising] = True
            attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), :max_gt_num *
                      2 * i] = True
    attn_mask = ~attn_mask
    dn_meta = {
        "dn_positive_idx": dn_positive_idx,
        "dn_num_group": num_group,
        "dn_num_split": [num_denoising, num_queries]
    }

    return input_query_class, input_query_bbox, attn_mask, dn_meta
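

# Illustrative sketch (not part of the original module): building a denoising
# group for a 2-image batch. The class count, hidden size, and GT tensors are
# arbitrary assumptions; `class_embed` is [num_classes, hidden_dim] so the
# gather above (with one extra zero row for padding) stays in range.
def _example_denoising_group():
    targets = {
        "gt_class": [
            paddle.to_tensor([[3], [7]], dtype='int32'),
            paddle.to_tensor([[1]], dtype='int32')
        ],
        "gt_bbox": [paddle.rand([2, 4]), paddle.rand([1, 4])]
    }
    num_classes, hidden_dim, num_queries = 80, 256, 300
    class_embed = paddle.randn([num_classes, hidden_dim])
    dn_class, dn_bbox, attn_mask, dn_meta = \
        get_contrastive_denoising_training_group(
            targets, num_classes, num_queries, class_embed, num_denoising=100)
    # with max_gt_num=2 and num_group=50: dn_class is [2, 200, 256],
    # dn_bbox is [2, 200, 4], attn_mask is [500, 500] boolean
    return dn_class, dn_bbox, attn_mask, dn_meta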


def get_sine_pos_embed(pos_tensor,
                       num_pos_feats=128,
                       temperature=10000,
                       exchange_xy=True):
    """Generate a sine position embedding from a position tensor.

    Args:
        pos_tensor (Tensor): Shape `[bs, n_queries, n]`.
        num_pos_feats (int): Projected dimension for each float in the
            tensor. Default: 128.
        temperature (int): The temperature used for scaling the position
            embedding. Default: 10000.
        exchange_xy (bool, optional): Exchange pos x and pos y. For example,
            if the input tensor is `[x, y]`, the result will be
            `[pos(y), pos(x)]`. Default: True.

    Returns:
        Tensor: The position embedding, with shape
            `[bs, n_queries, n * num_pos_feats]`.
    """
    scale = 2. * math.pi
    dim_t = 2. * paddle.floor_divide(
        paddle.arange(num_pos_feats), paddle.to_tensor(2))
    dim_t = scale / temperature**(dim_t / num_pos_feats)

    def sine_func(x):
        x *= dim_t
        return paddle.stack(
            (x[:, :, 0::2].sin(), x[:, :, 1::2].cos()), axis=3).flatten(2)

    pos_res = [sine_func(x) for x in pos_tensor.split(pos_tensor.shape[-1], -1)]
    if exchange_xy:
        pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
    pos_res = paddle.concat(pos_res, axis=2)
    return pos_res
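

# Illustrative sketch (not part of the original module): embedding a batch of
# normalized (x, y) positions. Input and output shapes follow the docstring;
# the batch and query counts are arbitrary assumptions.
def _example_get_sine_pos_embed():
    xy = paddle.rand([2, 5, 2])  # [bs, n_queries, (x, y)]
    pos = get_sine_pos_embed(xy, num_pos_feats=128)
    assert pos.shape == [2, 5, 2 * 128]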