''' Document Localization using Recursive CNN
Maintainer : Khurram Javed
Email : kjaved@ualberta.ca '''
import ast
import csv
import logging
import os
import xml.etree.ElementTree as ET

import numpy as np
from torchvision import transforms

import utils.utils as utils
  11. # To incdude a new Dataset, inherit from Dataset and add all the Dataset specific parameters here.
  12. # Goal : Remove any data specific parameters from the rest of the code
  13. logger = logging.getLogger('iCARL')
  14. class Dataset():
  15. '''
  16. Base class to reprenent a Dataset
  17. '''
  18. def __init__(self, name):
  19. self.name = name
  20. self.data = []
  21. self.labels = []
  22. class SmartDoc(Dataset):
  23. '''
  24. Class to include MNIST specific details
  25. '''
  26. def __init__(self, directory="data"):
  27. super().__init__("smartdoc")
  28. self.data = []
  29. self.labels = []
  30. for d in directory:
  31. self.directory = d
  32. self.train_transform = transforms.Compose([transforms.Resize([32, 32]),
  33. transforms.ColorJitter(1.5, 1.5, 0.9, 0.5),
  34. transforms.ToTensor()])
  35. self.test_transform = transforms.Compose([transforms.Resize([32, 32]),
  36. transforms.ToTensor()])
  37. logger.info("Pass train/test data paths here")
  38. self.classes_list = {}
  39. file_names = []
  40. print (self.directory, "gt.csv")
  41. with open(os.path.join(self.directory, "gt.csv"), 'r') as csvfile:
  42. spamreader = csv.reader(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
  43. import ast
  44. for row in spamreader:
  45. file_names.append(row[0])
  46. self.data.append(os.path.join(self.directory, row[0]))
  47. test = row[1].replace("array", "")
  48. self.labels.append((ast.literal_eval(test)))
  49. self.labels = np.array(self.labels)
  50. self.labels = np.reshape(self.labels, (-1, 8))
  51. logger.debug("Ground Truth Shape: %s", str(self.labels.shape))
  52. logger.debug("Data shape %s", str(len(self.data)))
  53. self.myData = [self.data, self.labels]
  54. class SmartDocDirectories(Dataset):
  55. '''
  56. Class to include MNIST specific details
  57. '''
  58. def __init__(self, directory="data"):
  59. super().__init__("smartdoc")
  60. self.data = []
  61. self.labels = []
  62. for folder in os.listdir(directory):
  63. if (os.path.isdir(directory + "/" + folder)):
  64. for file in os.listdir(directory + "/" + folder):
  65. images_dir = directory + "/" + folder + "/" + file
  66. if (os.path.isdir(images_dir)):
  67. list_gt = []
  68. tree = ET.parse(images_dir + "/" + file + ".gt")
  69. root = tree.getroot()
  70. for a in root.iter("frame"):
  71. list_gt.append(a)
  72. im_no = 0
  73. for image in os.listdir(images_dir):
  74. if image.endswith(".jpg"):
  75. # print(im_no)
  76. im_no += 1
  77. # Now we have opened the file and GT. Write code to create multiple files and scale gt
  78. list_of_points = {}
  79. # img = cv2.imread(images_dir + "/" + image)
  80. self.data.append(os.path.join(images_dir, image))
  81. for point in list_gt[int(float(image[0:-4])) - 1].iter("point"):
  82. myDict = point.attrib
  83. list_of_points[myDict["name"]] = (
  84. int(float(myDict['x'])), int(float(myDict['y'])))
  85. ground_truth = np.asarray(
  86. (list_of_points["tl"], list_of_points["tr"], list_of_points["br"],
  87. list_of_points["bl"]))
  88. ground_truth = utils.sort_gt(ground_truth)
  89. self.labels.append(ground_truth)
  90. self.labels = np.array(self.labels)
  91. self.labels = np.reshape(self.labels, (-1, 8))
  92. logger.debug("Ground Truth Shape: %s", str(self.labels.shape))
  93. logger.debug("Data shape %s", str(len(self.data)))
  94. self.myData = []
  95. for a in range(len(self.data)):
  96. self.myData.append([self.data[a], self.labels[a]])
  97. class SelfCollectedDataset(Dataset):
  98. '''
  99. Class to include MNIST specific details
  100. '''
  101. def __init__(self, directory="data"):
  102. super().__init__("smartdoc")
  103. self.data = []
  104. self.labels = []
  105. for image in os.listdir(directory):
  106. # print (image)
  107. if image.endswith("jpg") or image.endswith("JPG"):
  108. if os.path.isfile(os.path.join(directory, image + ".csv")):
  109. with open(os.path.join(directory, image + ".csv"), 'r') as csvfile:
  110. spamwriter = csv.reader(csvfile, delimiter=' ',
  111. quotechar='|', quoting=csv.QUOTE_MINIMAL)
  112. img_path = os.path.join(directory, image)
  113. gt = []
  114. for row in spamwriter:
  115. gt.append(row)
  116. gt = np.array(gt).astype(np.float32)
  117. ground_truth = utils.sort_gt(gt)
  118. self.labels.append(ground_truth)
  119. self.data.append(img_path)
  120. self.labels = np.array(self.labels)
  121. self.labels = np.reshape(self.labels, (-1, 8))
  122. logger.debug("Ground Truth Shape: %s", str(self.labels.shape))
  123. logger.debug("Data shape %s", str(len(self.data)))
  124. self.myData = []
  125. for a in range(len(self.data)):
  126. self.myData.append([self.data[a], self.labels[a]])
  127. class SmartDocCorner(Dataset):
  128. '''
  129. Class to include MNIST specific details
  130. '''
  131. def __init__(self, directory="data"):
  132. super().__init__("smartdoc")
  133. self.data = []
  134. self.labels = []
  135. for d in directory:
  136. self.directory = d
  137. self.train_transform = transforms.Compose([transforms.Resize([32, 32]),
  138. transforms.ColorJitter(0.5, 0.5, 0.5, 0.5),
  139. transforms.ToTensor()])
  140. self.test_transform = transforms.Compose([transforms.Resize([32, 32]),
  141. transforms.ToTensor()])
  142. logger.info("Pass train/test data paths here")
  143. self.classes_list = {}
  144. file_names = []
  145. with open(os.path.join(self.directory, "gt.csv"), 'r') as csvfile:
  146. spamreader = csv.reader(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
  147. import ast
  148. for row in spamreader:
  149. file_names.append(row[0])
  150. self.data.append(os.path.join(self.directory, row[0]))
  151. test = row[1].replace("array", "")
  152. self.labels.append((ast.literal_eval(test)))
  153. self.labels = np.array(self.labels)
  154. self.labels = np.reshape(self.labels, (-1, 2))
  155. logger.debug("Ground Truth Shape: %s", str(self.labels.shape))
  156. logger.debug("Data shape %s", str(len(self.data)))
  157. self.myData = [self.data, self.labels]