# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is the Paddle implementation of the MobileOne block, see: https://arxiv.org/pdf/2206.04040.pdf.
Some code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
The copyright of DingXiaoH/RepVGG is as follows:
MIT License [see LICENSE for details]
"""
import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal, Constant

from ppdet.modeling.ops import get_act_fn
from ppdet.modeling.layers import ConvNormLayer


class MobileOneBlock(nn.Layer):
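    """
    MobileOne building block (see https://arxiv.org/pdf/2206.04040.pdf).

    At training time the block is over-parameterized: a depthwise stage made
    of `conv_num` parallel kxk convs, a 1x1 conv and an optional BN identity
    branch, followed by a pointwise stage made of `conv_num` parallel 1x1
    convs and an optional BN identity branch. `convert_to_deploy()` folds
    all branches into two plain convolutions for inference.
    """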
    def __init__(self,
                 ch_in,
                 ch_out,
                 stride,
                 kernel_size,
                 conv_num=1,
                 norm_type='bn',
                 norm_decay=0.,
                 norm_groups=32,
                 bias_on=False,
                 lr_scale=1.,
                 freeze_norm=False,
                 initializer=Normal(
                     mean=0., std=0.01),
                 skip_quant=False,
                 act='relu'):
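        """
        Args:
            ch_in (int): number of input channels.
            ch_out (int): number of output channels.
            stride (int): stride of the depthwise (kxk) convolution.
            kernel_size (int): kernel size of the depthwise convolution.
            conv_num (int): k, the number of parallel conv branches per stage.
            norm_type (str): normalization type, 'bn' by default.
            norm_decay (float): weight decay applied to norm parameters.
            norm_groups (int): group number when group normalization is used.
            bias_on (bool): whether the conv layers carry a bias term.
            lr_scale (float): learning-rate multiplier for the conv weights.
            freeze_norm (bool): whether to freeze the norm parameters.
            initializer: weight initializer for the conv layers.
            skip_quant (bool): whether to skip quantization for these convs.
            act (str|dict|callable): activation applied after each stage.
        """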
        super(MobileOneBlock, self).__init__()
        self.ch_in = ch_in
        self.ch_out = ch_out
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = (kernel_size - 1) // 2
        self.k = conv_num

        self.depth_conv = nn.LayerList()
        self.point_conv = nn.LayerList()
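        # Build the k parallel over-parameterized branches: depthwise kxk
        # convs for stage 1 and pointwise 1x1 convs for stage 2.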
        for _ in range(self.k):
            self.depth_conv.append(
                ConvNormLayer(
                    ch_in,
                    ch_in,
                    kernel_size,
                    stride=stride,
                    groups=ch_in,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    norm_groups=norm_groups,
                    bias_on=bias_on,
                    lr_scale=lr_scale,
                    freeze_norm=freeze_norm,
                    initializer=initializer,
                    skip_quant=skip_quant))
            self.point_conv.append(
                ConvNormLayer(
                    ch_in,
                    ch_out,
                    1,
                    stride=1,
                    groups=1,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    norm_groups=norm_groups,
                    bias_on=bias_on,
                    lr_scale=lr_scale,
                    freeze_norm=freeze_norm,
                    initializer=initializer,
                    skip_quant=skip_quant))
        self.rbr_1x1 = ConvNormLayer(
            ch_in,
            ch_in,
            1,
            stride=self.stride,
            groups=ch_in,
            norm_type=norm_type,
            norm_decay=norm_decay,
            norm_groups=norm_groups,
            bias_on=bias_on,
            lr_scale=lr_scale,
            freeze_norm=freeze_norm,
            initializer=initializer,
            skip_quant=skip_quant)
        self.rbr_identity_st1 = nn.BatchNorm2D(
            num_features=ch_in,
            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(
                0.0))) if ch_in == ch_out and self.stride == 1 else None
        self.rbr_identity_st2 = nn.BatchNorm2D(
            num_features=ch_out,
            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(
                0.0))) if ch_in == ch_out and self.stride == 1 else None

        self.act = get_act_fn(act) if act is None or isinstance(act, (
            str, dict)) else act

    def forward(self, x):
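        # Deploy mode: after convert_to_deploy() only the two fused convs remain.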
        if hasattr(self, "conv1") and hasattr(self, "conv2"):
            y = self.act(self.conv2(self.act(self.conv1(x))))
        else:
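            # Training mode, stage 1 (depthwise): sum the k kxk branches,
            # the 1x1 branch and the identity branch, then activate.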
            if self.rbr_identity_st1 is None:
                id_out_st1 = 0
            else:
                id_out_st1 = self.rbr_identity_st1(x)
            x1_1 = 0
            for i in range(self.k):
                x1_1 += self.depth_conv[i](x)
            x1_2 = self.rbr_1x1(x)
            x1 = self.act(x1_1 + x1_2 + id_out_st1)
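            # Stage 2 (pointwise): sum the k 1x1 branches and the identity.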
            if self.rbr_identity_st2 is None:
                id_out_st2 = 0
            else:
                id_out_st2 = self.rbr_identity_st2(x1)
            x2_1 = 0
            for i in range(self.k):
                x2_1 += self.point_conv[i](x1)
            y = self.act(x2_1 + id_out_st2)
        return y

    def convert_to_deploy(self):
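        # Reparameterize for inference: create the fused depthwise conv
        # (conv1) and pointwise conv (conv2), copy in the equivalent
        # kernels/biases, then drop the training-time branches.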
        if not hasattr(self, 'conv1'):
            self.conv1 = nn.Conv2D(
                in_channels=self.ch_in,
                out_channels=self.ch_in,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                groups=self.ch_in,
                bias_attr=ParamAttr(
                    initializer=Constant(value=0.), learning_rate=1.))
        if not hasattr(self, 'conv2'):
            self.conv2 = nn.Conv2D(
                in_channels=self.ch_in,
                out_channels=self.ch_out,
                kernel_size=1,
                stride=1,
                padding='SAME',
                groups=1,
                bias_attr=ParamAttr(
                    initializer=Constant(value=0.), learning_rate=1.))

        conv1_kernel, conv1_bias, conv2_kernel, conv2_bias = \
            self.get_equivalent_kernel_bias()
        self.conv1.weight.set_value(conv1_kernel)
        self.conv1.bias.set_value(conv1_bias)
        self.conv2.weight.set_value(conv2_kernel)
        self.conv2.bias.set_value(conv2_bias)
        self.__delattr__('depth_conv')
        self.__delattr__('point_conv')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity_st1'):
            self.__delattr__('rbr_identity_st1')
        if hasattr(self, 'rbr_identity_st2'):
            self.__delattr__('rbr_identity_st2')

    def get_equivalent_kernel_bias(self):
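        # Fold BN into each branch and sum the results: stage 1 merges the k
        # kxk branches, the zero-padded 1x1 branch and the identity; stage 2
        # merges the k 1x1 branches and the identity.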
        st1_kernel3x3, st1_bias3x3 = self._fuse_bn_tensor(self.depth_conv)
        st1_kernel1x1, st1_bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        st1_kernelid, st1_biasid = self._fuse_bn_tensor(
            self.rbr_identity_st1, kernel_size=self.kernel_size)

        st2_kernel1x1, st2_bias1x1 = self._fuse_bn_tensor(self.point_conv)
        st2_kernelid, st2_biasid = self._fuse_bn_tensor(
            self.rbr_identity_st2, kernel_size=1)

        conv1_kernel = st1_kernel3x3 + self._pad_1x1_to_3x3_tensor(
            st1_kernel1x1) + st1_kernelid
        conv1_bias = st1_bias3x3 + st1_bias1x1 + st1_biasid

        conv2_kernel = st2_kernel1x1 + st2_kernelid
        conv2_bias = st2_bias1x1 + st2_biasid
        return conv1_kernel, conv1_bias, conv2_kernel, conv2_bias

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
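        # Zero-pad a 1x1 kernel to kxk so it can be summed with kxk kernels.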
        if kernel1x1 is None:
            return 0
        else:
            padding_size = (self.kernel_size - 1) // 2
            return nn.functional.pad(
                kernel1x1,
                [padding_size, padding_size, padding_size, padding_size])

    def _fuse_bn_tensor(self, branch, kernel_size=3):
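        # Standard BN folding: W' = W * gamma / sqrt(var + eps) and
        # b' = beta - mean * gamma / sqrt(var + eps). For a BN-only identity
        # branch, an explicit identity kernel is materialized first.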
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.LayerList):
            fused_kernels = []
            fused_bias = []
            for block in branch:
                kernel = block.conv.weight
                running_mean = block.norm._mean
                running_var = block.norm._variance
                gamma = block.norm.weight
                beta = block.norm.bias
                eps = block.norm._epsilon
                std = (running_var + eps).sqrt()
                t = (gamma / std).reshape((-1, 1, 1, 1))
                fused_kernels.append(kernel * t)
                fused_bias.append(beta - running_mean * gamma / std)
            return sum(fused_kernels), sum(fused_bias)
        elif isinstance(branch, ConvNormLayer):
            kernel = branch.conv.weight
            running_mean = branch.norm._mean
            running_var = branch.norm._variance
            gamma = branch.norm.weight
            beta = branch.norm.bias
            eps = branch.norm._epsilon
        else:
            assert isinstance(branch, nn.BatchNorm2D)
            input_dim = self.ch_in if kernel_size == 1 else 1
            kernel_value = paddle.zeros(
                shape=[self.ch_in, input_dim, kernel_size, kernel_size],
                dtype='float32')
            if kernel_size > 1:
                for i in range(self.ch_in):
                    kernel_value[i, i % input_dim, (kernel_size - 1) // 2, (
                        kernel_size - 1) // 2] = 1
            elif kernel_size == 1:
                for i in range(self.ch_in):
                    kernel_value[i, i % input_dim, 0, 0] = 1
            else:
                raise ValueError("Invalid kernel size received!")
            kernel = paddle.to_tensor(kernel_value, place=branch.weight.place)
            running_mean = branch._mean
            running_var = branch._variance
            gamma = branch.weight
            beta = branch.bias
            eps = branch._epsilon
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape((-1, 1, 1, 1))
        return kernel * t, beta - running_mean * gamma / std
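

# A minimal usage sketch (an illustration, not part of the original file):
# it assumes a PaddleDetection environment in which the imports above
# resolve. It builds a block, reparameterizes it, and checks that the
# training-time and deploy-time graphs agree in eval mode.
if __name__ == '__main__':
    block = MobileOneBlock(
        ch_in=32, ch_out=64, stride=1, kernel_size=3, conv_num=4)
    block.eval()  # use BN running statistics, as fusion does
    x = paddle.randn([1, 32, 56, 56])
    y_train = block(x)
    block.convert_to_deploy()
    y_deploy = block(x)
    # The max absolute difference should be close to zero.
    print(float((y_train - y_deploy).abs().max()))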