get_ocr_boxes.py (2.7 KB)
  1. import datetime
  2. import os.path
  3. import uuid
  4. import cv2
  5. from paddleocr import PaddleOCR, draw_ocr
  6. import argparse
  7. from PIL import Image
  8. # Example PaddleOCR `lang` values: `ch`, `en`, `fr`, `german`, `korean`, `japan`
  9. from tqdm import tqdm
  # PaddleOCR engines cached per language, so the model is loaded once rather
  # than once per image (get_boxes is called in a loop in batch mode).
  _OCR_CACHE = {}


  def _get_ocr(language):
      """Return a shared PaddleOCR engine for *language*, creating it on first use."""
      ocr = _OCR_CACHE.get(language)
      if ocr is None:
          ocr = PaddleOCR(use_angle_cls=True, lang=language)
          _OCR_CACHE[language] = ocr
      return ocr


  def _normalize_ocr_result(result):
      """Return one image's detections as a flat list of [box, (text, score)].

      The shape returned by ``PaddleOCR.ocr`` differs between releases: older
      ones return the detection list directly, newer ones (2.6+, per the
      PaddleOCR docs -- verify against the installed version) wrap it in a
      per-image outer list and may yield ``[None]`` when nothing is found.
      Both shapes are accepted here.
      """
      import numbers

      if not result or not result[0]:
          return []
      first = result[0]
      # Flat shape: first == [box, (text, score)], so first[0][0][0] is a
      # coordinate. Nested shape: first is a list of detections, so
      # first[0][0][0] is an [x, y] point rather than a number.
      if isinstance(first[0][0][0], numbers.Number):
          return result
      return first


  def get_boxes(img_path, language, pad=5, vis_path='result_lisa_ch.jpg',
                font_path='./fonts/STSONG.TTF'):
      """Detect text lines in an image and return their padded quadrilaterals.

      Args:
          img_path: path of the image to run OCR on.
          language: PaddleOCR ``lang`` code, e.g. ``'ch'``, ``'en'`` or
              ``'chinese_cht'``.
          pad: pixels to grow every box outward on each side (default 5).
              Boxes are not clipped, so coordinates may fall outside the image.
          vis_path: file to save an annotated copy of the image (boxes, texts
              and scores drawn by ``draw_ocr``). Defaults to the previously
              hard-coded name; pass ``None`` or ``''`` to skip it.
          font_path: font file passed to ``draw_ocr`` for the annotation text.

      Returns:
          List of boxes, each four ``[x, y]`` points ordered top-left,
          top-right, bottom-right, bottom-left. Empty if nothing was detected.
      """
      lines = _normalize_ocr_result(_get_ocr(language).ocr(img_path, cls=True))

      # Push each corner away from the box; (dx, dy) signs follow the corner
      # order top-left, top-right, bottom-right, bottom-left.
      offsets = ((-pad, -pad), (pad, -pad), (pad, pad), (-pad, pad))
      for line in lines:
          for point, (dx, dy) in zip(line[0], offsets):
              point[0] += dx
              point[1] += dy

      boxes = [line[0] for line in lines]
      if vis_path:
          # Close the source file promptly; convert() returns an independent copy.
          with Image.open(img_path) as img:
              image = img.convert('RGB')
          texts = [line[1][0] for line in lines]
          scores = [line[1][1] for line in lines]
          im_show = draw_ocr(image, boxes, texts, scores, font_path=font_path)
          Image.fromarray(im_show).save(vis_path)
      return boxes
  def image_cut_save(img_path, boxes, save_dir, cnt):
      """Crop every box out of an image and save each crop as a JPEG file.

      Files are named ``kdan_<YYYY-MM-DD>_<cnt zero-padded to 4>_<8 chars>.jpg``
      inside *save_dir*; the last part is the first 8 characters of
      ``uuid.uuid1()`` so repeated runs do not overwrite each other.

      Args:
          img_path: path of the source image (read with ``cv2.imread``).
          boxes: sequence of quadrilaterals, each an iterable of ``[x, y]``
              points (e.g. the output of ``get_boxes``). Each crop is the
              axis-aligned bounding rectangle of the box's points, clamped to
              the image.
          save_dir: output directory; created if it does not exist.
          cnt: sequence number used for the first crop.

      Returns:
          The next unused sequence number (``cnt + len(boxes)``), so callers
          can keep numbering across several images.

      Raises:
          ValueError: if boxes are given but ``img_path`` cannot be read.
          OSError: if a crop cannot be written.
      """
      boxes = list(boxes)
      os.makedirs(save_dir, exist_ok=True)
      if not boxes:
          return cnt

      img = cv2.imread(img_path)
      if img is None:  # cv2.imread reports failure by returning None
          raise ValueError(f'cannot read image: {img_path}')
      height, width = img.shape[:2]

      for box in tqdm(boxes):
          xs = [point[0] for point in box]
          ys = [point[1] for point in box]
          # Clamp to the image: padded boxes near the border can have negative
          # coordinates, which numpy slicing would treat as offsets from the
          # opposite edge (giving an empty or wrong crop).
          left, right = max(int(min(xs)), 0), min(int(max(xs)), width)
          upper, lower = max(int(min(ys)), 0), min(int(max(ys)), height)
          # Degenerate boxes are skipped: cv2.imwrite rejects empty images.
          if right > left and lower > upper:
              cropped = img[upper:lower, left:right]
              now = datetime.datetime.now()
              name = f'kdan_{now:%Y-%m-%d}_{str(cnt).zfill(4)}_{str(uuid.uuid1())[:8]}.jpg'
              save_path = os.path.join(save_dir, name)
              if not cv2.imwrite(save_path, cropped):
                  raise OSError(f'failed to write crop: {save_path}')
          cnt += 1
      return cnt
  def get_all_img(img_dir, extensions=None):
      """Return the paths of the files directly inside *img_dir*.

      Args:
          img_dir: directory to scan (not recursive).
          extensions: optional iterable of file suffixes such as
              ``('.jpg', '.png')``, matched case-insensitively. ``None``
              (the default) keeps every file.

      Returns:
          Sorted list of ``os.path.join(img_dir, name)`` paths. Sub-directories
          are skipped. Sorting makes the processing order reproducible, because
          ``os.listdir`` returns entries in arbitrary order.
      """
      suffixes = None if extensions is None else tuple(e.lower() for e in extensions)
      paths = []
      for name in sorted(os.listdir(img_dir)):
          path = os.path.join(img_dir, name)
          if not os.path.isfile(path):
              continue
          if suffixes is not None and not name.lower().endswith(suffixes):
              continue
          paths.append(path)
      return paths
  if __name__ == '__main__':
      parser = argparse.ArgumentParser(
          description='Detect text boxes with PaddleOCR and save each box as a cropped image.')
      parser.add_argument('--img_path', type=str, default='idcard_lisa.png',
                          help='single image to process when --img_dir is not given')
      parser.add_argument('--img_dir', type=str, default='',
                          help='directory of images to process (takes precedence over --img_path)')
      parser.add_argument('--language', type=str, default='chinese_cht',
                          help='PaddleOCR lang code, e.g. ch, en, fr, german, korean, japan')
      parser.add_argument('--save_dir', type=str, default='./idcard_imgs',
                          help='output directory for the cropped text boxes')
      args = parser.parse_args()

      if not args.img_dir:
          # Single-image mode: runs detection only (get_boxes also writes its
          # visualisation image).
          # NOTE(review): cropping is disabled here -- the image_cut_save call
          # was commented out in the original, so --save_dir is ignored in this
          # mode. Confirm whether that is intentional.
          get_boxes(args.img_path, args.language)
      else:
          cnt = 1  # sequence number embedded in crop file names
          for img_path in tqdm(get_all_img(args.img_dir)):
              boxes = get_boxes(img_path, args.language)
              image_cut_save(img_path, boxes, args.save_dir, cnt)
              # Advance past this image's crops so numbering keeps increasing
              # across images instead of restarting at 0001 for each one.
              cnt += len(boxes)