workspace.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import print_function
  16. from __future__ import division
  17. import importlib
  18. import os
  19. import sys
  20. import yaml
  21. import collections
  22. try:
  23. collectionsAbc = collections.abc
  24. except AttributeError:
  25. collectionsAbc = collections
  26. from .config.schema import SchemaDict, SharedConfig, extract_schema
  27. from .config.yaml_helpers import serializable
# Names exported by a star-import of this module.
__all__ = [
    'global_config',
    'load_config',
    'merge_config',
    'get_registered_modules',
    'create',
    'register',
    'serializable',
    'dump_value',
]
  38. def dump_value(value):
  39. # XXX this is hackish, but collections.abc is not available in python 2
  40. if hasattr(value, '__dict__') or isinstance(value, (dict, tuple, list)):
  41. value = yaml.dump(value, default_flow_style=True)
  42. value = value.replace('\n', '')
  43. value = value.replace('...', '')
  44. return "'{}'".format(value)
  45. else:
  46. # primitive types
  47. return str(value)
  48. class AttrDict(dict):
  49. """Single level attribute dict, NOT recursive"""
  50. def __init__(self, **kwargs):
  51. super(AttrDict, self).__init__()
  52. super(AttrDict, self).update(kwargs)
  53. def __getattr__(self, key):
  54. if key in self:
  55. return self[key]
  56. raise AttributeError("object has no attribute '{}'".format(key))
# Module-wide config store; created empty here and filled by the functions
# in this module.
global_config = AttrDict()

# Key under which a config file lists the base config files it builds on.
BASE_KEY = '_BASE_'
  59. # parse and load _BASE_ recursively
  60. def _load_config_with_base(file_path):
  61. with open(file_path) as f:
  62. file_cfg = yaml.load(f, Loader=yaml.Loader)
  63. # NOTE: cfgs outside have higher priority than cfgs in _BASE_
  64. if BASE_KEY in file_cfg:
  65. all_base_cfg = AttrDict()
  66. base_ymls = list(file_cfg[BASE_KEY])
  67. for base_yml in base_ymls:
  68. if base_yml.startswith("~"):
  69. base_yml = os.path.expanduser(base_yml)
  70. if not base_yml.startswith('/'):
  71. base_yml = os.path.join(os.path.dirname(file_path), base_yml)
  72. with open(base_yml) as f:
  73. base_cfg = _load_config_with_base(base_yml)
  74. all_base_cfg = merge_config(base_cfg, all_base_cfg)
  75. del file_cfg[BASE_KEY]
  76. return merge_config(file_cfg, all_base_cfg)
  77. return file_cfg
  78. def load_config(file_path):
  79. """
  80. Load config from file.
  81. Args:
  82. file_path (str): Path of the config file to be loaded.
  83. Returns: global config
  84. """
  85. _, ext = os.path.splitext(file_path)
  86. assert ext in ['.yml', '.yaml'], "only support yaml files for now"
  87. # load config from file and merge into global config
  88. cfg = _load_config_with_base(file_path)
  89. cfg['filename'] = os.path.splitext(os.path.split(file_path)[-1])[0]
  90. merge_config(cfg)
  91. return global_config
  92. def dict_merge(dct, merge_dct):
  93. """ Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
  94. updating only top-level keys, dict_merge recurses down into dicts nested
  95. to an arbitrary depth, updating keys. The ``merge_dct`` is merged into
  96. ``dct``.
  97. Args:
  98. dct: dict onto which the merge is executed
  99. merge_dct: dct merged into dct
  100. Returns: dct
  101. """
  102. for k, v in merge_dct.items():
  103. if (k in dct and isinstance(dct[k], dict) and
  104. isinstance(merge_dct[k], collectionsAbc.Mapping)):
  105. dict_merge(dct[k], merge_dct[k])
  106. else:
  107. dct[k] = merge_dct[k]
  108. return dct
  109. def merge_config(config, another_cfg=None):
  110. """
  111. Merge config into global config or another_cfg.
  112. Args:
  113. config (dict): Config to be merged.
  114. Returns: global config
  115. """
  116. global global_config
  117. dct = another_cfg or global_config
  118. return dict_merge(dct, config)
  119. def get_registered_modules():
  120. return {k: v for k, v in global_config.items() if isinstance(v, SchemaDict)}
  121. def make_partial(cls):
  122. op_module = importlib.import_module(cls.__op__.__module__)
  123. op = getattr(op_module, cls.__op__.__name__)
  124. cls.__category__ = getattr(cls, '__category__', None) or 'op'
  125. def partial_apply(self, *args, **kwargs):
  126. kwargs_ = self.__dict__.copy()
  127. kwargs_.update(kwargs)
  128. return op(*args, **kwargs_)
  129. if getattr(cls, '__append_doc__', True): # XXX should default to True?
  130. if sys.version_info[0] > 2:
  131. cls.__doc__ = "Wrapper for `{}` OP".format(op.__name__)
  132. cls.__init__.__doc__ = op.__doc__
  133. cls.__call__ = partial_apply
  134. cls.__call__.__doc__ = op.__doc__
  135. else:
  136. # XXX work around for python 2
  137. partial_apply.__doc__ = op.__doc__
  138. cls.__call__ = partial_apply
  139. return cls
  140. def register(cls):
  141. """
  142. Register a given module class.
  143. Args:
  144. cls (type): Module class to be registered.
  145. Returns: cls
  146. """
  147. if cls.__name__ in global_config:
  148. raise ValueError("Module class already registered: {}".format(
  149. cls.__name__))
  150. if hasattr(cls, '__op__'):
  151. cls = make_partial(cls)
  152. global_config[cls.__name__] = extract_schema(cls)
  153. return cls
