utility.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import os
  16. import imghdr
  17. import cv2
  18. import random
  19. import numpy as np
  20. import paddle
  21. import importlib.util
  22. import sys
  23. import subprocess
  24. def print_dict(d, logger, delimiter=0):
  25. """
  26. Recursively visualize a dict and
  27. indenting acrrording by the relationship of keys.
  28. """
  29. for k, v in sorted(d.items()):
  30. if isinstance(v, dict):
  31. logger.info("{}{} : ".format(delimiter * " ", str(k)))
  32. print_dict(v, logger, delimiter + 4)
  33. elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
  34. logger.info("{}{} : ".format(delimiter * " ", str(k)))
  35. for value in v:
  36. print_dict(value, logger, delimiter + 4)
  37. else:
  38. logger.info("{}{} : {}".format(delimiter * " ", k, v))
  39. def get_check_global_params(mode):
  40. check_params = ['use_gpu', 'max_text_length', 'image_shape', \
  41. 'image_shape', 'character_type', 'loss_type']
  42. if mode == "train_eval":
  43. check_params = check_params + [ \
  44. 'train_batch_size_per_card', 'test_batch_size_per_card']
  45. elif mode == "test":
  46. check_params = check_params + ['test_batch_size_per_card']
  47. return check_params
  48. def _check_image_file(path):
  49. img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'}
  50. return any([path.lower().endswith(e) for e in img_end])
  51. def get_image_file_list(img_file):
  52. imgs_lists = []
  53. if img_file is None or not os.path.exists(img_file):
  54. raise Exception("not found any img file in {}".format(img_file))
  55. img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'}
  56. if os.path.isfile(img_file) and _check_image_file(img_file):
  57. imgs_lists.append(img_file)
  58. elif os.path.isdir(img_file):
  59. for single_file in os.listdir(img_file):
  60. file_path = os.path.join(img_file, single_file)
  61. if os.path.isfile(file_path) and _check_image_file(file_path):
  62. imgs_lists.append(file_path)
  63. if len(imgs_lists) == 0:
  64. raise Exception("not found any img file in {}".format(img_file))
  65. imgs_lists = sorted(imgs_lists)
  66. return imgs_lists
  67. def check_and_read(img_path):
  68. if os.path.basename(img_path)[-3:] in ['gif', 'GIF']:
  69. gif = cv2.VideoCapture(img_path)
  70. ret, frame = gif.read()
  71. if not ret:
  72. logger = logging.getLogger('ppocr')
  73. logger.info("Cannot read {}. This gif image maybe corrupted.")
  74. return None, False
  75. if len(frame.shape) == 2 or frame.shape[-1] == 1:
  76. frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
  77. imgvalue = frame[:, :, ::-1]
  78. return imgvalue, True, False
  79. elif os.path.basename(img_path)[-3:] in ['pdf']:
  80. import fitz
  81. from PIL import Image
  82. imgs = []
  83. with fitz.open(img_path) as pdf:
  84. for pg in range(0, pdf.pageCount):
  85. page = pdf[pg]
  86. mat = fitz.Matrix(2, 2)
  87. pm = page.getPixmap(matrix=mat, alpha=False)
  88. # if width or height > 2000 pixels, don't enlarge the image
  89. if pm.width > 2000 or pm.height > 2000:
  90. pm = page.getPixmap(matrix=fitz.Matrix(1, 1), alpha=False)
  91. img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
  92. img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
  93. imgs.append(img)
  94. return imgs, False, True
  95. return None, False, False
  96. def load_vqa_bio_label_maps(label_map_path):
  97. with open(label_map_path, "r", encoding='utf-8') as fin:
  98. lines = fin.readlines()
  99. old_lines = [line.strip() for line in lines]
  100. lines = ["O"]
  101. for line in old_lines:
  102. # "O" has already been in lines
  103. if line.upper() in ["OTHER", "OTHERS", "IGNORE"]:
  104. continue
  105. lines.append(line)
  106. labels = ["O"]
  107. for line in lines[1:]:
  108. labels.append("B-" + line)
  109. labels.append("I-" + line)
  110. label2id_map = {label.upper(): idx for idx, label in enumerate(labels)}
  111. id2label_map = {idx: label.upper() for idx, label in enumerate(labels)}
  112. return label2id_map, id2label_map
  113. def set_seed(seed=1024):
  114. random.seed(seed)
  115. np.random.seed(seed)
  116. paddle.seed(seed)
  117. def check_install(module_name, install_name):
  118. spec = importlib.util.find_spec(module_name)
  119. if spec is None:
  120. print(f'Warnning! The {module_name} module is NOT installed')
  121. print(
  122. f'Try install {module_name} module automatically. You can also try to install manually by pip install {install_name}.'
  123. )
  124. python = sys.executable
  125. try:
  126. subprocess.check_call(
  127. [python, '-m', 'pip', 'install', install_name],
  128. stdout=subprocess.DEVNULL)
  129. print(f'The {module_name} module is now installed')
  130. except subprocess.CalledProcessError as exc:
  131. raise Exception(
  132. f"Install {module_name} failed, please install manually")
  133. else:
  134. print(f"{module_name} has been installed.")
  135. class AverageMeter:
  136. def __init__(self):
  137. self.reset()
  138. def reset(self):
  139. """reset"""
  140. self.val = 0
  141. self.avg = 0
  142. self.sum = 0
  143. self.count = 0
  144. def update(self, val, n=1):
  145. """update"""
  146. self.val = val
  147. self.sum += val * n
  148. self.count += n
  149. self.avg = self.sum / self.count