"""Split an annotated image dataset (COCO or YOLO format) into train/val sets."""
import argparse
import json
import os
import random
import shutil

try:
    from pycocotools.coco import COCO
except ImportError:  # pycocotools is only needed for the 'coco' mode
    COCO = None


def _make_split_dirs(output, subdirs):
    """Create ``<output>/<split>/<sub>`` for split in train/val and every sub in *subdirs*."""
    for split in ('train', 'val'):
        for sub in subdirs:
            # makedirs also creates the intermediate ``<output>/<split>`` directory.
            os.makedirs(os.path.join(output, split, sub), exist_ok=True)


def ensure_coco_save_dir(output):
    """Create the ``train/images`` and ``val/images`` folders under *output*."""
    _make_split_dirs(output, ('images',))


def ensure_yolo_save_dir(output):
    """Create ``{train,val}/{images,labels}`` folders under *output*."""
    _make_split_dirs(output, ('images', 'labels'))


def _split(items, train_ratio):
    """Randomly pick ``int(len(items) * train_ratio)`` items for training.

    Returns ``(train, val)``; *val* keeps the original order of *items*.
    Raises ValueError if *train_ratio* is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")
    train = random.sample(items, int(len(items) * train_ratio))
    train_set = set(train)  # O(1) membership instead of scanning the list
    val = [item for item in items if item not in train_set]
    return train, val


def process_coco(coco_json, image_dir, train_ratio, output):
    """Split a COCO dataset into train/val subsets under *output*.

    Every image listed in *coco_json* that exists under *image_dir* is copied to
    ``<output>/<split>/images/<file_name>`` and, together with its annotations,
    recorded in ``<output>/<split>/coco_annotations.json``. Images missing on
    disk are skipped.

    Raises:
        ImportError: if pycocotools is not installed.
        ValueError: if *train_ratio* is outside [0, 1].
    """
    if COCO is None:
        raise ImportError("pycocotools is required for COCO datasets "
                          "(pip install pycocotools)")
    coco = COCO(coco_json)

    image_ids = coco.getImgIds()
    train_ids, _ = _split(image_ids, train_ratio)
    train_ids = set(train_ids)

    def new_split_json():
        split_json = {'images': [], 'annotations': [],
                      'categories': coco.dataset['categories']}
        # 'info'/'licenses' are optional in many COCO files; copy only if present.
        for key in ('info', 'licenses'):
            if key in coco.dataset:
                split_json[key] = coco.dataset[key]
        return split_json

    splits = {'train': new_split_json(), 'val': new_split_json()}

    for image_id in image_ids:
        image_info = coco.loadImgs(image_id)[0]
        src = os.path.join(image_dir, image_info['file_name'])
        if not os.path.exists(src):
            continue
        split = 'train' if image_id in train_ids else 'val'
        split_json = splits[split]
        split_json['images'].append(image_info)
        # Use the COCO index instead of scanning every annotation per image.
        split_json['annotations'].extend(
            coco.loadAnns(coco.getAnnIds(imgIds=image_id)))
        dst = os.path.join(output, split, 'images', image_info['file_name'])
        # file_name may contain sub-directories that do not exist yet.
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        shutil.copy2(src, dst)

    for split, split_json in splits.items():
        json_path = os.path.join(output, split, 'coco_annotations.json')
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(split_json, f, indent=2)


def process_yolo(image_dir, train_ratio, output):
    """Split a YOLO dataset (``images/`` + ``labels/`` under *image_dir*) into train/val.

    Each image is copied to ``<output>/<split>/images`` and its label
    (same stem, ``.txt`` extension) to ``<output>/<split>/labels``. Images
    without a label file are still copied, just without a label. Plain files
    directly inside *image_dir* (e.g. a dataset yaml) are copied to *output*.

    Raises:
        ValueError: if *train_ratio* is outside [0, 1].
    """
    images_dir = os.path.join(image_dir, 'images')
    labels_dir = os.path.join(image_dir, 'labels')

    # Sort so a fixed random seed yields the same split on every platform.
    image_files = sorted(
        name for name in os.listdir(images_dir)
        if os.path.isfile(os.path.join(images_dir, name))
    )
    train_files, val_files = _split(image_files, train_ratio)

    for split, files in (('train', train_files), ('val', val_files)):
        dst_images = os.path.join(output, split, 'images')
        dst_labels = os.path.join(output, split, 'labels')
        for file_name in files:
            shutil.copy2(os.path.join(images_dir, file_name), dst_images)
            # Works for any image extension (.jpg, .png, .jpeg, ...).
            label_path = os.path.join(labels_dir,
                                      os.path.splitext(file_name)[0] + '.txt')
            if os.path.isfile(label_path):
                shutil.copy2(label_path, dst_labels)

    # Copy the remaining top-level files of image_dir into output.
    for item in os.listdir(image_dir):
        item_path = os.path.join(image_dir, item)
        if os.path.isfile(item_path):
            shutil.copy2(item_path, output)


def main(argv=None):
    """Parse command-line arguments and dispatch to the matching splitter."""
    parser = argparse.ArgumentParser(description="处理数据集")
    parser.add_argument("--coco_json", help="COCO 格式的 JSON 文件路径")
    parser.add_argument("image_dir", help="图片文件夹路径")
    parser.add_argument("type", choices=['coco', 'yolo'],
                        help="数据标注格式类型,支持 'coco' 或 'yolo'")
    parser.add_argument("--ratio", default=0.8, type=float,
                        help="训练集的比例,默认是 0.8")
    parser.add_argument("--seed", default=None, type=int,
                        help="随机种子,用于复现数据划分")
    parser.add_argument("output", help="输出路径")
    args = parser.parse_args(argv)

    if args.type == 'coco' and not args.coco_json:
        parser.error("type 为 'coco' 时必须提供 --coco_json")
    if not 0.0 <= args.ratio <= 1.0:
        parser.error("--ratio 必须在 [0, 1] 范围内")
    if args.seed is not None:
        random.seed(args.seed)

    if args.type == 'coco':
        ensure_coco_save_dir(args.output)
        process_coco(args.coco_json, args.image_dir, args.ratio, args.output)
    else:
        ensure_yolo_save_dir(args.output)
        process_yolo(args.image_dir, args.ratio, args.output)


if __name__ == '__main__':
    main()