prune.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. from paddle.utils import try_import
  19. from ppdet.core.workspace import register, serializable
  20. from ppdet.utils.logger import setup_logger
  21. logger = setup_logger(__name__)
  22. def print_prune_params(model):
  23. model_dict = model.state_dict()
  24. for key in model_dict.keys():
  25. weight_name = model_dict[key].name
  26. logger.info('Parameter name: {}, shape: {}'.format(
  27. weight_name, model_dict[key].shape))
  28. @register
  29. @serializable
  30. class Pruner(object):
  31. def __init__(self,
  32. criterion,
  33. pruned_params,
  34. pruned_ratios,
  35. print_params=False):
  36. super(Pruner, self).__init__()
  37. assert criterion in ['l1_norm', 'fpgm'], \
  38. "unsupported prune criterion: {}".format(criterion)
  39. self.criterion = criterion
  40. self.pruned_params = pruned_params
  41. self.pruned_ratios = pruned_ratios
  42. self.print_params = print_params
  43. def __call__(self, model):
  44. # FIXME: adapt to network graph when Training and inference are
  45. # inconsistent, now only supports prune inference network graph.
  46. model.eval()
  47. paddleslim = try_import('paddleslim')
  48. from paddleslim.analysis import dygraph_flops as flops
  49. input_spec = [{
  50. "image": paddle.ones(
  51. shape=[1, 3, 640, 640], dtype='float32'),
  52. "im_shape": paddle.full(
  53. [1, 2], 640, dtype='float32'),
  54. "scale_factor": paddle.ones(
  55. shape=[1, 2], dtype='float32')
  56. }]
  57. if self.print_params:
  58. print_prune_params(model)
  59. ori_flops = flops(model, input_spec) / (1000**3)
  60. logger.info("FLOPs before pruning: {}GFLOPs".format(ori_flops))
  61. if self.criterion == 'fpgm':
  62. pruner = paddleslim.dygraph.FPGMFilterPruner(model, input_spec)
  63. elif self.criterion == 'l1_norm':
  64. pruner = paddleslim.dygraph.L1NormFilterPruner(model, input_spec)
  65. logger.info("pruned params: {}".format(self.pruned_params))
  66. pruned_ratios = [float(n) for n in self.pruned_ratios]
  67. ratios = {}
  68. for i, param in enumerate(self.pruned_params):
  69. ratios[param] = pruned_ratios[i]
  70. pruner.prune_vars(ratios, [0])
  71. pruned_flops = flops(model, input_spec) / (1000**3)
  72. logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
  73. pruned_flops, (ori_flops - pruned_flops) / ori_flops))
  74. return model
  75. @register
  76. @serializable
  77. class PrunerQAT(object):
  78. def __init__(self, criterion, pruned_params, pruned_ratios,
  79. print_prune_params, quant_config, print_qat_model):
  80. super(PrunerQAT, self).__init__()
  81. assert criterion in ['l1_norm', 'fpgm'], \
  82. "unsupported prune criterion: {}".format(criterion)
  83. # Pruner hyperparameter
  84. self.criterion = criterion
  85. self.pruned_params = pruned_params
  86. self.pruned_ratios = pruned_ratios
  87. self.print_prune_params = print_prune_params
  88. # QAT hyperparameter
  89. self.quant_config = quant_config
  90. self.print_qat_model = print_qat_model
  91. def __call__(self, model):
  92. # FIXME: adapt to network graph when Training and inference are
  93. # inconsistent, now only supports prune inference network graph.
  94. model.eval()
  95. paddleslim = try_import('paddleslim')
  96. from paddleslim.analysis import dygraph_flops as flops
  97. input_spec = [{
  98. "image": paddle.ones(
  99. shape=[1, 3, 640, 640], dtype='float32'),
  100. "im_shape": paddle.full(
  101. [1, 2], 640, dtype='float32'),
  102. "scale_factor": paddle.ones(
  103. shape=[1, 2], dtype='float32')
  104. }]
  105. if self.print_prune_params:
  106. print_prune_params(model)
  107. ori_flops = flops(model, input_spec) / 1000
  108. logger.info("FLOPs before pruning: {}GFLOPs".format(ori_flops))
  109. if self.criterion == 'fpgm':
  110. pruner = paddleslim.dygraph.FPGMFilterPruner(model, input_spec)
  111. elif self.criterion == 'l1_norm':
  112. pruner = paddleslim.dygraph.L1NormFilterPruner(model, input_spec)
  113. logger.info("pruned params: {}".format(self.pruned_params))
  114. pruned_ratios = [float(n) for n in self.pruned_ratios]
  115. ratios = {}
  116. for i, param in enumerate(self.pruned_params):
  117. ratios[param] = pruned_ratios[i]
  118. pruner.prune_vars(ratios, [0])
  119. pruned_flops = flops(model, input_spec) / 1000
  120. logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
  121. pruned_flops, (ori_flops - pruned_flops) / ori_flops))
  122. self.quanter = paddleslim.dygraph.quant.QAT(config=self.quant_config)
  123. self.quanter.quantize(model)
  124. if self.print_qat_model:
  125. logger.info("Quantized model:")
  126. logger.info(model)
  127. return model
  128. def save_quantized_model(self, layer, path, input_spec=None, **config):
  129. self.quanter.save_quantized_model(
  130. model=layer, path=path, input_spec=input_spec, **config)