Browse Source

上传文件至 ''

DocumentAI Python SDK半自动标注工具
zhuzezhou 1 year ago
parent
commit
1c33a64373
2 changed files with 207 additions and 0 deletions
  1. 112 0
      auto_add_label.py
  2. 95 0
      auto_label.py

+ 112 - 0
auto_add_label.py

@@ -0,0 +1,112 @@
+import da_python as dap
+import cv2
+import argparse
+import os
+import glob
+from PIL import Image
+import copy
+import json
+from tqdm import tqdm
+
+def traverse_folder(path, image_paths):
+    # 获取当前文件夹下的所有文件和子文件夹
+    for file in os.listdir(path):
+        # 获取文件路径
+        file_path = os.path.join(path, file)
+
+        # 判断文件类型,如果是图片则加入列表
+        if os.path.isfile(file_path) and os.path.splitext(file_path)[1].lower() in allowed_extensions:
+            image_paths.append(file_path)
+        # 如果是文件夹,则递归调用本函数
+        elif os.path.isdir(file_path):
+            traverse_folder(file_path, image_paths)
+
+def add_labeling(model, key, image_paths, set_score, view_result_dir):
+    da, create_result = dap.create(key, model)
+    if create_result != dap.E_DA_SUCCESS:
+        print('create document ai failed:{}'.format(create_result))
+        exit()
+    detect_engine, a = da.detection()
+    # print(a)
+
+    # 存放信息
+    labelme_json = {}
+    shape_dic = {}
+
+    for image_path in tqdm(image_paths):
+        image = cv2.imread(image_path)
+        layout_analysis_result, layout_analysis_result_code = detect_engine.layout_analysis(image)
+        if layout_analysis_result_code == dap.E_DA_SUCCESS:
+            # 判断是否检测到目标
+            if len(layout_analysis_result.boxes):
+                # print(dir(layout_analysis_result))
+                # print(layout_analysis_result.boxes)
+                # print(layout_analysis_result.labels)
+                # print(layout_analysis_result.scores)
+                # labels_list = layout_analysis_result.labels
+                # box_list = layout_analysis_result.boxes
+                # score_list = layout_analysis_result.scores
+                json_path = image_path.replace(os.path.splitext(image_path)[1], '.json')
+                add_json_dic = {}
+                # 当json文件不存在,但是模型又检测到目标时,创建json文件
+                if not os.path.isfile(json_path):
+                    labelme_json['version'] = '5.0.1'
+                    labelme_json['flags'] = {}
+                    labelme_json['shapes'] = []
+                    labelme_json['imagePath'] = os.path.basename(image_path)
+                    labelme_json['imageData'] = None
+                    labelme_json['imageHeight'], labelme_json['imageWidth'] = image.shape[0], image.shape[1]
+                    labelme_json['updated_by'] = None
+                    json.dump(labelme_json, open(json_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=2)
+                with open(json_path, 'r', encoding='utf-8') as jf:
+                    json_info = json.load(jf)
+                    # 根据需求(忽略人工修正后的json文件)
+                    if json_info['updated_by']:
+                        print('没有添加')
+                        continue
+                    # 过滤低置信度标签,添加目标标签
+                    for index, score in enumerate(layout_analysis_result.scores):
+                        # print(index, score)
+                        # 过滤低置信度标签
+                        if score < set_score:
+                            continue
+                        # print('检测到')
+                        # 添加目标标签
+                        if layout_analysis_result.labels[index] in ['Table_0', 'Table_std', 'Figure']:
+                            continue
+                        shape_dic['label'] = layout_analysis_result.labels[index]
+                        box = layout_analysis_result.boxes[index]
+                        x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
+                        shape_dic['points'] = [[x1, y1], [x2, y2]]
+                        shape_dic['group_id'] = None
+                        shape_dic['shape_type'] = 'rectangle'  # 根据需求进行填写
+                        shape_dic['flags'] = {}
+                        json_info['shapes'].append(copy.deepcopy(shape_dic))
+                    json_info['updated_by'] = 3
+                    add_json_dic = json_info
+                jf.close()
+                json.dump(add_json_dic, open(json_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=2)
+            visualize_im = dap.visualize.detection(image, layout_analysis_result)
+            cv2.imwrite(os.path.join(view_result_dir, os.path.basename(image_path)), visualize_im)
+
+        else:
+          print('magic color infer failed:{}'.format(layout_analysis_result_code))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', type=str, default='', help='')
+    parser.add_argument('--model_licence', type=str, default='', help='')
+    parser.add_argument('--score', type=int, default=0.7, help='')
+    parser.add_argument('--image_dir', type=str, default='', help='')
+    parser.add_argument('--view_result_dir', type=str, default='', help='')
+    args = parser.parse_args()
+
+    # 定义允许的图片格式
+    allowed_extensions = ['.jpg', '.jpeg', '.png', '.gif']
+    # 初始化图片路径列表
+    image_paths = []
+    # 调用遍历函数
+    traverse_folder(args.image_dir, image_paths)
+
+    add_labeling(args.model, args.model_key, image_paths, args.score, args.view_result_dir)

+ 95 - 0
auto_label.py

@@ -0,0 +1,95 @@
+import da_python as dap
+import cv2
+import argparse
+import os
+import glob
+from PIL import Image
+import copy
+import json
+from tqdm import tqdm
+
+def traverse_folder(path, image_paths):
+    # 获取当前文件夹下的所有文件和子文件夹
+    for file in os.listdir(path):
+        # 获取文件路径
+        file_path = os.path.join(path, file)
+
+        # 判断文件类型,如果是图片则加入列表
+        if os.path.isfile(file_path) and os.path.splitext(file_path)[1].lower() in allowed_extensions:
+            image_paths.append(file_path)
+        # 如果是文件夹,则递归调用本函数
+        elif os.path.isdir(file_path):
+            traverse_folder(file_path, image_paths)
+
+def labeling(model, key, image_paths, set_score, view_result_dir):
+    da, create_result = dap.create(key, model)
+    if create_result != dap.E_DA_SUCCESS:
+        print('create document ai failed:{}'.format(create_result))
+        exit()
+    detect_engine, a = da.detection()
+    # print(a)
+
+    # 存放json信息
+    labelme_json = {}
+    shape_dic = {}
+
+    for image_path in tqdm(image_paths):
+        image = cv2.imread(image_path)
+        layout_analysis_result, layout_analysis_result_code = detect_engine.layout_analysis(image)
+        if layout_analysis_result_code == dap.E_DA_SUCCESS:
+            # 判断是否检测到目标
+            if len(layout_analysis_result.boxes):
+                # print(dir(layout_analysis_result))
+                # print(layout_analysis_result.boxes)
+                # print(layout_analysis_result.labels)
+                # print(layout_analysis_result.scores)
+                # labels_list = layout_analysis_result.labels
+                # box_list = layout_analysis_result.boxes
+                # score_list = layout_analysis_result.scores
+                shapes_list = []
+                labelme_json['version'] = '5.0.1'
+                labelme_json['flags'] = {}
+                # 根据分数进行过滤
+                for index, score in enumerate(layout_analysis_result.scores):
+                    # print(index, score)
+                    if score < set_score:
+                        continue
+                    # print('检测到')
+                    shape_dic['label'] = layout_analysis_result.labels[index]
+                    box = layout_analysis_result.boxes[index]
+                    x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
+                    shape_dic['points'] = [[x1, y1], [x2, y2]]
+                    shape_dic['group_id'] = None
+                    shape_dic['shape_type'] = 'rectangle'  # 根据需求进行填写
+                    shape_dic['flags'] = {}
+                    shapes_list.append(copy.deepcopy(shape_dic))
+                labelme_json['shapes'] = shapes_list
+                labelme_json['imagePath'] = os.path.basename(image_path)
+                labelme_json['imageData'] = None
+                labelme_json['imageHeight'], labelme_json['imageWidth'] = image.shape[0], image.shape[1]
+                json_path = image_path.replace(os.path.splitext(image_path)[1], '.json')
+                json.dump(labelme_json, open(json_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=2)
+            visualize_im = dap.visualize.detection(image, layout_analysis_result)
+            cv2.imwrite(os.path.join(view_result_dir, os.path.basename(image_path)), visualize_im)
+
+        else:
+          print('magic color infer failed:{}'.format(layout_analysis_result_code))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', type=str, default='', help='')
+    parser.add_argument('--model_licence', type=str, default='', help='')
+    parser.add_argument('--score', type=int, default=0.8, help='')
+    parser.add_argument('--image_dir', type=str, default='', help='')
+    parser.add_argument('--view_result_dir', type=str, default='', help='')
+    args = parser.parse_args()
+
+    # 定义允许的图片格式
+    allowed_extensions = ['.jpg', '.jpeg', '.png', '.gif']
+    # 初始化图片路径列表
+    image_paths = []
+    # 调用遍历函数
+    traverse_folder(args.image_dir, image_paths)
+
+    labeling(args.model, args.model_key, image_paths, args.score, args.view_result_dir)