divide_train_eval.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import argparse
  2. import os
  3. import random
  4. def get_all_img(img_fold):
  5. # 获取字体文件夹列表
  6. img_fold_list = os.listdir(img_fold)
  7. cnt = 0
  8. img_path_list = []
  9. for img_fold_in in img_fold_list:
  10. img_list1 = []
  11. if img_fold_in.endswith('.txt'):
  12. continue
  13. img_list = os.listdir(os.path.join(img_fold, img_fold_in))
  14. for img in img_list:
  15. if img.endswith('.txt'):
  16. continue
  17. img_path = str(img_fold_in) + '/' + str(img)
  18. img_list1.append(str(img_path) + ' ' + str(cnt))
  19. cnt += 1
  20. img_path_list.append(img_list1)
  21. return img_path_list
  22. def divide(lines, img_folds, train_ratio):
  23. fp_val = open(str(img_folds) + '/val_list.txt', 'a')
  24. fp_train = open(str(img_folds) + '/train_list.txt', 'a')
  25. train_size = 0
  26. val_size = 0
  27. for line in lines:
  28. length = len(line)
  29. trainList = random.sample(range(0, length), round(train_ratio * length))
  30. train_size += len(trainList)
  31. for i in trainList:
  32. fp_train.write(line[i] + '\n')
  33. testList = []
  34. for i in range(0, length):
  35. if i not in trainList:
  36. fp_val.write(line[i] + '\n')
  37. testList.append(i)
  38. val_size += len(testList)
  39. print('train images ', train_size)
  40. print('val images ', val_size)
  41. fp_val.close()
  42. fp_train.close()
  43. if __name__ == '__main__':
  44. parser = argparse.ArgumentParser()
  45. parser.add_argument('--img_dir', type=str, default='./font_img_dataset/windows_1/chinese_gray')
  46. parser.add_argument('--train_ratio', type=float, default=0.8)
  47. args = parser.parse_args()
  48. list1 = get_all_img(args.img_dir)
  49. divide(list1, args.img_dir, args.train_ratio)