# train_and_val_divide_copy.py
import argparse
import glob
import json
import os
import shutil
from typing import List, Optional, Sequence, Tuple

import cv2
import numpy as np
from labelme import utils
from sklearn.model_selection import train_test_split
from tqdm import tqdm
  11. np.random.seed(5)
  12. if __name__ == '__main__':
  13. parser = argparse.ArgumentParser()
  14. parser.add_argument('--anno_dirs', type=str, nargs="+", default=[])
  15. parser.add_argument('--train_ratio', type=float, default=0.9)
  16. parser.add_argument('--train_dir', type=str)
  17. parser.add_argument('--val_dir', type=str)
  18. args = parser.parse_args()
  19. labelme_folds = args.anno_dirs
  20. json_list_path = []
  21. train_path = []
  22. val_path = []
  23. train_val_path_list = []
  24. train_val_folds = [args.train_dir, args.val_dir]
  25. # 每个文件夹按照比例划分,遍历完,最后合并
  26. for labelme_path in labelme_folds:
  27. list_path = glob.glob(labelme_path + "/*.json")
  28. json_list_path.extend(list_path)
  29. train_path1, val_path1 = train_test_split(list_path, test_size=1-args.train_ratio, train_size=args.train_ratio)
  30. train_path.extend(train_path1)
  31. val_path.extend(val_path1)
  32. train_val_path_list.append(train_path)
  33. train_val_path_list.append(val_path)
  34. print("train_n:", len(train_path), 'val_n:', len(val_path))
  35. # print("train images: %d" % len(train_path))
  36. for index, data_path in enumerate(train_val_path_list):
  37. for file in tqdm(data_path):
  38. json_path = file
  39. temp_json = {}
  40. img_name = ''
  41. with open(json_path, 'r', encoding='utf-8')as jp:
  42. info = json.load(jp)
  43. info['imagePath'] = os.path.basename(info['imagePath'])
  44. img_name = info['imagePath']
  45. info['imageData'] = None
  46. temp_json = info
  47. jp.close()
  48. img_path = file.replace('json', str(img_name).split('.')[1])
  49. try:
  50. shutil.copy(img_path, train_val_folds[index])
  51. json.dump(temp_json,
  52. open(os.path.join(train_val_folds[index], os.path.basename(json_path)), 'w', encoding='utf-8'),
  53. ensure_ascii=False, indent=4)
  54. except Exception as e:
  55. print(e)
  56. print('Wrong Image:', img_name)
  57. continue
  58. # print(img_name + '-->', img_name.replace('jpg', 'jpg'))
  59. # print("eval images: %d" % len(val_path))
  60. # for file in tqdm(val_path):
  61. # json_path = file
  62. # img_name = ''
  63. # with open(json_path, 'r', encoding='utf-8') as jp:
  64. # info = json.load(jp)
  65. # img_name = os.path.basename(info['imagePath'])
  66. # jp.close()
  67. # img_path = file.replace('json', str(img_name).split('.')[1])
  68. # try:
  69. # shutil.copy(img_path, args.val_path)
  70. # shutil.copy(json_path, args.val_path)
  71. # except Exception as e:
  72. # print(e)
  73. # print('Wrong Image:', img_name)
  74. # continue