detr_transformer.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # Modified from DETR (https://github.com/facebookresearch/detr)
  16. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  17. from __future__ import absolute_import
  18. from __future__ import division
  19. from __future__ import print_function
  20. import paddle
  21. import paddle.nn as nn
  22. import paddle.nn.functional as F
  23. from ppdet.core.workspace import register
  24. from ..layers import MultiHeadAttention, _convert_attention_mask
  25. from .position_encoding import PositionEmbedding
  26. from .utils import _get_clones
  27. from ..initializer import linear_init_, conv_init_, xavier_uniform_, normal_
  28. __all__ = ['DETRTransformer']
  29. class TransformerEncoderLayer(nn.Layer):
  30. def __init__(self,
  31. d_model,
  32. nhead,
  33. dim_feedforward=2048,
  34. dropout=0.1,
  35. activation="relu",
  36. attn_dropout=None,
  37. act_dropout=None,
  38. normalize_before=False):
  39. super(TransformerEncoderLayer, self).__init__()
  40. attn_dropout = dropout if attn_dropout is None else attn_dropout
  41. act_dropout = dropout if act_dropout is None else act_dropout
  42. self.normalize_before = normalize_before
  43. self.self_attn = MultiHeadAttention(d_model, nhead, attn_dropout)
  44. # Implementation of Feedforward model
  45. self.linear1 = nn.Linear(d_model, dim_feedforward)
  46. self.dropout = nn.Dropout(act_dropout, mode="upscale_in_train")
  47. self.linear2 = nn.Linear(dim_feedforward, d_model)
  48. self.norm1 = nn.LayerNorm(d_model)
  49. self.norm2 = nn.LayerNorm(d_model)
  50. self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
  51. self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train")
  52. self.activation = getattr(F, activation)
  53. self._reset_parameters()
  54. def _reset_parameters(self):
  55. linear_init_(self.linear1)
  56. linear_init_(self.linear2)
  57. @staticmethod
  58. def with_pos_embed(tensor, pos_embed):
  59. return tensor if pos_embed is None else tensor + pos_embed
  60. def forward(self, src, src_mask=None, pos_embed=None):
  61. residual = src
  62. if self.normalize_before:
  63. src = self.norm1(src)
  64. q = k = self.with_pos_embed(src, pos_embed)
  65. src = self.self_attn(q, k, value=src, attn_mask=src_mask)
  66. src = residual + self.dropout1(src)
  67. if not self.normalize_before:
  68. src = self.norm1(src)
  69. residual = src
  70. if self.normalize_before:
  71. src = self.norm2(src)
  72. src = self.linear2(self.dropout(self.activation(self.linear1(src))))
  73. src = residual + self.dropout2(src)
  74. if not self.normalize_before:
  75. src = self.norm2(src)
  76. return src
  77. class TransformerEncoder(nn.Layer):
  78. def __init__(self, encoder_layer, num_layers, norm=None):
  79. super(TransformerEncoder, self).__init__()
  80. self.layers = _get_clones(encoder_layer, num_layers)
  81. self.num_layers = num_layers
  82. self.norm = norm
  83. def forward(self, src, src_mask=None, pos_embed=None):
  84. output = src
  85. for layer in self.layers:
  86. output = layer(output, src_mask=src_mask, pos_embed=pos_embed)
  87. if self.norm is not None:
  88. output = self.norm(output)
  89. return output
  90. class TransformerDecoderLayer(nn.Layer):
  91. def __init__(self,
  92. d_model,
  93. nhead,
  94. dim_feedforward=2048,
  95. dropout=0.1,
  96. activation="relu",
  97. attn_dropout=None,
  98. act_dropout=None,
  99. normalize_before=False):
  100. super(TransformerDecoderLayer, self).__init__()
  101. attn_dropout = dropout if attn_dropout is None else attn_dropout
  102. act_dropout = dropout if act_dropout is None else act_dropout
  103. self.normalize_before = normalize_before
  104. self.self_attn = MultiHeadAttention(d_model, nhead, attn_dropout)
  105. self.cross_attn = MultiHeadAttention(d_model, nhead, attn_dropout)
  106. # Implementation of Feedforward model
  107. self.linear1 = nn.Linear(d_model, dim_feedforward)
  108. self.dropout = nn.Dropout(act_dropout, mode="upscale_in_train")
  109. self.linear2 = nn.Linear(dim_feedforward, d_model)
  110. self.norm1 = nn.LayerNorm(d_model)
  111. self.norm2 = nn.LayerNorm(d_model)
  112. self.norm3 = nn.LayerNorm(d_model)
  113. self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
  114. self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train")
  115. self.dropout3 = nn.Dropout(dropout, mode="upscale_in_train")
  116. self.activation = getattr(F, activation)
  117. self._reset_parameters()
  118. def _reset_parameters(self):
  119. linear_init_(self.linear1)
  120. linear_init_(self.linear2)
  121. @staticmethod
  122. def with_pos_embed(tensor, pos_embed):
  123. return tensor if pos_embed is None else tensor + pos_embed
  124. def forward(self,
  125. tgt,
  126. memory,
  127. tgt_mask=None,
  128. memory_mask=None,
  129. pos_embed=None,
  130. query_pos_embed=None):
  131. tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
  132. residual = tgt
  133. if self.normalize_before:
  134. tgt = self.norm1(tgt)
  135. q = k = self.with_pos_embed(tgt, query_pos_embed)
  136. tgt = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask)
  137. tgt = residual + self.dropout1(tgt)
  138. if not self.normalize_before:
  139. tgt = self.norm1(tgt)
  140. residual = tgt
  141. if self.normalize_before:
  142. tgt = self.norm2(tgt)
  143. q = self.with_pos_embed(tgt, query_pos_embed)
  144. k = self.with_pos_embed(memory, pos_embed)
  145. tgt = self.cross_attn(q, k, value=memory, attn_mask=memory_mask)
  146. tgt = residual + self.dropout2(tgt)
  147. if not self.normalize_before:
  148. tgt = self.norm2(tgt)
  149. residual = tgt
  150. if self.normalize_before:
  151. tgt = self.norm3(tgt)
  152. tgt = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
  153. tgt = residual + self.dropout3(tgt)
  154. if not self.normalize_before:
  155. tgt = self.norm3(tgt)
  156. return tgt
  157. class TransformerDecoder(nn.Layer):
  158. def __init__(self,
  159. decoder_layer,
  160. num_layers,
  161. norm=None,
  162. return_intermediate=False):
  163. super(TransformerDecoder, self).__init__()
  164. self.layers = _get_clones(decoder_layer, num_layers)
  165. self.num_layers = num_layers
  166. self.norm = norm
  167. self.return_intermediate = return_intermediate
  168. def forward(self,
  169. tgt,
  170. memory,
  171. tgt_mask=None,
  172. memory_mask=None,
  173. pos_embed=None,
  174. query_pos_embed=None):
  175. tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
  176. output = tgt
  177. intermediate = []
  178. for layer in self.layers:
  179. output = layer(
  180. output,
  181. memory,
  182. tgt_mask=tgt_mask,
  183. memory_mask=memory_mask,
  184. pos_embed=pos_embed,
  185. query_pos_embed=query_pos_embed)
  186. if self.return_intermediate:
  187. intermediate.append(self.norm(output))
  188. if self.norm is not None:
  189. output = self.norm(output)
  190. if self.return_intermediate:
  191. return paddle.stack(intermediate)
  192. return output.unsqueeze(0)
  193. @register
  194. class DETRTransformer(nn.Layer):
  195. __shared__ = ['hidden_dim']
  196. def __init__(self,
  197. num_queries=100,
  198. position_embed_type='sine',
  199. return_intermediate_dec=True,
  200. backbone_num_channels=2048,
  201. hidden_dim=256,
  202. nhead=8,
  203. num_encoder_layers=6,
  204. num_decoder_layers=6,
  205. dim_feedforward=2048,
  206. dropout=0.1,
  207. activation="relu",
  208. attn_dropout=None,
  209. act_dropout=None,
  210. normalize_before=False):
  211. super(DETRTransformer, self).__init__()
  212. assert position_embed_type in ['sine', 'learned'],\
  213. f'ValueError: position_embed_type not supported {position_embed_type}!'
  214. self.hidden_dim = hidden_dim
  215. self.nhead = nhead
  216. encoder_layer = TransformerEncoderLayer(
  217. hidden_dim, nhead, dim_feedforward, dropout, activation,
  218. attn_dropout, act_dropout, normalize_before)
  219. encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
  220. self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers,
  221. encoder_norm)
  222. decoder_layer = TransformerDecoderLayer(
  223. hidden_dim, nhead, dim_feedforward, dropout, activation,
  224. attn_dropout, act_dropout, normalize_before)
  225. decoder_norm = nn.LayerNorm(hidden_dim)
  226. self.decoder = TransformerDecoder(
  227. decoder_layer,
  228. num_decoder_layers,
  229. decoder_norm,
  230. return_intermediate=return_intermediate_dec)
  231. self.input_proj = nn.Conv2D(
  232. backbone_num_channels, hidden_dim, kernel_size=1)
  233. self.query_pos_embed = nn.Embedding(num_queries, hidden_dim)
  234. self.position_embedding = PositionEmbedding(
  235. hidden_dim // 2,
  236. normalize=True if position_embed_type == 'sine' else False,
  237. embed_type=position_embed_type)
  238. self._reset_parameters()
  239. def _reset_parameters(self):
  240. for p in self.parameters():
  241. if p.dim() > 1:
  242. xavier_uniform_(p)
  243. conv_init_(self.input_proj)
  244. normal_(self.query_pos_embed.weight)
  245. @classmethod
  246. def from_config(cls, cfg, input_shape):
  247. return {
  248. 'backbone_num_channels': [i.channels for i in input_shape][-1],
  249. }
  250. def _convert_attention_mask(self, mask):
  251. return (mask - 1.0) * 1e9
  252. def forward(self, src, src_mask=None, *args, **kwargs):
  253. r"""
  254. Applies a Transformer model on the inputs.
  255. Parameters:
  256. src (List(Tensor)): Backbone feature maps with shape [[bs, c, h, w]].
  257. src_mask (Tensor, optional): A tensor used in multi-head attention
  258. to prevents attention to some unwanted positions, usually the
  259. paddings or the subsequent positions. It is a tensor with shape
  260. [bs, H, W]`. When the data type is bool, the unwanted positions
  261. have `False` values and the others have `True` values. When the
  262. data type is int, the unwanted positions have 0 values and the
  263. others have 1 values. When the data type is float, the unwanted
  264. positions have `-INF` values and the others have 0 values. It
  265. can be None when nothing wanted or needed to be prevented
  266. attention to. Default None.
  267. Returns:
  268. output (Tensor): [num_levels, batch_size, num_queries, hidden_dim]
  269. memory (Tensor): [batch_size, hidden_dim, h, w]
  270. """
  271. # use last level feature map
  272. src_proj = self.input_proj(src[-1])
  273. bs, c, h, w = paddle.shape(src_proj)
  274. # flatten [B, C, H, W] to [B, HxW, C]
  275. src_flatten = src_proj.flatten(2).transpose([0, 2, 1])
  276. if src_mask is not None:
  277. src_mask = F.interpolate(src_mask.unsqueeze(0), size=(h, w))[0]
  278. else:
  279. src_mask = paddle.ones([bs, h, w])
  280. pos_embed = self.position_embedding(src_mask).flatten(1, 2)
  281. if self.training:
  282. src_mask = self._convert_attention_mask(src_mask)
  283. src_mask = src_mask.reshape([bs, 1, 1, h * w])
  284. else:
  285. src_mask = None
  286. memory = self.encoder(
  287. src_flatten, src_mask=src_mask, pos_embed=pos_embed)
  288. query_pos_embed = self.query_pos_embed.weight.unsqueeze(0).tile(
  289. [bs, 1, 1])
  290. tgt = paddle.zeros_like(query_pos_embed)
  291. output = self.decoder(
  292. tgt,
  293. memory,
  294. memory_mask=src_mask,
  295. pos_embed=pos_embed,
  296. query_pos_embed=query_pos_embed)
  297. if self.training:
  298. src_mask = src_mask.reshape([bs, 1, 1, h, w])
  299. else:
  300. src_mask = None
  301. return (output, memory.transpose([0, 2, 1]).reshape([bs, c, h, w]),
  302. src_proj, src_mask)