123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283 |
- # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import print_function
- from __future__ import division
- import importlib
- import os
- import sys
- import yaml
- import collections
- try:
- collectionsAbc = collections.abc
- except AttributeError:
- collectionsAbc = collections
- from .config.schema import SchemaDict, SharedConfig, extract_schema
- from .config.yaml_helpers import serializable
- __all__ = [
- 'global_config',
- 'load_config',
- 'merge_config',
- 'get_registered_modules',
- 'create',
- 'register',
- 'serializable',
- 'dump_value',
- ]
def dump_value(value):
    """
    Render a config value as a single-line string.

    Args:
        value: Value to render. Objects exposing `__dict__` and dict / tuple /
            list containers are dumped as flow-style YAML; anything else is
            treated as a primitive.

    Returns: str, the quoted one-line YAML dump for containers/objects, or
        `str(value)` for primitives.
    """
    # XXX this is hackish, but collections.abc is not available in python 2
    if not hasattr(value, '__dict__') and not isinstance(value,
                                                         (dict, tuple, list)):
        # primitive types
        return str(value)

    dumped = yaml.dump(value, default_flow_style=True)
    # collapse the dump onto one line and drop any `...` sequences
    dumped = dumped.replace('\n', '').replace('...', '')
    return "'{}'".format(dumped)
class AttrDict(dict):
    """Single level attribute dict, NOT recursive.

    Keys given as keyword arguments are stored as ordinary dict items and can
    additionally be read through attribute access.
    """

    def __init__(self, **kwargs):
        parent = super(AttrDict, self)
        parent.__init__()
        parent.update(kwargs)

    def __getattr__(self, key):
        # only reached when normal attribute lookup fails, so real dict
        # attributes/methods keep their usual precedence
        if key not in self:
            raise AttributeError("object has no attribute '{}'".format(key))
        return self[key]
- global_config = AttrDict()
- BASE_KEY = '_BASE_'
- # parse and load _BASE_ recursively
def _load_config_with_base(file_path):
    """
    Load a YAML config file and resolve its `_BASE_` includes recursively.

    Args:
        file_path (str): Path of the YAML file to load. Relative entries of
            its `_BASE_` key are resolved against the directory of this file;
            entries starting with `~` are expanded to the user's home.

    Returns: the loaded config. Keys of `file_path` itself override keys
        coming from its `_BASE_` files, and a later `_BASE_` entry overrides
        an earlier one.
    """
    with open(file_path) as f:
        # NOTE(security): yaml.Loader can construct arbitrary Python objects,
        # so only trusted config files may be loaded through this function.
        file_cfg = yaml.load(f, Loader=yaml.Loader)

    if file_cfg is None:
        # an empty YAML document loads as None; treat it as an empty config
        file_cfg = {}

    # NOTE: cfgs outside have higher priority than cfgs in _BASE_
    if BASE_KEY not in file_cfg:
        return file_cfg

    base_ymls = file_cfg[BASE_KEY]
    if isinstance(base_ymls, str):
        # a single path written as a scalar; do not iterate its characters
        base_ymls = [base_ymls]

    all_base_cfg = AttrDict()
    for base_yml in list(base_ymls):
        if base_yml.startswith("~"):
            base_yml = os.path.expanduser(base_yml)
        if not os.path.isabs(base_yml):
            base_yml = os.path.join(os.path.dirname(file_path), base_yml)
        # the recursive call opens the file itself
        base_cfg = _load_config_with_base(base_yml)
        all_base_cfg = merge_config(base_cfg, all_base_cfg)

    del file_cfg[BASE_KEY]
    return merge_config(file_cfg, all_base_cfg)
