swin_transformer.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
  16. The copyright of microsoft/Swin-Transformer is as follows:
  17. MIT License [see LICENSE for details]
  18. """
  19. import paddle
  20. import paddle.nn as nn
  21. import paddle.nn.functional as F
  22. from ppdet.modeling.shape_spec import ShapeSpec
  23. from ppdet.core.workspace import register, serializable
  24. import numpy as np
  25. from .transformer_utils import DropPath, Identity
  26. from .transformer_utils import add_parameter, to_2tuple
  27. from .transformer_utils import ones_, zeros_, trunc_normal_
  28. class Mlp(nn.Layer):
  29. def __init__(self,
  30. in_features,
  31. hidden_features=None,
  32. out_features=None,
  33. act_layer=nn.GELU,
  34. drop=0.):
  35. super().__init__()
  36. out_features = out_features or in_features
  37. hidden_features = hidden_features or in_features
  38. self.fc1 = nn.Linear(in_features, hidden_features)
  39. self.act = act_layer()
  40. self.fc2 = nn.Linear(hidden_features, out_features)
  41. self.drop = nn.Dropout(drop)
  42. def forward(self, x):
  43. x = self.fc1(x)
  44. x = self.act(x)
  45. x = self.drop(x)
  46. x = self.fc2(x)
  47. x = self.drop(x)
  48. return x
  49. def window_partition(x, window_size):
  50. """
  51. Args:
  52. x: (B, H, W, C)
  53. window_size (int): window size
  54. Returns:
  55. windows: (num_windows*B, window_size, window_size, C)
  56. """
  57. B, H, W, C = x.shape
  58. x = x.reshape(
  59. [-1, H // window_size, window_size, W // window_size, window_size, C])
  60. windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape(
  61. [-1, window_size, window_size, C])
  62. return windows
  63. def window_reverse(windows, window_size, H, W):
  64. """
  65. Args:
  66. windows: (num_windows*B, window_size, window_size, C)
  67. window_size (int): Window size
  68. H (int): Height of image
  69. W (int): Width of image
  70. Returns:
  71. x: (B, H, W, C)
  72. """
  73. _, _, _, C = windows.shape
  74. B = int(windows.shape[0] / (H * W / window_size / window_size))
  75. x = windows.reshape(
  76. [-1, H // window_size, W // window_size, window_size, window_size, C])
  77. x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, H, W, C])
  78. return x
class WindowAttention(nn.Layer):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
            Note: because of ``or``, any falsy value (including 0) falls back to the default.
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self,
                 dim,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias:
        # one row per possible (dy, dx) offset inside a window, one column per head
        self.relative_position_bias_table = add_parameter(
            self,
            paddle.zeros(((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                          num_heads)))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = paddle.arange(self.window_size[0])
        coords_w = paddle.arange(self.window_size[1])
        coords = paddle.stack(paddle.meshgrid(
            [coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = paddle.flatten(coords, 1)  # 2, Wh*Ww
        coords_flatten_1 = coords_flatten.unsqueeze(axis=2)
        coords_flatten_2 = coords_flatten.unsqueeze(axis=1)
        # broadcasted difference: relative (dy, dx) between every token pair
        relative_coords = coords_flatten_1 - coords_flatten_2
        relative_coords = relative_coords.transpose(
            [1, 2, 0])  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[
            0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # row-major flattening of (dy, dx) into a single row index of the bias
        # table; dx ranges over 2*Ww-1 values, hence the multiplier
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        self.relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index",
                             self.relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table)
        self.softmax = nn.Softmax(axis=-1)

    def forward(self, x, mask=None):
        """ Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C); N must equal
                Wh*Ww for the relative position bias to broadcast.
            mask: additive attention mask with shape of (num_windows, Wh*Ww, Wh*Ww)
                or None. BasicLayer in this file fills it with 0 (same region) and
                -100 (different region). The leading dimension of x must be a
                multiple of num_windows.
        Returns:
            Tensor with shape (num_windows*B, N, C).
        """
        B_, N, C = x.shape
        # (B_, N, 3C) -> (3, B_, nH, N, head_dim)
        qkv = self.qkv(x).reshape(
            [-1, N, 3, self.num_heads, C // self.num_heads]).transpose(
                [2, 0, 3, 1, 4])
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = paddle.mm(q, k.transpose([0, 1, 3, 2]))  # B_, nH, N, N

        index = self.relative_position_index.flatten()

        relative_position_bias = paddle.index_select(
            self.relative_position_bias_table, index)
        relative_position_bias = relative_position_bias.reshape([
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1], -1
        ])  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.transpose(
            [2, 0, 1])  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            # split the window axis out of the batch so each window gets its own mask
            attn = attn.reshape([-1, nW, self.num_heads, N, N
                                 ]) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.reshape([-1, self.num_heads, N, N])
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # weighted sum of values, then merge heads back into channels: (B_, N, C)
        x = paddle.mm(attn, v).transpose([0, 2, 1, 3]).reshape([-1, N, C])
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
  171. class SwinTransformerBlock(nn.Layer):
  172. """ Swin Transformer Block.
  173. Args:
  174. dim (int): Number of input channels.
  175. num_heads (int): Number of attention heads.
  176. window_size (int): Window size.
  177. shift_size (int): Shift size for SW-MSA.
  178. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  179. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  180. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  181. drop (float, optional): Dropout rate. Default: 0.0
  182. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  183. drop_path (float, optional): Stochastic depth rate. Default: 0.0
  184. act_layer (nn.Layer, optional): Activation layer. Default: nn.GELU
  185. norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
  186. """
  187. def __init__(self,
  188. dim,
  189. num_heads,
  190. window_size=7,
  191. shift_size=0,
  192. mlp_ratio=4.,
  193. qkv_bias=True,
  194. qk_scale=None,
  195. drop=0.,
  196. attn_drop=0.,
  197. drop_path=0.,
  198. act_layer=nn.GELU,
  199. norm_layer=nn.LayerNorm):
  200. super().__init__()
  201. self.dim = dim
  202. self.num_heads = num_heads
  203. self.window_size = window_size
  204. self.shift_size = shift_size
  205. self.mlp_ratio = mlp_ratio
  206. assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
  207. self.norm1 = norm_layer(dim)
  208. self.attn = WindowAttention(
  209. dim,
  210. window_size=to_2tuple(self.window_size),
  211. num_heads=num_heads,
  212. qkv_bias=qkv_bias,
  213. qk_scale=qk_scale,
  214. attn_drop=attn_drop,
  215. proj_drop=drop)
  216. self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
  217. self.norm2 = norm_layer(dim)
  218. mlp_hidden_dim = int(dim * mlp_ratio)
  219. self.mlp = Mlp(in_features=dim,
  220. hidden_features=mlp_hidden_dim,
  221. act_layer=act_layer,
  222. drop=drop)
  223. self.H = None
  224. self.W = None
  225. def forward(self, x, mask_matrix):
  226. """ Forward function.
  227. Args:
  228. x: Input feature, tensor size (B, H*W, C).
  229. H, W: Spatial resolution of the input feature.
  230. mask_matrix: Attention mask for cyclic shift.
  231. """
  232. B, L, C = x.shape
  233. H, W = self.H, self.W
  234. assert L == H * W, "input feature has wrong size"
  235. shortcut = x
  236. x = self.norm1(x)
  237. x = x.reshape([-1, H, W, C])
  238. # pad feature maps to multiples of window size
  239. pad_l = pad_t = 0
  240. pad_r = (self.window_size - W % self.window_size) % self.window_size
  241. pad_b = (self.window_size - H % self.window_size) % self.window_size
  242. x = F.pad(x, [0, pad_l, 0, pad_b, 0, pad_r, 0, pad_t])
  243. _, Hp, Wp, _ = x.shape
  244. # cyclic shift
  245. if self.shift_size > 0:
  246. shifted_x = paddle.roll(
  247. x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2))
  248. attn_mask = mask_matrix
  249. else:
  250. shifted_x = x
  251. attn_mask = None
  252. # partition windows
  253. x_windows = window_partition(
  254. shifted_x, self.window_size) # nW*B, window_size, window_size, C
  255. x_windows = x_windows.reshape(
  256. [x_windows.shape[0], self.window_size * self.window_size,
  257. C]) # nW*B, window_size*window_size, C
  258. # W-MSA/SW-MSA
  259. attn_windows = self.attn(
  260. x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
  261. # merge windows
  262. attn_windows = attn_windows.reshape(
  263. [x_windows.shape[0], self.window_size, self.window_size, C])
  264. shifted_x = window_reverse(attn_windows, self.window_size, Hp,
  265. Wp) # B H' W' C
  266. # reverse cyclic shift
  267. if self.shift_size > 0:
  268. x = paddle.roll(
  269. shifted_x,
  270. shifts=(self.shift_size, self.shift_size),
  271. axis=(1, 2))
  272. else:
  273. x = shifted_x
  274. if pad_r > 0 or pad_b > 0:
  275. x = x[:, :H, :W, :]
  276. x = x.reshape([-1, H * W, C])
  277. # FFN
  278. x = shortcut + self.drop_path(x)
  279. x = x + self.drop_path(self.mlp(self.norm2(x)))
  280. return x
  281. class PatchMerging(nn.Layer):
  282. r""" Patch Merging Layer.
  283. Args:
  284. dim (int): Number of input channels.
  285. norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
  286. """
  287. def __init__(self, dim, norm_layer=nn.LayerNorm):
  288. super().__init__()
  289. self.dim = dim
  290. self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
  291. self.norm = norm_layer(4 * dim)
  292. def forward(self, x, H, W):
  293. """ Forward function.
  294. Args:
  295. x: Input feature, tensor size (B, H*W, C).
  296. H, W: Spatial resolution of the input feature.
  297. """
  298. B, L, C = x.shape
  299. assert L == H * W, "input feature has wrong size"
  300. x = x.reshape([-1, H, W, C])
  301. # padding
  302. pad_input = (H % 2 == 1) or (W % 2 == 1)
  303. if pad_input:
  304. x = F.pad(x, [0, 0, 0, W % 2, 0, H % 2])
  305. x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
  306. x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
  307. x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
  308. x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
  309. x = paddle.concat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
  310. x = x.reshape([-1, H * W // 4, 4 * C]) # B H/2*W/2 4*C
  311. x = self.norm(x)
  312. x = self.reduction(x)
  313. return x
  314. class BasicLayer(nn.Layer):
  315. """ A basic Swin Transformer layer for one stage.
  316. Args:
  317. dim (int): Number of input channels.
  318. input_resolution (tuple[int]): Input resolution.
  319. depth (int): Number of blocks.
  320. num_heads (int): Number of attention heads.
  321. window_size (int): Local window size.
  322. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  323. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  324. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  325. drop (float, optional): Dropout rate. Default: 0.0
  326. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  327. drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
  328. norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
  329. downsample (nn.Layer | None, optional): Downsample layer at the end of the layer. Default: None
  330. """
  331. def __init__(self,
  332. dim,
  333. depth,
  334. num_heads,
  335. window_size=7,
  336. mlp_ratio=4.,
  337. qkv_bias=True,
  338. qk_scale=None,
  339. drop=0.,
  340. attn_drop=0.,
  341. drop_path=0.,
  342. norm_layer=nn.LayerNorm,
  343. downsample=None):
  344. super().__init__()
  345. self.window_size = window_size
  346. self.shift_size = window_size // 2
  347. self.depth = depth
  348. # build blocks
  349. self.blocks = nn.LayerList([
  350. SwinTransformerBlock(
  351. dim=dim,
  352. num_heads=num_heads,
  353. window_size=window_size,
  354. shift_size=0 if (i % 2 == 0) else window_size // 2,
  355. mlp_ratio=mlp_ratio,
  356. qkv_bias=qkv_bias,
  357. qk_scale=qk_scale,
  358. drop=drop,
  359. attn_drop=attn_drop,
  360. drop_path=drop_path[i]
  361. if isinstance(drop_path, np.ndarray) else drop_path,
  362. norm_layer=norm_layer) for i in range(depth)
  363. ])
  364. # patch merging layer
  365. if downsample is not None:
  366. self.downsample = downsample(dim=dim, norm_layer=norm_layer)
  367. else:
  368. self.downsample = None
  369. def forward(self, x, H, W):
  370. """ Forward function.
  371. Args:
  372. x: Input feature, tensor size (B, H*W, C).
  373. H, W: Spatial resolution of the input feature.
  374. """
  375. # calculate attention mask for SW-MSA
  376. Hp = int(np.ceil(H / self.window_size)) * self.window_size
  377. Wp = int(np.ceil(W / self.window_size)) * self.window_size
  378. img_mask = paddle.zeros([1, Hp, Wp, 1], dtype='float32') # 1 Hp Wp 1
  379. h_slices = (slice(0, -self.window_size),
  380. slice(-self.window_size, -self.shift_size),
  381. slice(-self.shift_size, None))
  382. w_slices = (slice(0, -self.window_size),
  383. slice(-self.window_size, -self.shift_size),
  384. slice(-self.shift_size, None))
  385. cnt = 0
  386. for h in h_slices:
  387. for w in w_slices:
  388. try:
  389. img_mask[:, h, w, :] = cnt
  390. except:
  391. pass
  392. cnt += 1
  393. mask_windows = window_partition(
  394. img_mask, self.window_size) # nW, window_size, window_size, 1
  395. mask_windows = mask_windows.reshape(
  396. [-1, self.window_size * self.window_size])
  397. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  398. huns = -100.0 * paddle.ones_like(attn_mask)
  399. attn_mask = huns * (attn_mask != 0).astype("float32")
  400. for blk in self.blocks:
  401. blk.H, blk.W = H, W
  402. x = blk(x, attn_mask)
  403. if self.downsample is not None:
  404. x_down = self.downsample(x, H, W)
  405. Wh, Ww = (H + 1) // 2, (W + 1) // 2
  406. return x, H, W, x_down, Wh, Ww
  407. else:
  408. return x, H, W, x, H, W
  409. class PatchEmbed(nn.Layer):
  410. """ Image to Patch Embedding
  411. Args:
  412. patch_size (int): Patch token size. Default: 4.
  413. in_chans (int): Number of input image channels. Default: 3.
  414. embed_dim (int): Number of linear projection output channels. Default: 96.
  415. norm_layer (nn.Layer, optional): Normalization layer. Default: None
  416. """
  417. def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
  418. super().__init__()
  419. patch_size = to_2tuple(patch_size)
  420. self.patch_size = patch_size
  421. self.in_chans = in_chans
  422. self.embed_dim = embed_dim
  423. self.proj = nn.Conv2D(
  424. in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
  425. if norm_layer is not None:
  426. self.norm = norm_layer(embed_dim)
  427. else:
  428. self.norm = None
  429. def forward(self, x):
  430. B, C, H, W = x.shape
  431. # assert [H, W] == self.img_size[:2], "Input image size ({H}*{W}) doesn't match model ({}*{}).".format(H, W, self.img_size[0], self.img_size[1])
  432. if W % self.patch_size[1] != 0:
  433. x = F.pad(x, [0, self.patch_size[1] - W % self.patch_size[1], 0, 0])
  434. if H % self.patch_size[0] != 0:
  435. x = F.pad(x, [0, 0, 0, self.patch_size[0] - H % self.patch_size[0]])
  436. x = self.proj(x)
  437. if self.norm is not None:
  438. _, _, Wh, Ww = x.shape
  439. x = x.flatten(2).transpose([0, 2, 1])
  440. x = self.norm(x)
  441. x = x.transpose([0, 2, 1]).reshape([-1, self.embed_dim, Wh, Ww])
  442. return x
  443. @register
  444. @serializable
  445. class SwinTransformer(nn.Layer):
  446. """ Swin Transformer
  447. A PaddlePaddle impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
  448. https://arxiv.org/pdf/2103.14030
  449. Args:
  450. img_size (int | tuple(int)): Input image size. Default 224
  451. patch_size (int | tuple(int)): Patch size. Default: 4
  452. in_chans (int): Number of input image channels. Default: 3
  453. num_classes (int): Number of classes for classification head. Default: 1000
  454. embed_dim (int): Patch embedding dimension. Default: 96
  455. depths (tuple(int)): Depth of each Swin Transformer layer.
  456. num_heads (tuple(int)): Number of attention heads in different layers.
  457. window_size (int): Window size. Default: 7
  458. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
  459. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
  460. qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
  461. drop_rate (float): Dropout rate. Default: 0
  462. attn_drop_rate (float): Attention dropout rate. Default: 0
  463. drop_path_rate (float): Stochastic depth rate. Default: 0.1
  464. norm_layer (nn.Layer): Normalization layer. Default: nn.LayerNorm.
  465. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
  466. patch_norm (bool): If True, add normalization after patch embedding. Default: True
  467. """
  468. def __init__(self,
  469. pretrain_img_size=224,
  470. patch_size=4,
  471. in_chans=3,
  472. embed_dim=96,
  473. depths=[2, 2, 6, 2],
  474. num_heads=[3, 6, 12, 24],
  475. window_size=7,
  476. mlp_ratio=4.,
  477. qkv_bias=True,
  478. qk_scale=None,
  479. drop_rate=0.,
  480. attn_drop_rate=0.,
  481. drop_path_rate=0.2,
  482. norm_layer=nn.LayerNorm,
  483. ape=False,
  484. patch_norm=True,
  485. out_indices=(0, 1, 2, 3),
  486. frozen_stages=-1,
  487. pretrained=None):
  488. super(SwinTransformer, self).__init__()
  489. self.pretrain_img_size = pretrain_img_size
  490. self.num_layers = len(depths)
  491. self.embed_dim = embed_dim
  492. self.ape = ape
  493. self.patch_norm = patch_norm
  494. self.out_indices = out_indices
  495. self.frozen_stages = frozen_stages
  496. # split image into non-overlapping patches
  497. self.patch_embed = PatchEmbed(
  498. patch_size=patch_size,
  499. in_chans=in_chans,
  500. embed_dim=embed_dim,
  501. norm_layer=norm_layer if self.patch_norm else None)
  502. # absolute position embedding
  503. if self.ape:
  504. pretrain_img_size = to_2tuple(pretrain_img_size)
  505. patch_size = to_2tuple(patch_size)
  506. patches_resolution = [
  507. pretrain_img_size[0] // patch_size[0],
  508. pretrain_img_size[1] // patch_size[1]
  509. ]
  510. self.absolute_pos_embed = add_parameter(
  511. self,
  512. paddle.zeros((1, embed_dim, patches_resolution[0],
  513. patches_resolution[1])))
  514. trunc_normal_(self.absolute_pos_embed)
  515. self.pos_drop = nn.Dropout(p=drop_rate)
  516. # stochastic depth
  517. dpr = np.linspace(0, drop_path_rate,
  518. sum(depths)) # stochastic depth decay rule
  519. # build layers
  520. self.layers = nn.LayerList()
  521. for i_layer in range(self.num_layers):
  522. layer = BasicLayer(
  523. dim=int(embed_dim * 2**i_layer),
  524. depth=depths[i_layer],
  525. num_heads=num_heads[i_layer],
  526. window_size=window_size,
  527. mlp_ratio=mlp_ratio,
  528. qkv_bias=qkv_bias,
  529. qk_scale=qk_scale,
  530. drop=drop_rate,
  531. attn_drop=attn_drop_rate,
  532. drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
  533. norm_layer=norm_layer,
  534. downsample=PatchMerging
  535. if (i_layer < self.num_layers - 1) else None)
  536. self.layers.append(layer)
  537. num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
  538. self.num_features = num_features
  539. # add a norm layer for each output
  540. for i_layer in out_indices:
  541. layer = norm_layer(num_features[i_layer])
  542. layer_name = f'norm{i_layer}'
  543. self.add_sublayer(layer_name, layer)
  544. self.apply(self._init_weights)
  545. self._freeze_stages()
  546. if pretrained:
  547. if 'http' in pretrained: #URL
  548. path = paddle.utils.download.get_weights_path_from_url(
  549. pretrained)
  550. else: #model in local path
  551. path = pretrained
  552. self.set_state_dict(paddle.load(path))
  553. def _freeze_stages(self):
  554. if self.frozen_stages >= 0:
  555. self.patch_embed.eval()
  556. for param in self.patch_embed.parameters():
  557. param.stop_gradient = True
  558. if self.frozen_stages >= 1 and self.ape:
  559. self.absolute_pos_embed.stop_gradient = True
  560. if self.frozen_stages >= 2:
  561. self.pos_drop.eval()
  562. for i in range(0, self.frozen_stages - 1):
  563. m = self.layers[i]
  564. m.eval()
  565. for param in m.parameters():
  566. param.stop_gradient = True
  567. def _init_weights(self, m):
  568. if isinstance(m, nn.Linear):
  569. trunc_normal_(m.weight)
  570. if isinstance(m, nn.Linear) and m.bias is not None:
  571. zeros_(m.bias)
  572. elif isinstance(m, nn.LayerNorm):
  573. zeros_(m.bias)
  574. ones_(m.weight)
  575. def forward(self, x):
  576. """Forward function."""
  577. x = self.patch_embed(x['image'])
  578. B, _, Wh, Ww = x.shape
  579. if self.ape:
  580. # interpolate the position embedding to the corresponding size
  581. absolute_pos_embed = F.interpolate(
  582. self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
  583. x = (x + absolute_pos_embed).flatten(2).transpose([0, 2, 1])
  584. else:
  585. x = x.flatten(2).transpose([0, 2, 1])
  586. x = self.pos_drop(x)
  587. outs = []
  588. for i in range(self.num_layers):
  589. layer = self.layers[i]
  590. x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
  591. if i in self.out_indices:
  592. norm_layer = getattr(self, f'norm{i}')
  593. x_out = norm_layer(x_out)
  594. out = x_out.reshape((-1, H, W, self.num_features[i])).transpose(
  595. (0, 3, 1, 2))
  596. outs.append(out)
  597. return tuple(outs)
  598. @property
  599. def out_shape(self):
  600. out_strides = [4, 8, 16, 32]
  601. return [
  602. ShapeSpec(
  603. channels=self.num_features[i], stride=out_strides[i])
  604. for i in self.out_indices
  605. ]