generate_multi_language_configs.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import yaml
  15. from argparse import ArgumentParser, RawDescriptionHelpFormatter
  16. import os.path
  17. import logging
  18. logging.basicConfig(level=logging.INFO)
  19. support_list = {
  20. 'it': 'italian',
  21. 'xi': 'spanish',
  22. 'pu': 'portuguese',
  23. 'ru': 'russian',
  24. 'ar': 'arabic',
  25. 'ta': 'tamil',
  26. 'ug': 'uyghur',
  27. 'fa': 'persian',
  28. 'ur': 'urdu',
  29. 'rs': 'serbian latin',
  30. 'oc': 'occitan',
  31. 'rsc': 'serbian cyrillic',
  32. 'bg': 'bulgarian',
  33. 'uk': 'ukranian',
  34. 'be': 'belarusian',
  35. 'te': 'telugu',
  36. 'ka': 'kannada',
  37. 'chinese_cht': 'chinese tradition',
  38. 'hi': 'hindi',
  39. 'mr': 'marathi',
  40. 'ne': 'nepali',
  41. }
  42. latin_lang = [
  43. 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',
  44. 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
  45. 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
  46. 'sw', 'tl', 'tr', 'uz', 'vi', 'latin'
  47. ]
  48. arabic_lang = ['ar', 'fa', 'ug', 'ur']
  49. cyrillic_lang = [
  50. 'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',
  51. 'dar', 'inh', 'che', 'lbe', 'lez', 'tab', 'cyrillic'
  52. ]
  53. devanagari_lang = [
  54. 'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',
  55. 'sa', 'bgc', 'devanagari'
  56. ]
  57. multi_lang = latin_lang + arabic_lang + cyrillic_lang + devanagari_lang
  58. assert (os.path.isfile("./rec_multi_language_lite_train.yml")
  59. ), "Loss basic configuration file rec_multi_language_lite_train.yml.\
  60. You can download it from \
  61. https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/configs/rec/multi_language/"
  62. global_config = yaml.load(
  63. open("./rec_multi_language_lite_train.yml", 'rb'), Loader=yaml.Loader)
  64. project_path = os.path.abspath(os.path.join(os.getcwd(), "../../../"))
  65. class ArgsParser(ArgumentParser):
  66. def __init__(self):
  67. super(ArgsParser, self).__init__(
  68. formatter_class=RawDescriptionHelpFormatter)
  69. self.add_argument(
  70. "-o", "--opt", nargs='+', help="set configuration options")
  71. self.add_argument(
  72. "-l",
  73. "--language",
  74. nargs='+',
  75. help="set language type, support {}".format(support_list))
  76. self.add_argument(
  77. "--train",
  78. type=str,
  79. help="you can use this command to change the train dataset default path"
  80. )
  81. self.add_argument(
  82. "--val",
  83. type=str,
  84. help="you can use this command to change the eval dataset default path"
  85. )
  86. self.add_argument(
  87. "--dict",
  88. type=str,
  89. help="you can use this command to change the dictionary default path"
  90. )
  91. self.add_argument(
  92. "--data_dir",
  93. type=str,
  94. help="you can use this command to change the dataset default root path"
  95. )
  96. def parse_args(self, argv=None):
  97. args = super(ArgsParser, self).parse_args(argv)
  98. args.opt = self._parse_opt(args.opt)
  99. args.language = self._set_language(args.language)
  100. return args
  101. def _parse_opt(self, opts):
  102. config = {}
  103. if not opts:
  104. return config
  105. for s in opts:
  106. s = s.strip()
  107. k, v = s.split('=')
  108. config[k] = yaml.load(v, Loader=yaml.Loader)
  109. return config
  110. def _set_language(self, type):
  111. lang = type[0]
  112. assert (type), "please use -l or --language to choose language type"
  113. assert(
  114. lang in support_list.keys() or lang in multi_lang
  115. ),"the sub_keys(-l or --language) can only be one of support list: \n{},\nbut get: {}, " \
  116. "please check your running command".format(multi_lang, type)
  117. if lang in latin_lang:
  118. lang = "latin"
  119. elif lang in arabic_lang:
  120. lang = "arabic"
  121. elif lang in cyrillic_lang:
  122. lang = "cyrillic"
  123. elif lang in devanagari_lang:
  124. lang = "devanagari"
  125. global_config['Global'][
  126. 'character_dict_path'] = 'ppocr/utils/dict/{}_dict.txt'.format(lang)
  127. global_config['Global'][
  128. 'save_model_dir'] = './output/rec_{}_lite'.format(lang)
  129. global_config['Train']['dataset'][
  130. 'label_file_list'] = ["train_data/{}_train.txt".format(lang)]
  131. global_config['Eval']['dataset'][
  132. 'label_file_list'] = ["train_data/{}_val.txt".format(lang)]
  133. global_config['Global']['character_type'] = lang
  134. assert (
  135. os.path.isfile(
  136. os.path.join(project_path, global_config['Global'][
  137. 'character_dict_path']))
  138. ), "Loss default dictionary file {}_dict.txt.You can download it from \
  139. https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/ppocr/utils/dict/".format(
  140. lang)
  141. return lang
  142. def merge_config(config):
  143. """
  144. Merge config into global config.
  145. Args:
  146. config (dict): Config to be merged.
  147. Returns: global config
  148. """
  149. for key, value in config.items():
  150. if "." not in key:
  151. if isinstance(value, dict) and key in global_config:
  152. global_config[key].update(value)
  153. else:
  154. global_config[key] = value
  155. else:
  156. sub_keys = key.split('.')
  157. assert (
  158. sub_keys[0] in global_config
  159. ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
  160. global_config.keys(), sub_keys[0])
  161. cur = global_config[sub_keys[0]]
  162. for idx, sub_key in enumerate(sub_keys[1:]):
  163. if idx == len(sub_keys) - 2:
  164. cur[sub_key] = value
  165. else:
  166. cur = cur[sub_key]
  167. def loss_file(path):
  168. assert (
  169. os.path.exists(path)
  170. ), "There is no such file:{},Please do not forget to put in the specified file".format(
  171. path)
  172. if __name__ == '__main__':
  173. FLAGS = ArgsParser().parse_args()
  174. merge_config(FLAGS.opt)
  175. save_file_path = 'rec_{}_lite_train.yml'.format(FLAGS.language)
  176. if os.path.isfile(save_file_path):
  177. os.remove(save_file_path)
  178. if FLAGS.train:
  179. global_config['Train']['dataset']['label_file_list'] = [FLAGS.train]
  180. train_label_path = os.path.join(project_path, FLAGS.train)
  181. loss_file(train_label_path)
  182. if FLAGS.val:
  183. global_config['Eval']['dataset']['label_file_list'] = [FLAGS.val]
  184. eval_label_path = os.path.join(project_path, FLAGS.val)
  185. loss_file(eval_label_path)
  186. if FLAGS.dict:
  187. global_config['Global']['character_dict_path'] = FLAGS.dict
  188. dict_path = os.path.join(project_path, FLAGS.dict)
  189. loss_file(dict_path)
  190. if FLAGS.data_dir:
  191. global_config['Eval']['dataset']['data_dir'] = FLAGS.data_dir
  192. global_config['Train']['dataset']['data_dir'] = FLAGS.data_dir
  193. data_dir = os.path.join(project_path, FLAGS.data_dir)
  194. loss_file(data_dir)
  195. with open(save_file_path, 'w') as f:
  196. yaml.dump(
  197. dict(global_config), f, default_flow_style=False, sort_keys=False)
  198. logging.info("Project path is :{}".format(project_path))
  199. logging.info("Train list path set to :{}".format(global_config['Train'][
  200. 'dataset']['label_file_list'][0]))
  201. logging.info("Eval list path set to :{}".format(global_config['Eval'][
  202. 'dataset']['label_file_list'][0]))
  203. logging.info("Dataset root path set to :{}".format(global_config['Eval'][
  204. 'dataset']['data_dir']))
  205. logging.info("Dict path set to :{}".format(global_config['Global'][
  206. 'character_dict_path']))
  207. logging.info("Config file set to :configs/rec/multi_language/{}".
  208. format(save_file_path))