def load_config(file_path):
    """
    Load a YAML config file and merge it into the global config.

    The file's base name (without extension) is recorded under the
    `filename` key before merging.

    Args:
        file_path (str): Path of the config file to be loaded; must end in
            `.yml` or `.yaml`.

    Returns: global config
    """
    ext = os.path.splitext(file_path)[1]
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"

    # load config from file (including its _BASE_ files)
    cfg = _load_config_with_base(file_path)
    base_name = os.path.basename(file_path)
    cfg['filename'] = os.path.splitext(base_name)[0]

    # merge into global config
    merge_config(cfg)
    return global_config
def dict_merge(dct, merge_dct):
    """ Merge ``merge_dct`` into ``dct`` recursively, in place.

    Unlike :meth:``dict.update()``, which replaces top-level keys only, a key
    whose current value in ``dct`` is a dict and whose incoming value is a
    mapping is merged key by key, to an arbitrary depth. Every other key is
    simply overwritten with the incoming value.

    Args:
        dct: dict onto which the merge is executed
        merge_dct: dct merged into dct

    Returns: dct
    """
    for key, incoming in merge_dct.items():
        mergeable = (key in dct and isinstance(dct[key], dict) and
                     isinstance(incoming, collectionsAbc.Mapping))
        if not mergeable:
            dct[key] = incoming
            continue
        dict_merge(dct[key], incoming)
    return dct
def merge_config(config, another_cfg=None):
    """
    Merge config into global config or another_cfg.

    Args:
        config (dict): Config to be merged.
        another_cfg (dict): Optional merge target. When omitted (None),
            `config` is merged into the global config. An empty dict is a
            valid target and is merged into as given.

    Returns: the dict that was merged into (global config when
        `another_cfg` is None, otherwise `another_cfg`)
    """
    global global_config
    # test for None explicitly: an empty target dict is falsy, and `or` would
    # silently redirect the merge into the global config
    dct = global_config if another_cfg is None else another_cfg
    return dict_merge(dct, config)
def get_registered_modules():
    """
    Collect the registered modules from the global config.

    Returns: dict mapping name to entry for every global config item whose
        value is a `SchemaDict`
    """
    registered = {}
    for name, entry in global_config.items():
        if isinstance(entry, SchemaDict):
            registered[name] = entry
    return registered
def make_partial(cls):
    """
    Turn a class declaring `__op__` into a callable wrapper around that op.

    Instances of the returned class call the op with the caller's positional
    arguments plus the instance attributes as keyword arguments; keyword
    arguments given at call time override instance attributes.

    Args:
        cls (type): Class with an `__op__` attribute naming the wrapped
            function. `__category__` defaults to 'op' when unset or empty.
            Set `__append_doc__ = False` to skip docstring decoration.

    Returns: cls (modified in place)
    """
    # re-resolve the op from its defining module by name
    op_module = importlib.import_module(cls.__op__.__module__)
    op = getattr(op_module, cls.__op__.__name__)
    cls.__category__ = getattr(cls, '__category__', None) or 'op'

    def partial_apply(self, *args, **kwargs):
        kwargs_ = self.__dict__.copy()
        kwargs_.update(kwargs)
        return op(*args, **kwargs_)

    if getattr(cls, '__append_doc__', True):  # XXX should default to True?
        # set the doc on the plain function before attaching it; this works
        # on both python 2 and python 3
        partial_apply.__doc__ = op.__doc__
        if sys.version_info[0] > 2:
            cls.__doc__ = "Wrapper for `{}` OP".format(op.__name__)
            try:
                cls.__init__.__doc__ = op.__doc__
            except (AttributeError, TypeError):
                # no Python-level __init__ (e.g. inherited object.__init__),
                # whose __doc__ is read-only; nothing to decorate
                pass

    # install the wrapper unconditionally: opting out of doc decoration must
    # not leave the class without its `__call__`
    cls.__call__ = partial_apply
    return cls
