123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from . import prune
- from . import quant
- from . import distill
- from . import unstructured_prune
- from .prune import *
- from .quant import *
- from .distill import *
- from .unstructured_prune import *
- from .ofa import *
- import yaml
- from ppdet.core.workspace import load_config
- from ppdet.utils.checkpoint import load_pretrain_weight
def build_slim_model(cfg, slim_cfg, mode='train'):
    """Build or wrap the model described by ``cfg`` according to a slim config.

    The top-level ``slim`` key of the slim YAML file selects the strategy:
    ``Distill``, ``OFA``, ``DistillPrune``, ``PTQ``, ``UnstructuredPruner``,
    or any other name. Any other name is handled by the generic fallback
    branch, which instantiates ``create(cfg.slim)`` and applies it to the
    model (e.g. QAT or a pruner).

    Args:
        cfg: The ppdet config object, accessed both as a mapping and by
            attribute. Depending on the strategy, this function sets
            ``cfg['model']``, ``cfg['slim']``, ``cfg['slim_type']`` and/or
            ``cfg['unstructured_prune']`` on it.
        slim_cfg (str): Path to the slim YAML config file.
        mode (str): ``'train'`` builds the model for training. Any other
            value skips distillation and, where applicable, loads
            ``cfg.weights`` into the built model. ``'test'`` additionally
            disables the QAT activation preprocess (see TODO below).

    Returns:
        The same ``cfg`` object that was passed in.
    """
    # Explicit import: otherwise `create` is only reachable if one of the
    # module-level wildcard imports happens to re-export it.
    from ppdet.core.workspace import create

    # NOTE(review): yaml.Loader can construct arbitrary Python objects, so
    # only trusted slim config files should be passed here. It is kept
    # instead of SafeLoader because ppdet configs may rely on custom YAML
    # tags -- confirm before switching.
    with open(slim_cfg) as f:
        slim_load_cfg = yaml.load(f, Loader=yaml.Loader)
    slim_name = slim_load_cfg['slim']
    slim_method = slim_load_cfg.get('slim_method')

    # Outside training, a Distill config is ignored and cfg returned as-is.
    if mode != 'train' and slim_name == 'Distill':
        return cfg

    # The branches below read cfg.slim / cfg.pruner / cfg.weights after
    # calling load_config(slim_cfg) and ignore its return value. They rely
    # on load_config merging the slim config into the config that cfg is.
    if slim_name == 'Distill':
        if slim_method == 'FGD':
            model = FGDDistillModel(cfg, slim_cfg)
        elif slim_method == 'LD':
            model = LDDistillModel(cfg, slim_cfg)
        elif slim_method == 'CWD':
            model = CWDDistillModel(cfg, slim_cfg)
        else:
            model = DistillModel(cfg, slim_cfg)
        cfg['model'] = model
        cfg['slim_type'] = cfg.slim
    elif slim_name == 'OFA':
        load_config(slim_cfg)
        model = create(cfg.architecture)
        load_pretrain_weight(model, cfg.weights)
        slim = create(cfg.slim)
        # Record the slim type *name* before cfg['slim'] is replaced by the
        # instantiated object, so slim_type is a name in every branch.
        cfg['slim_type'] = cfg.slim
        cfg['slim'] = slim
        cfg['model'] = slim(model, model.state_dict())
    elif slim_name == 'DistillPrune':
        if mode == 'train':
            model = DistillModel(cfg, slim_cfg)
            pruner = create(cfg.pruner)
            # Return value unused: presumably prunes the student in place.
            pruner(model.student_model)
        else:
            model = create(cfg.architecture)
            # Read before load_config(); the ordering suggests the slim
            # config may override cfg.weights.
            weights = cfg.weights
            load_config(slim_cfg)
            pruner = create(cfg.pruner)
            model = pruner(model)
            load_pretrain_weight(model, weights)
        cfg['model'] = model
        cfg['slim_type'] = cfg.slim
    elif slim_name == 'PTQ':
        model = create(cfg.architecture)
        load_config(slim_cfg)
        load_pretrain_weight(model, cfg.weights)
        slim = create(cfg.slim)
        # Same ordering as the OFA branch: keep slim_type as the type name.
        cfg['slim_type'] = cfg.slim
        cfg['slim'] = slim
        cfg['model'] = slim(model)
    elif slim_name == 'UnstructuredPruner':
        # No model is built here; only the slim object and a flag are stored.
        load_config(slim_cfg)
        slim = create(cfg.slim)
        cfg['slim_type'] = cfg.slim
        cfg['slim'] = slim
        cfg['unstructured_prune'] = True
    else:
        load_config(slim_cfg)
        model = create(cfg.architecture)
        if mode == 'train':
            load_pretrain_weight(model, cfg.pretrain_weights)
        slim = create(cfg.slim)
        cfg['slim_type'] = cfg.slim
        # TODO: fix quant export model in framework.
        if mode == 'test' and 'QAT' in slim_name:
            slim.quant_config['activation_preprocess_type'] = None
        cfg['model'] = slim(model)
        cfg['slim'] = slim
        # Scoped to this branch: the other branches either load weights
        # themselves or (UnstructuredPruner) never set cfg['model'].
        if mode != 'train':
            load_pretrain_weight(cfg['model'], cfg.weights)
    return cfg
|