"""Split LabelMe annotation folders into train/val directories.

Each folder passed via ``--anno_dirs`` is split independently by
``--train_ratio`` and the per-folder results are merged. Each selected
annotation JSON is written (with ``imagePath`` reduced to a bare file name and
``imageData`` cleared) together with its image into ``--train_dir`` or
``--val_dir``.
"""
import os
import json
import argparse
import math
import numpy as np
import glob
import cv2
from sklearn.model_selection import train_test_split
from labelme import utils
from tqdm import tqdm
import shutil

# Seeds NumPy's global RNG, which train_test_split uses when no random_state
# is passed, so the split is repeatable for the same input file list.
np.random.seed(5)


def _split_folder(json_paths, train_ratio):
    """Split one folder's annotation files into ``(train, val)`` lists.

    Args:
        json_paths: Annotation file paths from a single folder.
        train_ratio: Fraction of files to put in the training split; the
            caller guarantees ``0 < train_ratio < 1``.

    Returns:
        A ``(train, val)`` pair of lists. Folders with fewer than two files
        cannot be split, so all of their files go to ``train``.
    """
    n = len(json_paths)
    if n < 2:
        return list(json_paths), []
    # Compute an integer validation count ourselves. Passing both train_size
    # and test_size as floats lets float noise (1 - 0.7 == 0.30000000000000004)
    # make the two rounded counts exceed n, which sklearn rejects. Rounding to
    # 6 decimals removes that noise; ceil matches sklearn's own rounding of a
    # float test_size. Clamp so neither split ends up empty.
    n_val = math.ceil(round((1 - train_ratio) * n, 6))
    n_val = min(max(n_val, 1), n - 1)
    train, val = train_test_split(json_paths, test_size=n_val)
    return train, val


def _export_sample(json_path, out_dir):
    """Copy one LabelMe annotation and its image into ``out_dir``.

    The image is looked up beside the JSON file: same stem, with the extension
    taken from the annotation's ``imagePath``. The written JSON has
    ``imagePath`` reduced to a bare file name and ``imageData`` set to None.
    Files with the same base name from different input folders overwrite each
    other in ``out_dir``.

    Returns:
        True if both the image and the JSON were written, False if copying or
        writing failed (the error is printed and the sample is skipped).
    """
    with open(json_path, 'r', encoding='utf-8') as jp:
        info = json.load(jp)

    # imagePath may carry Windows separators; os.path.basename only splits on
    # '/' on POSIX, so normalise first.
    img_name = os.path.basename(str(info['imagePath']).replace('\\', '/'))
    info['imagePath'] = img_name
    info['imageData'] = None

    # Swap only the file extension; str.replace('json', ...) would also rewrite
    # any directory component that contains "json".
    img_path = os.path.splitext(json_path)[0] + os.path.splitext(img_name)[1]
    try:
        shutil.copy(img_path, out_dir)
        out_json = os.path.join(out_dir, os.path.basename(json_path))
        with open(out_json, 'w', encoding='utf-8') as fp:
            json.dump(info, fp, ensure_ascii=False, indent=4)
    except OSError as e:
        print(e)
        print('Wrong Image:', img_name)
        return False
    return True


def main():
    """Parse CLI arguments, split every annotation folder, and export files."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--anno_dirs', type=str, nargs="+", default=[])
    parser.add_argument('--train_ratio', type=float, default=0.9)
    parser.add_argument('--train_dir', type=str, required=True)
    parser.add_argument('--val_dir', type=str, required=True)
    args = parser.parse_args()
    if not 0 < args.train_ratio < 1:
        parser.error('--train_ratio must be strictly between 0 and 1')

    train_path = []
    val_path = []
    # Split each folder by the ratio separately, then merge the results.
    for labelme_path in args.anno_dirs:
        # Sort so the seeded split does not depend on filesystem listing order.
        list_path = sorted(glob.glob(os.path.join(labelme_path, '*.json')))
        if not list_path:
            print('No .json files in', labelme_path)
            continue
        train_part, val_part = _split_folder(list_path, args.train_ratio)
        train_path.extend(train_part)
        val_path.extend(val_part)

    print("train_n:", len(train_path), 'val_n:', len(val_path))

    for out_dir, paths in ((args.train_dir, train_path),
                           (args.val_dir, val_path)):
        os.makedirs(out_dir, exist_ok=True)
        for json_path in tqdm(paths):
            _export_sample(json_path, out_dir)


if __name__ == '__main__':
    main()