auto_add_label.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import da_python as dap
  2. import cv2
  3. import argparse
  4. import os
  5. import glob
  6. from PIL import Image
  7. import copy
  8. import json
  9. from tqdm import tqdm
  10. def traverse_folder(path, image_paths):
  11. # 获取当前文件夹下的所有文件和子文件夹
  12. for file in os.listdir(path):
  13. # 获取文件路径
  14. file_path = os.path.join(path, file)
  15. # 判断文件类型,如果是图片则加入列表
  16. if os.path.isfile(file_path) and os.path.splitext(file_path)[1].lower() in allowed_extensions:
  17. image_paths.append(file_path)
  18. # 如果是文件夹,则递归调用本函数
  19. elif os.path.isdir(file_path):
  20. traverse_folder(file_path, image_paths)
  21. def add_labeling(model, key, image_paths, set_score, view_result_dir):
  22. da, create_result = dap.create(key, model)
  23. if create_result != dap.E_DA_SUCCESS:
  24. print('create document ai failed:{}'.format(create_result))
  25. exit()
  26. detect_engine, a = da.detection()
  27. # print(a)
  28. # 存放信息
  29. labelme_json = {}
  30. shape_dic = {}
  31. for image_path in tqdm(image_paths):
  32. image = cv2.imread(image_path)
  33. layout_analysis_result, layout_analysis_result_code = detect_engine.layout_analysis(image)
  34. if layout_analysis_result_code == dap.E_DA_SUCCESS:
  35. # 判断是否检测到目标
  36. if len(layout_analysis_result.boxes):
  37. # print(dir(layout_analysis_result))
  38. # print(layout_analysis_result.boxes)
  39. # print(layout_analysis_result.labels)
  40. # print(layout_analysis_result.scores)
  41. # labels_list = layout_analysis_result.labels
  42. # box_list = layout_analysis_result.boxes
  43. # score_list = layout_analysis_result.scores
  44. json_path = image_path.replace(os.path.splitext(image_path)[1], '.json')
  45. add_json_dic = {}
  46. # 当json文件不存在,但是模型又检测到目标时,创建json文件
  47. if not os.path.isfile(json_path):
  48. labelme_json['version'] = '5.0.1'
  49. labelme_json['flags'] = {}
  50. labelme_json['shapes'] = []
  51. labelme_json['imagePath'] = os.path.basename(image_path)
  52. labelme_json['imageData'] = None
  53. labelme_json['imageHeight'], labelme_json['imageWidth'] = image.shape[0], image.shape[1]
  54. labelme_json['updated_by'] = None
  55. json.dump(labelme_json, open(json_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=2)
  56. with open(json_path, 'r', encoding='utf-8') as jf:
  57. json_info = json.load(jf)
  58. # 根据需求(忽略人工修正后的json文件)
  59. if json_info['updated_by']:
  60. print('没有添加')
  61. continue
  62. # 过滤低置信度标签,添加目标标签
  63. for index, score in enumerate(layout_analysis_result.scores):
  64. # print(index, score)
  65. # 过滤低置信度标签
  66. if score < set_score:
  67. continue
  68. # print('检测到')
  69. # 添加目标标签
  70. if layout_analysis_result.labels[index] in ['Table_0', 'Table_std', 'Figure']:
  71. continue
  72. shape_dic['label'] = layout_analysis_result.labels[index]
  73. box = layout_analysis_result.boxes[index]
  74. x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
  75. shape_dic['points'] = [[x1, y1], [x2, y2]]
  76. shape_dic['group_id'] = None
  77. shape_dic['shape_type'] = 'rectangle' # 根据需求进行填写
  78. shape_dic['flags'] = {}
  79. json_info['shapes'].append(copy.deepcopy(shape_dic))
  80. json_info['updated_by'] = 3
  81. add_json_dic = json_info
  82. jf.close()
  83. json.dump(add_json_dic, open(json_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=2)
  84. visualize_im = dap.visualize.detection(image, layout_analysis_result)
  85. cv2.imwrite(os.path.join(view_result_dir, os.path.basename(image_path)), visualize_im)
  86. else:
  87. print('magic color infer failed:{}'.format(layout_analysis_result_code))
  88. if __name__ == '__main__':
  89. parser = argparse.ArgumentParser()
  90. parser.add_argument('--model', type=str, default='', help='')
  91. parser.add_argument('--model_licence', type=str, default='', help='')
  92. parser.add_argument('--score', type=int, default=0.7, help='')
  93. parser.add_argument('--image_dir', type=str, default='', help='')
  94. parser.add_argument('--view_result_dir', type=str, default='', help='')
  95. args = parser.parse_args()
  96. # 定义允许的图片格式
  97. allowed_extensions = ['.jpg', '.jpeg', '.png', '.gif']
  98. # 初始化图片路径列表
  99. image_paths = []
  100. # 调用遍历函数
  101. traverse_folder(args.image_dir, image_paths)
  102. add_labeling(args.model, args.model_key, image_paths, args.score, args.view_result_dir)