|
@@ -0,0 +1,71 @@
|
|
|
+import argparse
|
|
|
+import ijson
|
|
|
+import os
|
|
|
+import shutil
|
|
|
+from tqdm import tqdm
|
|
|
+
|
|
|
+
|
|
|
def copy_file_to_directory(source_dir, target_dir, filename):
    """Copy ``filename`` from ``source_dir`` to the same relative path in ``target_dir``.

    Missing destination directories are created on demand, including any
    sub-directory component carried by ``filename`` (e.g. ``"sub/a.jpg"``).
    If the source file does not exist, a message is printed and nothing is
    copied.

    Args:
        source_dir: Directory that contains the file to copy.
        target_dir: Directory the file is copied into.
        filename: Path of the file relative to ``source_dir``.
    """
    source_file = os.path.join(source_dir, filename)

    # Guard clause: report the missing file and leave the target untouched.
    if not os.path.isfile(source_file):
        print(f"文件 {source_file} 不存在")
        return

    target_file = os.path.join(target_dir, filename)

    # Create the parent of the *target file* rather than just target_dir, so
    # a filename that contains a sub-directory does not make the copy fail.
    os.makedirs(os.path.dirname(target_file), exist_ok=True)

    shutil.copy(source_file, target_file)
|
|
|
+
|
|
|
+
|
|
|
def _count_items(json_file, prefix):
    """Return how many items ijson yields for ``prefix`` in ``json_file``."""
    # ijson reads bytes, so the file is opened in binary mode.
    with open(json_file, 'rb') as f:
        return sum(1 for _ in ijson.items(f, prefix))


def find_filenames_for_category(json_file, category_id, image_dir, output_dir):
    """Copy every image that has an annotation of ``category_id`` to ``output_dir``.

    The JSON file is parsed incrementally with ijson instead of being loaded
    as a whole. It is read four times in total: one counting pass plus one
    processing pass for ``annotations``, and the same for ``images``; the
    counting passes exist only to give tqdm a total for its progress bar.

    Args:
        json_file: Path to the annotation JSON file; it is read for the
            top-level ``annotations`` and ``images`` arrays.
        category_id: Annotations whose ``category_id`` equals this value
            select their ``image_id``.
        image_dir: Directory the selected images are copied from.
        output_dir: Directory the selected images are copied to.

    Returns:
        The ``file_name`` of every selected image, in the order the images
        appear in the JSON file. A name is included even if the file turned
        out to be missing from ``image_dir``.
    """
    # Step 1: collect the image_id of every annotation with the requested category.
    matching_image_ids = set()

    # Count the annotations first so the progress bar has a total.
    print('load annotations...')
    total_annotations = _count_items(json_file, 'annotations.item')
    print('annotations count : ', total_annotations)

    with open(json_file, 'rb') as f:
        for annotation in tqdm(ijson.items(f, 'annotations.item'), total=total_annotations,
                               desc="Processing annotations"):
            if annotation['category_id'] == category_id:
                matching_image_ids.add(annotation['image_id'])

    # Step 2: map the collected image ids to file names and copy the files.
    matching_filenames = []

    print('load images...')
    # Count the images first so the progress bar has a total.
    total_images = _count_items(json_file, 'images.item')
    print('images count : ', total_images)

    with open(json_file, 'rb') as f:
        for image in tqdm(ijson.items(f, 'images.item'), total=total_images, desc="Processing images"):
            if image['id'] in matching_image_ids:
                matching_filenames.append(image['file_name'])
                copy_file_to_directory(image_dir, output_dir, image['file_name'])

    print(len(matching_filenames), 'files matched')
    return matching_filenames
|
|
|
+
|
|
|
+
|
|
|
# Script entry point: parse the command-line arguments and run the export.
def _build_parser():
    """Build the command-line parser for this script."""
    # The text must be passed as ``description``: the first positional
    # argument of ArgumentParser is ``prog``, which would replace the program
    # name shown in the usage line.
    parser = argparse.ArgumentParser(description="处理数据集")
    parser.add_argument("json_file", help="publaynet json 文件路径")
    parser.add_argument('image_dir', help='图片目录')
    parser.add_argument('output', help='输出目录')
    # Optional so existing invocations keep selecting category 3.
    parser.add_argument('--category-id', type=int, default=3,
                        help='要导出的类别 id,默认为 3')
    return parser


if __name__ == '__main__':
    args = _build_parser().parse_args()
    find_filenames_for_category(args.json_file, args.category_id, args.image_dir, args.output)
|