__init__.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from . import prune
  15. from . import quant
  16. from . import distill
  17. from . import unstructured_prune
  18. from .prune import *
  19. from .quant import *
  20. from .distill import *
  21. from .unstructured_prune import *
  22. from .ofa import *
  23. import yaml
  24. from ppdet.core.workspace import load_config
  25. from ppdet.utils.checkpoint import load_pretrain_weight
  26. def build_slim_model(cfg, slim_cfg, mode='train'):
  27. with open(slim_cfg) as f:
  28. slim_load_cfg = yaml.load(f, Loader=yaml.Loader)
  29. if mode != 'train' and slim_load_cfg['slim'] == 'Distill':
  30. return cfg
  31. if slim_load_cfg['slim'] == 'Distill':
  32. if "slim_method" in slim_load_cfg and slim_load_cfg[
  33. 'slim_method'] == "FGD":
  34. model = FGDDistillModel(cfg, slim_cfg)
  35. elif "slim_method" in slim_load_cfg and slim_load_cfg[
  36. 'slim_method'] == "LD":
  37. model = LDDistillModel(cfg, slim_cfg)
  38. elif "slim_method" in slim_load_cfg and slim_load_cfg[
  39. 'slim_method'] == "CWD":
  40. model = CWDDistillModel(cfg, slim_cfg)
  41. else:
  42. model = DistillModel(cfg, slim_cfg)
  43. cfg['model'] = model
  44. cfg['slim_type'] = cfg.slim
  45. elif slim_load_cfg['slim'] == 'OFA':
  46. load_config(slim_cfg)
  47. model = create(cfg.architecture)
  48. load_pretrain_weight(model, cfg.weights)
  49. slim = create(cfg.slim)
  50. cfg['slim'] = slim
  51. cfg['model'] = slim(model, model.state_dict())
  52. cfg['slim_type'] = cfg.slim
  53. elif slim_load_cfg['slim'] == 'DistillPrune':
  54. if mode == 'train':
  55. model = DistillModel(cfg, slim_cfg)
  56. pruner = create(cfg.pruner)
  57. pruner(model.student_model)
  58. else:
  59. model = create(cfg.architecture)
  60. weights = cfg.weights
  61. load_config(slim_cfg)
  62. pruner = create(cfg.pruner)
  63. model = pruner(model)
  64. load_pretrain_weight(model, weights)
  65. cfg['model'] = model
  66. cfg['slim_type'] = cfg.slim
  67. elif slim_load_cfg['slim'] == 'PTQ':
  68. model = create(cfg.architecture)
  69. load_config(slim_cfg)
  70. load_pretrain_weight(model, cfg.weights)
  71. slim = create(cfg.slim)
  72. cfg['slim'] = slim
  73. cfg['model'] = slim(model)
  74. cfg['slim_type'] = cfg.slim
  75. elif slim_load_cfg['slim'] == 'UnstructuredPruner':
  76. load_config(slim_cfg)
  77. slim = create(cfg.slim)
  78. cfg['slim_type'] = cfg.slim
  79. cfg['slim'] = slim
  80. cfg['unstructured_prune'] = True
  81. else:
  82. load_config(slim_cfg)
  83. model = create(cfg.architecture)
  84. if mode == 'train':
  85. load_pretrain_weight(model, cfg.pretrain_weights)
  86. slim = create(cfg.slim)
  87. cfg['slim_type'] = cfg.slim
  88. # TODO: fix quant export model in framework.
  89. if mode == 'test' and 'QAT' in slim_load_cfg['slim']:
  90. slim.quant_config['activation_preprocess_type'] = None
  91. cfg['model'] = slim(model)
  92. cfg['slim'] = slim
  93. if mode != 'train':
  94. load_pretrain_weight(cfg['model'], cfg.weights)
  95. return cfg