dino_transformer.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # Modified from Deformable-DETR (https://github.com/fundamentalvision/Deformable-DETR)
  16. # Copyright (c) 2020 SenseTime. All Rights Reserved.
  17. # Modified from detrex (https://github.com/IDEA-Research/detrex)
  18. # Copyright 2022 The IDEA Authors. All rights reserved.
  19. from __future__ import absolute_import
  20. from __future__ import division
  21. from __future__ import print_function
  22. import math
  23. import paddle
  24. import paddle.nn as nn
  25. import paddle.nn.functional as F
  26. from paddle import ParamAttr
  27. from paddle.regularizer import L2Decay
  28. from ppdet.core.workspace import register
  29. from ..layers import MultiHeadAttention
  30. from .position_encoding import PositionEmbedding
  31. from ..heads.detr_head import MLP
  32. from .deformable_transformer import MSDeformableAttention
  33. from ..initializer import (linear_init_, constant_, xavier_uniform_, normal_,
  34. bias_init_with_prob)
  35. from .utils import (_get_clones, get_valid_ratio,
  36. get_contrastive_denoising_training_group,
  37. get_sine_pos_embed, inverse_sigmoid)
# Public API of this module: only the top-level transformer is exported via
# `from ... import *`; the encoder/decoder layer classes are building blocks
# that DINOTransformer instantiates internally.
__all__ = ['DINOTransformer']
  39. class DINOTransformerEncoderLayer(nn.Layer):
  40. def __init__(self,
  41. d_model=256,
  42. n_head=8,
  43. dim_feedforward=1024,
  44. dropout=0.,
  45. activation="relu",
  46. n_levels=4,
  47. n_points=4,
  48. weight_attr=None,
  49. bias_attr=None):
  50. super(DINOTransformerEncoderLayer, self).__init__()
  51. # self attention
  52. self.self_attn = MSDeformableAttention(d_model, n_head, n_levels,
  53. n_points, 1.0)
  54. self.dropout1 = nn.Dropout(dropout)
  55. self.norm1 = nn.LayerNorm(
  56. d_model,
  57. weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
  58. bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
  59. # ffn
  60. self.linear1 = nn.Linear(d_model, dim_feedforward, weight_attr,
  61. bias_attr)
  62. self.activation = getattr(F, activation)
  63. self.dropout2 = nn.Dropout(dropout)
  64. self.linear2 = nn.Linear(dim_feedforward, d_model, weight_attr,
  65. bias_attr)
  66. self.dropout3 = nn.Dropout(dropout)
  67. self.norm2 = nn.LayerNorm(
  68. d_model,
  69. weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
  70. bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
  71. self._reset_parameters()
  72. def _reset_parameters(self):
  73. linear_init_(self.linear1)
  74. linear_init_(self.linear2)
  75. xavier_uniform_(self.linear1.weight)
  76. xavier_uniform_(self.linear2.weight)
  77. def with_pos_embed(self, tensor, pos):
  78. return tensor if pos is None else tensor + pos
  79. def forward_ffn(self, src):
  80. src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
  81. src = src + self.dropout3(src2)
  82. src = self.norm2(src)
  83. return src
  84. def forward(self,
  85. src,
  86. reference_points,
  87. spatial_shapes,
  88. level_start_index,
  89. src_mask=None,
  90. query_pos_embed=None):
  91. # self attention
  92. src2 = self.self_attn(
  93. self.with_pos_embed(src, query_pos_embed), reference_points, src,
  94. spatial_shapes, level_start_index, src_mask)
  95. src = src + self.dropout1(src2)
  96. src = self.norm1(src)
  97. # ffn
  98. src = self.forward_ffn(src)
  99. return src
  100. class DINOTransformerEncoder(nn.Layer):
  101. def __init__(self, encoder_layer, num_layers):
  102. super(DINOTransformerEncoder, self).__init__()
  103. self.layers = _get_clones(encoder_layer, num_layers)
  104. self.num_layers = num_layers
  105. @staticmethod
  106. def get_reference_points(spatial_shapes, valid_ratios, offset=0.5):
  107. valid_ratios = valid_ratios.unsqueeze(1)
  108. reference_points = []
  109. for i, (H, W) in enumerate(spatial_shapes):
  110. ref_y, ref_x = paddle.meshgrid(
  111. paddle.arange(end=H) + offset, paddle.arange(end=W) + offset)
  112. ref_y = ref_y.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 1] *
  113. H)
  114. ref_x = ref_x.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 0] *
  115. W)
  116. reference_points.append(paddle.stack((ref_x, ref_y), axis=-1))
  117. reference_points = paddle.concat(reference_points, 1).unsqueeze(2)
  118. reference_points = reference_points * valid_ratios
  119. return reference_points
  120. def forward(self,
  121. feat,
  122. spatial_shapes,
  123. level_start_index,
  124. feat_mask=None,
  125. query_pos_embed=None,
  126. valid_ratios=None):
  127. if valid_ratios is None:
  128. valid_ratios = paddle.ones(
  129. [feat.shape[0], spatial_shapes.shape[0], 2])
  130. reference_points = self.get_reference_points(spatial_shapes,
  131. valid_ratios)
  132. for layer in self.layers:
  133. feat = layer(feat, reference_points, spatial_shapes,
  134. level_start_index, feat_mask, query_pos_embed)
  135. return feat
  136. class DINOTransformerDecoderLayer(nn.Layer):
  137. def __init__(self,
  138. d_model=256,
  139. n_head=8,
  140. dim_feedforward=1024,
  141. dropout=0.,
  142. activation="relu",
  143. n_levels=4,
  144. n_points=4,
  145. weight_attr=None,
  146. bias_attr=None):
  147. super(DINOTransformerDecoderLayer, self).__init__()
  148. # self attention
  149. self.self_attn = MultiHeadAttention(d_model, n_head, dropout=dropout)
  150. self.dropout1 = nn.Dropout(dropout)
  151. self.norm1 = nn.LayerNorm(
  152. d_model,
  153. weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
  154. bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
  155. # cross attention
  156. self.cross_attn = MSDeformableAttention(d_model, n_head, n_levels,
  157. n_points, 1.0)
  158. self.dropout2 = nn.Dropout(dropout)
  159. self.norm2 = nn.LayerNorm(
  160. d_model,
  161. weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
  162. bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
  163. # ffn
  164. self.linear1 = nn.Linear(d_model, dim_feedforward, weight_attr,
  165. bias_attr)
  166. self.activation = getattr(F, activation)
  167. self.dropout3 = nn.Dropout(dropout)
  168. self.linear2 = nn.Linear(dim_feedforward, d_model, weight_attr,
  169. bias_attr)
  170. self.dropout4 = nn.Dropout(dropout)
  171. self.norm3 = nn.LayerNorm(
  172. d_model,
  173. weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
  174. bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
  175. self._reset_parameters()
  176. def _reset_parameters(self):
  177. linear_init_(self.linear1)
  178. linear_init_(self.linear2)
  179. xavier_uniform_(self.linear1.weight)
  180. xavier_uniform_(self.linear2.weight)
  181. def with_pos_embed(self, tensor, pos):
  182. return tensor if pos is None else tensor + pos
  183. def forward_ffn(self, tgt):
  184. return self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
  185. def forward(self,
  186. tgt,
  187. reference_points,
  188. memory,
  189. memory_spatial_shapes,
  190. memory_level_start_index,
  191. attn_mask=None,
  192. memory_mask=None,
  193. query_pos_embed=None):
  194. # self attention
  195. q = k = self.with_pos_embed(tgt, query_pos_embed)
  196. if attn_mask is not None:
  197. attn_mask = attn_mask.astype('bool')
  198. tgt2 = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)
  199. tgt = tgt + self.dropout1(tgt2)
  200. tgt = self.norm1(tgt)
  201. # cross attention
  202. tgt2 = self.cross_attn(
  203. self.with_pos_embed(tgt, query_pos_embed), reference_points, memory,
  204. memory_spatial_shapes, memory_level_start_index, memory_mask)
  205. tgt = tgt + self.dropout2(tgt2)
  206. tgt = self.norm2(tgt)
  207. # ffn
  208. tgt2 = self.forward_ffn(tgt)
  209. tgt = tgt + self.dropout4(tgt2)
  210. tgt = self.norm3(tgt)
  211. return tgt
  212. class DINOTransformerDecoder(nn.Layer):
  213. def __init__(self,
  214. hidden_dim,
  215. decoder_layer,
  216. num_layers,
  217. return_intermediate=True):
  218. super(DINOTransformerDecoder, self).__init__()
  219. self.layers = _get_clones(decoder_layer, num_layers)
  220. self.hidden_dim = hidden_dim
  221. self.num_layers = num_layers
  222. self.return_intermediate = return_intermediate
  223. self.norm = nn.LayerNorm(
  224. hidden_dim,
  225. weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
  226. bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
  227. def forward(self,
  228. tgt,
  229. reference_points,
  230. memory,
  231. memory_spatial_shapes,
  232. memory_level_start_index,
  233. bbox_head,
  234. query_pos_head,
  235. valid_ratios=None,
  236. attn_mask=None,
  237. memory_mask=None):
  238. if valid_ratios is None:
  239. valid_ratios = paddle.ones(
  240. [memory.shape[0], memory_spatial_shapes.shape[0], 2])
  241. output = tgt
  242. intermediate = []
  243. inter_ref_bboxes = []
  244. for i, layer in enumerate(self.layers):
  245. reference_points_input = reference_points.unsqueeze(
  246. 2) * valid_ratios.tile([1, 1, 2]).unsqueeze(1)
  247. query_pos_embed = get_sine_pos_embed(
  248. reference_points_input[..., 0, :], self.hidden_dim // 2)
  249. query_pos_embed = query_pos_head(query_pos_embed)
  250. output = layer(output, reference_points_input, memory,
  251. memory_spatial_shapes, memory_level_start_index,
  252. attn_mask, memory_mask, query_pos_embed)
  253. inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(
  254. reference_points))
  255. if self.return_intermediate:
  256. intermediate.append(self.norm(output))
  257. inter_ref_bboxes.append(inter_ref_bbox)
  258. reference_points = inter_ref_bbox.detach()
  259. if self.return_intermediate:
  260. return paddle.stack(intermediate), paddle.stack(inter_ref_bboxes)
  261. return output, reference_points
  262. @register
  263. class DINOTransformer(nn.Layer):
  264. __shared__ = ['num_classes', 'hidden_dim']
  265. def __init__(self,
  266. num_classes=80,
  267. hidden_dim=256,
  268. num_queries=900,
  269. position_embed_type='sine',
  270. return_intermediate_dec=True,
  271. backbone_feat_channels=[512, 1024, 2048],
  272. num_levels=4,
  273. num_encoder_points=4,
  274. num_decoder_points=4,
  275. nhead=8,
  276. num_encoder_layers=6,
  277. num_decoder_layers=6,
  278. dim_feedforward=1024,
  279. dropout=0.,
  280. activation="relu",
  281. pe_temperature=10000,
  282. pe_offset=-0.5,
  283. num_denoising=100,
  284. label_noise_ratio=0.5,
  285. box_noise_scale=1.0,
  286. learnt_init_query=True,
  287. eps=1e-2):
  288. super(DINOTransformer, self).__init__()
  289. assert position_embed_type in ['sine', 'learned'], \
  290. f'ValueError: position_embed_type not supported {position_embed_type}!'
  291. assert len(backbone_feat_channels) <= num_levels
  292. self.hidden_dim = hidden_dim
  293. self.nhead = nhead
  294. self.num_levels = num_levels
  295. self.num_classes = num_classes
  296. self.num_queries = num_queries
  297. self.eps = eps
  298. self.num_decoder_layers = num_decoder_layers
  299. # backbone feature projection
  300. self._build_input_proj_layer(backbone_feat_channels)
  301. # Transformer module
  302. encoder_layer = DINOTransformerEncoderLayer(
  303. hidden_dim, nhead, dim_feedforward, dropout, activation, num_levels,
  304. num_encoder_points)
  305. self.encoder = DINOTransformerEncoder(encoder_layer, num_encoder_layers)
  306. decoder_layer = DINOTransformerDecoderLayer(
  307. hidden_dim, nhead, dim_feedforward, dropout, activation, num_levels,
  308. num_decoder_points)
  309. self.decoder = DINOTransformerDecoder(hidden_dim, decoder_layer,
  310. num_decoder_layers,
  311. return_intermediate_dec)
  312. # denoising part
  313. self.denoising_class_embed = nn.Embedding(
  314. num_classes,
  315. hidden_dim,
  316. weight_attr=ParamAttr(initializer=nn.initializer.Normal()))
  317. self.num_denoising = num_denoising
  318. self.label_noise_ratio = label_noise_ratio
  319. self.box_noise_scale = box_noise_scale
  320. # position embedding
  321. self.position_embedding = PositionEmbedding(
  322. hidden_dim // 2,
  323. temperature=pe_temperature,
  324. normalize=True if position_embed_type == 'sine' else False,
  325. embed_type=position_embed_type,
  326. offset=pe_offset)
  327. self.level_embed = nn.Embedding(num_levels, hidden_dim)
  328. # decoder embedding
  329. self.learnt_init_query = learnt_init_query
  330. if learnt_init_query:
  331. self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
  332. self.query_pos_head = MLP(2 * hidden_dim,
  333. hidden_dim,
  334. hidden_dim,
  335. num_layers=2)
  336. # encoder head
  337. self.enc_output = nn.Sequential(
  338. nn.Linear(hidden_dim, hidden_dim),
  339. nn.LayerNorm(
  340. hidden_dim,
  341. weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
  342. bias_attr=ParamAttr(regularizer=L2Decay(0.0))))
  343. self.enc_score_head = nn.Linear(hidden_dim, num_classes)
  344. self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
  345. # decoder head
  346. self.dec_score_head = nn.LayerList([
  347. nn.Linear(hidden_dim, num_classes)
  348. for _ in range(num_decoder_layers)
  349. ])
  350. self.dec_bbox_head = nn.LayerList([
  351. MLP(hidden_dim, hidden_dim, 4, num_layers=3)
  352. for _ in range(num_decoder_layers)
  353. ])
  354. self._reset_parameters()
  355. def _reset_parameters(self):
  356. # class and bbox head init
  357. bias_cls = bias_init_with_prob(0.01)
  358. linear_init_(self.enc_score_head)
  359. constant_(self.enc_score_head.bias, bias_cls)
  360. constant_(self.enc_bbox_head.layers[-1].weight)
  361. constant_(self.enc_bbox_head.layers[-1].bias)
  362. for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
  363. linear_init_(cls_)
  364. constant_(cls_.bias, bias_cls)
  365. constant_(reg_.layers[-1].weight)
  366. constant_(reg_.layers[-1].bias)
  367. linear_init_(self.enc_output[0])
  368. xavier_uniform_(self.enc_output[0].weight)
  369. normal_(self.level_embed.weight)
  370. xavier_uniform_(self.tgt_embed.weight)
  371. xavier_uniform_(self.query_pos_head.layers[0].weight)
  372. xavier_uniform_(self.query_pos_head.layers[1].weight)
  373. for l in self.input_proj:
  374. xavier_uniform_(l[0].weight)
  375. constant_(l[0].bias)
  376. @classmethod
  377. def from_config(cls, cfg, input_shape):
  378. return {'backbone_feat_channels': [i.channels for i in input_shape], }
  379. def _build_input_proj_layer(self, backbone_feat_channels):
  380. self.input_proj = nn.LayerList()
  381. for in_channels in backbone_feat_channels:
  382. self.input_proj.append(
  383. nn.Sequential(
  384. ('conv', nn.Conv2D(
  385. in_channels, self.hidden_dim, kernel_size=1)),
  386. ('norm', nn.GroupNorm(
  387. 32,
  388. self.hidden_dim,
  389. weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
  390. bias_attr=ParamAttr(regularizer=L2Decay(0.0))))))
  391. in_channels = backbone_feat_channels[-1]
  392. for _ in range(self.num_levels - len(backbone_feat_channels)):
  393. self.input_proj.append(
  394. nn.Sequential(
  395. ('conv', nn.Conv2D(
  396. in_channels,
  397. self.hidden_dim,
  398. kernel_size=3,
  399. stride=2,
  400. padding=1)), ('norm', nn.GroupNorm(
  401. 32,
  402. self.hidden_dim,
  403. weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
  404. bias_attr=ParamAttr(regularizer=L2Decay(0.0))))))
  405. in_channels = self.hidden_dim
  406. def _get_encoder_input(self, feats, pad_mask=None):
  407. # get projection features
  408. proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
  409. if self.num_levels > len(proj_feats):
  410. len_srcs = len(proj_feats)
  411. for i in range(len_srcs, self.num_levels):
  412. if i == len_srcs:
  413. proj_feats.append(self.input_proj[i](feats[-1]))
  414. else:
  415. proj_feats.append(self.input_proj[i](proj_feats[-1]))
  416. # get encoder inputs
  417. feat_flatten = []
  418. mask_flatten = []
  419. lvl_pos_embed_flatten = []
  420. spatial_shapes = []
  421. valid_ratios = []
  422. for i, feat in enumerate(proj_feats):
  423. bs, _, h, w = paddle.shape(feat)
  424. spatial_shapes.append(paddle.concat([h, w]))
  425. # [b,c,h,w] -> [b,h*w,c]
  426. feat_flatten.append(feat.flatten(2).transpose([0, 2, 1]))
  427. if pad_mask is not None:
  428. mask = F.interpolate(pad_mask.unsqueeze(0), size=(h, w))[0]
  429. else:
  430. mask = paddle.ones([bs, h, w])
  431. valid_ratios.append(get_valid_ratio(mask))
  432. # [b, h*w, c]
  433. pos_embed = self.position_embedding(mask).flatten(1, 2)
  434. lvl_pos_embed = pos_embed + self.level_embed.weight[i]
  435. lvl_pos_embed_flatten.append(lvl_pos_embed)
  436. if pad_mask is not None:
  437. # [b, h*w]
  438. mask_flatten.append(mask.flatten(1))
  439. # [b, l, c]
  440. feat_flatten = paddle.concat(feat_flatten, 1)
  441. # [b, l]
  442. mask_flatten = None if pad_mask is None else paddle.concat(mask_flatten,
  443. 1)
  444. # [b, l, c]
  445. lvl_pos_embed_flatten = paddle.concat(lvl_pos_embed_flatten, 1)
  446. # [num_levels, 2]
  447. spatial_shapes = paddle.to_tensor(
  448. paddle.stack(spatial_shapes).astype('int64'))
  449. # [l], 每一个level的起始index
  450. level_start_index = paddle.concat([
  451. paddle.zeros(
  452. [1], dtype='int64'), spatial_shapes.prod(1).cumsum(0)[:-1]
  453. ])
  454. # [b, num_levels, 2]
  455. valid_ratios = paddle.stack(valid_ratios, 1)
  456. return (feat_flatten, spatial_shapes, level_start_index, mask_flatten,
  457. lvl_pos_embed_flatten, valid_ratios)
  458. def forward(self, feats, pad_mask=None, gt_meta=None):
  459. # input projection and embedding
  460. (feat_flatten, spatial_shapes, level_start_index, mask_flatten,
  461. lvl_pos_embed_flatten,
  462. valid_ratios) = self._get_encoder_input(feats, pad_mask)
  463. # encoder
  464. memory = self.encoder(feat_flatten, spatial_shapes, level_start_index,
  465. mask_flatten, lvl_pos_embed_flatten, valid_ratios)
  466. # prepare denoising training
  467. if self.training:
  468. denoising_class, denoising_bbox, attn_mask, dn_meta = \
  469. get_contrastive_denoising_training_group(gt_meta,
  470. self.num_classes,
  471. self.num_queries,
  472. self.denoising_class_embed.weight,
  473. self.num_denoising,
  474. self.label_noise_ratio,
  475. self.box_noise_scale)
  476. else:
  477. denoising_class, denoising_bbox, attn_mask, dn_meta = None, None, None, None
  478. target, init_ref_points, enc_topk_bboxes, enc_topk_logits = \
  479. self._get_decoder_input(
  480. memory, spatial_shapes, mask_flatten, denoising_class,
  481. denoising_bbox)
  482. # decoder
  483. inter_feats, inter_ref_bboxes = self.decoder(
  484. target, init_ref_points, memory, spatial_shapes, level_start_index,
  485. self.dec_bbox_head, self.query_pos_head, valid_ratios, attn_mask,
  486. mask_flatten)
  487. out_bboxes = []
  488. out_logits = []
  489. for i in range(self.num_decoder_layers):
  490. out_logits.append(self.dec_score_head[i](inter_feats[i]))
  491. if i == 0:
  492. out_bboxes.append(
  493. F.sigmoid(self.dec_bbox_head[i](inter_feats[i]) +
  494. inverse_sigmoid(init_ref_points)))
  495. else:
  496. out_bboxes.append(
  497. F.sigmoid(self.dec_bbox_head[i](inter_feats[i]) +
  498. inverse_sigmoid(inter_ref_bboxes[i - 1])))
  499. out_bboxes = paddle.stack(out_bboxes)
  500. out_logits = paddle.stack(out_logits)
  501. return (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits,
  502. dn_meta)
  503. def _get_encoder_output_anchors(self,
  504. memory,
  505. spatial_shapes,
  506. memory_mask=None,
  507. grid_size=0.05):
  508. output_anchors = []
  509. idx = 0
  510. for lvl, (h, w) in enumerate(spatial_shapes):
  511. if memory_mask is not None:
  512. mask_ = memory_mask[:, idx:idx + h * w].reshape([-1, h, w])
  513. valid_H = paddle.sum(mask_[:, :, 0], 1)
  514. valid_W = paddle.sum(mask_[:, 0, :], 1)
  515. else:
  516. valid_H, valid_W = h, w
  517. grid_y, grid_x = paddle.meshgrid(
  518. paddle.arange(
  519. end=h, dtype=memory.dtype),
  520. paddle.arange(
  521. end=w, dtype=memory.dtype))
  522. grid_xy = paddle.stack([grid_x, grid_y], -1)
  523. valid_WH = paddle.stack([valid_W, valid_H], -1).reshape(
  524. [-1, 1, 1, 2]).astype(grid_xy.dtype)
  525. grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
  526. wh = paddle.ones_like(grid_xy) * grid_size * (2.0**lvl)
  527. output_anchors.append(
  528. paddle.concat([grid_xy, wh], -1).reshape([-1, h * w, 4]))
  529. idx += h * w
  530. output_anchors = paddle.concat(output_anchors, 1)
  531. valid_mask = ((output_anchors > self.eps) *
  532. (output_anchors < 1 - self.eps)).all(-1, keepdim=True)
  533. output_anchors = paddle.log(output_anchors / (1 - output_anchors))
  534. if memory_mask is not None:
  535. valid_mask = (valid_mask * (memory_mask.unsqueeze(-1) > 0)) > 0
  536. output_anchors = paddle.where(valid_mask, output_anchors,
  537. paddle.to_tensor(float("inf")))
  538. memory = paddle.where(valid_mask, memory, paddle.to_tensor(0.))
  539. output_memory = self.enc_output(memory)
  540. return output_memory, output_anchors
  541. def _get_decoder_input(self,
  542. memory,
  543. spatial_shapes,
  544. memory_mask=None,
  545. denoising_class=None,
  546. denoising_bbox=None):
  547. bs, _, _ = memory.shape
  548. # prepare input for decoder
  549. output_memory, output_anchors = self._get_encoder_output_anchors(
  550. memory, spatial_shapes, memory_mask)
  551. enc_outputs_class = self.enc_score_head(output_memory)
  552. enc_outputs_coord_unact = self.enc_bbox_head(
  553. output_memory) + output_anchors
  554. _, topk_ind = paddle.topk(
  555. enc_outputs_class.max(-1), self.num_queries, axis=1)
  556. # extract region proposal boxes
  557. batch_ind = paddle.arange(end=bs, dtype=topk_ind.dtype)
  558. batch_ind = batch_ind.unsqueeze(-1).tile([1, self.num_queries])
  559. topk_ind = paddle.stack([batch_ind, topk_ind], axis=-1)
  560. topk_coords_unact = paddle.gather_nd(enc_outputs_coord_unact,
  561. topk_ind) # unsigmoided.
  562. reference_points = enc_topk_bboxes = F.sigmoid(topk_coords_unact)
  563. if denoising_bbox is not None:
  564. reference_points = paddle.concat([denoising_bbox, enc_topk_bboxes],
  565. 1)
  566. enc_topk_logits = paddle.gather_nd(enc_outputs_class, topk_ind)
  567. # extract region features
  568. if self.learnt_init_query:
  569. target = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1])
  570. else:
  571. target = paddle.gather_nd(output_memory, topk_ind).detach()
  572. if denoising_class is not None:
  573. target = paddle.concat([denoising_class, target], 1)
  574. return target, reference_points.detach(
  575. ), enc_topk_bboxes, enc_topk_logits