document_data_generator.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import os
  2. from tqdm import tqdm
  3. import cv2
  4. import numpy as np
  5. import utils
  6. import dataprocessor
  7. def args_processor():
  8. import argparse
  9. parser = argparse.ArgumentParser()
  10. parser.add_argument("-i", "--input-dir", help="Path to data files (Extract images using video_to_image.py first")
  11. parser.add_argument("-o", "--output-dir", help="Directory to store results")
  12. parser.add_argument("--dataset", default="smartdoc", help="'smartdoc' or 'selfcollected' dataset")
  13. return parser.parse_args()
  14. if __name__ == '__main__':
  15. if __name__ == '__main__':
  16. args = args_processor()
  17. input_directory = args.input_dir
  18. if not os.path.isdir(args.output_dir):
  19. os.mkdir(args.output_dir)
  20. import csv
  21. # Dataset iterator
  22. if args.dataset == "smartdoc":
  23. dataset_test = dataprocessor.dataset.SmartDocDirectories(input_directory)
  24. elif args.dataset == "selfcollected":
  25. dataset_test = dataprocessor.dataset.SelfCollectedDataset(input_directory)
  26. else:
  27. print("Incorrect dataset type; please choose between smartdoc or selfcollected")
  28. assert (False)
  29. with open(os.path.join(args.output_dir, 'gt.csv'), 'a', newline='') as csvfile:
  30. spamwriter = csv.writer(csvfile, delimiter=',',
  31. quotechar='|')
  32. # Counter for file naming
  33. counter = 0
  34. for data_elem in tqdm(dataset_test.myData):
  35. img_path = data_elem[0]
  36. target = data_elem[1].reshape((4, 2))
  37. img = cv2.imread(img_path)
  38. if args.dataset == "selfcollected":
  39. target = target / (img.shape[1], img.shape[0])
  40. target = target * (1920, 1920)
  41. img = cv2.resize(img, (1920, 1920))
  42. corner_cords = target
  43. for angle in range(0, 271, 90):
  44. img_rotate, gt_rotate = utils.utils.rotate(img, corner_cords, angle)
  45. for random_crop in range(0, 16):
  46. counter += 1
  47. f_name = str(counter).zfill(8)
  48. img_crop, gt_crop = utils.utils.random_crop(img_rotate, gt_rotate)
  49. mah_size = img_crop.shape
  50. img_crop = cv2.resize(img_crop, (64, 64))
  51. gt_crop = np.array(gt_crop)
  52. # no=0
  53. # for a in range(0,4):
  54. # no+=1
  55. # cv2.circle(img_crop, tuple(((gt_crop[a]*64).astype(int))), 2,(255-no*60,no*60,0),9)
  56. # # cv2.imwrite("asda.jpg", img)
  57. cv2.imwrite(os.path.join(args.output_dir, f_name+".jpg"), img_crop)
  58. spamwriter.writerow((f_name+".jpg", tuple(list(gt_crop))))
  59. print(tuple(list(gt_crop)))