# document_data_generator.py
import argparse
import csv
import os

import cv2
import numpy as np
from tqdm import tqdm

import dataprocessor
import utils
  8. def str2bool(v):
  9. if isinstance(v, bool):
  10. return v
  11. if v.lower() in ('yes', 'true', 't', 'y', '1'):
  12. return True
  13. elif v.lower() in ('no', 'false', 'f', 'n', '0'):
  14. return False
  15. else:
  16. raise argparse.ArgumentTypeError('Boolean value expected.')
  17. def args_processor():
  18. parser = argparse.ArgumentParser()
  19. parser.add_argument("-i", "--input-dir", help="Path to data files (Extract images using video_to_image.py first")
  20. parser.add_argument("-o", "--output-dir", help="Directory to store results")
  21. parser.add_argument("-v", "--visualize", help="Draw the point on the corner", default=False, type=bool)
  22. parser.add_argument("-a", "--augment", type=str2bool, nargs='?',
  23. const=True, default=True,
  24. help="Augment image dataset")
  25. parser.add_argument("--dataset", default="smartdoc", help="'smartdoc' or 'selfcollected' dataset")
  26. return parser.parse_args()
  27. if __name__ == '__main__':
  28. if __name__ == '__main__':
  29. args = args_processor()
  30. input_directory = args.input_dir
  31. if not os.path.isdir(args.output_dir):
  32. os.mkdir(args.output_dir)
  33. import csv
  34. # Dataset iterator
  35. if args.dataset == "smartdoc":
  36. dataset_test = dataprocessor.dataset.SmartDocDirectories(input_directory)
  37. elif args.dataset == "selfcollected":
  38. dataset_test = dataprocessor.dataset.SelfCollectedDataset(input_directory)
  39. else:
  40. print("Incorrect dataset type; please choose between smartdoc or selfcollected")
  41. assert (False)
  42. with open(os.path.join(args.output_dir, 'gt.csv'), 'a') as csvfile:
  43. spamwriter = csv.writer(csvfile, delimiter=',',
  44. quotechar='|', quoting=csv.QUOTE_MINIMAL)
  45. # Counter for file naming
  46. counter = 0
  47. for data_elem in tqdm(dataset_test.myData):
  48. img_path = data_elem[0]
  49. target = data_elem[1].reshape((4, 2))
  50. img = cv2.imread(img_path)
  51. if args.dataset == "selfcollected":
  52. target = target / (img.shape[1], img.shape[0])
  53. target = target * (1920, 1920)
  54. img = cv2.resize(img, (1920, 1920))
  55. corner_cords = target
  56. angles = [0, 271, 90] if args.augment else [0]
  57. random_crops = [0, 16] if args.augment else [0]
  58. for angle in angles:
  59. img_rotate, gt_rotate = utils.utils.rotate(img, corner_cords, angle)
  60. for random_crop in random_crops:
  61. counter += 1
  62. f_name = str(counter).zfill(8)
  63. img_crop, gt_crop = utils.utils.random_crop(img_rotate, gt_rotate)
  64. mah_size = img_crop.shape
  65. img_crop = cv2.resize(img_crop, (64, 64))
  66. gt_crop = np.array(gt_crop)
  67. if (args.visualize):
  68. no=0
  69. for a in range(0,4):
  70. no+=1
  71. cv2.circle(img_crop, tuple(((gt_crop[a]*64).astype(int))), 2,(255-no*60,no*60,0),9)
  72. # # cv2.imwrite("asda.jpg", img)
  73. cv2.imwrite(os.path.join(args.output_dir, f_name+".jpg"), img_crop)
  74. spamwriter.writerow((f_name+".jpg", tuple(list(gt_crop))))