123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248 |
- # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import print_function
- from __future__ import division
- import inspect
- import importlib
- import re
# Optional third-party helpers. When either import fails for any reason, a
# no-op stand-in returning None is bound to the same name.
try:
    from docstring_parser import parse as doc_parse
except Exception:

    def doc_parse(*_unused):
        """Stand-in bound when `docstring_parser` cannot be imported."""
        return None


try:
    from typeguard import check_type
except Exception:

    def check_type(*_unused):
        """Stand-in bound when `typeguard` cannot be imported."""
        return None
- __all__ = ['SchemaValue', 'SchemaDict', 'SharedConfig', 'extract_schema']
class SchemaValue(object):
    """
    Schema entry for a single named parameter.

    Args:
        name: parameter name
        doc: description text for the parameter
        type: type associated with the parameter, or None

    The `default` attribute exists only after `set_default` has been called.
    """

    def __init__(self, name, doc='', type=None):
        super(SchemaValue, self).__init__()
        self.name, self.doc, self.type = name, doc, type

    def set_default(self, value):
        """Store `value` as this entry's default."""
        setattr(self, 'default', value)

    def has_default(self):
        """Return True if a default has been stored on this entry."""
        return hasattr(self, 'default')
class SchemaDict(dict):
    """
    Dict of config values paired with a per-key schema.

    Attributes set by the constructor:
        schema (dict): maps key -> schema entry (see `set_schema`)
        strict (bool): when True, `validate` rejects keys absent from `schema`
        doc (str): description text

    `validate` and `find_mismatch_keys` also read a `name` attribute for
    their messages. The constructor does not set it; when it has never been
    assigned it is treated as an empty string.
    """

    def __init__(self, **kwargs):
        super(SchemaDict, self).__init__()
        self.schema = {}
        self.strict = False
        self.doc = ""
        self.update(kwargs)

    def __setitem__(self, key, value):
        # XXX also update regular dict to SchemaDict??
        # Assigning a dict over an existing SchemaDict entry merges into that
        # entry instead of replacing it.
        is_nested = key in self and isinstance(self[key], SchemaDict)
        if isinstance(value, dict) and is_nested:
            self[key].update(value)
        else:
            super(SchemaDict, self).__setitem__(key, value)

    def __missing__(self, key):
        """
        Resolve a lookup for a key that is not stored in the dict.

        Returns the schema default when one exists, otherwise the schema
        entry itself when the key is declared in the schema.

        Raises:
            KeyError: if the key is not declared in the schema either.
        """
        if self.has_default(key):
            return self.schema[key].default
        elif key in self.schema:
            return self.schema[key]
        else:
            raise KeyError(key)

    def copy(self):
        """Return a new SchemaDict with the same attributes and items."""
        newone = SchemaDict()
        # Attribute values (including the `schema` dict object) are carried
        # over by reference, not deep-copied.
        newone.__dict__.update(self.__dict__)
        newone.update(self)
        return newone

    def set_schema(self, key, value):
        """Register `value` (a SchemaValue) as the schema entry for `key`."""
        assert isinstance(value, SchemaValue)
        self.schema[key] = value

    def set_strict(self, strict):
        """Set whether `validate` rejects keys that are not in the schema."""
        self.strict = strict

    def has_default(self, key):
        """Return True if `key` is in the schema and its entry has a default."""
        return key in self.schema and self.schema[key].has_default()

    def is_default(self, key):
        """
        Return True if `key` has a schema default and is considered to be at
        that default.
        """
        if not self.has_default(key):
            return False
        # Values that carry a `__dict__` are reported as default without
        # being compared.
        if hasattr(self[key], '__dict__'):
            return True
        else:
            return key not in self or self[key] == self.schema[key].default

    def find_default_keys(self):
        """
        Return the keys, drawn from stored keys followed by schema keys, for
        which `is_default` is True. A key present in both appears twice.
        """
        return [
            k for k in list(self.keys()) + list(self.schema.keys())
            if self.is_default(k)
        ]

    def mandatory(self):
        """Return True if at least one schema key has no default."""
        # Test the predicate itself, not the truthiness of the key names.
        return any(not self.has_default(k) for k in self.schema)

    def find_missing_keys(self):
        """
        Return schema keys that are neither stored nor defaulted, followed by
        stored keys whose value is a '<missing>' / '<value>' placeholder.
        """
        missing = [
            k for k in self.schema.keys()
            if k not in self and not self.has_default(k)
        ]
        placeholders = [k for k in self if self[k] in ('<missing>', '<value>')]
        return missing + placeholders

    def find_extra_keys(self):
        """Return stored keys that are not declared in the schema."""
        return list(set(self.keys()) - set(self.schema.keys()))

    def find_mismatch_keys(self):
        """
        Return the names of schema entries whose current value makes
        `check_type` raise. Entries whose `type` is None are skipped.
        """
        label = self._schema_name()
        mismatch_keys = []
        for arg in self.schema.values():
            if arg.type is None:
                continue
            try:
                # NOTE(review): this is the (argname, value, expected_type)
                # call form; newer typeguard releases may expect
                # (value, expected_type). Confirm against the installed
                # version, since any exception raised here is reported as a
                # mismatch.
                check_type("{}.{}".format(label, arg.name), self[arg.name],
                           arg.type)
            except Exception:
                mismatch_keys.append(arg.name)
        return mismatch_keys

    def validate(self):
        """
        Check stored values against the schema.

        Raises:
            ValueError: if keys are missing, or (in strict mode) if keys not
                declared in the schema are present.
            TypeError: if any value fails its type check.
        """
        label = self._schema_name()
        missing_keys = self.find_missing_keys()
        if missing_keys:
            raise ValueError("Missing param for class<{}>: {}".format(
                label, ", ".join(missing_keys)))
        extra_keys = self.find_extra_keys()
        if extra_keys and self.strict:
            raise ValueError("Extraneous param for class<{}>: {}".format(
                label, ", ".join(extra_keys)))
        mismatch_keys = self.find_mismatch_keys()
        if mismatch_keys:
            raise TypeError("Wrong param type for class<{}>: {}".format(
                label, ", ".join(mismatch_keys)))

    def _schema_name(self):
        """Return the `name` attribute, or '' when it was never assigned."""
        return getattr(self, 'name', '')
