utility.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import sys
import platform
import cv2
import numpy as np
import paddle
from PIL import Image, ImageDraw, ImageFont
import math
from paddle import inference
import time
import random
from ppocr.utils.logging import get_logger


def str2bool(v):
    return v.lower() in ("true", "t", "1")


def init_args():
    parser = argparse.ArgumentParser()
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--use_xpu", type=str2bool, default=False)
    parser.add_argument("--use_npu", type=str2bool, default=False)
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    parser.add_argument("--min_subgraph_size", type=int, default=15)
    parser.add_argument("--precision", type=str, default="fp32")
    parser.add_argument("--gpu_mem", type=int, default=500)
    parser.add_argument("--gpu_id", type=int, default=0)

    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--page_num", type=int, default=0)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_dir", type=str)
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default='max')
    parser.add_argument("--det_box_type", type=str, default='quad')

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
    parser.add_argument("--max_batch_size", type=int, default=10)
    parser.add_argument("--use_dilation", type=str2bool, default=False)
    parser.add_argument("--det_db_score_mode", type=str, default="fast")

    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # SAST params
    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)

    # PSE params
    parser.add_argument("--det_pse_thresh", type=float, default=0)
    parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
    parser.add_argument("--det_pse_min_area", type=float, default=16)
    parser.add_argument("--det_pse_scale", type=int, default=1)

    # FCE params
    parser.add_argument("--scales", type=list, default=[8, 16, 32])
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--fourier_degree", type=int, default=5)

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
    parser.add_argument("--rec_model_dir", type=str)
    parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
    parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default="./ppocr/utils/ppocr_keys_v1.txt")
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument(
        "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
    parser.add_argument("--drop_score", type=float, default=0.5)

    # params for e2e
    parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
    parser.add_argument("--e2e_model_dir", type=str)
    parser.add_argument("--e2e_limit_side_len", type=float, default=768)
    parser.add_argument("--e2e_limit_type", type=str, default='max')

    # PGNet params
    parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
    parser.add_argument(
        "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
    parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
    parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_dir", type=str)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    parser.add_argument("--label_list", type=list, default=['0', '180'])
    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)

    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--cpu_threads", type=int, default=10)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)
    parser.add_argument("--warmup", type=str2bool, default=False)

    # SR params
    parser.add_argument("--sr_model_dir", type=str)
    parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
    parser.add_argument("--sr_batch_num", type=int, default=1)

    # params for output
    parser.add_argument(
        "--draw_img_save_dir", type=str, default="./inference_results")
    parser.add_argument("--save_crop_res", type=str2bool, default=False)
    parser.add_argument("--crop_res_save_dir", type=str, default="./output")

    # multi-process
    parser.add_argument("--use_mp", type=str2bool, default=False)
    parser.add_argument("--total_process_num", type=int, default=1)
    parser.add_argument("--process_id", type=int, default=0)

    parser.add_argument("--benchmark", type=str2bool, default=False)
    parser.add_argument("--save_log_path", type=str, default="./log_output/")
    parser.add_argument("--show_log", type=str2bool, default=True)
    parser.add_argument("--use_onnx", type=str2bool, default=False)
    return parser


def parse_args():
    parser = init_args()
    return parser.parse_args()
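
# Hedged usage sketch: init_args()/parse_args() can also be driven
# programmatically instead of from the command line (the model path
# below is hypothetical):
#
#   args = init_args().parse_args(
#       ["--det_model_dir", "./ch_PP-OCRv3_det_infer", "--use_gpu", "false"])
#   assert args.det_algorithm == 'DB' and args.use_gpu is False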


def create_predictor(args, mode, logger):
    if mode == "det":
        model_dir = args.det_model_dir
    elif mode == 'cls':
        model_dir = args.cls_model_dir
    elif mode == 'rec':
        model_dir = args.rec_model_dir
    elif mode == 'table':
        model_dir = args.table_model_dir
    elif mode == 'ser':
        model_dir = args.ser_model_dir
    elif mode == 're':
        model_dir = args.re_model_dir
    elif mode == "sr":
        model_dir = args.sr_model_dir
    elif mode == 'layout':
        model_dir = args.layout_model_dir
    else:
        model_dir = args.e2e_model_dir

    if model_dir is None:
        logger.info("model dir for mode {} is not set: {}".format(mode,
                                                                  model_dir))
        sys.exit(0)
    if args.use_onnx:
        import onnxruntime as ort
        model_file_path = model_dir
        if not os.path.exists(model_file_path):
            raise ValueError("model file path not found: {}".format(
                model_file_path))
        sess = ort.InferenceSession(model_file_path)
        return sess, sess.get_inputs()[0], None, None
    else:
        file_names = ['model', 'inference']
        for file_name in file_names:
            model_file_path = '{}/{}.pdmodel'.format(model_dir, file_name)
            params_file_path = '{}/{}.pdiparams'.format(model_dir, file_name)
            if os.path.exists(model_file_path) and os.path.exists(
                    params_file_path):
                break
        if not os.path.exists(model_file_path):
            raise ValueError(
                "model.pdmodel or inference.pdmodel not found in {}".format(
                    model_dir))
        if not os.path.exists(params_file_path):
            raise ValueError(
                "model.pdiparams or inference.pdiparams not found in {}".format(
                    model_dir))

        config = inference.Config(model_file_path, params_file_path)

        if hasattr(args, 'precision'):
            if args.precision == "fp16" and args.use_tensorrt:
                precision = inference.PrecisionType.Half
            elif args.precision == "int8":
                precision = inference.PrecisionType.Int8
            else:
                precision = inference.PrecisionType.Float32
        else:
            precision = inference.PrecisionType.Float32

        if args.use_gpu:
            gpu_id = get_infer_gpuid()
            if gpu_id is None:
                logger.warning(
                    "GPU is not found in current device by nvidia-smi. "
                    "Please check your device or ignore it if run on jetson.")
            config.enable_use_gpu(args.gpu_mem, args.gpu_id)
            if args.use_tensorrt:
                config.enable_tensorrt_engine(
                    workspace_size=1 << 30,
                    precision_mode=precision,
                    max_batch_size=args.max_batch_size,
                    min_subgraph_size=args.min_subgraph_size,  # skip the minimum trt subgraph
                    use_calib_mode=False)

                # collect shape
                trt_shape_f = os.path.join(model_dir,
                                           f"{mode}_trt_dynamic_shape.txt")
                if not os.path.exists(trt_shape_f):
                    config.collect_shape_range_info(trt_shape_f)
                    logger.info(
                        f"collect dynamic shape info into : {trt_shape_f}")
                try:
                    config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f,
                                                               True)
                except Exception as E:
                    logger.info(E)
                    logger.info("Please keep your paddlepaddle-gpu >= 2.3.0!")
        elif args.use_npu:
            config.enable_npu()
        elif args.use_xpu:
            config.enable_xpu(10 * 1024 * 1024)
        else:
            config.disable_gpu()
            if args.enable_mkldnn:
                # cache 10 different shapes for mkldnn to avoid memory leak
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
                if args.precision == "fp16":
                    config.enable_mkldnn_bfloat16()
            if hasattr(args, "cpu_threads"):
                config.set_cpu_math_library_num_threads(args.cpu_threads)
            else:
                # default to 10 cpu threads
                config.set_cpu_math_library_num_threads(10)

        # enable memory optim
        config.enable_memory_optim()
        config.disable_glog_info()
        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
        config.delete_pass("matmul_transpose_reshape_fuse_pass")
        if mode == 're':
            config.delete_pass("simplify_with_basic_ops_pass")
        if mode == 'table':
            config.delete_pass("fc_fuse_pass")  # not supported for table
        config.switch_use_feed_fetch_ops(False)
        config.switch_ir_optim(True)

        # create predictor
        predictor = inference.create_predictor(config)
        input_names = predictor.get_input_names()
        if mode in ['ser', 're']:
            input_tensor = []
            for name in input_names:
                input_tensor.append(predictor.get_input_handle(name))
        else:
            for name in input_names:
                input_tensor = predictor.get_input_handle(name)
        output_tensors = get_output_tensors(args, mode, predictor)
        return predictor, input_tensor, output_tensors, config
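
# Minimal sketch of driving the returned predictor (assumes an exported
# detection model in the hypothetical "./det_infer" directory):
#
#   args = init_args().parse_args(["--det_model_dir", "./det_infer"])
#   predictor, input_tensor, output_tensors, _ = create_predictor(
#       args, "det", get_logger())
#   input_tensor.copy_from_cpu(np.zeros((1, 3, 640, 640), dtype=np.float32))
#   predictor.run()
#   preds = output_tensors[0].copy_to_cpu()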


def get_output_tensors(args, mode, predictor):
    output_names = predictor.get_output_names()
    output_tensors = []
    if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet"]:
        output_name = 'softmax_0.tmp_0'
        if output_name in output_names:
            return [predictor.get_output_handle(output_name)]
        else:
            for output_name in output_names:
                output_tensor = predictor.get_output_handle(output_name)
                output_tensors.append(output_tensor)
    else:
        for output_name in output_names:
            output_tensor = predictor.get_output_handle(output_name)
            output_tensors.append(output_tensor)
    return output_tensors


def get_infer_gpuid():
    sysstr = platform.system()
    if sysstr == "Windows":
        return 0

    if not paddle.fluid.core.is_compiled_with_rocm():
        cmd = "env | grep CUDA_VISIBLE_DEVICES"
    else:
        cmd = "env | grep HIP_VISIBLE_DEVICES"
    env_cuda = os.popen(cmd).readlines()
    if len(env_cuda) == 0:
        return 0
    else:
        gpu_id = env_cuda[0].strip().split("=")[1]
        return int(gpu_id[0])


def draw_e2e_res(dt_boxes, strs, img_path):
    src_im = cv2.imread(img_path)
    # avoid shadowing the builtin `str` with the loop variable
    for box, txt in zip(dt_boxes, strs):
        box = box.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
        cv2.putText(
            src_im,
            txt,
            org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
            fontFace=cv2.FONT_HERSHEY_COMPLEX,
            fontScale=0.7,
            color=(0, 255, 0),
            thickness=1)
    return src_im


def draw_text_det_res(dt_boxes, img):
    for box in dt_boxes:
        box = np.array(box).astype(np.int32).reshape(-1, 2)
        cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
    return img


def resize_img(img, input_size=600):
    """
    Resize img and limit the longest side of the image to input_size.
    """
    img = np.array(img)
    im_shape = img.shape
    im_size_max = np.max(im_shape[0:2])
    im_scale = float(input_size) / float(im_size_max)
    img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
    return img


def draw_ocr(image,
             boxes,
             txts=None,
             scores=None,
             drop_score=0.5,
             font_path="./doc/fonts/simfang.ttf"):
    """
    Visualize the results of OCR detection and recognition.
    args:
        image(Image|array): RGB image
        boxes(list): boxes with shape (N, 4, 2)
        txts(list): the recognized texts
        scores(list): the corresponding scores of the texts
        drop_score(float): only results with a score greater than drop_score
            will be visualized
        font_path: the path of the font used to draw text
    return(array):
        the visualized image
    """
    if scores is None:
        scores = [1] * len(boxes)
    box_num = len(boxes)
    for i in range(box_num):
        if scores is not None and (scores[i] < drop_score or
                                   math.isnan(scores[i])):
            continue
        box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
    if txts is not None:
        img = np.array(resize_img(image, input_size=600))
        txt_img = text_visual(
            txts,
            scores,
            img_h=img.shape[0],
            img_w=600,
            threshold=drop_score,
            font_path=font_path)
        img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
        return img
    return image
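
# Example call (file names hypothetical; boxes/txts/scores would come from
# the detector and recognizer above):
#
#   image = Image.open("test.jpg").convert("RGB")
#   vis = draw_ocr(image, boxes, txts, scores, drop_score=0.5)
#   cv2.imwrite("vis.jpg", vis[:, :, ::-1])  # RGB -> BGR for OpenCV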


def draw_ocr_box_txt(image,
                     boxes,
                     txts=None,
                     scores=None,
                     drop_score=0.5,
                     font_path="./doc/fonts/simfang.ttf"):
    h, w = image.height, image.width
    img_left = image.copy()
    img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
    random.seed(0)

    draw_left = ImageDraw.Draw(img_left)
    if txts is None or len(txts) != len(boxes):
        txts = [None] * len(boxes)
    for idx, (box, txt) in enumerate(zip(boxes, txts)):
        if scores is not None and scores[idx] < drop_score:
            continue
        color = (random.randint(0, 255), random.randint(0, 255),
                 random.randint(0, 255))
        draw_left.polygon(box, fill=color)
        img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
        pts = np.array(box, np.int32).reshape((-1, 1, 2))
        cv2.polylines(img_right_text, [pts], True, color, 1)
        img_right = cv2.bitwise_and(img_right, img_right_text)
    img_left = Image.blend(image, img_left, 0.5)
    img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
    img_show.paste(img_left, (0, 0, w, h))
    img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
    return np.array(img_show)


def draw_box_txt_fine(img_size, box, txt, font_path="./doc/fonts/simfang.ttf"):
    box_height = int(
        math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][1])**2))
    box_width = int(
        math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][1])**2))

    if box_height > 2 * box_width and box_height > 30:
        img_text = Image.new('RGB', (box_height, box_width), (255, 255, 255))
        draw_text = ImageDraw.Draw(img_text)
        if txt:
            font = create_font(txt, (box_height, box_width), font_path)
            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
        img_text = img_text.transpose(Image.ROTATE_270)
    else:
        img_text = Image.new('RGB', (box_width, box_height), (255, 255, 255))
        draw_text = ImageDraw.Draw(img_text)
        if txt:
            font = create_font(txt, (box_width, box_height), font_path)
            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)

    pts1 = np.float32(
        [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]])
    pts2 = np.array(box, dtype=np.float32)
    M = cv2.getPerspectiveTransform(pts1, pts2)

    img_text = np.array(img_text, dtype=np.uint8)
    img_right_text = cv2.warpPerspective(
        img_text,
        M,
        img_size,
        flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(255, 255, 255))
    return img_right_text


def create_font(txt, sz, font_path="./doc/fonts/simfang.ttf"):
    font_size = int(sz[1] * 0.99)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    length = font.getsize(txt)[0]
    if length > sz[0]:
        font_size = int(font_size * sz[0] / length)
        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    return font


def str_count(s):
    """
    Count the number of Chinese characters;
    a single English character or a single digit
    counts as half of a Chinese character.
    args:
        s(string): the input string
    return(int):
        the equivalent number of Chinese characters
    """
    import string
    count_zh = count_pu = 0
    s_len = len(s)
    en_dg_count = 0
    for c in s:
        if c in string.ascii_letters or c.isdigit() or c.isspace():
            en_dg_count += 1
        elif c.isalpha():
            count_zh += 1
        else:
            count_pu += 1
    return s_len - math.ceil(en_dg_count / 2)
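
# Worked examples of the half-width rule (values derived from the code above):
#
#   str_count("abcd")    # 4 ASCII chars   -> 4 - ceil(4 / 2) = 2
#   str_count("中文ab")  # 2 CJK + 2 ASCII -> 4 - ceil(2 / 2) = 3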


def text_visual(texts,
                scores,
                img_h=400,
                img_w=600,
                threshold=0.,
                font_path="./doc/simfang.ttf"):
    """
    Create a new blank img and draw txt on it.
    args:
        texts(list): the texts to be drawn
        scores(list|None): the corresponding score of each txt
        img_h(int): the height of the blank img
        img_w(int): the width of the blank img
        threshold(float): texts with a score below this value are skipped
        font_path: the path of the font used to draw text
    return(array):
    """
    if scores is not None:
        assert len(texts) == len(
            scores), "The number of txts and corresponding scores must match"

    def create_blank_img():
        blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
        blank_img[:, img_w - 1:] = 0
        blank_img = Image.fromarray(blank_img).convert("RGB")
        draw_txt = ImageDraw.Draw(blank_img)
        return blank_img, draw_txt

    blank_img, draw_txt = create_blank_img()

    font_size = 20
    txt_color = (0, 0, 0)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    gap = font_size + 5
    txt_img_list = []
    count, index = 1, 0
    for idx, txt in enumerate(texts):
        index += 1
        if scores[idx] < threshold or math.isnan(scores[idx]):
            index -= 1
            continue
        first_line = True
        while str_count(txt) >= img_w // font_size - 4:
            tmp = txt
            txt = tmp[:img_w // font_size - 4]
            if first_line:
                new_txt = str(index) + ': ' + txt
                first_line = False
            else:
                new_txt = '    ' + txt
            draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
            txt = tmp[img_w // font_size - 4:]
            if count >= img_h // gap - 1:
                txt_img_list.append(np.array(blank_img))
                blank_img, draw_txt = create_blank_img()
                count = 0
            count += 1
        if first_line:
            new_txt = str(index) + ': ' + txt + '   ' + '%.3f' % (scores[idx])
        else:
            new_txt = "  " + txt + "  " + '%.3f' % (scores[idx])
        draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
        # whether to add a new blank img or not
        if count >= img_h // gap - 1 and idx + 1 < len(texts):
            txt_img_list.append(np.array(blank_img))
            blank_img, draw_txt = create_blank_img()
            count = 0
        count += 1
    txt_img_list.append(np.array(blank_img))
    if len(txt_img_list) == 1:
        blank_img = np.array(txt_img_list[0])
    else:
        blank_img = np.concatenate(txt_img_list, axis=1)
    return np.array(blank_img)


def base64_to_cv2(b64str):
    import base64
    data = base64.b64decode(b64str.encode('utf8'))
    data = np.frombuffer(data, np.uint8)
    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
    return data
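
# Round-trip sketch pairing this decoder with the standard-library encoder
# (the file name is hypothetical):
#
#   import base64
#   with open("test.jpg", "rb") as f:
#       b64str = base64.b64encode(f.read()).decode('utf8')
#   img = base64_to_cv2(b64str)  # BGR ndarray, as cv2.imread would return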


def draw_boxes(image, boxes, scores=None, drop_score=0.5):
    if scores is None:
        scores = [1] * len(boxes)
    for (box, score) in zip(boxes, scores):
        if score < drop_score:
            continue
        box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
    return image


def get_rotate_crop_image(img, points):
    '''
    img_height, img_width = img.shape[0:2]
    left = int(np.min(points[:, 0]))
    right = int(np.max(points[:, 0]))
    top = int(np.min(points[:, 1]))
    bottom = int(np.max(points[:, 1]))
    img_crop = img[top:bottom, left:right, :].copy()
    points[:, 0] = points[:, 0] - left
    points[:, 1] = points[:, 1] - top
    '''
    assert len(points) == 4, "shape of points must be 4*2"
    img_crop_width = int(
        max(
            np.linalg.norm(points[0] - points[1]),
            np.linalg.norm(points[2] - points[3])))
    img_crop_height = int(
        max(
            np.linalg.norm(points[0] - points[3]),
            np.linalg.norm(points[1] - points[2])))
    pts_std = np.float32([[0, 0], [img_crop_width, 0],
                          [img_crop_width, img_crop_height],
                          [0, img_crop_height]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M, (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img
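
# Example with a synthetic quad (points must be a float32 (4, 2) array in
# top-left, top-right, bottom-right, bottom-left order):
#
#   quad = np.array([[10, 10], [110, 20], [108, 50], [8, 40]], np.float32)
#   crop = get_rotate_crop_image(img, quad)  # perspective-rectified patch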


def get_minarea_rect_crop(img, points):
    bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
    points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

    index_a, index_b, index_c, index_d = 0, 1, 2, 3
    if points[1][1] > points[0][1]:
        index_a = 0
        index_d = 1
    else:
        index_a = 1
        index_d = 0
    if points[3][1] > points[2][1]:
        index_b = 2
        index_c = 3
    else:
        index_b = 3
        index_c = 2

    box = [points[index_a], points[index_b], points[index_c], points[index_d]]
    crop_img = get_rotate_crop_image(img, np.array(box))
    return crop_img


def check_gpu(use_gpu):
    if use_gpu and not paddle.is_compiled_with_cuda():
        use_gpu = False
    return use_gpu


if __name__ == '__main__':
    pass