1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
- import argparse
- import os
- import random
def get_all_img(img_fold):
    """Collect labelled image entries, grouped by font folder.

    Each sub-directory of ``img_fold`` is treated as one class and receives
    an integer label, starting at 0 and increasing by one per processed
    sub-directory. Entries whose name ends in ``.txt`` are ignored at both
    levels (this also skips list files previously written into ``img_fold``).

    Args:
        img_fold: Root directory containing one sub-directory per font.

    Returns:
        A list with one inner list per processed sub-directory. Each inner
        list holds strings of the form ``"<folder>/<image> <label>"``.

    Raises:
        OSError: If ``img_fold`` or one of its sub-directories cannot be
            listed.
    """
    img_path_list = []
    label = 0

    # Get the list of font folders. It is sorted because os.listdir returns
    # entries in arbitrary order, which would otherwise make the label given
    # to each folder differ between machines and runs.
    for folder_name in sorted(os.listdir(img_fold)):
        if folder_name.endswith('.txt'):
            continue

        folder_path = os.path.join(img_fold, folder_name)
        # Stray files (e.g. OS metadata files) are not classes; listing them
        # as directories would raise NotADirectoryError.
        if not os.path.isdir(folder_path):
            continue

        entries = [
            f'{folder_name}/{image_name} {label}'
            for image_name in sorted(os.listdir(folder_path))
            if not image_name.endswith('.txt')
        ]
        img_path_list.append(entries)

        # NOTE(review): the source this was recovered from had lost its
        # indentation; the label is assumed to advance once per font folder
        # (not once per image) -- confirm against the original script.
        label += 1

    return img_path_list
def divide(lines, img_folds, train_ratio):
    """Randomly split each group of entries into train and validation lists.

    For every group in ``lines``, ``round(train_ratio * len(group))`` entries
    are drawn at random (via ``random.sample``) and appended to
    ``train_list.txt``; the remaining entries are appended, in their original
    order, to ``val_list.txt``. Both files live directly in ``img_folds`` and
    are opened in append mode, so existing content is kept. The total number
    of train and validation entries is printed at the end.

    Args:
        lines: Iterable of sequences of strings; each sequence is split
            independently so every group keeps roughly the same ratio.
        img_folds: Directory in which the two list files are written.
        train_ratio: Fraction of each group that goes to the training list.

    Raises:
        ValueError: Propagated from ``random.sample`` when the requested
            sample size is negative or larger than the group.
        OSError: If either list file cannot be opened or written.
    """
    val_path = os.path.join(str(img_folds), 'val_list.txt')
    train_path = os.path.join(str(img_folds), 'train_list.txt')

    train_size = 0
    val_size = 0

    # Context managers guarantee both files are closed even if sampling or
    # writing fails part-way through.
    with open(val_path, 'a') as fp_val, open(train_path, 'a') as fp_train:
        for line in lines:
            length = len(line)
            train_indices = random.sample(
                range(length), round(train_ratio * length)
            )
            train_size += len(train_indices)
            # Train entries are written in the order they were sampled.
            fp_train.writelines(line[i] + '\n' for i in train_indices)

            # Set membership keeps the complement scan linear instead of
            # quadratic in the group size.
            chosen = set(train_indices)
            held_out = [line[i] for i in range(length) if i not in chosen]
            fp_val.writelines(entry + '\n' for entry in held_out)
            val_size += len(held_out)

    print('train images ', train_size)
    print('val images ', val_size)
if __name__ == '__main__':
    # Script entry point: read the dataset root and the train fraction from
    # the command line, gather the labelled image entries, then write the
    # train/val list files into the dataset root.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--img_dir',
        type=str,
        default='./font_img_dataset/windows_1/english',
    )
    cli.add_argument('--train_ratio', type=float, default=0.8)
    options = cli.parse_args()

    grouped_entries = get_all_img(options.img_dir)
    divide(grouped_entries, options.img_dir, options.train_ratio)
|