custom_pan.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import copy
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import DropBlock, MultiHeadAttention
from ppdet.modeling.ops import get_act_fn
from ..backbones.cspresnet import ConvBNLayer, BasicBlock
from ..shape_spec import ShapeSpec
from ..initializer import linear_init_

__all__ = ['CustomCSPPAN']


def _get_clones(module, N):
    return nn.LayerList([copy.deepcopy(module) for _ in range(N)])
class SPP(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 k,
                 pool_size,
                 act='swish',
                 data_format='NCHW'):
        super(SPP, self).__init__()
        self.pool = []
        self.data_format = data_format
        for i, size in enumerate(pool_size):
            pool = self.add_sublayer(
                'pool{}'.format(i),
                nn.MaxPool2D(
                    kernel_size=size,
                    stride=1,
                    padding=size // 2,
                    data_format=data_format,
                    ceil_mode=False))
            self.pool.append(pool)
        self.conv = ConvBNLayer(ch_in, ch_out, k, padding=k // 2, act=act)

    def forward(self, x):
        outs = [x]
        for pool in self.pool:
            outs.append(pool(x))
        if self.data_format == 'NCHW':
            y = paddle.concat(outs, axis=1)
        else:
            y = paddle.concat(outs, axis=-1)
        y = self.conv(y)
        return y
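# Usage sketch (illustrative, not part of the original file): SPP concatenates
# the input with one stride-1 max-pooled copy per kernel size in `pool_size`,
# so the fusing conv expects (len(pool_size) + 1) times the input channels.
# The shapes below are assumptions for illustration only.
#
#     x = paddle.rand([2, 256, 20, 20])                       # hypothetical feature map
#     spp = SPP(ch_in=256 * 4, ch_out=256, k=1, pool_size=[5, 9, 13])
#     y = spp(x)                                              # [2, 256, 20, 20]; H and W are preserved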
class CSPStage(nn.Layer):
    def __init__(self,
                 block_fn,
                 ch_in,
                 ch_out,
                 n,
                 act='swish',
                 spp=False,
                 use_alpha=False):
        super(CSPStage, self).__init__()
        ch_mid = int(ch_out // 2)
        self.conv1 = ConvBNLayer(ch_in, ch_mid, 1, act=act)
        self.conv2 = ConvBNLayer(ch_in, ch_mid, 1, act=act)
        self.convs = nn.Sequential()
        next_ch_in = ch_mid
        for i in range(n):
            self.convs.add_sublayer(
                str(i),
                eval(block_fn)(next_ch_in,
                               ch_mid,
                               act=act,
                               shortcut=False,
                               use_alpha=use_alpha))
            if i == (n - 1) // 2 and spp:
                self.convs.add_sublayer(
                    'spp', SPP(ch_mid * 4, ch_mid, 1, [5, 9, 13], act=act))
            next_ch_in = ch_mid
        self.conv3 = ConvBNLayer(ch_mid * 2, ch_out, 1, act=act)

    def forward(self, x):
        y1 = self.conv1(x)
        y2 = self.conv2(x)
        y2 = self.convs(y2)
        y = paddle.concat([y1, y2], axis=1)
        y = self.conv3(y)
        return y
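# Usage sketch (illustrative, not part of the original file): CSPStage splits the
# input into two ch_out // 2 branches, runs `n` blocks (plus an optional SPP in
# the middle) on one branch, then concatenates both branches and fuses them with
# a 1x1 conv. The values below are assumptions for illustration only.
#
#     stage = CSPStage('BasicBlock', ch_in=512, ch_out=512, n=3, act='swish', spp=True)
#     x = paddle.rand([2, 512, 20, 20])
#     y = stage(x)                                            # [2, 512, 20, 20]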
class TransformerEncoderLayer(nn.Layer):
    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation="relu",
                 attn_dropout=None,
                 act_dropout=None,
                 normalize_before=False):
        super(TransformerEncoderLayer, self).__init__()
        attn_dropout = dropout if attn_dropout is None else attn_dropout
        act_dropout = dropout if act_dropout is None else act_dropout
        self.normalize_before = normalize_before

        self.self_attn = MultiHeadAttention(d_model, nhead, attn_dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(act_dropout, mode="upscale_in_train")
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
        self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train")
        self.activation = getattr(F, activation)
        self._reset_parameters()

    def _reset_parameters(self):
        linear_init_(self.linear1)
        linear_init_(self.linear2)

    @staticmethod
    def with_pos_embed(tensor, pos_embed):
        return tensor if pos_embed is None else tensor + pos_embed

    def forward(self, src, src_mask=None, pos_embed=None):
        residual = src
        if self.normalize_before:
            src = self.norm1(src)
        q = k = self.with_pos_embed(src, pos_embed)
        src = self.self_attn(q, k, value=src, attn_mask=src_mask)
        src = residual + self.dropout1(src)
        if not self.normalize_before:
            src = self.norm1(src)

        residual = src
        if self.normalize_before:
            src = self.norm2(src)
        src = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = residual + self.dropout2(src)
        if not self.normalize_before:
            src = self.norm2(src)
        return src
class TransformerEncoder(nn.Layer):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, src_mask=None, pos_embed=None):
        output = src
        for layer in self.layers:
            output = layer(output, src_mask=src_mask, pos_embed=pos_embed)
        if self.norm is not None:
            output = self.norm(output)
        return output
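# Usage sketch (illustrative, not part of the original file): the encoder stacks
# `num_layers` deep copies of one TransformerEncoderLayer and applies them to a
# [B, L, C] token sequence; CustomCSPPAN feeds it the flattened deepest feature
# map. The dimensions below are assumptions for illustration only.
#
#     layer = TransformerEncoderLayer(d_model=1024, nhead=4, dim_feedforward=2048,
#                                     activation='gelu')
#     encoder = TransformerEncoder(layer, num_layers=4)
#     tokens = paddle.rand([2, 400, 1024])                    # a 20x20 map flattened to 400 tokens
#     out = encoder(tokens)                                    # [2, 400, 1024]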
@register
@serializable
class CustomCSPPAN(nn.Layer):
    __shared__ = [
        'norm_type', 'data_format', 'width_mult', 'depth_mult', 'trt',
        'eval_size'
    ]

    def __init__(self,
                 in_channels=[256, 512, 1024],
                 out_channels=[1024, 512, 256],
                 norm_type='bn',
                 act='leaky',
                 stage_fn='CSPStage',
                 block_fn='BasicBlock',
                 stage_num=1,
                 block_num=3,
                 drop_block=False,
                 block_size=3,
                 keep_prob=0.9,
                 spp=False,
                 data_format='NCHW',
                 width_mult=1.0,
                 depth_mult=1.0,
                 use_alpha=False,
                 trt=False,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation='gelu',
                 nhead=4,
                 num_layers=4,
                 attn_dropout=None,
                 act_dropout=None,
                 normalize_before=False,
                 use_trans=False,
                 eval_size=None):
        super(CustomCSPPAN, self).__init__()
        out_channels = [max(round(c * width_mult), 1) for c in out_channels]
        block_num = max(round(block_num * depth_mult), 1)
        act = get_act_fn(
            act, trt=trt) if act is None or isinstance(act,
                                                       (str, dict)) else act
        self.num_blocks = len(in_channels)
        self.data_format = data_format
        self._out_channels = out_channels
        self.hidden_dim = in_channels[-1]
        in_channels = in_channels[::-1]
        self.use_trans = use_trans
        self.eval_size = eval_size
        if use_trans:
            if eval_size is not None:
                self.pos_embed = self.build_2d_sincos_position_embedding(
                    eval_size[1] // 32,
                    eval_size[0] // 32,
                    embed_dim=self.hidden_dim)
            else:
                self.pos_embed = None

            encoder_layer = TransformerEncoderLayer(
                self.hidden_dim, nhead, dim_feedforward, dropout, activation,
                attn_dropout, act_dropout, normalize_before)
            encoder_norm = nn.LayerNorm(
                self.hidden_dim) if normalize_before else None
            self.encoder = TransformerEncoder(encoder_layer, num_layers,
                                              encoder_norm)

        fpn_stages = []
        fpn_routes = []
        for i, (ch_in, ch_out) in enumerate(zip(in_channels, out_channels)):
            if i > 0:
                ch_in += ch_pre // 2

            stage = nn.Sequential()
            for j in range(stage_num):
                stage.add_sublayer(
                    str(j),
                    eval(stage_fn)(block_fn,
                                   ch_in if j == 0 else ch_out,
                                   ch_out,
                                   block_num,
                                   act=act,
                                   spp=(spp and i == 0),
                                   use_alpha=use_alpha))

            if drop_block:
                stage.add_sublayer('drop', DropBlock(block_size, keep_prob))

            fpn_stages.append(stage)

            if i < self.num_blocks - 1:
                fpn_routes.append(
                    ConvBNLayer(
                        ch_in=ch_out,
                        ch_out=ch_out // 2,
                        filter_size=1,
                        stride=1,
                        padding=0,
                        act=act))

            ch_pre = ch_out

        self.fpn_stages = nn.LayerList(fpn_stages)
        self.fpn_routes = nn.LayerList(fpn_routes)

        pan_stages = []
        pan_routes = []
        for i in reversed(range(self.num_blocks - 1)):
            pan_routes.append(
                ConvBNLayer(
                    ch_in=out_channels[i + 1],
                    ch_out=out_channels[i + 1],
                    filter_size=3,
                    stride=2,
                    padding=1,
                    act=act))

            ch_in = out_channels[i] + out_channels[i + 1]
            ch_out = out_channels[i]
            stage = nn.Sequential()
            for j in range(stage_num):
                stage.add_sublayer(
                    str(j),
                    eval(stage_fn)(block_fn,
                                   ch_in if j == 0 else ch_out,
                                   ch_out,
                                   block_num,
                                   act=act,
                                   spp=False,
                                   use_alpha=use_alpha))
            if drop_block:
                stage.add_sublayer('drop', DropBlock(block_size, keep_prob))

            pan_stages.append(stage)

        self.pan_stages = nn.LayerList(pan_stages[::-1])
        self.pan_routes = nn.LayerList(pan_routes[::-1])
    def build_2d_sincos_position_embedding(
            self,
            w,
            h,
            embed_dim=1024,
            temperature=10000., ):
        grid_w = paddle.arange(int(w), dtype=paddle.float32)
        grid_h = paddle.arange(int(h), dtype=paddle.float32)
        grid_w, grid_h = paddle.meshgrid(grid_w, grid_h)
        assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = embed_dim // 4
        omega = paddle.arange(pos_dim, dtype=paddle.float32) / pos_dim
        omega = 1. / (temperature**omega)

        out_w = grid_w.flatten()[..., None] @ omega[None]
        out_h = grid_h.flatten()[..., None] @ omega[None]

        pos_emb = paddle.concat(
            [
                paddle.sin(out_w), paddle.cos(out_w), paddle.sin(out_h),
                paddle.cos(out_h)
            ],
            axis=1)[None, :, :]

        return pos_emb
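    # Shape sketch (illustrative, not part of the original file): for a w x h grid
    # the result is [1, w * h, embed_dim], built from sin/cos of the grid
    # coordinates at embed_dim // 4 frequencies and added to the flattened tokens
    # as a positional signal. The numbers below are assumptions for illustration.
    #
    #     pe = self.build_2d_sincos_position_embedding(20, 20, embed_dim=1024)
    #     # pe.shape == [1, 400, 1024]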
    def forward(self, blocks, for_mot=False):
        if self.use_trans:
            last_feat = blocks[-1]
            n, c, h, w = last_feat.shape
            # flatten [B, C, H, W] to [B, HxW, C]
            src_flatten = last_feat.flatten(2).transpose([0, 2, 1])
            if self.eval_size is not None and not self.training:
                pos_embed = self.pos_embed
            else:
                pos_embed = self.build_2d_sincos_position_embedding(
                    w=w, h=h, embed_dim=self.hidden_dim)
            memory = self.encoder(src_flatten, pos_embed=pos_embed)
            last_feat_encode = memory.transpose([0, 2, 1]).reshape(
                [n, c, h, w])
            blocks[-1] = last_feat_encode

        blocks = blocks[::-1]
        fpn_feats = []

        # top-down (FPN) pass: start from the deepest level and fuse upsampled
        # features into the shallower levels
        for i, block in enumerate(blocks):
            if i > 0:
                block = paddle.concat([route, block], axis=1)
            route = self.fpn_stages[i](block)
            fpn_feats.append(route)

            if i < self.num_blocks - 1:
                route = self.fpn_routes[i](route)
                route = F.interpolate(
                    route, scale_factor=2., data_format=self.data_format)

        pan_feats = [fpn_feats[-1], ]
        route = fpn_feats[-1]

        # bottom-up (PAN) pass: downsample and fuse back toward the deepest level
        for i in reversed(range(self.num_blocks - 1)):
            block = fpn_feats[i]
            route = self.pan_routes[i](route)
            block = paddle.concat([route, block], axis=1)
            route = self.pan_stages[i](block)
            pan_feats.append(route)

        return pan_feats[::-1]
    @classmethod
    def from_config(cls, cfg, input_shape):
        return {'in_channels': [i.channels for i in input_shape], }

    @property
    def out_shape(self):
        return [ShapeSpec(channels=c) for c in self._out_channels]
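# Minimal smoke test (illustrative sketch, not part of the original file). It
# assumes a working ppdet installation; act='swish' is used here because it is
# one of the activations get_act_fn resolves, and the feature shapes mimic
# strides 8/16/32 on a 640x640 input.
if __name__ == '__main__':
    neck = CustomCSPPAN(
        in_channels=[256, 512, 1024],
        out_channels=[1024, 512, 256],
        act='swish',
        spp=True)
    feats = [
        paddle.rand([2, 256, 80, 80]),
        paddle.rand([2, 512, 40, 40]),
        paddle.rand([2, 1024, 20, 20]),
    ]
    outs = neck(feats)
    for out in outs:
        print(out.shape)  # channel counts follow out_channels: 1024, 512, 256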