Przeglądaj źródła

上传文件至 'Data_Copy'

复制图片和对应json文件
zhuzezhou 1 rok temu
rodzic
commit
b51722163a
2 zmienionych plików z 127 dodań i 0 usunięć
  1. 47 0
      Data_Copy/data_copy.py
  2. 80 0
      Data_Copy/train_and_val_divide_copy.py

+ 47 - 0
Data_Copy/data_copy.py

@@ -0,0 +1,47 @@
+import os
+import json
+import argparse
+import numpy as np
+import glob
+import cv2
+from sklearn.model_selection import train_test_split
+from labelme import utils
+from tqdm import tqdm
+import shutil
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--anno_dirs', type=str, nargs="+", default=[])
+    parser.add_argument('--data_dir', type=str, default='')
+    args = parser.parse_args()
+    labelme_folds = args.anno_dirs
+    json_list_path = []
+
+    for labelme_path in labelme_folds:
+        list_path = glob.glob(labelme_path + "/*.json")
+        json_list_path.extend(list_path)
+
+    print('reading...')
+    print("images: %d" % len(json_list_path))
+
+    for file in tqdm(json_list_path):
+        json_path = file
+        temp_json = {}
+        img_name = ''
+        with open(json_path, 'r', encoding='utf-8') as jp:
+            info = json.load(jp)
+            info['imagePath'] = os.path.basename(info['imagePath'])
+            img_name = info['imagePath']
+            info['imageData'] = None
+            temp_json = info
+        jp.close()
+        img_path = file.replace('json', str(img_name).split('.')[1])
+        try:
+            shutil.copy(img_path, args.data_dir)
+            json.dump(temp_json,
+                      open(os.path.join(args.data_dir, os.path.basename(json_path)), 'w', encoding='utf-8'),
+                      ensure_ascii=False, indent=4)
+        except Exception as e:
+            print(e)
+            print('Wrong Image:', img_name)
+            continue

+ 80 - 0
Data_Copy/train_and_val_divide_copy.py

@@ -0,0 +1,80 @@
+import os
+import json
+import argparse
+import numpy as np
+import glob
+import cv2
+from sklearn.model_selection import train_test_split
+from labelme import utils
+from tqdm import tqdm
+import shutil
+
+np.random.seed(5)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--anno_dirs', type=str, nargs="+", default=[])
+    parser.add_argument('--train_ratio', type=float, default=0.9)
+    parser.add_argument('--train_dir', type=str)
+    parser.add_argument('--val_dir', type=str)
+    args = parser.parse_args()
+    labelme_folds = args.anno_dirs
+    json_list_path = []
+    train_path = []
+    val_path = []
+    train_val_path_list = []
+    train_val_folds = [args.train_dir, args.val_dir]
+
+    # 每个文件夹按照比例划分,遍历完,最后合并
+    for labelme_path in labelme_folds:
+        list_path = glob.glob(labelme_path + "/*.json")
+        json_list_path.extend(list_path)
+        train_path1, val_path1 = train_test_split(list_path, test_size=1-args.train_ratio, train_size=args.train_ratio)
+        train_path.extend(train_path1)
+        val_path.extend(val_path1)
+
+    train_val_path_list.append(train_path)
+    train_val_path_list.append(val_path)
+
+    print("train_n:", len(train_path), 'val_n:', len(val_path))
+
+    # print("train images: %d" % len(train_path))
+    for index, data_path in enumerate(train_val_path_list):
+        for file in tqdm(data_path):
+            json_path = file
+            temp_json = {}
+            img_name = ''
+            with open(json_path, 'r', encoding='utf-8')as jp:
+                info = json.load(jp)
+                info['imagePath'] = os.path.basename(info['imagePath'])
+                img_name = info['imagePath']
+                info['imageData'] = None
+                temp_json = info
+            jp.close()
+            img_path = file.replace('json', str(img_name).split('.')[1])
+            try:
+                shutil.copy(img_path, train_val_folds[index])
+                json.dump(temp_json,
+                          open(os.path.join(train_val_folds[index], os.path.basename(json_path)), 'w', encoding='utf-8'),
+                          ensure_ascii=False, indent=4)
+            except Exception as e:
+                print(e)
+                print('Wrong Image:', img_name)
+                continue
+        # print(img_name + '-->', img_name.replace('jpg', 'jpg'))
+    # print("eval images: %d" % len(val_path))
+    # for file in tqdm(val_path):
+    #     json_path = file
+    #     img_name = ''
+    #     with open(json_path, 'r', encoding='utf-8') as jp:
+    #         info = json.load(jp)
+    #         img_name = os.path.basename(info['imagePath'])
+    #     jp.close()
+    #     img_path = file.replace('json', str(img_name).split('.')[1])
+    #     try:
+    #         shutil.copy(img_path, args.val_path)
+    #         shutil.copy(json_path, args.val_path)
+    #     except Exception as e:
+    #         print(e)
+    #         print('Wrong Image:', img_name)
+    #         continue