auto_label.py 4.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import da_python as dap
  2. import cv2
  3. import argparse
  4. import os
  5. import glob
  6. from PIL import Image
  7. import copy
  8. import json
  9. from tqdm import tqdm
  10. def traverse_folder(path, image_paths):
  11. # 获取当前文件夹下的所有文件和子文件夹
  12. for file in os.listdir(path):
  13. # 获取文件路径
  14. file_path = os.path.join(path, file)
  15. # 判断文件类型,如果是图片则加入列表
  16. if os.path.isfile(file_path) and os.path.splitext(file_path)[1].lower() in allowed_extensions:
  17. image_paths.append(file_path)
  18. # 如果是文件夹,则递归调用本函数
  19. elif os.path.isdir(file_path):
  20. traverse_folder(file_path, image_paths)
  21. def labeling(model, key, image_paths, set_score, view_result_dir):
  22. da, create_result = dap.create(key, model)
  23. if create_result != dap.E_DA_SUCCESS:
  24. print('create document ai failed:{}'.format(create_result))
  25. exit()
  26. detect_engine, a = da.detection()
  27. # print(a)
  28. # 存放json信息
  29. labelme_json = {}
  30. shape_dic = {}
  31. for image_path in tqdm(image_paths):
  32. image = cv2.imread(image_path)
  33. layout_analysis_result, layout_analysis_result_code = detect_engine.layout_analysis(image)
  34. if layout_analysis_result_code == dap.E_DA_SUCCESS:
  35. # 判断是否检测到目标
  36. if len(layout_analysis_result.boxes):
  37. # print(dir(layout_analysis_result))
  38. # print(layout_analysis_result.boxes)
  39. # print(layout_analysis_result.labels)
  40. # print(layout_analysis_result.scores)
  41. # labels_list = layout_analysis_result.labels
  42. # box_list = layout_analysis_result.boxes
  43. # score_list = layout_analysis_result.scores
  44. shapes_list = []
  45. labelme_json['version'] = '5.0.1'
  46. labelme_json['flags'] = {}
  47. # 根据分数进行过滤
  48. for index, score in enumerate(layout_analysis_result.scores):
  49. # print(index, score)
  50. if score < set_score:
  51. continue
  52. # print('检测到')
  53. shape_dic['label'] = layout_analysis_result.labels[index]
  54. box = layout_analysis_result.boxes[index]
  55. x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
  56. shape_dic['points'] = [[x1, y1], [x2, y2]]
  57. shape_dic['group_id'] = None
  58. shape_dic['shape_type'] = 'rectangle' # 根据需求进行填写
  59. shape_dic['flags'] = {}
  60. shapes_list.append(copy.deepcopy(shape_dic))
  61. labelme_json['shapes'] = shapes_list
  62. labelme_json['imagePath'] = os.path.basename(image_path)
  63. labelme_json['imageData'] = None
  64. labelme_json['imageHeight'], labelme_json['imageWidth'] = image.shape[0], image.shape[1]
  65. json_path = image_path.replace(os.path.splitext(image_path)[1], '.json')
  66. json.dump(labelme_json, open(json_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=2)
  67. visualize_im = dap.visualize.detection(image, layout_analysis_result)
  68. cv2.imwrite(os.path.join(view_result_dir, os.path.basename(image_path)), visualize_im)
  69. else:
  70. print('magic color infer failed:{}'.format(layout_analysis_result_code))
  71. if __name__ == '__main__':
  72. parser = argparse.ArgumentParser()
  73. parser.add_argument('--model', type=str, default='', help='')
  74. parser.add_argument('--model_licence', type=str, default='', help='')
  75. parser.add_argument('--score', type=int, default=0.8, help='')
  76. parser.add_argument('--image_dir', type=str, default='', help='')
  77. parser.add_argument('--view_result_dir', type=str, default='', help='')
  78. args = parser.parse_args()
  79. # 定义允许的图片格式
  80. allowed_extensions = ['.jpg', '.jpeg', '.png', '.gif']
  81. # 初始化图片路径列表
  82. image_paths = []
  83. # 调用遍历函数
  84. traverse_folder(args.image_dir, image_paths)
  85. labeling(args.model, args.model_key, image_paths, args.score, args.view_result_dir)