# dataset.py
  1. """ Document Localization using Recursive CNN
  2. Maintainer : Khurram Javed
  3. Email : kjaved@ualberta.ca """
  4. import imgaug.augmenters as iaa
  5. import csv
  6. import logging
  7. import os
  8. import xml.etree.ElementTree as ET
  9. import numpy as np
  10. from torchvision import transforms
  11. import utils.utils as utils
  12. # To incdude a new Dataset, inherit from Dataset and add all the Dataset specific parameters here.
  13. # Goal : Remove any data specific parameters from the rest of the code
  14. logger = logging.getLogger("iCARL")
  15. class Dataset:
  16. """
  17. Base class to reprenent a Dataset
  18. """
  19. def __init__(self, name):
  20. self.name = name
  21. self.data = []
  22. self.labels = []
  23. def getTransformsByImgaug():
  24. return iaa.Sequential(
  25. [
  26. iaa.Resize(32),
  27. # Add blur
  28. iaa.Sometimes(
  29. 0.05,
  30. iaa.OneOf(
  31. [
  32. iaa.GaussianBlur(
  33. (0, 3.0)
  34. ), # blur images with a sigma between 0 and 3.0
  35. iaa.AverageBlur(
  36. k=(2, 11)
  37. ), # blur image using local means with kernel sizes between 2 and 7
  38. iaa.MedianBlur(
  39. k=(3, 11)
  40. ), # blur image using local medians with kernel sizes between 2 and 7
  41. iaa.MotionBlur(k=15, angle=[-45, 45]),
  42. ]
  43. ),
  44. ),
  45. # Add color
  46. iaa.Sometimes(
  47. 0.05,
  48. iaa.OneOf(
  49. [
  50. iaa.WithHueAndSaturation(iaa.WithChannels(0, iaa.Add((0, 50)))),
  51. iaa.AddToBrightness((-30, 30)),
  52. iaa.MultiplyBrightness((0.5, 1.5)),
  53. iaa.AddToHueAndSaturation((-50, 50), per_channel=True),
  54. iaa.Grayscale(alpha=(0.0, 1.0)),
  55. iaa.ChangeColorTemperature((1100, 10000)),
  56. iaa.KMeansColorQuantization(),
  57. ]
  58. ),
  59. ),
  60. # Add wether
  61. iaa.Sometimes(
  62. 0.05,
  63. iaa.OneOf(
  64. [
  65. iaa.Clouds(),
  66. iaa.Fog(),
  67. iaa.Snowflakes(flake_size=(0.1, 0.4), speed=(0.01, 0.05)),
  68. iaa.Rain(speed=(0.1, 0.3)),
  69. ]
  70. ),
  71. ),
  72. # Add contrast
  73. iaa.Sometimes(
  74. 0.05,
  75. iaa.OneOf(
  76. [
  77. iaa.GammaContrast((0.5, 2.0)),
  78. iaa.GammaContrast((0.5, 2.0), per_channel=True),
  79. iaa.SigmoidContrast(gain=(3, 10), cutoff=(0.4, 0.6)),
  80. iaa.SigmoidContrast(
  81. gain=(3, 10), cutoff=(0.4, 0.6), per_channel=True
  82. ),
  83. iaa.LogContrast(gain=(0.6, 1.4)),
  84. iaa.LogContrast(gain=(0.6, 1.4), per_channel=True),
  85. iaa.LinearContrast((0.4, 1.6)),
  86. iaa.LinearContrast((0.4, 1.6), per_channel=True),
  87. iaa.AllChannelsCLAHE(),
  88. iaa.AllChannelsCLAHE(clip_limit=(1, 10)),
  89. iaa.AllChannelsCLAHE(clip_limit=(1, 10), per_channel=True),
  90. iaa.Alpha((0.0, 1.0), iaa.HistogramEqualization()),
  91. iaa.Alpha((0.0, 1.0), iaa.AllChannelsHistogramEqualization()),
  92. iaa.AllChannelsHistogramEqualization(),
  93. ]
  94. ),
  95. )
  96. ]
  97. ).augment_image
  98. class SmartDoc(Dataset):
  99. """
  100. Class to include MNIST specific details
  101. """
  102. def __init__(self, directory="data"):
  103. super().__init__("smartdoc")
  104. self.data = []
  105. self.labels = []
  106. for d in directory:
  107. self.directory = d
  108. self.train_transform = transforms.Compose(
  109. [
  110. getTransformsByImgaug(),
  111. # transforms.Resize([32, 32]),
  112. # transforms.ColorJitter(1.5, 1.5, 0.9, 0.5),
  113. transforms.ToTensor(),
  114. ]
  115. )
  116. self.test_transform = transforms.Compose(
  117. [
  118. iaa.Sequential(
  119. [
  120. iaa.Resize(32),
  121. ]
  122. ).augment_image,
  123. transforms.ToTensor(),
  124. ]
  125. )
  126. logger.info("Pass train/test data paths here")
  127. self.classes_list = {}
  128. file_names = []
  129. print(self.directory, "gt.csv")
  130. with open(os.path.join(self.directory, "gt.csv"), "r") as csvfile:
  131. spamreader = csv.reader(
  132. csvfile, delimiter=",", quotechar="|", quoting=csv.QUOTE_MINIMAL
  133. )
  134. import ast
  135. for row in spamreader:
  136. file_names.append(row[0])
  137. self.data.append(os.path.join(self.directory, row[0]))
  138. test = row[1].replace("array", "")
  139. self.labels.append((ast.literal_eval(test)))
  140. self.labels = np.array(self.labels)
  141. self.labels = np.reshape(self.labels, (-1, 8))
  142. logger.debug("Ground Truth Shape: %s", str(self.labels.shape))
  143. logger.debug("Data shape %s", str(len(self.data)))
  144. self.myData = [self.data, self.labels]
  145. class SmartDocDirectories(Dataset):
  146. """
  147. Class to include MNIST specific details
  148. """
  149. def __init__(self, directory="data"):
  150. super().__init__("smartdoc")
  151. self.data = []
  152. self.labels = []
  153. for folder in os.listdir(directory):
  154. if os.path.isdir(directory + "/" + folder):
  155. for file in os.listdir(directory + "/" + folder):
  156. images_dir = directory + "/" + folder + "/" + file
  157. if os.path.isdir(images_dir):
  158. list_gt = []
  159. tree = ET.parse(images_dir + "/" + file + ".gt")
  160. root = tree.getroot()
  161. for a in root.iter("frame"):
  162. list_gt.append(a)
  163. im_no = 0
  164. for image in os.listdir(images_dir):
  165. if image.endswith(".jpg"):
  166. # print(im_no)
  167. im_no += 1
  168. # Now we have opened the file and GT. Write code to create multiple files and scale gt
  169. list_of_points = {}
  170. # img = cv2.imread(images_dir + "/" + image)
  171. self.data.append(os.path.join(images_dir, image))
  172. for point in list_gt[int(float(image[0:-4])) - 1].iter(
  173. "point"
  174. ):
  175. myDict = point.attrib
  176. list_of_points[myDict["name"]] = (
  177. int(float(myDict["x"])),
  178. int(float(myDict["y"])),
  179. )
  180. ground_truth = np.asarray(
  181. (
  182. list_of_points["tl"],
  183. list_of_points["tr"],
  184. list_of_points["br"],
  185. list_of_points["bl"],
  186. )
  187. )
  188. ground_truth = utils.sort_gt(ground_truth)
  189. self.labels.append(ground_truth)
  190. self.labels = np.array(self.labels)
  191. self.labels = np.reshape(self.labels, (-1, 8))
  192. logger.debug("Ground Truth Shape: %s", str(self.labels.shape))
  193. logger.debug("Data shape %s", str(len(self.data)))
  194. self.myData = []
  195. for a in range(len(self.data)):
  196. self.myData.append([self.data[a], self.labels[a]])
  197. class SelfCollectedDataset(Dataset):
  198. """
  199. Class to include MNIST specific details
  200. """
  201. def __init__(self, directory="data"):
  202. super().__init__("smartdoc")
  203. self.data = []
  204. self.labels = []
  205. for image in os.listdir(directory):
  206. # print (image)
  207. if image.endswith("jpg") or image.endswith("JPG"):
  208. if os.path.isfile(os.path.join(directory, image + ".csv")):
  209. with open(os.path.join(directory, image + ".csv"), "r") as csvfile:
  210. spamwriter = csv.reader(
  211. csvfile,
  212. delimiter=" ",
  213. quotechar="|",
  214. quoting=csv.QUOTE_MINIMAL,
  215. )
  216. img_path = os.path.join(directory, image)
  217. gt = []
  218. for row in spamwriter:
  219. gt.append(row)
  220. gt = np.array(gt).astype(np.float32)
  221. ground_truth = utils.sort_gt(gt)
  222. self.labels.append(ground_truth)
  223. self.data.append(img_path)
  224. self.labels = np.array(self.labels)
  225. self.labels = np.reshape(self.labels, (-1, 8))
  226. logger.debug("Ground Truth Shape: %s", str(self.labels.shape))
  227. logger.debug("Data shape %s", str(len(self.data)))
  228. self.myData = []
  229. for a in range(len(self.data)):
  230. self.myData.append([self.data[a], self.labels[a]])
  231. class SmartDocCorner(Dataset):
  232. """
  233. Class to include MNIST specific details
  234. """
  235. def __init__(self, directory="data"):
  236. super().__init__("smartdoc")
  237. self.data = []
  238. self.labels = []
  239. for d in directory:
  240. self.directory = d
  241. self.train_transform = transforms.Compose(
  242. [
  243. getTransformsByImgaug(),
  244. transforms.ToTensor(),
  245. ]
  246. )
  247. self.test_transform = transforms.Compose(
  248. [
  249. iaa.Sequential(
  250. [
  251. iaa.Resize(32),
  252. ]
  253. ).augment_image,
  254. transforms.ToTensor(),
  255. ]
  256. )
  257. logger.info("Pass train/test data paths here")
  258. self.classes_list = {}
  259. file_names = []
  260. with open(os.path.join(self.directory, "gt.csv"), "r") as csvfile:
  261. spamreader = csv.reader(
  262. csvfile, delimiter=",", quotechar="|", quoting=csv.QUOTE_MINIMAL
  263. )
  264. import ast
  265. for row in spamreader:
  266. file_names.append(row[0])
  267. self.data.append(os.path.join(self.directory, row[0]))
  268. test = row[1].replace("array", "")
  269. self.labels.append((ast.literal_eval(test)))
  270. self.labels = np.array(self.labels)
  271. self.labels = np.reshape(self.labels, (-1, 2))
  272. logger.debug("Ground Truth Shape: %s", str(self.labels.shape))
  273. logger.debug("Data shape %s", str(len(self.data)))
  274. self.myData = [self.data, self.labels]