cli.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from argparse import ArgumentParser, RawDescriptionHelpFormatter
  15. import yaml
  16. import re
  17. from ppdet.core.workspace import get_registered_modules, dump_value
  18. __all__ = ['ColorTTY', 'ArgsParser']
  19. class ColorTTY(object):
  20. def __init__(self):
  21. super(ColorTTY, self).__init__()
  22. self.colors = ['red', 'green', 'yellow', 'blue', 'magenta', 'cyan']
  23. def __getattr__(self, attr):
  24. if attr in self.colors:
  25. color = self.colors.index(attr) + 31
  26. def color_message(message):
  27. return "[{}m{}".format(color, message)
  28. setattr(self, attr, color_message)
  29. return color_message
  30. def bold(self, message):
  31. return self.with_code('01', message)
  32. def with_code(self, code, message):
  33. return "[{}m{}".format(code, message)
  34. class ArgsParser(ArgumentParser):
  35. def __init__(self):
  36. super(ArgsParser, self).__init__(
  37. formatter_class=RawDescriptionHelpFormatter)
  38. self.add_argument("-c", "--config", help="configuration file to use")
  39. self.add_argument(
  40. "-o", "--opt", nargs='*', help="set configuration options")
  41. def parse_args(self, argv=None):
  42. args = super(ArgsParser, self).parse_args(argv)
  43. assert args.config is not None, \
  44. "Please specify --config=configure_file_path."
  45. args.opt = self._parse_opt(args.opt)
  46. return args
  47. def _parse_opt(self, opts):
  48. config = {}
  49. if not opts:
  50. return config
  51. for s in opts:
  52. s = s.strip()
  53. k, v = s.split('=', 1)
  54. if '.' not in k:
  55. config[k] = yaml.load(v, Loader=yaml.Loader)
  56. else:
  57. keys = k.split('.')
  58. if keys[0] not in config:
  59. config[keys[0]] = {}
  60. cur = config[keys[0]]
  61. for idx, key in enumerate(keys[1:]):
  62. if idx == len(keys) - 2:
  63. cur[key] = yaml.load(v, Loader=yaml.Loader)
  64. else:
  65. cur[key] = {}
  66. cur = cur[key]
  67. return config
  68. def merge_args(config, args, exclude_args=['config', 'opt', 'slim_config']):
  69. for k, v in vars(args).items():
  70. if k not in exclude_args:
  71. config[k] = v
  72. return config
  73. def print_total_cfg(config):
  74. modules = get_registered_modules()
  75. color_tty = ColorTTY()
  76. green = '___{}___'.format(color_tty.colors.index('green') + 31)
  77. styled = {}
  78. for key in config.keys():
  79. if not config[key]: # empty schema
  80. continue
  81. if key not in modules and not hasattr(config[key], '__dict__'):
  82. styled[key] = config[key]
  83. continue
  84. elif key in modules:
  85. module = modules[key]
  86. else:
  87. type_name = type(config[key]).__name__
  88. if type_name in modules:
  89. module = modules[type_name].copy()
  90. module.update({
  91. k: v
  92. for k, v in config[key].__dict__.items()
  93. if k in module.schema
  94. })
  95. key += " ({})".format(type_name)
  96. default = module.find_default_keys()
  97. missing = module.find_missing_keys()
  98. mismatch = module.find_mismatch_keys()
  99. extra = module.find_extra_keys()
  100. dep_missing = []
  101. for dep in module.inject:
  102. if isinstance(module[dep], str) and module[dep] != '<value>':
  103. if module[dep] not in modules: # not a valid module
  104. dep_missing.append(dep)
  105. else:
  106. dep_mod = modules[module[dep]]
  107. # empty dict but mandatory
  108. if not dep_mod and dep_mod.mandatory():
  109. dep_missing.append(dep)
  110. override = list(
  111. set(module.keys()) - set(default) - set(extra) - set(dep_missing))
  112. replacement = {}
  113. for name in set(override + default + extra + mismatch + missing):
  114. new_name = name
  115. if name in missing:
  116. value = "<missing>"
  117. else:
  118. value = module[name]
  119. if name in extra:
  120. value = dump_value(value) + " <extraneous>"
  121. elif name in mismatch:
  122. value = dump_value(value) + " <type mismatch>"
  123. elif name in dep_missing:
  124. value = dump_value(value) + " <module config missing>"
  125. elif name in override and value != '<missing>':
  126. mark = green
  127. new_name = mark + name
  128. replacement[new_name] = value
  129. styled[key] = replacement
  130. buffer = yaml.dump(styled, default_flow_style=False, default_style='')
  131. buffer = (re.sub(r"<missing>", r"[31m<missing>[0m", buffer))
  132. buffer = (re.sub(r"<extraneous>", r"[33m<extraneous>[0m", buffer))
  133. buffer = (re.sub(r"<type mismatch>", r"[31m<type mismatch>[0m", buffer))
  134. buffer = (re.sub(r"<module config missing>",
  135. r"[31m<module config missing>[0m", buffer))
  136. buffer = re.sub(r"___(\d+)___(.*?):", r"[\1m\2[0m:", buffer)
  137. print(buffer)