class SharedConfig(object):
    """
    Placeholder object for a `__shared__` annotation.

    Resolution order for the injected value:

    1. a value set for `key` on the module itself in the config file
    2. a value for `key` present elsewhere in the config file
    3. the `default_value` given here

    Args:
        key: config[key] will be injected
        default_value: fallback value
    """

    def __init__(self, key, default_value=None):
        super(SharedConfig, self).__init__()
        self.default_value = default_value
        self.key = key
def extract_schema(cls):
    """
    Extract schema from a given class

    The schema is derived from the signature of `cls.__init__`: every
    positional argument except `self` becomes a schema entry, with its
    description taken from the parsed docstring when available, and its
    default taken from the signature.

    Args:
        cls (type): Class from which to extract.

    Returns:
        schema (SchemaDict): Extracted schema.

    Raises:
        AssertionError: if a name listed in `cls.__shared__` is a
            constructor argument without a default value.
    """
    ctor = cls.__init__
    # python 2 compatibility
    if hasattr(inspect, 'getfullargspec'):
        argspec = inspect.getfullargspec(ctor)
        annotations = argspec.annotations
        has_kwargs = argspec.varkw is not None
    else:
        # `getfullargspec` is unavailable on this interpreter, so fall back
        # to the older `getargspec`, whose `**kwargs` field is `keywords`.
        argspec = inspect.getargspec(ctor)
        # python 2 type hinting workaround, see pep-3107
        # however, since `typeguard` does not support python 2, type checking
        # is still python 3 only for now
        annotations = getattr(ctor, '__annotations__', {})
        has_kwargs = argspec.keywords is not None

    names = [arg for arg in argspec.args if arg != 'self']
    defaults = argspec.defaults
    num_defaults = len(defaults) if defaults else 0
    num_required = len(names) - num_defaults

    docs = cls.__doc__
    if docs is None and getattr(cls, '__category__', None) == 'op':
        docs = cls.__call__.__doc__
    try:
        docstring = doc_parse(docs)
    except Exception:
        docstring = None

    # Map argument name -> description from the parsed docstring, if any.
    comments = {}
    if docstring is not None:
        for p in docstring.params:
            match_obj = re.match('^([a-zA-Z_]+[a-zA-Z_0-9]*).*', p.arg_name)
            if match_obj is not None:
                comments[match_obj.group(1)] = p.description

    schema = SchemaDict()
    schema.name = cls.__name__
    schema.doc = ""
    # An empty docstring leaves `doc` empty instead of failing on `docs[0]`.
    if docs:
        text = docs[1:] if docs[0] == '\n' else docs
        schema.doc = text.split("\n")[0].strip()
    # XXX handle paddle's weird doc convention
    if '**' == schema.doc[:2] and '**' == schema.doc[-2:]:
        schema.doc = schema.doc[2:-2].strip()
    schema.category = getattr(cls, '__category__', None) or 'module'
    # A constructor accepting **kwargs tolerates keys outside the schema.
    schema.strict = not has_kwargs
    schema.pymodule = importlib.import_module(cls.__module__)
    schema.inject = getattr(cls, '__inject__', [])
    schema.shared = getattr(cls, '__shared__', [])

    for idx, name in enumerate(names):
        # Fall back to the argument name when no description was found.
        comment = comments.get(name) or name
        if name in schema.inject:
            type_ = None
        else:
            type_ = annotations.get(name) or None
        value_schema = SchemaValue(name, comment, type_)
        if name in schema.shared:
            # Raised explicitly so the check survives `python -O`; otherwise
            # a negative index below would silently pick the wrong default.
            if idx < num_required:
                raise AssertionError("shared config must have default value")
            default = defaults[idx - num_required]
            value_schema.set_default(SharedConfig(name, default))
        elif idx >= num_required:
            default = defaults[idx - num_required]
            value_schema.set_default(default)
        schema.set_schema(name, value_schema)

    return schema
|