1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- import os
- import json
- import argparse
- import numpy as np
- import glob
- import cv2
- from sklearn.model_selection import train_test_split
- from labelme import utils
- from tqdm import tqdm
- import shutil
# Fixed seed so repeated runs give the same split: train_test_split is called
# without random_state, in which case scikit-learn draws from NumPy's global RNG.
np.random.seed(5)


def resolve_image_path(json_path: str, image_name: str) -> str:
    """Return the path of the image stored next to a labelme JSON file.

    The image is expected to share the JSON file's base name and to carry the
    extension of ``image_name`` (the annotation's ``imagePath`` entry).

    Only the final extension of each name is swapped, so folder or file names
    that happen to contain "json" or additional dots are left intact.
    """
    stem, _ = os.path.splitext(json_path)
    _, image_ext = os.path.splitext(image_name)
    return stem + image_ext


def export_sample(json_path: str, dest_dir: str) -> str:
    """Copy one annotated image plus a slimmed copy of its JSON into ``dest_dir``.

    The written JSON keeps every field of the source annotation except that
    ``imagePath`` is reduced to a bare file name and ``imageData`` is set to
    ``None`` (the embedded image payload is dropped).

    Args:
        json_path: Path of the source labelme annotation file.
        dest_dir: Existing directory that receives the image and the JSON.

    Returns:
        The image file name recorded in the exported annotation.

    Raises:
        OSError: The annotation or image cannot be read, or the copy fails.
        ValueError: The annotation is not valid JSON.
        KeyError: The annotation has no ``imagePath`` entry.
        TypeError: ``imagePath`` is not a path-like value.
    """
    with open(json_path, 'r', encoding='utf-8') as src:
        info = json.load(src)

    image_name = os.path.basename(info['imagePath'])
    info['imagePath'] = image_name
    info['imageData'] = None

    # Copy the image first: if it is missing, no orphan JSON is written.
    shutil.copy(resolve_image_path(json_path, image_name), dest_dir)

    out_path = os.path.join(dest_dir, os.path.basename(json_path))
    with open(out_path, 'w', encoding='utf-8') as dst:
        json.dump(info, dst, ensure_ascii=False, indent=4)
    return image_name


def main(argv=None):
    """Split labelme annotation folders into train/val sets and export them.

    Command-line options:
        --anno_dirs    One or more folders holding ``*.json`` annotations with
                       their images alongside.
        --train_ratio  Fraction of each folder assigned to the train set
                       (strictly between 0 and 1, default 0.9).
        --train_dir    Output folder for the train set (created if missing).
        --val_dir      Output folder for the val set (created if missing).

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--anno_dirs', type=str, nargs="+", default=[])
    parser.add_argument('--train_ratio', type=float, default=0.9)
    # Without both output folders every copy would fail, so demand them up front.
    parser.add_argument('--train_dir', type=str, required=True)
    parser.add_argument('--val_dir', type=str, required=True)
    args = parser.parse_args(argv)
    if not 0.0 < args.train_ratio < 1.0:
        parser.error('--train_ratio must be strictly between 0 and 1')

    train_path = []
    val_path = []
    # Split every folder by the requested ratio on its own, then merge the
    # per-folder results once all folders have been visited.
    for labelme_path in args.anno_dirs:
        # glob order depends on the filesystem; sorting makes the seeded
        # shuffle pick the same files on every machine.
        list_path = sorted(glob.glob(os.path.join(labelme_path, '*.json')))
        if not list_path:
            print('No JSON files found in:', labelme_path)
            continue
        # Pass only train_size and let scikit-learn use the remainder as the
        # test set: a hand-computed `1 - ratio` carries float rounding error
        # that can push the two rounded counts past the number of samples.
        train_part, val_part = train_test_split(list_path, train_size=args.train_ratio)
        train_path.extend(train_part)
        val_path.extend(val_part)
    print("train_n:", len(train_path), 'val_n:', len(val_path))

    for dest_dir, json_paths in ((args.train_dir, train_path), (args.val_dir, val_path)):
        # shutil.copy would otherwise create a *file* named like the folder.
        os.makedirs(dest_dir, exist_ok=True)
        for json_path in tqdm(json_paths):
            try:
                export_sample(json_path, dest_dir)
            except (OSError, ValueError, KeyError, TypeError) as e:
                # Best effort: report the broken sample and keep going.
                print(e)
                print('Wrong Image:', json_path)


if __name__ == '__main__':
    main()
|