# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import ReLU, Swish, GELU

from ppdet.core.workspace import register
from ..shape_spec import ShapeSpec

__all__ = ['TransEncoder']

class BertEmbeddings(nn.Layer):
    def __init__(self, word_size, position_embeddings_size, word_type_size,
                 hidden_size, dropout_prob):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(
            word_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(position_embeddings_size,
                                                hidden_size)
        self.token_type_embeddings = nn.Embedding(word_type_size, hidden_size)
        self.layernorm = nn.LayerNorm(hidden_size, epsilon=1e-8)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x, token_type_ids=None, position_ids=None):
        seq_len = paddle.shape(x)[1]
        if position_ids is None:
            position_ids = paddle.arange(seq_len).unsqueeze(0).expand_as(x)
        if token_type_ids is None:
            # Embedding lookups need integer indices, so the default token
            # type ids are built as int64 zeros.
            token_type_ids = paddle.zeros(paddle.shape(x), dtype="int64")

        word_embs = self.word_embeddings(x)
        position_embs = self.position_embeddings(position_ids)
        token_type_embs = self.token_type_embeddings(token_type_ids)

        embs_cmb = word_embs + position_embs + token_type_embs
        embs_out = self.layernorm(embs_cmb)
        embs_out = self.dropout(embs_out)
        return embs_out

class BertSelfAttention(nn.Layer):
    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_probs_dropout_prob,
                 output_attentions=False):
        super(BertSelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The hidden_size must be a multiple of the number of attention "
                "heads, but got {} % {} != 0".format(hidden_size,
                                                     num_attention_heads))

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(attention_probs_dropout_prob)
        self.output_attentions = output_attentions

    def forward(self, x, attention_mask, head_mask=None):
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        query_dim1, query_dim2 = paddle.shape(query)[:-1]
        new_shape = [
            query_dim1, query_dim2, self.num_attention_heads,
            self.attention_head_size
        ]
        query = query.reshape(new_shape).transpose(perm=(0, 2, 1, 3))
        key = key.reshape(new_shape).transpose(perm=(0, 2, 3, 1))
        value = value.reshape(new_shape).transpose(perm=(0, 2, 1, 3))

        # Scaled dot-product attention with an additive mask.
        attention = paddle.matmul(query,
                                  key) / math.sqrt(self.attention_head_size)
        attention = attention + attention_mask
        attention_value = F.softmax(attention, axis=-1)
        attention_value = self.dropout(attention_value)

        if head_mask is not None:
            attention_value = attention_value * head_mask

        context = paddle.matmul(attention_value,
                                value).transpose(perm=(0, 2, 1, 3))
        ctx_dim1, ctx_dim2 = paddle.shape(context)[:-2]
        new_context_shape = [
            ctx_dim1,
            ctx_dim2,
            self.all_head_size,
        ]
        context = context.reshape(new_context_shape)

        if self.output_attentions:
            return (context, attention_value)
        else:
            return (context, )

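
# Shape walkthrough for BertSelfAttention (informal note, not part of the
# original computation): with hidden_size H, num_attention_heads h and head
# size d = H // h, an input of shape [B, L, H] is projected to Q/K/V,
# reshaped to [B, h, L, d] (K to [B, h, d, L]), scored with
# softmax(QK / sqrt(d) + mask), and the weighted values are merged back to
# [B, L, H].
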
class BertAttention(nn.Layer):
    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_probs_dropout_prob,
                 fc_dropout_prob,
                 output_attentions=False):
        super(BertAttention, self).__init__()
        self.bert_selfattention = BertSelfAttention(
            hidden_size, num_attention_heads, attention_probs_dropout_prob,
            output_attentions)
        self.fc = nn.Linear(hidden_size, hidden_size)
        self.layernorm = nn.LayerNorm(hidden_size, epsilon=1e-8)
        self.dropout = nn.Dropout(fc_dropout_prob)

    def forward(self, x, attention_mask, head_mask=None):
        attention_feats = self.bert_selfattention(x, attention_mask, head_mask)
        features = self.fc(attention_feats[0])
        features = self.dropout(features)
        features = self.layernorm(features + x)

        if len(attention_feats) == 2:
            return (features, attention_feats[1])
        else:
            return (features, )

class BertFeedForward(nn.Layer):
    def __init__(self,
                 hidden_size,
                 intermediate_size,
                 num_attention_heads,
                 attention_probs_dropout_prob,
                 fc_dropout_prob,
                 act_fn='ReLU',
                 output_attentions=False):
        super(BertFeedForward, self).__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        # The activation is resolved by name, e.g. 'gelu' maps to the
        # module-level gelu function defined below.
        self.act_fn = eval(act_fn)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)
        self.layernorm = nn.LayerNorm(hidden_size, epsilon=1e-8)
        self.dropout = nn.Dropout(fc_dropout_prob)

    def forward(self, x):
        features = self.fc1(x)
        features = self.act_fn(features)
        features = self.fc2(features)
        features = self.dropout(features)
        features = self.layernorm(features + x)
        return features

class BertLayer(nn.Layer):
    def __init__(self,
                 hidden_size,
                 intermediate_size,
                 num_attention_heads,
                 attention_probs_dropout_prob,
                 fc_dropout_prob,
                 act_fn='ReLU',
                 output_attentions=False):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(hidden_size, num_attention_heads,
                                       attention_probs_dropout_prob,
                                       fc_dropout_prob, output_attentions)
        self.feed_forward = BertFeedForward(
            hidden_size, intermediate_size, num_attention_heads,
            attention_probs_dropout_prob, fc_dropout_prob, act_fn,
            output_attentions)

    def forward(self, x, attention_mask, head_mask=None):
        attention_feats = self.attention(x, attention_mask, head_mask)
        features = self.feed_forward(attention_feats[0])

        if len(attention_feats) == 2:
            return (features, attention_feats[1])
        else:
            return (features, )

class BertEncoder(nn.Layer):
    def __init__(self,
                 num_hidden_layers,
                 hidden_size,
                 intermediate_size,
                 num_attention_heads,
                 attention_probs_dropout_prob,
                 fc_dropout_prob,
                 act_fn='ReLU',
                 output_attentions=False,
                 output_hidden_feats=False):
        super(BertEncoder, self).__init__()
        self.output_attentions = output_attentions
        self.output_hidden_feats = output_hidden_feats
        self.layers = nn.LayerList([
            BertLayer(hidden_size, intermediate_size, num_attention_heads,
                      attention_probs_dropout_prob, fc_dropout_prob, act_fn,
                      output_attentions) for _ in range(num_hidden_layers)
        ])

    def forward(self, x, attention_mask, head_mask=None):
        all_features = (x, )
        all_attentions = ()
        for i, layer in enumerate(self.layers):
            mask = head_mask[i] if head_mask is not None else None
            layer_out = layer(x, attention_mask, mask)

            if self.output_hidden_feats:
                all_features = all_features + (x, )
            x = layer_out[0]
            if self.output_attentions:
                all_attentions = all_attentions + (layer_out[1], )

        outputs = (x, )
        if self.output_hidden_feats:
            outputs += (all_features, )
        if self.output_attentions:
            outputs += (all_attentions, )
        return outputs

class BertPooler(nn.Layer):
    def __init__(self, hidden_size):
        super(BertPooler, self).__init__()
        self.fc = nn.Linear(hidden_size, hidden_size)
        self.act = nn.Tanh()

    def forward(self, x):
        first_token = x[:, 0]
        pooled_output = self.fc(first_token)
        pooled_output = self.act(pooled_output)
        return pooled_output

class METROEncoder(nn.Layer):
    def __init__(self,
                 vocab_size,
                 num_hidden_layers,
                 features_dims,
                 position_embeddings_size,
                 hidden_size,
                 intermediate_size,
                 output_feature_dim,
                 num_attention_heads,
                 attention_probs_dropout_prob,
                 fc_dropout_prob,
                 act_fn='ReLU',
                 output_attentions=False,
                 output_hidden_feats=False,
                 use_img_layernorm=False):
        super(METROEncoder, self).__init__()
        self.img_dims = features_dims
        self.num_hidden_layers = num_hidden_layers
        self.use_img_layernorm = use_img_layernorm
        self.output_attentions = output_attentions
        self.output_hidden_feats = output_hidden_feats
        self.embedding = BertEmbeddings(vocab_size, position_embeddings_size, 2,
                                        hidden_size, fc_dropout_prob)
        self.encoder = BertEncoder(
            num_hidden_layers, hidden_size, intermediate_size,
            num_attention_heads, attention_probs_dropout_prob, fc_dropout_prob,
            act_fn, output_attentions, output_hidden_feats)
        self.pooler = BertPooler(hidden_size)
        self.position_embeddings = nn.Embedding(position_embeddings_size,
                                                hidden_size)
        self.img_embedding = nn.Linear(
            features_dims, hidden_size, bias_attr=True)
        self.dropout = nn.Dropout(fc_dropout_prob)
        self.cls_head = nn.Linear(hidden_size, output_feature_dim)
        self.residual = nn.Linear(features_dims, output_feature_dim)
        if use_img_layernorm:
            # LayerNorm over the projected image features, referenced in
            # forward() when use_img_layernorm is enabled.
            self.layernorm = nn.LayerNorm(hidden_size, epsilon=1e-8)
        self.apply(self.init_weights)

    def init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.set_value(
                paddle.normal(
                    mean=0.0, std=0.02, shape=module.weight.shape))
        elif isinstance(module, nn.LayerNorm):
            module.bias.set_value(paddle.zeros(shape=module.bias.shape))
            module.weight.set_value(
                paddle.full(
                    shape=module.weight.shape, fill_value=1.0))
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.set_value(paddle.zeros(shape=module.bias.shape))

    def forward(self, x):
        batchsize, seq_len = paddle.shape(x)[:2]
        input_ids = paddle.zeros((batchsize, seq_len), dtype="int64")
        position_ids = paddle.arange(
            seq_len, dtype="int64").unsqueeze(0).expand_as(input_ids)

        attention_mask = paddle.ones_like(input_ids).unsqueeze(1).unsqueeze(2)
        head_mask = [None] * self.num_hidden_layers

        position_embs = self.position_embeddings(position_ids)
        # Turn the {0, 1} mask into an additive mask: kept positions get 0,
        # masked positions get -10000 so they vanish after the softmax.
        attention_mask = (1.0 - attention_mask) * -10000.0
        img_features = self.img_embedding(x)

        # We empirically observe that adding an additional learnable position
        # embedding leads to more stable training.
        embeddings = position_embs + img_features
        if self.use_img_layernorm:
            embeddings = self.layernorm(embeddings)
        embeddings = self.dropout(embeddings)

        encoder_outputs = self.encoder(
            embeddings, attention_mask, head_mask=head_mask)
        pred_score = self.cls_head(encoder_outputs[0])
        res_img_feats = self.residual(x)
        pred_score = pred_score + res_img_feats

        if self.output_attentions and self.output_hidden_feats:
            return pred_score, encoder_outputs[1], encoder_outputs[-1]
        else:
            return pred_score

def gelu(x):
    """Implementation of the gelu activation function.
    https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))

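
# Note: the expression above is the exact GELU, x * Phi(x), where
# Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) is the standard normal CDF. Since
# TransEncoder passes act_fn='gelu' by default, BertFeedForward resolves its
# activation to this function.
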
@register
class TransEncoder(nn.Layer):
    def __init__(self,
                 vocab_size=30522,
                 num_hidden_layers=4,
                 num_attention_heads=4,
                 position_embeddings_size=512,
                 intermediate_size=3072,
                 input_feat_dim=[2048, 512, 128],
                 hidden_feat_dim=[1024, 256, 128],
                 attention_probs_dropout_prob=0.1,
                 fc_dropout_prob=0.1,
                 act_fn='gelu',
                 output_attentions=False,
                 output_hidden_feats=False):
        super(TransEncoder, self).__init__()
        output_feat_dim = input_feat_dim[1:] + [3]
        trans_encoder = []
        for i in range(len(output_feat_dim)):
            features_dims = input_feat_dim[i]
            output_feature_dim = output_feat_dim[i]
            hidden_size = hidden_feat_dim[i]

            # init a transformer encoder and append it to a list
            assert hidden_size % num_attention_heads == 0
            model = METROEncoder(vocab_size, num_hidden_layers, features_dims,
                                 position_embeddings_size, hidden_size,
                                 intermediate_size, output_feature_dim,
                                 num_attention_heads,
                                 attention_probs_dropout_prob, fc_dropout_prob,
                                 act_fn, output_attentions, output_hidden_feats)
            trans_encoder.append(model)
        self.trans_encoder = paddle.nn.Sequential(*trans_encoder)

    def forward(self, x):
        out = self.trans_encoder(x)
        return out
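
# Minimal usage sketch (illustrative only, assuming token-level features with
# the default input_feat_dim[0] = 2048 channels). The stacked METROEncoders
# progressively reduce the feature dimension 2048 -> 512 -> 128 -> 3:
#
#     encoder = TransEncoder()
#     feats = paddle.randn([2, 49, 2048])   # [batch, num_tokens, feat_dim]
#     out = encoder(feats)                  # shape [2, 49, 3]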