spilit_data.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import json
  2. import os
  3. import random
  4. from pycocotools.coco import COCO
  5. import argparse
  6. import shutil
  7. def ensure_coco_save_dir(output):
  8. train_dir = os.path.join(output, 'train')
  9. if not os.path.exists(train_dir):
  10. os.makedirs(train_dir)
  11. train_image_dir = os.path.join(output, 'train', 'images')
  12. if not os.path.exists(train_image_dir):
  13. os.makedirs(train_image_dir)
  14. val_dir = os.path.join(output, 'val')
  15. if not os.path.exists(val_dir):
  16. os.makedirs(val_dir)
  17. val_image_dir = os.path.join(output, 'val', 'images')
  18. if not os.path.exists(val_image_dir):
  19. os.makedirs(val_image_dir)
  20. def ensure_yolo_save_dir(output):
  21. train_dir = os.path.join(output, 'train')
  22. if not os.path.exists(train_dir):
  23. os.makedirs(train_dir)
  24. train_image_dir = os.path.join(output, 'train', 'images')
  25. if not os.path.exists(train_image_dir):
  26. os.makedirs(train_image_dir)
  27. train_label_dir = os.path.join(output, 'train', 'labels')
  28. if not os.path.exists(train_label_dir):
  29. os.makedirs(train_label_dir)
  30. test_dir = os.path.join(output, 'val')
  31. if not os.path.exists(test_dir):
  32. os.makedirs(test_dir)
  33. test_image_dir = os.path.join(output, 'val', 'images')
  34. if not os.path.exists(test_image_dir):
  35. os.makedirs(test_image_dir)
  36. test_label_dir = os.path.join(output, 'val', 'labels')
  37. if not os.path.exists(test_label_dir):
  38. os.makedirs(test_label_dir)
  39. def process_coco(coco_json, image_dir, train_ratio, output):
  40. # 加载 COCO 数据集
  41. coco = COCO(coco_json)
  42. # 获取所有图像的 ID
  43. image_ids = coco.getImgIds()
  44. # 随机划分训练集和测试集(例如,80% 为训练集,20% 为测试集)
  45. train_image_ids = random.sample(image_ids, int(len(image_ids) * train_ratio))
  46. val_image_ids = list(set(image_ids) - set(train_image_ids))
  47. # 为训练集和测试集创建新的 JSON 文件
  48. train_json = {'images': [], 'annotations': [], 'categories': coco.dataset['categories'], 'info': coco.dataset['info']}
  49. val_json = {'images': [], 'annotations': [], 'categories': coco.dataset['categories'], 'info': coco.dataset['info']}
  50. train_dir = os.path.join(output, 'train')
  51. val_dir = os.path.join(output, 'val')
  52. # 遍历图像 ID,并将它们分配到训练集或测试集
  53. for image_id in image_ids:
  54. image_info = coco.loadImgs(image_id)[0]
  55. file_name = os.path.join(image_dir, image_info['file_name'])
  56. # 检查图像文件是否存在(可选)
  57. if os.path.exists(file_name):
  58. if image_id in train_image_ids:
  59. train_json['images'].append(image_info)
  60. train_image_path = os.path.join(train_dir, image_info['file_name'])
  61. # 还需要获取并添加相关的标注信息到 train_json['annotations']
  62. annot_list = [annotation for annotation in coco.dataset['annotations'] if
  63. annotation['image_id'] == image_id]
  64. train_json['annotations'].extend(annot_list)
  65. shutil.copy2(file_name, train_image_path)
  66. else:
  67. val_json['images'].append(image_info)
  68. val_image_path = os.path.join(val_dir, image_info['file_name'])
  69. # 还需要获取并添加相关的标注信息到 val_json['annotations']
  70. annot_list = [annotation for annotation in coco.dataset['annotations'] if
  71. annotation['image_id'] == image_id]
  72. val_json['annotations'].extend(annot_list)
  73. shutil.copy2(file_name, val_image_path)
  74. # 保存新的 JSON 文件(这里省略了保存 annotations 的部分)
  75. train_json_path = os.path.join(output, 'train', 'coco_annotations.json')
  76. val_json_path = os.path.join(output, 'val', 'coco_annotations.json')
  77. with open(train_json_path, 'w') as f:
  78. json.dump(train_json, f, indent=2)
  79. with open(val_json_path, 'w') as f:
  80. json.dump(val_json, f, indent=2)
  81. def process_yolo(image_dir, train_ratio, output):
  82. images_dir = os.path.join(image_dir, 'images')
  83. labels_dir = os.path.join(image_dir, 'labels')
  84. image_files = os.listdir(images_dir)
  85. train_files = random.sample(image_files, int(len(image_files) * train_ratio))
  86. val_files = list(set(image_files) - set(train_files))
  87. train_images_dir = os.path.join(output, 'train', 'images')
  88. train_labels_dir = os.path.join(output, 'train', 'labels')
  89. val_images_dir = os.path.join(output, 'val', 'images')
  90. val_labels_dir = os.path.join(output, 'val', 'labels')
  91. for file_name in train_files:
  92. shutil.copy2(os.path.join(images_dir, file_name), train_images_dir)
  93. label_file_name = file_name.replace('.jpg', '.txt')
  94. shutil.copy2(os.path.join(labels_dir, label_file_name), train_labels_dir)
  95. for file_name in val_files:
  96. shutil.copy2(os.path.join(images_dir, file_name), val_images_dir)
  97. label_file_name = file_name.replace('.jpg', '.txt')
  98. shutil.copy2(os.path.join(labels_dir, label_file_name), val_labels_dir)
  99. # 复制 image_dir 下的其他文件到 output
  100. for item in os.listdir(image_dir):
  101. item_path = os.path.join(image_dir, item)
  102. if os.path.isfile(item_path) and item not in ['images', 'labels']:
  103. shutil.copy2(item_path, output)
  104. # 主程序入口,处理命令行参数
  105. if __name__ == '__main__':
  106. parser = argparse.ArgumentParser("处理数据集\n")
  107. parser.add_argument("--coco_json", help="COCO 格式的 JSON 文件路径")
  108. parser.add_argument("image_dir", help="图片文件夹路径")
  109. parser.add_argument("type", help="数据标注格式类型,支持 'coco' 或 'yolo'")
  110. parser.add_argument("--ratio", default=0.8, type=float, help="训练集的比例,默认是 0.8")
  111. parser.add_argument("output", help="输出路径")
  112. args = parser.parse_args()
  113. # 根据数据标注格式类型调用不同的处理函数
  114. if args.type == 'coco':
  115. ensure_coco_save_dir(args.output)
  116. process_coco(args.coco_json, args.image_dir, args.ratio, args.output)
  117. elif args.type == 'yolo':
  118. ensure_yolo_save_dir(args.output)
  119. process_yolo(args.image_dir, args.ratio, args.output)