Sfoglia il codice sorgente

[feat] 功能更新

WangChao 3 mesi fa
parent
commit
9e31bcaadb
2 ha cambiato i file con 36 aggiunte e 20 eliminazioni
  1. 27 18
      data_collection/clean_label.py
  2. 9 2
      data_collection/images_fetching.py

+ 27 - 18
data_collection/clean_label.py

@@ -17,24 +17,33 @@ def clean_label_files(labels_dir, output_labels_dir, log_file):
         for label_file in label_files:  # 遍历每一个txt文件名
             file_path = os.path.join(labels_dir, label_file)  # 拼接得到旧的txt文件地址
             output_file_path = os.path.join(output_labels_dir, label_file)  # 新的txt文件地址
-            with open(file_path, 'r') as f:  # 打开旧的txt文件
-                lines = f.readlines()  # 得到文件内容
-            new_lines = []
-            file_needs_logging = False
-            for line in lines:  # 遍历处理文件内容
-                parts = line.strip().split()
-                if parts:
-                    first_num = int(parts[0])
-                    other_nums = [float(num) for num in parts[1:]]
-                    if first_num not in [4, 20] and all(
-                            0 <= num <= 1 for num in other_nums):  # 只有当一行内容的第一个数字不是4或20,其他数字都在[0,1]之间时这一行内容才能被保留
-                        new_lines.append(line)
-                    elif not all(0 <= num <= 1 for num in other_nums):  # 如果不满足上诉条件且有数字不在【0,1】的,要记录文件名
-                        file_needs_logging = True
-            if file_needs_logging:
-                processed_files.append(label_file)
-            with open(output_file_path, 'w') as f:  # 将文本内容写入新文件夹中
-                f.writelines(new_lines)
+
+            _, file_extension = os.path.splitext(file_path)
+            if not file_extension.lower() == ".txt":
+                continue
+
+            try:
+                with open(file_path, 'r') as f:  # 打开旧的txt文件
+                    lines = f.readlines()  # 得到文件内容
+                new_lines = []
+                file_needs_logging = False
+                for line in lines:  # 遍历处理文件内容
+                    parts = line.strip().split()
+                    if parts:
+                        first_num = int(parts[0])
+                        other_nums = [float(num) for num in parts[1:]]
+                        if first_num not in [4, 20] and all(
+                                0 <= num <= 1 for num in other_nums):  # 只有当一行内容的第一个数字不是4或20,其他数字都在[0,1]之间时这一行内容才能被保留
+                            new_lines.append(line)
+                        elif not all(0 <= num <= 1 for num in other_nums):  # 如果不满足上诉条件且有数字不在【0,1】的,要记录文件名
+                            file_needs_logging = True
+                if file_needs_logging:
+                    processed_files.append(label_file)
+                with open(output_file_path, 'w') as f:  # 将文本内容写入新文件夹中
+                    f.writelines(new_lines)
+
+            except Exception as e:
+                print(f"发生错误: {e}")
 
             pbar.update(1)  # 更新进度条
 

+ 9 - 2
data_collection/images_fetching.py

@@ -36,7 +36,14 @@ def pyMuPDF_fitz(pdfPath, imagePath, need):
         pdfDoc = fitz.open(pdfPath)
 
         file_name = os.path.splitext(os.path.basename(pdfPath))[0]
-        page_array = get_render_pages(pdfDoc.page_count, need)
+
+        page_array = []
+        if need == 0:
+            page_array = list(range(pdfDoc.page_count))
+        elif need >= pdfDoc.page_count:
+            page_array = list(range(pdfDoc.page_count))
+        else:
+            page_array = get_render_pages(pdfDoc.page_count, need)
 
         file_info = {
             "file_name" : file_name,
@@ -87,7 +94,7 @@ def main():
     parse = argparse.ArgumentParser("从PDF当中随机抽取图片\n")
     parse.add_argument("input", help="输入路径")
     parse.add_argument("output", help="输出路径")
-    parse.add_argument("--count", type=int, help="每份PDF需要提取的图片数量")
+    parse.add_argument("--count", type=int, default=0, help="每份PDF需要提取的图片数量")
 
     args = parse.parse_args()
     process_files(args.input, os.path.join(args.output, "images"), args.count)