corpus_generators.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import random
  15. from utils.logging import get_logger
  16. class FileCorpus(object):
  17. def __init__(self, config):
  18. self.logger = get_logger()
  19. self.logger.info("using FileCorpus")
  20. self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
  21. corpus_file = config["CorpusGenerator"]["corpus_file"]
  22. self.language = config["CorpusGenerator"]["language"]
  23. with open(corpus_file, 'r') as f:
  24. corpus_raw = f.read()
  25. self.corpus_list = corpus_raw.split("\n")[:-1]
  26. assert len(self.corpus_list) > 0
  27. random.shuffle(self.corpus_list)
  28. self.index = 0
  29. def generate(self, corpus_length=0):
  30. if self.index >= len(self.corpus_list):
  31. self.index = 0
  32. random.shuffle(self.corpus_list)
  33. corpus = self.corpus_list[self.index]
  34. if corpus_length != 0:
  35. corpus = corpus[0:corpus_length]
  36. if corpus_length > len(corpus):
  37. self.logger.warning("generated corpus is shorter than expected.")
  38. self.index += 1
  39. return self.language, corpus
  40. class EnNumCorpus(object):
  41. def __init__(self, config):
  42. self.logger = get_logger()
  43. self.logger.info("using NumberCorpus")
  44. self.num_list = "0123456789"
  45. self.en_char_list = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
  46. self.height = config["Global"]["image_height"]
  47. self.max_width = config["Global"]["image_width"]
  48. def generate(self, corpus_length=0):
  49. corpus = ""
  50. if corpus_length == 0:
  51. corpus_length = random.randint(5, 15)
  52. for i in range(corpus_length):
  53. if random.random() < 0.2:
  54. corpus += "{}".format(random.choice(self.en_char_list))
  55. else:
  56. corpus += "{}".format(random.choice(self.num_list))
  57. return "en", corpus