def register(cls):
    """
    Register a given module class in the global config.

    Classes declaring `__op__` are first converted with `make_partial`; the
    schema extracted from the class is then stored under the class name.

    Args:
        cls (type): Module class to be registered.

    Returns: cls

    Raises:
        ValueError: if a class of the same name is already registered.
    """
    name = cls.__name__
    if name in global_config:
        raise ValueError("Module class already registered: {}".format(name))

    is_op_wrapper = hasattr(cls, '__op__')
    if is_op_wrapper:
        cls = make_partial(cls)

    global_config[cls.__name__] = extract_schema(cls)
    return cls
def create(cls_or_name, **kwargs):
    """
    Create an instance of given module class.

    Constructor arguments are taken from the module's entry in the global
    config, then completed from its `shared` and `inject` annotations and,
    when the class defines it, its `from_config` hook.

    Args:
        cls_or_name (type or str): Class of which to create instance.
        **kwargs: Forwarded to `cls.from_config(config, **kwargs)` when the
            class defines `from_config`; otherwise unused.

    Returns: instance of type `cls_or_name`, or the object stored in the
        global config under that name when it is an already built instance

    Raises:
        AssertionError: if `cls_or_name` is neither a class nor a str.
        ValueError: if the module, or a module it injects, is not present in
            the global config, or an injection value has an unsupported type.
    """
    # isinstance (not `type(x) in [...]`) so that classes built with a custom
    # metaclass are accepted as well
    assert isinstance(cls_or_name, (type, str)), \
        "should be a class or name of a class"
    if isinstance(cls_or_name, str):
        name = cls_or_name
    else:
        name = cls_or_name.__name__

    if name not in global_config:
        raise ValueError("The module {} is not registered".format(name))
    config = global_config[name]
    if not isinstance(config, SchemaDict):
        if hasattr(config, "__dict__"):
            # support instance return directly
            return config
        raise ValueError("The module {} is not registered".format(name))

    cls = getattr(config.pymodule, name)
    cls_kwargs = {}
    cls_kwargs.update(config)

    # parse `shared` annotation of registered modules
    if getattr(config, 'shared', None):
        for k in config.shared:
            target_key = config[k]
            shared_conf = config.schema[k].default
            assert isinstance(shared_conf, SharedConfig)
            if target_key is not None and not isinstance(target_key,
                                                         SharedConfig):
                continue  # value is given for the module
            elif shared_conf.key in global_config:
                # `key` is present in config
                cls_kwargs[k] = global_config[shared_conf.key]
            else:
                cls_kwargs[k] = shared_conf.default_value

    # let the class derive extra constructor arguments from its config
    if getattr(cls, 'from_config', None):
        cls_kwargs.update(cls.from_config(config, **kwargs))

    # parse `inject` annotation of registered modules
    if getattr(config, 'inject', None):
        for k in config.inject:
            target_key = config[k]
            # optional dependency
            if target_key is None:
                continue

            if isinstance(target_key, dict) or hasattr(target_key, '__dict__'):
                if 'name' not in target_key.keys():
                    continue
                inject_name = str(target_key['name'])
                if inject_name not in global_config:
                    raise ValueError(
                        "Missing injection name {} for {} and check it's "
                        "name in cfg file".format(inject_name, k))
                target = global_config[inject_name]
                # the remaining keys override the injected module's entry in
                # the global config before it is built
                for i, v in target_key.items():
                    if i == 'name':
                        continue
                    target[i] = v
                if isinstance(target, SchemaDict):
                    cls_kwargs[k] = create(inject_name)
            elif isinstance(target_key, str):
                if target_key not in global_config:
                    raise ValueError("Missing injection config: {}".format(
                        target_key))
                target = global_config[target_key]
                if isinstance(target, SchemaDict):
                    cls_kwargs[k] = create(target_key)
                elif hasattr(target, '__dict__'):  # serialized object
                    cls_kwargs[k] = target
            else:
                raise ValueError("Unsupported injection type: {}".format(
                    target_key))

    # NOTE: values are handed to the constructor by reference (no deep copy),
    # so a created instance that mutates a list/dict argument also mutates
    # the corresponding global config value
    return cls(**cls_kwargs)
|