|
@@ -0,0 +1,80 @@
|
|
|
+import os
|
|
|
+import json
|
|
|
+import argparse
|
|
|
+import numpy as np
|
|
|
+import glob
|
|
|
+import cv2
|
|
|
+from sklearn.model_selection import train_test_split
|
|
|
+from labelme import utils
|
|
|
+from tqdm import tqdm
|
|
|
+import shutil
|
|
|
+
|
|
|
# train_test_split is called below without an explicit random_state, so it
# draws from NumPy's global RNG; seeding it here keeps the shuffle repeatable
# for a given input order.
np.random.seed(5)


def resolve_image_path(json_path, image_name):
    """Return the path of the image that belongs to a labelme annotation.

    The image is looked up next to the annotation file under the same base
    name; only the extension differs, and it is taken from ``image_name``.

    Args:
        json_path: Path of the labelme ``.json`` annotation file.
        image_name: Image file name recorded in the annotation
            (its ``imagePath`` field, reduced to a base name).

    Returns:
        ``json_path`` with its extension swapped for the extension of
        ``image_name``. If ``image_name`` has no extension, the result has
        none either.
    """
    # Only the trailing extension is swapped. A plain str.replace('json', ...)
    # would also rewrite "json" inside directory names, and split('.')[1]
    # picks the wrong part for names such as "img.v2.jpg".
    stem = os.path.splitext(json_path)[0]
    extension = os.path.splitext(image_name)[1]
    return stem + extension


def split_annotations(anno_dirs, train_ratio):
    """Split the annotation files of every folder into train and val lists.

    Each folder is split on its own according to ``train_ratio`` and the
    per-folder results are merged afterwards.

    Args:
        anno_dirs: Folders that contain labelme ``*.json`` files.
        train_ratio: Fraction of every folder assigned to the training set.

    Returns:
        A ``(train_paths, val_paths)`` tuple of lists of annotation paths.
    """
    train_paths = []
    val_paths = []
    for anno_dir in anno_dirs:
        json_paths = glob.glob(os.path.join(anno_dir, "*.json"))
        if not json_paths:
            # train_test_split raises ValueError on an empty input.
            print('No json annotations found in:', anno_dir)
            continue
        train_part, val_part = train_test_split(
            json_paths,
            test_size=1 - train_ratio,
            train_size=train_ratio,
        )
        train_paths.extend(train_part)
        val_paths.extend(val_part)
    return train_paths, val_paths


def export_sample(json_path, dest_dir):
    """Copy one annotated image and its cleaned annotation into ``dest_dir``.

    The annotation is rewritten so that ``imagePath`` holds only the image
    base name and ``imageData`` is ``None``.

    Args:
        json_path: Path of the labelme annotation to export.
        dest_dir: Existing directory that receives the image and the json.

    Returns:
        True if both files were written, False if copying or writing failed
        (the failure is reported on stdout and the sample is skipped).
    """
    with open(json_path, 'r', encoding='utf-8') as src:
        info = json.load(src)

    image_name = os.path.basename(info['imagePath'])
    info['imagePath'] = image_name
    info['imageData'] = None

    image_path = resolve_image_path(json_path, image_name)
    json_out = os.path.join(dest_dir, os.path.basename(json_path))
    try:
        shutil.copy(image_path, dest_dir)
        # The context manager guarantees the output is flushed and closed.
        with open(json_out, 'w', encoding='utf-8') as dst:
            json.dump(info, dst, ensure_ascii=False, indent=4)
    except OSError as e:
        # Best effort: report the broken sample and carry on with the rest.
        print(e)
        print('Wrong Image:', image_name)
        return False
    return True


def main():
    """Parse the command line, split the annotations and export both sets."""
    parser = argparse.ArgumentParser(
        description='Split labelme annotation folders into train/val sets.')
    parser.add_argument('--anno_dirs', type=str, nargs="+", default=[],
                        help='folders containing labelme json files and images')
    parser.add_argument('--train_ratio', type=float, default=0.9,
                        help='fraction of each folder used for training')
    # Both output folders are mandatory: without them no sample can be copied.
    parser.add_argument('--train_dir', type=str, required=True,
                        help='output folder for the training set')
    parser.add_argument('--val_dir', type=str, required=True,
                        help='output folder for the validation set')
    args = parser.parse_args()

    train_path, val_path = split_annotations(args.anno_dirs, args.train_ratio)
    print("train_n:", len(train_path), 'val_n:', len(val_path))

    for dest_dir, json_paths in ((args.train_dir, train_path),
                                 (args.val_dir, val_path)):
        # shutil.copy treats a missing destination directory as a file name,
        # so the folder has to exist before anything is copied into it.
        os.makedirs(dest_dir, exist_ok=True)
        for json_path in tqdm(json_paths):
            export_sample(json_path, dest_dir)


if __name__ == '__main__':
    main()