Parcourir la source

[feat] 数据划分脚本

WangChao il y a 5 mois
Parent
commit
8b3b52d27a
1 fichiers modifiés avec 91 ajouts et 0 suppressions
  1. 91 0
      data_collection/spilit_data.py

+ 91 - 0
data_collection/spilit_data.py

@@ -0,0 +1,91 @@
+import json
+import os
+import random
+from pycocotools.coco import COCO
+import argparse
+import shutil
+
+
+def ensure_save_dir(output):
+    train_dir = os.path.join(output, 'train')
+    if not os.path.exists(train_dir):
+        os.makedirs(train_dir)
+
+    train_image_dir = os.path.join(output, 'train', 'images')
+    if not os.path.exists(train_image_dir):
+        os.makedirs(train_image_dir)
+
+    test_dir = os.path.join(output, 'test')
+    if not os.path.exists(test_dir):
+        os.makedirs(test_dir)
+
+    test_image_dir = os.path.join(output, 'test', 'images')
+    if not os.path.exists(test_image_dir):
+        os.makedirs(test_image_dir)
+
+
+def process_coco(coco_json, image_dir, train_ratio, output):
+    # 加载 COCO 数据集
+    coco = COCO(coco_json)
+
+    # 获取所有图像的 ID
+    image_ids = coco.getImgIds()
+
+    # 随机划分训练集和测试集(例如,80% 为训练集,20% 为测试集)
+    train_image_ids = random.sample(image_ids, int(len(image_ids) * train_ratio))
+    test_image_ids = list(set(image_ids) - set(train_image_ids))
+
+    # 为训练集和测试集创建新的 JSON 文件
+    train_json = {'images': [], 'annotations': [], 'categories': coco.dataset['categories'], 'info': coco.dataset['info']}
+    test_json = {'images': [], 'annotations': [], 'categories': coco.dataset['categories'], 'info': coco.dataset['info']}
+
+    train_dir = os.path.join(output, 'train')
+    test_dir = os.path.join(output, 'test')
+    # 遍历图像 ID,并将它们分配到训练集或测试集
+    for image_id in image_ids:
+        image_info = coco.loadImgs(image_id)[0]
+        file_name = os.path.join(image_dir, image_info['file_name'])
+
+        # 检查图像文件是否存在(可选)
+        if os.path.exists(file_name):
+            if image_id in train_image_ids:
+                train_json['images'].append(image_info)
+                train_image_path = os.path.join(train_dir, image_info['file_name'])
+
+                # 还需要获取并添加相关的标注信息到 train_json['annotations']
+                annot_list = [annotation for annotation in coco.dataset['annotations'] if
+                              annotation['image_id'] == image_id]
+                train_json['annotations'].extend(annot_list)
+                shutil.copy2(file_name, train_image_path)
+            else:
+                test_json['images'].append(image_info)
+                test_image_path = os.path.join(test_dir, image_info['file_name'])
+
+                # 还需要获取并添加相关的标注信息到 test_json['annotations']
+                annot_list = [annotation for annotation in coco.dataset['annotations'] if
+                              annotation['image_id'] == image_id]
+                test_json['annotations'].extend(annot_list)
+                shutil.copy2(file_name, test_image_path)
+
+    # 保存新的 JSON 文件(这里省略了保存 annotations 的部分)
+    train_json_path = os.path.join(output, 'train', 'coco_annotations.json')
+    test_json_path = os.path.join(output, 'test', 'coco_annotations.json')
+
+    with open(train_json_path, 'w') as f:
+        json.dump(train_json, f, indent=2)
+    with open(test_json_path, 'w') as f:
+        json.dump(test_json, f, indent=2)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser("删除特定标注信息\n")
+    parser.add_argument("coco_json", help="json文件路径")
+    parser.add_argument("image_dir", help="图片文件夹路径")
+    parser.add_argument("type", help="数据标注格式类型,目前支持coco")
+    parser.add_argument("--ratio", default=0.8, help="比例,默认是0.8")
+    parser.add_argument("output", help="输出路径")
+
+    args = parser.parse_args()
+    if args.type == 'coco':
+        ensure_save_dir(args.output)
+        process_coco(args.coco_json, args.image_dir, args.ratio, args.output)