divide_rec_train_val.py 991 B

123456789101112131415161718192021222324252627282930313233
  1. import argparse
  2. import numpy as np
  3. def divide(lines, train_ratio):
  4. fp_val = open('./rec_gt_val.txt', 'w+', encoding='utf-8')
  5. fp_train = open('./rec_gt_train.txt', 'w+', encoding='utf-8')
  6. length = len(lines)
  7. shuffled_indices = np.random.permutation(length)
  8. train_size = int(length * train_ratio)
  9. train_list = shuffled_indices[:train_size]
  10. val_list = shuffled_indices[train_size:]
  11. for i in val_list:
  12. fp_val.write(lines[i])
  13. for i in train_list:
  14. fp_train.write(lines[i])
  15. print('train images ', train_size)
  16. print('val images ', length - train_size)
  17. fp_val.close()
  18. fp_train.close()
  19. if __name__ == '__main__':
  20. parser = argparse.ArgumentParser()
  21. parser.add_argument('--path', type=str, default='./file.txt')
  22. parser.add_argument('--train_ratio', type=float, default=0.8)
  23. args = parser.parse_args()
  24. fp = open(args.path, 'r', encoding='utf-8')
  25. lines = fp.readlines()
  26. divide(lines, args.train_ratio)