publaynet.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import argparse
  2. import ijson
  3. import os
  4. import shutil
  5. from tqdm import tqdm
  6. def copy_file_to_directory(source_dir, target_dir, filename):
  7. # 构建源文件路径
  8. source_file = os.path.join(source_dir, filename)
  9. # 检查文件是否存在
  10. if os.path.isfile(source_file):
  11. # 构建目标文件路径
  12. target_file = os.path.join(target_dir, filename)
  13. # 确保目标目录存在
  14. os.makedirs(target_dir, exist_ok=True)
  15. # 拷贝文件
  16. shutil.copy(source_file, target_file)
  17. else:
  18. print(f"文件 {source_file} 不存在")
  19. def find_filenames_for_category(json_file, category_id, image_dir, output_dir):
  20. # Step 1: 找出 category_id 为 3 的 image_id 列表
  21. matching_image_ids = set()
  22. # 首先计算 annotations 的总项数,方便在 tqdm 进度条中使用
  23. print('load annotations...')
  24. with open(json_file, 'r', encoding='utf-8') as f:
  25. total_annotations = sum(1 for _ in ijson.items(f, 'annotations.item'))
  26. print('annotations count : ', total_annotations)
  27. with open(json_file, 'r', encoding='utf-8') as f:
  28. # 使用 tqdm 包装 ijson.items,显示解析 annotations 的进度
  29. for annotation in tqdm(ijson.items(f, 'annotations.item'), total=total_annotations,
  30. desc="Processing annotations"):
  31. if annotation['category_id'] == category_id:
  32. matching_image_ids.add(annotation['image_id'])
  33. # Step 2: 根据 image_id 找到对应的 file_name
  34. matching_filenames = []
  35. print('load images...')
  36. # 计算 images 的总项数,方便在 tqdm 进度条中使用
  37. with open(json_file, 'r', encoding='utf-8') as f:
  38. total_images = sum(1 for _ in ijson.items(f, 'images.item'))
  39. print('images count : ', total_images)
  40. count = 0
  41. with open(json_file, 'r', encoding='utf-8') as f:
  42. # 使用 tqdm 包装 ijson.items,显示解析 images 的进度
  43. for image in tqdm(ijson.items(f, 'images.item'), total=total_images, desc="Processing images"):
  44. if image['id'] in matching_image_ids:
  45. count += 1
  46. copy_file_to_directory(image_dir, output_dir, image['file_name'])
  47. print(count, 'files matched')
  48. # 主程序入口,处理命令行参数
  49. if __name__ == '__main__':
  50. parser = argparse.ArgumentParser("处理数据集\n")
  51. parser.add_argument("json_file", help="publaynet json 文件路径")
  52. parser.add_argument('image_dir', help='图片目录')
  53. parser.add_argument('output', help='输出目录')
  54. args = parser.parse_args()
  55. find_filenames_for_category(args.json_file, 3, args.image_dir, args.output)