# fuse_utils.py
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import paddle
  16. import paddle.nn as nn
  17. __all__ = ['fuse_conv_bn']
  18. def fuse_conv_bn(model):
  19. is_train = False
  20. if model.training:
  21. model.eval()
  22. is_train = True
  23. fuse_list = []
  24. tmp_pair = [None, None]
  25. for name, layer in model.named_sublayers():
  26. if isinstance(layer, nn.Conv2D):
  27. tmp_pair[0] = name
  28. if isinstance(layer, nn.BatchNorm2D):
  29. tmp_pair[1] = name
  30. if tmp_pair[0] and tmp_pair[1] and len(tmp_pair) == 2:
  31. fuse_list.append(tmp_pair)
  32. tmp_pair = [None, None]
  33. model = fuse_layers(model, fuse_list)
  34. if is_train:
  35. model.train()
  36. return model
  37. def find_parent_layer_and_sub_name(model, name):
  38. """
  39. Given the model and the name of a layer, find the parent layer and
  40. the sub_name of the layer.
  41. For example, if name is 'block_1/convbn_1/conv_1', the parent layer is
  42. 'block_1/convbn_1' and the sub_name is `conv_1`.
  43. Args:
  44. model(paddle.nn.Layer): the model to be quantized.
  45. name(string): the name of a layer
  46. Returns:
  47. parent_layer, subname
  48. """
  49. assert isinstance(model, nn.Layer), \
  50. "The model must be the instance of paddle.nn.Layer."
  51. assert len(name) > 0, "The input (name) should not be empty."
  52. last_idx = 0
  53. idx = 0
  54. parent_layer = model
  55. while idx < len(name):
  56. if name[idx] == '.':
  57. sub_name = name[last_idx:idx]
  58. if hasattr(parent_layer, sub_name):
  59. parent_layer = getattr(parent_layer, sub_name)
  60. last_idx = idx + 1
  61. idx += 1
  62. sub_name = name[last_idx:idx]
  63. return parent_layer, sub_name
  64. class Identity(nn.Layer):
  65. '''a layer to replace bn or relu layers'''
  66. def __init__(self, *args, **kwargs):
  67. super(Identity, self).__init__()
  68. def forward(self, input):
  69. return input
  70. def fuse_layers(model, layers_to_fuse, inplace=False):
  71. '''
  72. fuse layers in layers_to_fuse
  73. Args:
  74. model(nn.Layer): The model to be fused.
  75. layers_to_fuse(list): The layers' names to be fused. For
  76. example,"fuse_list = [["conv1", "bn1"], ["conv2", "bn2"]]".
  77. A TypeError would be raised if "fuse" was set as
  78. True but "fuse_list" was None.
  79. Default: None.
  80. inplace(bool): Whether apply fusing to the input model.
  81. Default: False.
  82. Return
  83. fused_model(paddle.nn.Layer): The fused model.
  84. '''
  85. if not inplace:
  86. model = copy.deepcopy(model)
  87. for layers_list in layers_to_fuse:
  88. layer_list = []
  89. for layer_name in layers_list:
  90. parent_layer, sub_name = find_parent_layer_and_sub_name(model,
  91. layer_name)
  92. layer_list.append(getattr(parent_layer, sub_name))
  93. new_layers = _fuse_func(layer_list)
  94. for i, item in enumerate(layers_list):
  95. parent_layer, sub_name = find_parent_layer_and_sub_name(model, item)
  96. setattr(parent_layer, sub_name, new_layers[i])
  97. return model
  98. def _fuse_func(layer_list):
  99. '''choose the fuser method and fuse layers'''
  100. types = tuple(type(m) for m in layer_list)
  101. fusion_method = types_to_fusion_method.get(types, None)
  102. new_layers = [None] * len(layer_list)
  103. fused_layer = fusion_method(*layer_list)
  104. for handle_id, pre_hook_fn in layer_list[0]._forward_pre_hooks.items():
  105. fused_layer.register_forward_pre_hook(pre_hook_fn)
  106. del layer_list[0]._forward_pre_hooks[handle_id]
  107. for handle_id, hook_fn in layer_list[-1]._forward_post_hooks.items():
  108. fused_layer.register_forward_post_hook(hook_fn)
  109. del layer_list[-1]._forward_post_hooks[handle_id]
  110. new_layers[0] = fused_layer
  111. for i in range(1, len(layer_list)):
  112. identity = Identity()
  113. identity.training = layer_list[0].training
  114. new_layers[i] = identity
  115. return new_layers
  116. def _fuse_conv_bn(conv, bn):
  117. '''fuse conv and bn for train or eval'''
  118. assert(conv.training == bn.training),\
  119. "Conv and BN both must be in the same mode (train or eval)."
  120. if conv.training:
  121. assert bn._num_features == conv._out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
  122. raise NotImplementedError
  123. else:
  124. return _fuse_conv_bn_eval(conv, bn)
  125. def _fuse_conv_bn_eval(conv, bn):
  126. '''fuse conv and bn for eval'''
  127. assert (not (conv.training or bn.training)), "Fusion only for eval!"
  128. fused_conv = copy.deepcopy(conv)
  129. fused_weight, fused_bias = _fuse_conv_bn_weights(
  130. fused_conv.weight, fused_conv.bias, bn._mean, bn._variance, bn._epsilon,
  131. bn.weight, bn.bias)
  132. fused_conv.weight.set_value(fused_weight)
  133. if fused_conv.bias is None:
  134. fused_conv.bias = paddle.create_parameter(
  135. shape=[fused_conv._out_channels], is_bias=True, dtype=bn.bias.dtype)
  136. fused_conv.bias.set_value(fused_bias)
  137. return fused_conv
  138. def _fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
  139. '''fuse weights and bias of conv and bn'''
  140. if conv_b is None:
  141. conv_b = paddle.zeros_like(bn_rm)
  142. if bn_w is None:
  143. bn_w = paddle.ones_like(bn_rm)
  144. if bn_b is None:
  145. bn_b = paddle.zeros_like(bn_rm)
  146. bn_var_rsqrt = paddle.rsqrt(bn_rv + bn_eps)
  147. conv_w = conv_w * \
  148. (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
  149. conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
  150. return conv_w, conv_b
# Registry mapping a tuple of layer types to the function that fuses them;
# consulted by _fuse_func via the type tuple of the layers being fused.
types_to_fusion_method = {(nn.Conv2D, nn.BatchNorm2D): _fuse_conv_bn, }