123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import importlib
- import inspect
- import yaml
- from .schema import SharedConfig
- __all__ = ['serializable', 'Callable']
def represent_dictionary_order(self, dict_data):
    """Represent a mapping as a plain YAML map, keeping its item order.

    Args:
        self: the dumper/representer object; must provide
            ``represent_mapping``.
        dict_data: mapping whose ``items()`` are emitted in iteration order.

    Returns:
        Whatever ``self.represent_mapping`` returns for the standard
        YAML map tag.
    """
    ordered_items = dict_data.items()
    return self.represent_mapping('tag:yaml.org,2002:map', ordered_items)
def setup_orderdict():
    """Register ``represent_dictionary_order`` as the YAML representer
    for ``collections.OrderedDict``."""
    import collections

    yaml.add_representer(collections.OrderedDict, represent_dictionary_order)
- def _make_python_constructor(cls):
- def python_constructor(loader, node):
- if isinstance(node, yaml.SequenceNode):
- args = loader.construct_sequence(node, deep=True)
- return cls(*args)
- else:
- kwargs = loader.construct_mapping(node, deep=True)
- try:
- return cls(**kwargs)
- except Exception as ex:
- print("Error when construct {} instance from yaml config".
- format(cls.__name__))
- raise ex
- return python_constructor
- def _make_python_representer(cls):
- # python 2 compatibility
- if hasattr(inspect, 'getfullargspec'):
- argspec = inspect.getfullargspec(cls)
- else:
- argspec = inspect.getfullargspec(cls.__init__)
- argnames = [arg for arg in argspec.args if arg != 'self']
- def python_representer(dumper, obj):
- if argnames:
- data = {name: getattr(obj, name) for name in argnames}
- else:
- data = obj.__dict__
- if '_id' in data:
- del data['_id']
- return dumper.represent_mapping(u'!{}'.format(cls.__name__), data)
- return python_representer
def serializable(cls):
    """
    Class decorator that registers a YAML loader and dumper for ``cls``,
    which must be "trivially serializable".

    The class is registered under the tag ``!<cls.__name__>``.

    Args:
        cls: class to be serialized

    Returns: cls
    """
    yaml_tag = u'!{}'.format(cls.__name__)

    constructor = _make_python_constructor(cls)
    yaml.add_constructor(yaml_tag, constructor)

    representer = _make_python_representer(cls)
    yaml.add_representer(cls, representer)

    return cls
# Dump a SharedConfig by representing its ``default_value``.
yaml.add_representer(
    SharedConfig,
    lambda dumper, shared: dumper.represent_data(shared.default_value))
@serializable
class Callable(object):
    """
    Helper to be used in Yaml for creating arbitrary class objects

    Args:
        full_type (str): the full module path to target function; a name
            without a dot is looked up in the builtins module
        args (list): positional arguments passed to the target when the
            instance is called, defaults to an empty list
        kwargs (dict): keyword arguments passed to the target when the
            instance is called, defaults to an empty dict
    """

    def __init__(self, full_type, args=None, kwargs=None):
        super(Callable, self).__init__()
        self.full_type = full_type
        # Fresh containers per instance: literal ``[]`` / ``{}`` defaults
        # would be shared by every instance built without them.
        self.args = [] if args is None else args
        self.kwargs = {} if kwargs is None else kwargs

    def __call__(self):
        """Import the target named by ``full_type`` and call it.

        Returns:
            The result of calling the target with the stored ``args``
            and ``kwargs``.
        """
        module_name, _, func_name = self.full_type.rpartition('.')
        if module_name:
            module = importlib.import_module(module_name)
        else:
            # Bare name: resolve it against the builtins module, which is
            # named ``__builtin__`` on Python 2.
            try:
                module = importlib.import_module('builtins')
            except ImportError:
                module = importlib.import_module('__builtin__')

        func = getattr(module, func_name)
        return func(*self.args, **self.kwargs)
|