
[feat] Add YOLO format support to dataset splitting

WangChao, 3 months ago
Parent commit f9c72fca8d
1 file changed, 81 additions and 20 deletions

+ 81 - 20
data_collection/spilit_data.py
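
Note: the hunks below rename the COCO test split to val and add a parallel YOLO path. As a rough sketch of the layout the new ensure_yolo_save_dir prepares (the output root ./dataset is an illustrative assumption, not part of the commit):

import os

output = './dataset'  # illustrative output root, passed to the script as the output argument
# train/ and val/, each with images/ and labels/ subfolders; exist_ok replaces the explicit checks
for split in ('train', 'val'):
    for sub in ('images', 'labels'):
        os.makedirs(os.path.join(output, split, sub), exist_ok=True)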

@@ -6,7 +6,7 @@ import argparse
 import shutil
 
 
-def ensure_save_dir(output):
+def ensure_coco_save_dir(output):
     train_dir = os.path.join(output, 'train')
     if not os.path.exists(train_dir):
         os.makedirs(train_dir)
@@ -15,14 +15,40 @@ def ensure_save_dir(output):
     if not os.path.exists(train_image_dir):
         os.makedirs(train_image_dir)
 
-    test_dir = os.path.join(output, 'test')
+    val_dir = os.path.join(output, 'val')
+    if not os.path.exists(val_dir):
+        os.makedirs(val_dir)
+
+    val_image_dir = os.path.join(output, 'val', 'images')
+    if not os.path.exists(val_image_dir):
+        os.makedirs(val_image_dir)
+
+
+def ensure_yolo_save_dir(output):
+    train_dir = os.path.join(output, 'train')
+    if not os.path.exists(train_dir):
+        os.makedirs(train_dir)
+
+    train_image_dir = os.path.join(output, 'train', 'images')
+    if not os.path.exists(train_image_dir):
+        os.makedirs(train_image_dir)
+
+    train_label_dir = os.path.join(output, 'train', 'labels')
+    if not os.path.exists(train_label_dir):
+        os.makedirs(train_label_dir)
+
+    test_dir = os.path.join(output, 'val')
     if not os.path.exists(test_dir):
         os.makedirs(test_dir)
 
-    test_image_dir = os.path.join(output, 'test', 'images')
+    test_image_dir = os.path.join(output, 'val', 'images')
     if not os.path.exists(test_image_dir):
         os.makedirs(test_image_dir)
 
+    test_label_dir = os.path.join(output, 'val', 'labels')
+    if not os.path.exists(test_label_dir):
+        os.makedirs(test_label_dir)
+
 
 def process_coco(coco_json, image_dir, train_ratio, output):
     # Load the COCO dataset
@@ -33,14 +59,14 @@ def process_coco(coco_json, image_dir, train_ratio, output):
 
     # Randomly split into training and validation sets (e.g., 80% train, 20% val)
     train_image_ids = random.sample(image_ids, int(len(image_ids) * train_ratio))
-    test_image_ids = list(set(image_ids) - set(train_image_ids))
+    val_image_ids = list(set(image_ids) - set(train_image_ids))
 
     # Create new JSON files for the training and validation sets
     train_json = {'images': [], 'annotations': [], 'categories': coco.dataset['categories'], 'info': coco.dataset['info']}
-    test_json = {'images': [], 'annotations': [], 'categories': coco.dataset['categories'], 'info': coco.dataset['info']}
+    val_json = {'images': [], 'annotations': [], 'categories': coco.dataset['categories'], 'info': coco.dataset['info']}
 
     train_dir = os.path.join(output, 'train')
-    test_dir = os.path.join(output, 'test')
+    val_dir = os.path.join(output, 'val')
     # Iterate over the image IDs and assign each to the training or validation set
     for image_id in image_ids:
         image_info = coco.loadImgs(image_id)[0]
@@ -58,34 +84,69 @@ def process_coco(coco_json, image_dir, train_ratio, output):
                 train_json['annotations'].extend(annot_list)
                 shutil.copy2(file_name, train_image_path)
             else:
-                test_json['images'].append(image_info)
-                test_image_path = os.path.join(test_dir, image_info['file_name'])
+                val_json['images'].append(image_info)
+                val_image_path = os.path.join(val_dir, image_info['file_name'])
 
-                # Also fetch and append the related annotations to test_json['annotations']
+                # Also fetch and append the related annotations to val_json['annotations']
                 annot_list = [annotation for annotation in coco.dataset['annotations'] if
                               annotation['image_id'] == image_id]
-                test_json['annotations'].extend(annot_list)
-                shutil.copy2(file_name, test_image_path)
+                val_json['annotations'].extend(annot_list)
+                shutil.copy2(file_name, val_image_path)
 
     # Save the new JSON files (the annotation-saving part is omitted here)
     train_json_path = os.path.join(output, 'train', 'coco_annotations.json')
-    test_json_path = os.path.join(output, 'test', 'coco_annotations.json')
+    val_json_path = os.path.join(output, 'val', 'coco_annotations.json')
 
     with open(train_json_path, 'w') as f:
         json.dump(train_json, f, indent=2)
-    with open(test_json_path, 'w') as f:
-        json.dump(test_json, f, indent=2)
+    with open(val_json_path, 'w') as f:
+        json.dump(val_json, f, indent=2)
+
+
+def process_yolo(image_dir, train_ratio, output):
+    images_dir = os.path.join(image_dir, 'images')
+    labels_dir = os.path.join(image_dir, 'labels')
+    image_files = os.listdir(images_dir)
+    train_files = random.sample(image_files, int(len(image_files) * train_ratio))
+    val_files = list(set(image_files) - set(train_files))
 
+    train_images_dir = os.path.join(output, 'train', 'images')
+    train_labels_dir = os.path.join(output, 'train', 'labels')
+    val_images_dir = os.path.join(output, 'val', 'images')
+    val_labels_dir = os.path.join(output, 'val', 'labels')
 
+    for file_name in train_files:
+        shutil.copy2(os.path.join(images_dir, file_name), train_images_dir)
+        label_file_name = file_name.replace('.jpg', '.txt')
+        shutil.copy2(os.path.join(labels_dir, label_file_name), train_labels_dir)
+
+    for file_name in val_files:
+        shutil.copy2(os.path.join(images_dir, file_name), val_images_dir)
+        label_file_name = file_name.replace('.jpg', '.txt')
+        shutil.copy2(os.path.join(labels_dir, label_file_name), val_labels_dir)
+
+    # Copy any other files under image_dir to output
+    for item in os.listdir(image_dir):
+        item_path = os.path.join(image_dir, item)
+        if os.path.isfile(item_path) and item not in ['images', 'labels']:
+            shutil.copy2(item_path, output)
+
+
+# Main entry point: handle command-line arguments
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser("Delete specific annotations\n")
-    parser.add_argument("coco_json", help="Path to the JSON file")
+    parser = argparse.ArgumentParser("Process the dataset\n")
+    parser.add_argument("--coco_json", help="Path to the COCO-format JSON file")
     parser.add_argument("image_dir", help="Path to the image folder")
-    parser.add_argument("type", help="Annotation format type; currently only 'coco' is supported")
-    parser.add_argument("--ratio", default=0.8, help="Ratio, defaults to 0.8")
+    parser.add_argument("type", help="Annotation format type; 'coco' or 'yolo' is supported")
+    parser.add_argument("--ratio", default=0.8, type=float, help="Training-set ratio, defaults to 0.8")
     parser.add_argument("output", help="Output path")
 
     args = parser.parse_args()
+
+    # Call a different handler depending on the annotation format type
     if args.type == 'coco':
-        ensure_save_dir(args.output)
-        process_coco(args.coco_json, args.image_dir, args.ratio, args.output)
+        ensure_coco_save_dir(args.output)
+        process_coco(args.coco_json, args.image_dir, args.ratio, args.output)
+    elif args.type == 'yolo':
+        ensure_yolo_save_dir(args.output)
+        process_yolo(args.image_dir, args.ratio, args.output)
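
For reference, a minimal usage sketch of the new YOLO path; the paths ./raw_yolo and ./dataset are illustrative assumptions, and ./raw_yolo is expected to contain images/ and labels/ subfolders. The equivalent CLI call would be python data_collection/spilit_data.py ./raw_yolo yolo ./dataset --ratio 0.8.

# Assumes spilit_data.py is importable, e.g. when run from data_collection/.
from spilit_data import ensure_yolo_save_dir, process_yolo

ensure_yolo_save_dir('./dataset')             # create train/ and val/ with images/ and labels/
process_yolo('./raw_yolo', 0.8, './dataset')  # copy ~80% of image/label pairs to train, the rest to val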