import argparse
import os
import shutil

import ijson
from tqdm import tqdm


def copy_file_to_directory(source_dir, target_dir, filename):
    """Copy ``source_dir/filename`` to ``target_dir/filename``.

    The target file's parent directory is created if needed, so a
    ``filename`` that contains sub-directories is handled too.

    Args:
        source_dir: Directory the file is read from.
        target_dir: Directory the file is copied into.
        filename: File name (possibly a relative path) inside both directories.

    Returns:
        True if the file was copied, False if the source file does not exist.
    """
    source_file = os.path.join(source_dir, filename)
    if not os.path.isfile(source_file):
        # tqdm.write keeps an active progress bar intact instead of
        # interleaving this message with it.
        tqdm.write(f"文件 {source_file} 不存在")
        return False

    target_file = os.path.join(target_dir, filename)
    # Create the file's own parent directory (not just target_dir) so that
    # names like "sub/img.jpg" work; fall back to "." for an empty dirname.
    os.makedirs(os.path.dirname(target_file) or ".", exist_ok=True)
    shutil.copy(source_file, target_file)
    return True


def _count_items(json_file, prefix):
    """Return the number of items ijson yields for ``prefix`` in ``json_file``.

    This is a full streaming pass over the file; it only exists to give the
    tqdm progress bar a total.
    """
    # ijson consumes bytes; binary mode avoids an internal re-encode.
    with open(json_file, 'rb') as f:
        return sum(1 for _ in ijson.items(f, prefix))


def _collect_image_ids(json_file, category_id):
    """Return the set of ``image_id`` values of annotations whose
    ``category_id`` equals ``category_id``."""
    print('load annotations...')
    total_annotations = _count_items(json_file, 'annotations.item')
    print('annotations count : ', total_annotations)

    image_ids = set()
    with open(json_file, 'rb') as f:
        annotations = ijson.items(f, 'annotations.item')
        for annotation in tqdm(annotations, total=total_annotations,
                               desc="Processing annotations"):
            if annotation['category_id'] == category_id:
                image_ids.add(annotation['image_id'])
    return image_ids


def find_filenames_for_category(json_file, category_id, image_dir, output_dir):
    """Copy every image that has an annotation of ``category_id``.

    Streams the top-level ``annotations`` and ``images`` arrays of
    ``json_file`` with ijson (so the whole file is never loaded at once),
    collects the image ids that have at least one matching annotation, then
    copies each matching image's ``file_name`` from ``image_dir`` into
    ``output_dir``.

    Args:
        json_file: Path to the annotation JSON file.
        category_id: Category id to select; compared with ``==`` against
            each annotation's ``category_id``.
        image_dir: Directory containing the source images.
        output_dir: Directory the matching images are copied into.

    Returns:
        List of ``file_name`` values of the matched images, in file order
        (including any whose source file was missing).
    """
    matching_image_ids = _collect_image_ids(json_file, category_id)
    if not matching_image_ids:
        # Nothing can match; skip the two passes over the images array.
        print(0, 'files matched')
        return []

    print('load images...')
    total_images = _count_items(json_file, 'images.item')
    print('images count : ', total_images)

    matching_filenames = []
    missing = 0
    with open(json_file, 'rb') as f:
        images = ijson.items(f, 'images.item')
        for image in tqdm(images, total=total_images, desc="Processing images"):
            if image['id'] not in matching_image_ids:
                continue
            file_name = image['file_name']
            matching_filenames.append(file_name)
            if not copy_file_to_directory(image_dir, output_dir, file_name):
                missing += 1

    print(len(matching_filenames), 'files matched')
    if missing:
        print(missing, 'files missing in', image_dir)
    return matching_filenames


# Script entry point: parse command-line arguments and run the filter.
if __name__ == '__main__':
    # ArgumentParser's first positional parameter is `prog`, so the
    # description has to be passed by keyword.
    parser = argparse.ArgumentParser(description="处理数据集")
    parser.add_argument("json_file", help="publaynet json 文件路径")
    parser.add_argument('image_dir', help='图片目录')
    parser.add_argument('output', help='输出目录')
    parser.add_argument('--category-id', type=int, default=3,
                        help='要筛选的 category_id (默认: 3)')
    args = parser.parse_args()
    find_filenames_for_category(args.json_file, args.category_id,
                                args.image_dir, args.output)