Browse Source

删除 'auto_label.py'

zhuzezhou 1 year ago
parent
commit
7ae76c841f
1 changed files with 0 additions and 95 deletions
  1. 0 95
      auto_label.py

+ 0 - 95
auto_label.py

@@ -1,95 +0,0 @@
-import da_python as dap
-import cv2
-import argparse
-import os
-import glob
-from PIL import Image
-import copy
-import json
-from tqdm import tqdm
-
-def traverse_folder(path, image_paths):
-    # 获取当前文件夹下的所有文件和子文件夹
-    for file in os.listdir(path):
-        # 获取文件路径
-        file_path = os.path.join(path, file)
-
-        # 判断文件类型,如果是图片则加入列表
-        if os.path.isfile(file_path) and os.path.splitext(file_path)[1].lower() in allowed_extensions:
-            image_paths.append(file_path)
-        # 如果是文件夹,则递归调用本函数
-        elif os.path.isdir(file_path):
-            traverse_folder(file_path, image_paths)
-
-def labeling(model, key, image_paths, set_score, view_result_dir):
-    da, create_result = dap.create(key, model)
-    if create_result != dap.E_DA_SUCCESS:
-        print('create document ai failed:{}'.format(create_result))
-        exit()
-    detect_engine, a = da.detection()
-    # print(a)
-
-    # 存放json信息
-    labelme_json = {}
-    shape_dic = {}
-
-    for image_path in tqdm(image_paths):
-        image = cv2.imread(image_path)
-        layout_analysis_result, layout_analysis_result_code = detect_engine.layout_analysis(image)
-        if layout_analysis_result_code == dap.E_DA_SUCCESS:
-            # 判断是否检测到目标
-            if len(layout_analysis_result.boxes):
-                # print(dir(layout_analysis_result))
-                # print(layout_analysis_result.boxes)
-                # print(layout_analysis_result.labels)
-                # print(layout_analysis_result.scores)
-                # labels_list = layout_analysis_result.labels
-                # box_list = layout_analysis_result.boxes
-                # score_list = layout_analysis_result.scores
-                shapes_list = []
-                labelme_json['version'] = '5.0.1'
-                labelme_json['flags'] = {}
-                # 根据分数进行过滤
-                for index, score in enumerate(layout_analysis_result.scores):
-                    # print(index, score)
-                    if score < set_score:
-                        continue
-                    # print('检测到')
-                    shape_dic['label'] = layout_analysis_result.labels[index]
-                    box = layout_analysis_result.boxes[index]
-                    x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
-                    shape_dic['points'] = [[x1, y1], [x2, y2]]
-                    shape_dic['group_id'] = None
-                    shape_dic['shape_type'] = 'rectangle'  # 根据需求进行填写
-                    shape_dic['flags'] = {}
-                    shapes_list.append(copy.deepcopy(shape_dic))
-                labelme_json['shapes'] = shapes_list
-                labelme_json['imagePath'] = os.path.basename(image_path)
-                labelme_json['imageData'] = None
-                labelme_json['imageHeight'], labelme_json['imageWidth'] = image.shape[0], image.shape[1]
-                json_path = image_path.replace(os.path.splitext(image_path)[1], '.json')
-                json.dump(labelme_json, open(json_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=2)
-            visualize_im = dap.visualize.detection(image, layout_analysis_result)
-            cv2.imwrite(os.path.join(view_result_dir, os.path.basename(image_path)), visualize_im)
-
-        else:
-          print('magic color infer failed:{}'.format(layout_analysis_result_code))
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--model', type=str, default='', help='')
-    parser.add_argument('--model_licence', type=str, default='', help='')
-    parser.add_argument('--score', type=int, default=0.8, help='')
-    parser.add_argument('--image_dir', type=str, default='', help='')
-    parser.add_argument('--view_result_dir', type=str, default='', help='')
-    args = parser.parse_args()
-
-    # 定义允许的图片格式
-    allowed_extensions = ['.jpg', '.jpeg', '.png', '.gif']
-    # 初始化图片路径列表
-    image_paths = []
-    # 调用遍历函数
-    traverse_folder(args.image_dir, image_paths)
-
-    labeling(args.model, args.model_key, image_paths, args.score, args.view_result_dir)