cspresnet.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Constant

from ppdet.modeling.ops import get_act_fn
from ppdet.core.workspace import register, serializable
from ..shape_spec import ShapeSpec

__all__ = ['CSPResNet', 'BasicBlock', 'EffectiveSELayer', 'ConvBNLayer']


class ConvBNLayer(nn.Layer):
    """Conv2D + BatchNorm2D + activation; the basic conv unit used below."""

    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size=3,
                 stride=1,
                 groups=1,
                 padding=0,
                 act=None):
        super(ConvBNLayer, self).__init__()
        self.conv = nn.Conv2D(
            in_channels=ch_in,
            out_channels=ch_out,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias_attr=False)
        # BN affine parameters are excluded from weight decay.
        self.bn = nn.BatchNorm2D(
            ch_out,
            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
        self.act = get_act_fn(act) if act is None or isinstance(act, (
            str, dict)) else act

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return x


class RepVggBlock(nn.Layer):
    """RepVGG-style block: parallel 3x3 and 1x1 branches at train time that
    can be fused into a single 3x3 conv for deployment via
    convert_to_deploy()."""

    def __init__(self, ch_in, ch_out, act='relu', alpha=False):
        super(RepVggBlock, self).__init__()
        self.ch_in = ch_in
        self.ch_out = ch_out
        self.conv1 = ConvBNLayer(
            ch_in, ch_out, 3, stride=1, padding=1, act=None)
        self.conv2 = ConvBNLayer(
            ch_in, ch_out, 1, stride=1, padding=0, act=None)
        self.act = get_act_fn(act) if act is None or isinstance(act, (
            str, dict)) else act
        if alpha:
            # Learnable scale on the 1x1 branch.
            self.alpha = self.create_parameter(
                shape=[1],
                attr=ParamAttr(initializer=Constant(value=1.)),
                dtype="float32")
        else:
            self.alpha = None

    def forward(self, x):
        if hasattr(self, 'conv'):
            # Fused (deploy) path.
            y = self.conv(x)
        # Compare against None: truth-testing the parameter tensor would
        # misread an alpha trained to 0 and silently take the wrong branch.
        elif self.alpha is not None:
            y = self.conv1(x) + self.alpha * self.conv2(x)
        else:
            y = self.conv1(x) + self.conv2(x)
        y = self.act(y)
        return y

    def convert_to_deploy(self):
        if not hasattr(self, 'conv'):
            self.conv = nn.Conv2D(
                in_channels=self.ch_in,
                out_channels=self.ch_out,
                kernel_size=3,
                stride=1,
                padding=1,
                groups=1)
        kernel, bias = self.get_equivalent_kernel_bias()
        self.conv.weight.set_value(kernel)
        self.conv.bias.set_value(bias)
        self.__delattr__('conv1')
        self.__delattr__('conv2')

    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
        if self.alpha is not None:
            return kernel3x3 + self.alpha * self._pad_1x1_to_3x3_tensor(
                kernel1x1), bias3x3 + self.alpha * bias1x1
        else:
            return kernel3x3 + self._pad_1x1_to_3x3_tensor(
                kernel1x1), bias3x3 + bias1x1

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            # Zero-pad the 1x1 kernel to 3x3 so the two branches can be summed.
            return F.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        # Fold BN into the conv: W' = W * gamma / std, b' = beta - mean * gamma / std.
        kernel = branch.conv.weight
        running_mean = branch.bn._mean
        running_var = branch.bn._variance
        gamma = branch.bn.weight
        beta = branch.bn.bias
        eps = branch.bn._epsilon
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape((-1, 1, 1, 1))
        return kernel * t, beta - running_mean * gamma / std


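# A minimal sanity check for the reparameterization above (a sketch, not part
# of the upstream file): after convert_to_deploy(), the fused conv should
# reproduce the two-branch output in eval mode.
#
#     block = RepVggBlock(32, 32, act='relu')
#     block.eval()
#     x = paddle.rand([1, 32, 16, 16])
#     y_train = block(x)
#     block.convert_to_deploy()
#     y_deploy = block(x)
#     assert paddle.allclose(y_train, y_deploy, atol=1e-4).item()

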
class BasicBlock(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 act='relu',
                 shortcut=True,
                 use_alpha=False):
        super(BasicBlock, self).__init__()
        # The residual shortcut requires matching channel counts.
        assert ch_in == ch_out
        self.conv1 = ConvBNLayer(ch_in, ch_out, 3, stride=1, padding=1, act=act)
        self.conv2 = RepVggBlock(ch_out, ch_out, act=act, alpha=use_alpha)
        self.shortcut = shortcut

    def forward(self, x):
        y = self.conv1(x)
        y = self.conv2(y)
        if self.shortcut:
            return paddle.add(x, y)
        else:
            return y


class EffectiveSELayer(nn.Layer):
    """Effective Squeeze-Excitation.

    From `CenterMask: Real-Time Anchor-Free Instance Segmentation` -
    https://arxiv.org/abs/1911.06667
    """

    def __init__(self, channels, act='hardsigmoid'):
        super(EffectiveSELayer, self).__init__()
        self.fc = nn.Conv2D(channels, channels, kernel_size=1, padding=0)
        self.act = get_act_fn(act) if act is None or isinstance(act, (
            str, dict)) else act

    def forward(self, x):
        # Global average pool -> 1x1 conv -> gate the input channel-wise.
        x_se = x.mean((2, 3), keepdim=True)
        x_se = self.fc(x_se)
        return x * self.act(x_se)


class CSPResStage(nn.Layer):
    def __init__(self,
                 block_fn,
                 ch_in,
                 ch_out,
                 n,
                 stride,
                 act='relu',
                 attn='eca',
                 use_alpha=False):
        super(CSPResStage, self).__init__()
        ch_mid = (ch_in + ch_out) // 2
        if stride == 2:
            self.conv_down = ConvBNLayer(
                ch_in, ch_mid, 3, stride=2, padding=1, act=act)
        else:
            self.conv_down = None
        # CSP split: two 1x1 convs halve the channels; one half passes
        # through the residual blocks, the other acts as a shortcut, and the
        # halves are concatenated back to ch_mid.
        self.conv1 = ConvBNLayer(ch_mid, ch_mid // 2, 1, act=act)
        self.conv2 = ConvBNLayer(ch_mid, ch_mid // 2, 1, act=act)
        self.blocks = nn.Sequential(*[
            block_fn(
                ch_mid // 2,
                ch_mid // 2,
                act=act,
                shortcut=True,
                use_alpha=use_alpha) for i in range(n)
        ])
        # `attn` is used only as a truthy switch; any non-empty value enables
        # the EffectiveSELayer.
        if attn:
            self.attn = EffectiveSELayer(ch_mid, act='hardsigmoid')
        else:
            self.attn = None
        self.conv3 = ConvBNLayer(ch_mid, ch_out, 1, act=act)

    def forward(self, x):
        if self.conv_down is not None:
            x = self.conv_down(x)
        y1 = self.conv1(x)
        y2 = self.blocks(self.conv2(x))
        y = paddle.concat([y1, y2], axis=1)
        if self.attn is not None:
            y = self.attn(y)
        y = self.conv3(y)
        return y


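# Channel bookkeeping example (illustrative comment, not upstream code): for
# CSPResStage(BasicBlock, ch_in=64, ch_out=128, n=1, stride=2),
# ch_mid = (64 + 128) // 2 = 96, so conv1 and conv2 each emit 48 channels,
# the concat restores 96, and conv3 projects 96 -> 128.

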
@register
@serializable
class CSPResNet(nn.Layer):
    __shared__ = ['width_mult', 'depth_mult', 'trt']

    def __init__(self,
                 layers=[3, 6, 6, 3],
                 channels=[64, 128, 256, 512, 1024],
                 act='swish',
                 return_idx=[1, 2, 3],
                 depth_wise=False,
                 use_large_stem=False,
                 width_mult=1.0,
                 depth_mult=1.0,
                 trt=False,
                 use_checkpoint=False,
                 use_alpha=False,
                 **args):
        super(CSPResNet, self).__init__()
        self.use_checkpoint = use_checkpoint
        # Scale widths and depths, keeping at least one channel/block each.
        channels = [max(round(c * width_mult), 1) for c in channels]
        layers = [max(round(l * depth_mult), 1) for l in layers]
        act = get_act_fn(
            act, trt=trt) if act is None or isinstance(act,
                                                       (str, dict)) else act

        if use_large_stem:
            self.stem = nn.Sequential(
                ('conv1', ConvBNLayer(
                    3, channels[0] // 2, 3, stride=2, padding=1, act=act)),
                ('conv2', ConvBNLayer(
                    channels[0] // 2,
                    channels[0] // 2,
                    3,
                    stride=1,
                    padding=1,
                    act=act)),
                ('conv3', ConvBNLayer(
                    channels[0] // 2,
                    channels[0],
                    3,
                    stride=1,
                    padding=1,
                    act=act)))
        else:
            self.stem = nn.Sequential(
                ('conv1', ConvBNLayer(
                    3, channels[0] // 2, 3, stride=2, padding=1, act=act)),
                ('conv2', ConvBNLayer(
                    channels[0] // 2,
                    channels[0],
                    3,
                    stride=1,
                    padding=1,
                    act=act)))

        n = len(channels) - 1
        self.stages = nn.Sequential(*[(str(i), CSPResStage(
            BasicBlock,
            channels[i],
            channels[i + 1],
            layers[i],
            2,
            act=act,
            use_alpha=use_alpha)) for i in range(n)])

        self._out_channels = channels[1:]
        # The stem downsamples by 4; each stage halves the resolution again.
        self._out_strides = [4 * 2**i for i in range(n)]
        self.return_idx = return_idx
        if use_checkpoint:
            paddle.seed(0)

    def forward(self, inputs):
        x = inputs['image']
        x = self.stem(x)
        outs = []
        for idx, stage in enumerate(self.stages):
            if self.use_checkpoint and self.training:
                # Recompute activations in backward to trade compute for memory.
                x = paddle.distributed.fleet.utils.recompute(
                    stage, x, **{"preserve_rng_state": True})
            else:
                x = stage(x)
            if idx in self.return_idx:
                outs.append(x)
        return outs

    @property
    def out_shape(self):
        return [
            ShapeSpec(
                channels=self._out_channels[i], stride=self._out_strides[i])
            for i in self.return_idx
        ]


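# A minimal smoke test (a sketch, not part of the upstream file). It assumes
# paddle and ppdet are importable and checks that the returned feature maps
# match the declared out_shape specs.
if __name__ == '__main__':
    model = CSPResNet(width_mult=0.50, depth_mult=0.33)
    model.eval()
    feats = model({'image': paddle.rand([1, 3, 224, 224])})
    for feat, spec in zip(feats, model.out_shape):
        # Channel counts must agree with the declared ShapeSpec.
        assert feat.shape[1] == spec.channels
        print(feat.shape, 'stride', spec.stride)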