yaml_helpers.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import importlib
  15. import inspect
  16. import yaml
  17. from .schema import SharedConfig
  18. __all__ = ['serializable', 'Callable']
  19. def represent_dictionary_order(self, dict_data):
  20. return self.represent_mapping('tag:yaml.org,2002:map', dict_data.items())
  21. def setup_orderdict():
  22. from collections import OrderedDict
  23. yaml.add_representer(OrderedDict, represent_dictionary_order)
  24. def _make_python_constructor(cls):
  25. def python_constructor(loader, node):
  26. if isinstance(node, yaml.SequenceNode):
  27. args = loader.construct_sequence(node, deep=True)
  28. return cls(*args)
  29. else:
  30. kwargs = loader.construct_mapping(node, deep=True)
  31. try:
  32. return cls(**kwargs)
  33. except Exception as ex:
  34. print("Error when construct {} instance from yaml config".
  35. format(cls.__name__))
  36. raise ex
  37. return python_constructor
  38. def _make_python_representer(cls):
  39. # python 2 compatibility
  40. if hasattr(inspect, 'getfullargspec'):
  41. argspec = inspect.getfullargspec(cls)
  42. else:
  43. argspec = inspect.getfullargspec(cls.__init__)
  44. argnames = [arg for arg in argspec.args if arg != 'self']
  45. def python_representer(dumper, obj):
  46. if argnames:
  47. data = {name: getattr(obj, name) for name in argnames}
  48. else:
  49. data = obj.__dict__
  50. if '_id' in data:
  51. del data['_id']
  52. return dumper.represent_mapping(u'!{}'.format(cls.__name__), data)
  53. return python_representer
  54. def serializable(cls):
  55. """
  56. Add loader and dumper for given class, which must be
  57. "trivially serializable"
  58. Args:
  59. cls: class to be serialized
  60. Returns: cls
  61. """
  62. yaml.add_constructor(u'!{}'.format(cls.__name__),
  63. _make_python_constructor(cls))
  64. yaml.add_representer(cls, _make_python_representer(cls))
  65. return cls
  66. yaml.add_representer(SharedConfig,
  67. lambda d, o: d.represent_data(o.default_value))
  68. @serializable
  69. class Callable(object):
  70. """
  71. Helper to be used in Yaml for creating arbitrary class objects
  72. Args:
  73. full_type (str): the full module path to target function
  74. """
  75. def __init__(self, full_type, args=[], kwargs={}):
  76. super(Callable, self).__init__()
  77. self.full_type = full_type
  78. self.args = args
  79. self.kwargs = kwargs
  80. def __call__(self):
  81. if '.' in self.full_type:
  82. idx = self.full_type.rfind('.')
  83. module = importlib.import_module(self.full_type[:idx])
  84. func_name = self.full_type[idx + 1:]
  85. else:
  86. try:
  87. module = importlib.import_module('builtins')
  88. except Exception:
  89. module = importlib.import_module('__builtin__')
  90. func_name = self.full_type
  91. func = getattr(module, func_name)
  92. return func(*self.args, **self.kwargs)