import argparse
import os
import random


def get_all_img(img_fold):
    """Collect labelled image paths grouped by class (font) sub-folder.

    Each immediate sub-directory of ``img_fold`` is one class (one font
    folder). Sub-directories are visited in sorted order and given
    consecutive integer labels starting at 0, so the label mapping is the
    same on every machine and every run.

    Args:
        img_fold: Root directory containing one sub-directory per class.
            Top-level files (such as the generated ``*.txt`` list files)
            are ignored.

    Returns:
        A list with one entry per class sub-directory. Each entry is a list
        of ``"<subdir>/<file> <label>"`` strings, one for every entry of that
        sub-directory whose name does not end in ``.txt``, in sorted order.
    """
    # os.listdir returns entries in arbitrary, filesystem-dependent order;
    # sort so the folder -> label mapping is reproducible.
    class_dirs = [
        name
        for name in sorted(os.listdir(img_fold))
        # Skip list files written by divide() and any other stray top-level
        # file; calling os.listdir on a plain file would raise.
        if not name.endswith('.txt') and os.path.isdir(os.path.join(img_fold, name))
    ]

    img_path_list = []
    for label, class_dir in enumerate(class_dirs):
        entries = sorted(os.listdir(os.path.join(img_fold, class_dir)))
        # Paths are stored relative to img_fold with '/' as the separator.
        img_path_list.append(
            [f'{class_dir}/{img} {label}' for img in entries if not img.endswith('.txt')]
        )
    return img_path_list


def divide(lines, img_folds, train_ratio, seed=None):
    """Randomly split each class's entries into train and validation lists.

    For every group in ``lines``, ``round(train_ratio * len(group))`` entries
    are sampled without replacement. They are appended, in sampled order, to
    ``<img_folds>/train_list.txt``. The remaining entries are appended, in
    their original order, to ``<img_folds>/val_list.txt``. Splitting per
    group keeps each class represented in roughly the requested proportion.

    Note: both files are opened in append mode, so running the split again
    on the same folder adds duplicate entries. Delete the old list files
    first if a fresh split is wanted.

    Args:
        lines: Iterable of groups (lists) of entry strings, e.g. the output
            of ``get_all_img``.
        img_folds: Directory the two list files are written to.
        train_ratio: Fraction of each group sent to the training list.
            Values outside [0, 1] can make ``random.sample`` raise
            ``ValueError``.
        seed: Optional seed for a private ``random.Random`` so the split is
            reproducible. When ``None`` (the default) the global ``random``
            module is used, as before.

    Returns:
        ``(train_size, val_size)``: number of entries written to each file.
    """
    rng = random if seed is None else random.Random(seed)
    train_size = 0
    val_size = 0
    # Explicit UTF-8: font folder names may be non-ASCII, and the locale
    # default encoding varies between machines.
    with open(os.path.join(str(img_folds), 'val_list.txt'), 'a', encoding='utf-8') as fp_val, \
            open(os.path.join(str(img_folds), 'train_list.txt'), 'a', encoding='utf-8') as fp_train:
        for group in lines:
            length = len(group)
            train_idx = rng.sample(range(length), round(train_ratio * length))
            fp_train.writelines(group[i] + '\n' for i in train_idx)

            # Set gives O(1) membership tests; scanning the list was O(n^2).
            train_set = set(train_idx)
            val_lines = [group[i] for i in range(length) if i not in train_set]
            fp_val.writelines(line + '\n' for line in val_lines)

            train_size += len(train_idx)
            val_size += len(val_lines)

    print('train images ', train_size)
    print('val images ', val_size)
    return train_size, val_size


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--img_dir', type=str, default='./font_img_dataset/windows_1/chinese_gray')
    parser.add_argument('--train_ratio', type=float, default=0.8)
    parser.add_argument('--seed', type=int, default=None,
                        help='optional random seed for a reproducible split')
    args = parser.parse_args()
    list1 = get_all_img(args.img_dir)
    divide(list1, args.img_dir, args.train_ratio, seed=args.seed)