vision_transformer.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
from paddle.nn.initializer import Constant

from ppdet.modeling.shape_spec import ShapeSpec
from ppdet.core.workspace import register, serializable

from .transformer_utils import zeros_, DropPath, Identity


class Mlp(nn.Layer):
    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Layer):
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 window_size=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias_attr=False)
        if qkv_bias:
            self.q_bias = self.create_parameter(
                shape=([dim]), default_initializer=zeros_)
            self.v_bias = self.create_parameter(
                shape=([dim]), default_initializer=zeros_)
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (
                2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = self.create_parameter(
                shape=(self.num_relative_distance, num_heads),
                default_initializer=zeros_)  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token to cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = paddle.arange(window_size[0])
            coords_w = paddle.arange(window_size[1])
            coords = paddle.stack(paddle.meshgrid(
                [coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = paddle.flatten(coords, 1)  # 2, Wh*Ww

            coords_flatten_1 = paddle.unsqueeze(coords_flatten, 2)
            coords_flatten_2 = paddle.unsqueeze(coords_flatten, 1)
            relative_coords = coords_flatten_1.clone() - coords_flatten_2.clone()
            # relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww

            relative_coords = relative_coords.transpose(
                (1, 2, 0))  # .contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = \
                paddle.zeros(shape=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index",
                                 relative_position_index)
            # trunc_normal_(self.relative_position_bias_table, std=.02)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None):
        x_shape = paddle.shape(x)
        N, C = x_shape[1], x_shape[2]

        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = paddle.concat(
                (self.q_bias, paddle.zeros_like(self.v_bias), self.v_bias))
        qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape((-1, N, 3, self.num_heads,
                           C // self.num_heads)).transpose((2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale

        if self.relative_position_bias_table is not None:
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.reshape([-1])].reshape([
                    self.window_size[0] * self.window_size[1] + 1,
                    self.window_size[0] * self.window_size[1] + 1, -1
                ])  # Wh*Ww, Wh*Ww, nH
            relative_position_bias = relative_position_bias.transpose(
                (2, 0, 1))  # .contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = nn.functional.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Layer):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 window_size=None,
                 init_values=None,
                 act_layer=nn.GELU,
                 norm_layer='nn.LayerNorm',
                 epsilon=1e-5):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, epsilon=1e-6)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            window_size=window_size)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
        self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=mlp_hidden_dim,
                       act_layer=act_layer,
                       drop=drop)

        if init_values is not None:
            self.gamma_1 = self.create_parameter(
                shape=([dim]), default_initializer=Constant(value=init_values))
            self.gamma_2 = self.create_parameter(
                shape=([dim]), default_initializer=Constant(value=init_values))
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, rel_pos_bias=None):
        if self.gamma_1 is None:
            x = x + self.drop_path(
                self.attn(
                    self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(
                self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Layer):
    """ Image to Patch Embedding
    """

    def __init__(self,
                 img_size=[224, 224],
                 patch_size=16,
                 in_chans=3,
                 embed_dim=768):
        super().__init__()
        self.num_patches_w = img_size[0] // patch_size
        self.num_patches_h = img_size[1] // patch_size

        num_patches = self.num_patches_w * self.num_patches_h
        self.patch_shape = (img_size[0] // patch_size,
                            img_size[1] // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size

        self.num_patches = num_patches
        self.proj = nn.Conv2D(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    @property
    def num_patches_in_h(self):
        return self.img_size[1] // self.patch_size

    @property
    def num_patches_in_w(self):
        return self.img_size[0] // self.patch_size

    def forward(self, x, mask=None):
        B, C, H, W = x.shape
        return self.proj(x)
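

# Shape note (illustrative sketch, not part of the original module): with
# img_size=[224, 224] and patch_size=16 the projection above produces a
# 14 x 14 grid of patch tokens, which the backbone later flattens to 196
# tokens of width embed_dim. Assuming a Paddle runtime:
#
#   embed = PatchEmbed(img_size=[224, 224], patch_size=16, embed_dim=768)
#   y = embed(paddle.randn([1, 3, 224, 224]))  # shape [1, 768, 14, 14]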


class RelativePositionBias(nn.Layer):
    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (
            2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = self.create_parameter(
            shape=(self.num_relative_distance, num_heads),
            default_initializer=zeros_)
        # cls to token & token to cls & cls to cls

        # get pair-wise relative position index for each token inside the window
        coords_h = paddle.arange(window_size[0])
        coords_w = paddle.arange(window_size[1])
        coords = paddle.stack(paddle.meshgrid(
            [coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = coords.flatten(1)  # 2, Wh*Ww

        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.transpose((1, 2, 0))  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = \
            paddle.zeros(shape=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self):
        relative_position_bias = \
            self.relative_position_bias_table[self.relative_position_index.reshape([-1])].reshape([
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1, -1])  # Wh*Ww, Wh*Ww, nH
        return relative_position_bias.transpose((2, 0, 1))  # nH, Wh*Ww, Wh*Ww
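

# Illustrative sketch (not part of the original module): how the BEiT-style
# relative-position index built above is laid out for a tiny window. It uses
# numpy only, so it can be inspected without a Paddle runtime; the function
# name and its arguments are assumptions made for this example.
def _toy_relative_position_index(wh=2, ww=2):
    num_relative_distance = (2 * wh - 1) * (2 * ww - 1) + 3
    coords = np.stack(np.meshgrid(np.arange(wh), np.arange(ww), indexing='ij'))
    coords_flatten = coords.reshape(2, -1)                         # 2, Wh*Ww
    rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    rel = rel.transpose(1, 2, 0)                                   # Wh*Ww, Wh*Ww, 2
    rel[:, :, 0] += wh - 1        # shift row offsets to start from 0
    rel[:, :, 1] += ww - 1        # shift column offsets to start from 0
    rel[:, :, 0] *= 2 * ww - 1    # fold (row, col) offsets into a single id
    index = np.zeros((wh * ww + 1, wh * ww + 1), dtype='int64')
    index[1:, 1:] = rel.sum(-1)                 # patch-to-patch distances
    index[0, 0:] = num_relative_distance - 3    # cls -> patch
    index[0:, 0] = num_relative_distance - 2    # patch -> cls
    index[0, 0] = num_relative_distance - 1     # cls -> cls
    return index  # each value indexes a row of relative_position_bias_table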


def get_sinusoid_encoding_table(n_position, d_hid, token=False):
    ''' Sinusoid position encoding table '''

    def get_position_angle_vec(position):
        return [
            position / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ]

    sinusoid_table = np.array(
        [get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if token:
        sinusoid_table = np.concatenate(
            [sinusoid_table, np.zeros([1, d_hid])], axis=0)

    return paddle.to_tensor(sinusoid_table, dtype=paddle.float32).unsqueeze(0)
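

# Example usage of the helper above (a minimal sketch, assuming a Paddle
# runtime; the shapes are implied by the code rather than stated anywhere):
#
#   table = get_sinusoid_encoding_table(n_position=196, d_hid=768)
#   # table.shape == [1, 196, 768]; pass token=True to append one extra
#   # all-zero row, e.g. for a class token.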


@register
@serializable
class VisionTransformer(nn.Layer):
    """ Vision Transformer with support for patch input
    """

    def __init__(self,
                 img_size=[672, 1092],
                 patch_size=16,
                 in_chans=3,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer='nn.LayerNorm',
                 init_values=None,
                 use_rel_pos_bias=False,
                 use_shared_rel_pos_bias=False,
                 epsilon=1e-5,
                 final_norm=False,
                 pretrained=None,
                 out_indices=[3, 5, 7, 11],
                 use_abs_pos_emb=False,
                 use_sincos_pos_emb=True,
                 with_fpn=True,
                 num_fpn_levels=4,
                 use_checkpoint=False,
                 **args):
        super().__init__()
        self.img_size = img_size
        self.embed_dim = embed_dim
        self.with_fpn = with_fpn
        self.use_checkpoint = use_checkpoint
        self.use_sincos_pos_emb = use_sincos_pos_emb
        self.use_rel_pos_bias = use_rel_pos_bias
        self.final_norm = final_norm
        self.out_indices = out_indices
        self.num_fpn_levels = num_fpn_levels

        if use_checkpoint:
            paddle.seed(0)

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim)

        self.pos_w = self.patch_embed.num_patches_in_w
        self.pos_h = self.patch_embed.num_patches_in_h

        self.cls_token = self.create_parameter(
            shape=(1, 1, embed_dim),
            default_initializer=paddle.nn.initializer.Constant(value=0.))

        if use_abs_pos_emb:
            self.pos_embed = self.create_parameter(
                shape=(1, self.pos_w * self.pos_h + 1, embed_dim),
                default_initializer=paddle.nn.initializer.TruncatedNormal(
                    std=.02))
        elif use_sincos_pos_emb:
            pos_embed = self.build_2d_sincos_position_embedding(embed_dim)

            self.pos_embed = pos_embed
            self.pos_embed = self.create_parameter(shape=pos_embed.shape)
            self.pos_embed.set_value(pos_embed.numpy())
            self.pos_embed.stop_gradient = True
        else:
            self.pos_embed = None

        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(
                window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None

        dpr = np.linspace(0, drop_path_rate, depth)

        self.blocks = nn.LayerList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                init_values=init_values,
                window_size=self.patch_embed.patch_shape
                if use_rel_pos_bias else None,
                epsilon=epsilon) for i in range(depth)
        ])

        self.pretrained = pretrained
        self.init_weight()

        assert len(out_indices) <= 4, ''
        self.out_indices = out_indices
        self.out_channels = [embed_dim for _ in range(num_fpn_levels)]
        self.out_strides = [4, 8, 16, 32][-num_fpn_levels:] if with_fpn else [
            patch_size for _ in range(len(out_indices))
        ]

        self.norm = Identity()

        if self.with_fpn:
            assert num_fpn_levels <= 4, ''
            self.init_fpn(
                embed_dim=embed_dim,
                patch_size=patch_size, )

    def init_weight(self):
        pretrained = self.pretrained

        if pretrained:
            if 'http' in pretrained:  # URL
                path = paddle.utils.download.get_weights_path_from_url(
                    pretrained)
            else:  # model in local path
                path = pretrained

            load_state_dict = paddle.load(path)
            model_state_dict = self.state_dict()
            pos_embed_name = "pos_embed"

            if pos_embed_name in load_state_dict.keys():
                load_pos_embed = paddle.to_tensor(
                    load_state_dict[pos_embed_name], dtype="float32")
                if self.pos_embed.shape != load_pos_embed.shape:
                    pos_size = int(math.sqrt(load_pos_embed.shape[1] - 1))
                    model_state_dict[pos_embed_name] = self.resize_pos_embed(
                        load_pos_embed, (pos_size, pos_size),
                        (self.pos_h, self.pos_w))

                    # self.set_state_dict(model_state_dict)
                    load_state_dict[pos_embed_name] = model_state_dict[
                        pos_embed_name]

                    print("Load pos_embed and resize it from {} to {} .".format(
                        load_pos_embed.shape, self.pos_embed.shape))

            self.set_state_dict(load_state_dict)
            print("Load load_state_dict....")

    def init_fpn(self, embed_dim=768, patch_size=16, out_with_norm=False):
        if patch_size == 16:
            self.fpn1 = nn.Sequential(
                nn.Conv2DTranspose(
                    embed_dim, embed_dim, kernel_size=2, stride=2),
                nn.BatchNorm2D(embed_dim),
                nn.GELU(),
                nn.Conv2DTranspose(
                    embed_dim, embed_dim, kernel_size=2, stride=2), )

            self.fpn2 = nn.Sequential(
                nn.Conv2DTranspose(
                    embed_dim, embed_dim, kernel_size=2, stride=2), )

            self.fpn3 = Identity()

            self.fpn4 = nn.MaxPool2D(kernel_size=2, stride=2)
        elif patch_size == 8:
            self.fpn1 = nn.Sequential(
                nn.Conv2DTranspose(
                    embed_dim, embed_dim, kernel_size=2, stride=2), )

            self.fpn2 = Identity()

            self.fpn3 = nn.Sequential(nn.MaxPool2D(kernel_size=2, stride=2), )

            self.fpn4 = nn.Sequential(nn.MaxPool2D(kernel_size=4, stride=4), )

        if not out_with_norm:
            self.norm = Identity()
        else:
            self.norm = nn.LayerNorm(embed_dim, epsilon=1e-6)

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        if npatch == N and w0 == self.patch_embed.num_patches_w and h0 == self.patch_embed.num_patches_h:
            return self.pos_embed

        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        # w0, h0 = w0 + 0.1, h0 + 0.1
        # patch_pos_embed = nn.functional.interpolate(
        #     patch_pos_embed.reshape([
        #         1, self.patch_embed.num_patches_w,
        #         self.patch_embed.num_patches_h, dim
        #     ]).transpose((0, 3, 1, 2)),
        #     scale_factor=(w0 / self.patch_embed.num_patches_w,
        #                   h0 / self.patch_embed.num_patches_h),
        #     mode='bicubic', )
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape([
                1, self.patch_embed.num_patches_w,
                self.patch_embed.num_patches_h, dim
            ]).transpose((0, 3, 1, 2)),
            (w0, h0),
            mode='bicubic', )
        assert int(w0) == patch_pos_embed.shape[-2] and int(
            h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.transpose(
            (0, 2, 3, 1)).reshape([1, -1, dim])
        return paddle.concat(
            (class_pos_embed.unsqueeze(0), patch_pos_embed), axis=1)

    def resize_pos_embed(self, pos_embed, old_hw, new_hw):
        """
        Resize pos_embed weight.
        Args:
            pos_embed (Tensor): the pos_embed weight
            old_hw (list[int]): the height and width of old pos_embed
            new_hw (list[int]): the height and width of new pos_embed
        Returns:
            Tensor: the resized pos_embed weight
        """
        cls_pos_embed = pos_embed[:, :1, :]
        pos_embed = pos_embed[:, 1:, :]

        pos_embed = pos_embed.transpose([0, 2, 1])
        pos_embed = pos_embed.reshape([1, -1, old_hw[0], old_hw[1]])
        pos_embed = F.interpolate(
            pos_embed, new_hw, mode='bicubic', align_corners=False)
        pos_embed = pos_embed.flatten(2).transpose([0, 2, 1])
        pos_embed = paddle.concat([cls_pos_embed, pos_embed], axis=1)

        return pos_embed

    def build_2d_sincos_position_embedding(
            self,
            embed_dim=768,
            temperature=10000., ):
        h, w = self.patch_embed.patch_shape
        grid_w = paddle.arange(w, dtype=paddle.float32)
        grid_h = paddle.arange(h, dtype=paddle.float32)
        grid_w, grid_h = paddle.meshgrid(grid_w, grid_h)
        assert embed_dim % 4 == 0, \
            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = embed_dim // 4

        omega = paddle.arange(pos_dim, dtype=paddle.float32) / pos_dim
        omega = 1. / (temperature**omega)

        out_w = grid_w.flatten()[..., None] @ omega[None]
        out_h = grid_h.flatten()[..., None] @ omega[None]

        pos_emb = paddle.concat(
            [
                paddle.sin(out_w), paddle.cos(out_w), paddle.sin(out_h),
                paddle.cos(out_h)
            ],
            axis=1)[None, :, :]

        pe_token = paddle.zeros([1, 1, embed_dim], dtype=paddle.float32)
        pos_embed = paddle.concat([pe_token, pos_emb], axis=1)
        # pos_embed.stop_gradient = True
        return pos_embed

    def forward(self, x):
        x = x['image'] if isinstance(x, dict) else x
        _, _, h, w = x.shape

        x = self.patch_embed(x)

        B, D, Hp, Wp = x.shape  # b * c * h * w

        cls_tokens = self.cls_token.expand(
            (B, self.cls_token.shape[-2], self.cls_token.shape[-1]))
        x = x.flatten(2).transpose([0, 2, 1])  # b * hw * c
        x = paddle.concat([cls_tokens, x], axis=1)

        if self.pos_embed is not None:
            # x = x + self.interpolate_pos_encoding(x, w, h)
            x = x + self.interpolate_pos_encoding(x, h, w)

        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias(
        ) if self.rel_pos_bias is not None else None

        feats = []
        for idx, blk in enumerate(self.blocks):
            if self.use_checkpoint and self.training:
                x = paddle.distributed.fleet.utils.recompute(
                    blk, x, rel_pos_bias, **{"preserve_rng_state": True})
            else:
                x = blk(x, rel_pos_bias)

            if idx in self.out_indices:
                xp = paddle.reshape(
                    paddle.transpose(
                        self.norm(x[:, 1:, :]), perm=[0, 2, 1]),
                    shape=[B, D, Hp, Wp])
                feats.append(xp)

        if self.with_fpn:
            fpns = [self.fpn1, self.fpn2, self.fpn3, self.fpn4][
                -self.num_fpn_levels:]
            assert len(fpns) == len(feats) or len(feats) == 1, ''
            outputs = []
            for i, m in enumerate(fpns):
                outputs.append(
                    m(feats[i] if len(feats) == len(fpns) else feats[-1]))
            return outputs

        return feats

    @property
    def num_layers(self):
        return len(self.blocks)

    @property
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    @property
    def out_shape(self):
        return [
            ShapeSpec(
                channels=c, stride=s)
            for c, s in zip(self.out_channels, self.out_strides)
        ]
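

# Minimal smoke-test sketch (not part of the original file): builds the
# backbone with a small, illustrative configuration and runs a dummy forward
# pass. It assumes a working Paddle + ppdet installation; the argument values
# below are assumptions chosen for this example, not the defaults of any
# particular detector config.
if __name__ == '__main__':
    model = VisionTransformer(
        img_size=[224, 224],
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        out_indices=[3, 5, 7, 11],
        use_sincos_pos_emb=True,
        with_fpn=True,
        num_fpn_levels=4)
    model.eval()
    dummy = paddle.randn([1, 3, 224, 224])
    # The backbone accepts either a tensor or a dict with an 'image' key.
    feats = model({'image': dummy})
    for stride, feat in zip(model.out_strides, feats):
        print(stride, feat.shape)  # strides 4, 8, 16, 32 -> 56, 28, 14, 7 grids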