def create(cls_or_name, **kwargs):
    """
    Create an instance of given module class.

    The keyword arguments for the constructor are assembled from the entry
    stored in the global config under the class name, then adjusted by the
    entry's `shared` keys, by `cls.from_config` (if the class has one) and by
    the entry's `inject` keys, in that order.

    Args:
        cls_or_name (type or str): Class of which to create instance.
        **kwargs: Only forwarded to `cls.from_config`; not passed to the
            constructor directly.

    Returns: instance of type `cls_or_name`. If the global config entry for
        the name is not a `SchemaDict` but an object with a `__dict__`, that
        stored object is returned as is.

    Raises:
        AssertionError: If `cls_or_name` is neither a class nor a str, or if
            the schema default of a `shared` key is not a `SharedConfig`.
        ValueError: If the name is not in the global config, if its entry is
            neither a `SchemaDict` nor an object with a `__dict__`, if an
            injection target is missing from the global config, or if an
            `inject` value is of an unsupported type.
    """
    assert type(cls_or_name) in [type, str
                                 ], "should be a class or name of a class"
    # and/or idiom: use the string itself, otherwise the class name
    name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__
    if name in global_config:
        if isinstance(global_config[name], SchemaDict):
            pass
        elif hasattr(global_config[name], "__dict__"):
            # support instance return directly
            return global_config[name]
        else:
            raise ValueError("The module {} is not registered".format(name))
    else:
        raise ValueError("The module {} is not registered".format(name))

    config = global_config[name]
    # look the class up by name on `config.pymodule`
    # (presumably the module that defines the class -- confirm in SchemaDict)
    cls = getattr(config.pymodule, name)
    # start from a shallow copy of the items stored in the registered entry
    cls_kwargs = {}
    cls_kwargs.update(global_config[name])

    # parse `shared` annotation of registered modules
    if getattr(config, 'shared', None):
        for k in config.shared:
            target_key = config[k]
            shared_conf = config.schema[k].default
            assert isinstance(shared_conf, SharedConfig)
            if target_key is not None and not isinstance(target_key,
                                                         SharedConfig):
                continue  # value is given for the module
            elif shared_conf.key in global_config:
                # `key` is present in config
                cls_kwargs[k] = global_config[shared_conf.key]
            else:
                cls_kwargs[k] = shared_conf.default_value

    # let the class contribute keyword arguments of its own; these override
    # what has been collected so far
    if getattr(cls, 'from_config', None):
        cls_kwargs.update(cls.from_config(config, **kwargs))

    # parse `inject` annotation of registered modules
    if getattr(config, 'inject', None):
        for k in config.inject:
            target_key = config[k]
            # optional dependency
            if target_key is None:
                continue

            if isinstance(target_key, dict) or hasattr(target_key, '__dict__'):
                # mapping form: 'name' selects the global config entry, the
                # remaining items are written into that entry
                if 'name' not in target_key.keys():
                    continue
                inject_name = str(target_key['name'])
                if inject_name not in global_config:
                    raise ValueError(
                        "Missing injection name {} and check it's name in cfg file".
                        format(k))
                target = global_config[inject_name]
                # NOTE: this modifies the entry held by the global config
                for i, v in target_key.items():
                    if i == 'name':
                        continue
                    target[i] = v
                if isinstance(target, SchemaDict):
                    cls_kwargs[k] = create(inject_name)
            elif isinstance(target_key, str):
                # string form: the value is the key of the global config
                # entry to inject
                if target_key not in global_config:
                    raise ValueError("Missing injection config:", target_key)
                target = global_config[target_key]
                if isinstance(target, SchemaDict):
                    cls_kwargs[k] = create(target_key)
                elif hasattr(target, '__dict__'):  # serialized object
                    cls_kwargs[k] = target
            else:
                raise ValueError("Unsupported injection type:", target_key)
    # prevent modification of global config values of reference types
    # (e.g., list, dict) from within the created module instances
    #kwargs = copy.deepcopy(kwargs)
    return cls(**cls_kwargs)