quant.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle.utils import try_import

from ppdet.core.workspace import register, serializable
from ppdet.utils.logger import setup_logger

logger = setup_logger(__name__)
@register
@serializable
class QAT(object):
    # Wraps PaddleSlim's dygraph quantization-aware training (QAT) so it can
    # be instantiated from a ppdet YAML config.
    def __init__(self, quant_config, print_model):
        super(QAT, self).__init__()
        self.quant_config = quant_config
        self.print_model = print_model

    def __call__(self, model):
        # PaddleSlim is an optional dependency, imported lazily.
        paddleslim = try_import('paddleslim')
        self.quanter = paddleslim.dygraph.quant.QAT(config=self.quant_config)
        if self.print_model:
            logger.info("Model before quant:")
            logger.info(model)

        # For PP-YOLOE, convert the model to deploy mode first.
        for layer in model.sublayers():
            if hasattr(layer, 'convert_to_deploy'):
                layer.convert_to_deploy()

        # Insert fake-quant ops into the model in place.
        self.quanter.quantize(model)

        if self.print_model:
            logger.info("Quantized model:")
            logger.info(model)

        return model

    def save_quantized_model(self, layer, path, input_spec=None, **config):
        self.quanter.save_quantized_model(
            model=layer, path=path, input_spec=input_spec, **config)
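
# Usage sketch (not part of the original file): how the QAT hook is typically
# exercised in ppdet's slim training flow. The quant_config keys below are
# illustrative assumptions drawn from common PaddleSlim QAT options; check the
# PaddleSlim docs for the full set.
#
#   quant_config = {
#       'activation_preprocess_type': 'PACT',
#       'weight_quantize_type': 'channel_wise_abs_max',
#       'activation_quantize_type': 'moving_average_abs_max',
#   }
#   qat = QAT(quant_config=quant_config, print_model=False)
#   model = qat(model)        # insert fake-quant ops, then finetune as usual
#   qat.save_quantized_model(model, 'output_inference/model',
#                            input_spec=input_spec)
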
@register
@serializable
class PTQ(object):
    # Wraps PaddleSlim's post-training quantization (PTQ); calibration data
    # is fed to the returned quant_model by the caller.
    def __init__(self,
                 ptq_config,
                 quant_batch_num=10,
                 output_dir='output_inference',
                 fuse=True,
                 fuse_list=None):
        super(PTQ, self).__init__()
        self.ptq_config = ptq_config
        self.quant_batch_num = quant_batch_num
        self.output_dir = output_dir
        self.fuse = fuse
        self.fuse_list = fuse_list

    def __call__(self, model):
        paddleslim = try_import('paddleslim')
        self.ptq = paddleslim.PTQ(**self.ptq_config)
        model.eval()
        quant_model = self.ptq.quantize(
            model, fuse=self.fuse, fuse_list=self.fuse_list)
        return quant_model

    def save_quantized_model(self,
                             quant_model,
                             quantize_model_path,
                             input_spec=None):
        self.ptq.save_quantized_model(quant_model, quantize_model_path,
                                      input_spec)
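
# Usage sketch (not part of the original file): a rough outline of ppdet's
# post-quant export flow. 'loader' and the ptq_config values are illustrative
# assumptions; ppdet's slim configs pass similar settings.
#
#   ptq = PTQ(ptq_config={'activation_quantizer': 'HistQuantizer'},
#             quant_batch_num=10)
#   quant_model = ptq(model)                  # wrap the eval-mode model
#   for idx, data in enumerate(loader):       # run calibration batches
#       quant_model(data['image'])
#       if idx + 1 >= ptq.quant_batch_num:
#           break
#   ptq.save_quantized_model(quant_model, 'output_inference/model',
#                            input_spec=input_spec)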