1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
- import argparse
- import ijson
- import os
- import shutil
- from tqdm import tqdm
def copy_file_to_directory(source_dir, target_dir, filename):
    """Copy ``filename`` from ``source_dir`` into ``target_dir``.

    The destination directory (including any sub-directory component carried
    by ``filename``) is created on demand. If the source file does not exist,
    a message is printed and nothing is copied or created.

    Args:
        source_dir: Directory that should contain the file.
        target_dir: Directory the file is copied into.
        filename: File name, relative to ``source_dir``; it may include
            sub-directories.
    """
    source_file = os.path.join(source_dir, filename)

    if not os.path.isfile(source_file):
        print(f"文件 {source_file} 不存在")
        return

    target_file = os.path.join(target_dir, filename)
    # Create the parent of the actual target path, not just target_dir, so a
    # filename such as "sub/img.jpg" does not fail on a missing "sub" folder.
    # "or '.'" guards against an empty dirname when target_dir is "".
    os.makedirs(os.path.dirname(target_file) or ".", exist_ok=True)
    shutil.copy(source_file, target_file)
def _count_items(json_file, prefix):
    """Return the number of items ijson yields for ``prefix`` in ``json_file``."""
    with open(json_file, 'rb') as f:
        return sum(1 for _ in ijson.items(f, prefix))


def find_filenames_for_category(json_file, category_id, image_dir, output_dir):
    """Copy every image that has an annotation of ``category_id``.

    The JSON file is streamed with ijson rather than loaded whole. It is read
    in separate passes: one counting pass per section (only to give the tqdm
    progress bars a total), one pass over ``annotations`` to collect the
    matching ``image_id`` values, and one pass over ``images`` to copy the
    files whose ``id`` was collected.

    Args:
        json_file: Path to the annotation JSON; it must have top-level
            ``annotations`` and ``images`` arrays.
        category_id: Value compared against each annotation's ``category_id``.
        image_dir: Directory containing the image files.
        output_dir: Directory the matching images are copied into.
    """
    # The file is opened in binary mode throughout: ijson parses bytes, and
    # text-mode handles are deprecated and slower.
    print('load annotations...')
    total_annotations = _count_items(json_file, 'annotations.item')
    print('annotations count : ', total_annotations)

    matching_image_ids = set()
    with open(json_file, 'rb') as f:
        for annotation in tqdm(ijson.items(f, 'annotations.item'), total=total_annotations,
                               desc="Processing annotations"):
            if annotation['category_id'] == category_id:
                matching_image_ids.add(annotation['image_id'])

    print('load images...')
    total_images = _count_items(json_file, 'images.item')
    print('images count : ', total_images)

    count = 0
    with open(json_file, 'rb') as f:
        for image in tqdm(ijson.items(f, 'images.item'), total=total_images, desc="Processing images"):
            if image['id'] in matching_image_ids:
                # Counts matched image entries, whether or not the file
                # actually exists on disk (the copy helper reports misses).
                count += 1
                copy_file_to_directory(image_dir, output_dir, image['file_name'])
    print(count, 'files matched')
def build_arg_parser():
    """Build the command-line parser for this script.

    The description is passed by keyword: the first positional parameter of
    ``ArgumentParser`` is ``prog``, so passing the text positionally would
    replace the program name in the usage line.
    """
    parser = argparse.ArgumentParser(description="处理数据集\n")
    parser.add_argument("json_file", help="publaynet json 文件路径")
    parser.add_argument('image_dir', help='图片目录')
    parser.add_argument('output', help='输出目录')
    # Defaults to 3, the value that used to be hard-coded, so existing
    # invocations behave exactly as before.
    parser.add_argument('--category-id', dest='category_id', type=int, default=3,
                        help='要筛选的 category_id（默认 3）')
    return parser


def main(argv=None):
    """Parse ``argv`` (``sys.argv[1:]`` when None) and copy the matching images."""
    args = build_arg_parser().parse_args(argv)
    find_filenames_for_category(args.json_file, args.category_id, args.image_dir, args.output)


if __name__ == '__main__':
    main()
|