Explorar o código

[feat] bug 修复

WangChao hai 4 meses
pai
achega
a7556b98e3
Modificáronse 2 ficheiros con 61 adicións e 2 borrados
  1. 55 0
      data_collection/clean_label.py
  2. 6 2
      data_collection/spilit_data.py

+ 55 - 0
data_collection/clean_label.py

@@ -0,0 +1,55 @@
+import json
+import os
+import random
+import shutil
+import argparse
+from tqdm import tqdm
+
+
+def clean_label_files(labels_dir, output_labels_dir, log_file):
+    if not os.path.exists(output_labels_dir):
+        os.makedirs(output_labels_dir)
+
+    processed_files = []
+    label_files = os.listdir(labels_dir)
+
+    with tqdm(total=len(label_files), desc="Processing files", unit="file") as pbar:
+        for label_file in label_files:  # 遍历每一个txt文件名
+            file_path = os.path.join(labels_dir, label_file)  # 拼接得到旧的txt文件地址
+            output_file_path = os.path.join(output_labels_dir, label_file)  # 新的txt文件地址
+            with open(file_path, 'r') as f:  # 打开旧的txt文件
+                lines = f.readlines()  # 得到文件内容
+            new_lines = []
+            file_needs_logging = False
+            for line in lines:  # 遍历处理文件内容
+                parts = line.strip().split()
+                if parts:
+                    first_num = int(parts[0])
+                    other_nums = [float(num) for num in parts[1:]]
+                    if first_num not in [4, 20] and all(
+                            0 <= num <= 1 for num in other_nums):  # 只有当一行内容的第一个数字不是4或20,其他数字都在[0,1]之间时这一行内容才能被保留
+                        new_lines.append(line)
+                    elif not all(0 <= num <= 1 for num in other_nums):  # 如果不满足上诉条件且有数字不在【0,1】的,要记录文件名
+                        file_needs_logging = True
+            if file_needs_logging:
+                processed_files.append(label_file)
+            with open(output_file_path, 'w') as f:  # 将文本内容写入新文件夹中
+                f.writelines(new_lines)
+
+            pbar.update(1)  # 更新进度条
+
+    with open(log_file, 'w') as log:
+        for file_name in processed_files:
+            log.write(f"{file_name}\n")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser("删除特定标注信息\n")
+    parser.add_argument("image_dir", help="图片文件夹路径")
+    parser.add_argument("--log_file", default="processed_files.log", help="记录被处理文件的日志文件路径")
+
+    args = parser.parse_args()
+
+    output_labels_dir = os.path.join(args.image_dir, 'processed_labels')  # 生成的新的labels文件夹
+    log_file_dir = os.path.join(args.image_dir, args.log_file)  # 日志文件
+    clean_label_files(os.path.join(args.image_dir, 'labels'), output_labels_dir, log_file_dir)

+ 6 - 2
data_collection/spilit_data.py

@@ -117,12 +117,16 @@ def process_yolo(image_dir, train_ratio, output):
 
     for file_name in train_files:
         shutil.copy2(os.path.join(images_dir, file_name), train_images_dir)
-        label_file_name = file_name.replace('.jpg', '.txt')
+        # label_file_name = file_name.replace('.jpg', '.txt')
+        prefix = file_name[:-4]
+        label_file_name = prefix + ".txt"
         shutil.copy2(os.path.join(labels_dir, label_file_name), train_labels_dir)
 
     for file_name in val_files:
         shutil.copy2(os.path.join(images_dir, file_name), val_images_dir)
-        label_file_name = file_name.replace('.jpg', '.txt')
+        # label_file_name = file_name.replace('.jpg', '.txt')
+        prefix = file_name[:-4]
+        label_file_name = prefix + ".txt"
         shutil.copy2(os.path.join(labels_dir, label_file_name), val_labels_dir)
 
     # 复制 image_dir 下的其他文件到 output