# schema.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import print_function
  16. from __future__ import division
  17. import inspect
  18. import importlib
  19. import re
  20. try:
  21. from docstring_parser import parse as doc_parse
  22. except Exception:
  23. def doc_parse(*args):
  24. pass
  25. try:
  26. from typeguard import check_type
  27. except Exception:
  28. def check_type(*args):
  29. pass
  30. __all__ = ['SchemaValue', 'SchemaDict', 'SharedConfig', 'extract_schema']
  31. class SchemaValue(object):
  32. def __init__(self, name, doc='', type=None):
  33. super(SchemaValue, self).__init__()
  34. self.name = name
  35. self.doc = doc
  36. self.type = type
  37. def set_default(self, value):
  38. self.default = value
  39. def has_default(self):
  40. return hasattr(self, 'default')
  41. class SchemaDict(dict):
  42. def __init__(self, **kwargs):
  43. super(SchemaDict, self).__init__()
  44. self.schema = {}
  45. self.strict = False
  46. self.doc = ""
  47. self.update(kwargs)
  48. def __setitem__(self, key, value):
  49. # XXX also update regular dict to SchemaDict??
  50. if isinstance(value, dict) and key in self and isinstance(self[key],
  51. SchemaDict):
  52. self[key].update(value)
  53. else:
  54. super(SchemaDict, self).__setitem__(key, value)
  55. def __missing__(self, key):
  56. if self.has_default(key):
  57. return self.schema[key].default
  58. elif key in self.schema:
  59. return self.schema[key]
  60. else:
  61. raise KeyError(key)
  62. def copy(self):
  63. newone = SchemaDict()
  64. newone.__dict__.update(self.__dict__)
  65. newone.update(self)
  66. return newone
  67. def set_schema(self, key, value):
  68. assert isinstance(value, SchemaValue)
  69. self.schema[key] = value
  70. def set_strict(self, strict):
  71. self.strict = strict
  72. def has_default(self, key):
  73. return key in self.schema and self.schema[key].has_default()
  74. def is_default(self, key):
  75. if not self.has_default(key):
  76. return False
  77. if hasattr(self[key], '__dict__'):
  78. return True
  79. else:
  80. return key not in self or self[key] == self.schema[key].default
  81. def find_default_keys(self):
  82. return [
  83. k for k in list(self.keys()) + list(self.schema.keys())
  84. if self.is_default(k)
  85. ]
  86. def mandatory(self):
  87. return any([k for k in self.schema.keys() if not self.has_default(k)])
  88. def find_missing_keys(self):
  89. missing = [
  90. k for k in self.schema.keys()
  91. if k not in self and not self.has_default(k)
  92. ]
  93. placeholders = [k for k in self if self[k] in ('<missing>', '<value>')]
  94. return missing + placeholders
  95. def find_extra_keys(self):
  96. return list(set(self.keys()) - set(self.schema.keys()))
  97. def find_mismatch_keys(self):
  98. mismatch_keys = []
  99. for arg in self.schema.values():
  100. if arg.type is not None:
  101. try:
  102. check_type("{}.{}".format(self.name, arg.name),
  103. self[arg.name], arg.type)
  104. except Exception:
  105. mismatch_keys.append(arg.name)
  106. return mismatch_keys
  107. def validate(self):
  108. missing_keys = self.find_missing_keys()
  109. if missing_keys:
  110. raise ValueError("Missing param for class<{}>: {}".format(
  111. self.name, ", ".join(missing_keys)))
  112. extra_keys = self.find_extra_keys()
  113. if extra_keys and self.strict:
  114. raise ValueError("Extraneous param for class<{}>: {}".format(
  115. self.name, ", ".join(extra_keys)))
  116. mismatch_keys = self.find_mismatch_keys()
  117. if mismatch_keys:
  118. raise TypeError("Wrong param type for class<{}>: {}".format(
  119. self.name, ", ".join(mismatch_keys)))
  120. class SharedConfig(object):
  121. """
  122. Representation class for `__shared__` annotations, which work as follows:
  123. - if `key` is set for the module in config file, its value will take
  124. precedence
  125. - if `key` is not set for the module but present in the config file, its
  126. value will be used
  127. - otherwise, use the provided `default_value` as fallback
  128. Args:
  129. key: config[key] will be injected
  130. default_value: fallback value
  131. """
  132. def __init__(self, key, default_value=None):
  133. super(SharedConfig, self).__init__()
  134. self.key = key
  135. self.default_value = default_value
  136. def extract_schema(cls):
  137. """
  138. Extract schema from a given class
  139. Args:
  140. cls (type): Class from which to extract.
  141. Returns:
  142. schema (SchemaDict): Extracted schema.
  143. """
  144. ctor = cls.__init__
  145. # python 2 compatibility
  146. if hasattr(inspect, 'getfullargspec'):
  147. argspec = inspect.getfullargspec(ctor)
  148. annotations = argspec.annotations
  149. has_kwargs = argspec.varkw is not None
  150. else:
  151. argspec = inspect.getfullargspec(ctor)
  152. # python 2 type hinting workaround, see pep-3107
  153. # however, since `typeguard` does not support python 2, type checking
  154. # is still python 3 only for now
  155. annotations = getattr(ctor, '__annotations__', {})
  156. has_kwargs = argspec.varkw is not None
  157. names = [arg for arg in argspec.args if arg != 'self']
  158. defaults = argspec.defaults
  159. num_defaults = argspec.defaults is not None and len(argspec.defaults) or 0
  160. num_required = len(names) - num_defaults
  161. docs = cls.__doc__
  162. if docs is None and getattr(cls, '__category__', None) == 'op':
  163. docs = cls.__call__.__doc__
  164. try:
  165. docstring = doc_parse(docs)
  166. except Exception:
  167. docstring = None
  168. if docstring is None:
  169. comments = {}
  170. else:
  171. comments = {}
  172. for p in docstring.params:
  173. match_obj = re.match('^([a-zA-Z_]+[a-zA-Z_0-9]*).*', p.arg_name)
  174. if match_obj is not None:
  175. comments[match_obj.group(1)] = p.description
  176. schema = SchemaDict()
  177. schema.name = cls.__name__
  178. schema.doc = ""
  179. if docs is not None:
  180. start_pos = docs[0] == '\n' and 1 or 0
  181. schema.doc = docs[start_pos:].split("\n")[0].strip()
  182. # XXX handle paddle's weird doc convention
  183. if '**' == schema.doc[:2] and '**' == schema.doc[-2:]:
  184. schema.doc = schema.doc[2:-2].strip()
  185. schema.category = hasattr(cls, '__category__') and getattr(
  186. cls, '__category__') or 'module'
  187. schema.strict = not has_kwargs
  188. schema.pymodule = importlib.import_module(cls.__module__)
  189. schema.inject = getattr(cls, '__inject__', [])
  190. schema.shared = getattr(cls, '__shared__', [])
  191. for idx, name in enumerate(names):
  192. comment = name in comments and comments[name] or name
  193. if name in schema.inject:
  194. type_ = None
  195. else:
  196. type_ = name in annotations and annotations[name] or None
  197. value_schema = SchemaValue(name, comment, type_)
  198. if name in schema.shared:
  199. assert idx >= num_required, "shared config must have default value"
  200. default = defaults[idx - num_required]
  201. value_schema.set_default(SharedConfig(name, default))
  202. elif idx >= num_required:
  203. default = defaults[idx - num_required]
  204. value_schema.set_default(default)
  205. schema.set_schema(name, value_schema)
  206. return schema