123456789101112131415161718192021222324252627282930313233 |
- import argparse
- import numpy as np
def divide(lines, train_ratio, val_path='./rec_gt_val.txt',
           train_path='./rec_gt_train.txt', seed=None):
    """Randomly split ``lines`` into a training file and a validation file.

    The line indices are shuffled, the first ``int(len(lines) * train_ratio)``
    of them go to the training file, and the rest go to the validation file.
    The resulting split sizes are printed.

    Args:
        lines: Sequence of text lines, e.g. the result of ``readlines()``.
            Each written line is guaranteed to end with ``'\\n'``.
        train_ratio: Fraction of lines assigned to the training split; must
            be within ``[0, 1]``.
        val_path: Output path for the validation lines.
        train_path: Output path for the training lines.
        seed: Optional seed for a reproducible split. When ``None`` the
            global ``np.random`` state is used.

    Raises:
        ValueError: If ``train_ratio`` is not within ``[0, 1]`` (including NaN).
    """
    # Validate before touching the filesystem; out-of-range ratios would
    # otherwise slice from the end / report negative counts.
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f'train_ratio must be in [0, 1], got {train_ratio}')

    length = len(lines)
    rng = np.random if seed is None else np.random.default_rng(seed)
    shuffled_indices = rng.permutation(length)
    train_size = int(length * train_ratio)
    train_list = shuffled_indices[:train_size]
    val_list = shuffled_indices[train_size:]

    def _with_newline(line):
        # readlines() may leave the file's final line without '\n'; after
        # shuffling it could otherwise be glued onto the following record.
        return line if line.endswith('\n') else line + '\n'

    # Context managers guarantee both outputs are closed even on error.
    with open(val_path, 'w', encoding='utf-8') as fp_val, \
            open(train_path, 'w', encoding='utf-8') as fp_train:
        fp_val.writelines(_with_newline(lines[i]) for i in val_list)
        fp_train.writelines(_with_newline(lines[i]) for i in train_list)

    print('train images ', train_size)
    print('val images ', length - train_size)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Randomly split a line-based label file into '
                    'training and validation files.')
    parser.add_argument('--path', type=str, default='./file.txt',
                        help='input text file, one sample per line')
    parser.add_argument('--train_ratio', type=float, default=0.8,
                        help='fraction of lines assigned to the training split')
    args = parser.parse_args()
    # Read the input and release the handle before splitting
    # (the file was previously never closed).
    with open(args.path, 'r', encoding='utf-8') as fp:
        lines = fp.readlines()
    divide(lines, args.train_ratio)
|