dataloaders.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. ''' Document Localization using Recursive CNN
  2. Maintainer : Khurram Javed
  3. Email : kjaved@ualberta.ca '''
  4. import logging
  5. import PIL
  6. import torch.utils.data as td
  7. import tqdm
  8. from PIL import Image
# Module-level logger shared by the dataset classes below.
# NOTE(review): the name 'iCARL' looks inherited from another project; it is
# left unchanged because logging configuration elsewhere may key on it.
logger = logging.getLogger('iCARL')
  10. class HddLoader(td.Dataset):
  11. def __init__(self, data, transform=None, cuda=False):
  12. self.data = data
  13. self.transform = transform
  14. self.cuda = cuda
  15. self.len = len(data[0])
  16. def __len__(self):
  17. return self.len
  18. def __getitem__(self, index):
  19. '''
  20. Replacing this with a more efficient implemnetation selection; removing c
  21. :param index:
  22. :return:
  23. '''
  24. assert (index < len(self.data[0]))
  25. assert (index < self.len)
  26. img = Image.open(self.data[0][index])
  27. target = self.data[1][index]
  28. if self.transform is not None:
  29. img = self.transform(img)
  30. return img, target
  31. class RamLoader(td.Dataset):
  32. def __init__(self, data, transform=None, cuda=False):
  33. self.data = data
  34. self.transform = transform
  35. self.cuda = cuda
  36. self.len = len(data[0])
  37. self.loadInRam()
  38. def loadInRam(self):
  39. self.loaded_data = []
  40. logger.info("Loading data in RAM")
  41. for i in tqdm.tqdm(self.data[0]):
  42. img = Image.open(i)
  43. if self.transform is not None:
  44. img = self.transform(img)
  45. self.loaded_data.append(img)
  46. def __len__(self):
  47. return self.len
  48. def __getitem__(self, index):
  49. '''
  50. Replacing this with a more efficient implemnetation selection; removing c
  51. :param index:
  52. :return:
  53. '''
  54. assert (index < len(self.data[0]))
  55. assert (index < self.len)
  56. target = self.data[1][index]
  57. img = self.loaded_data[index]
  58. return img, target
  59. class SingleFolderLoaderResized(td.Dataset):
  60. '''
  61. This loader class decodes all the images into tensors; this removes the decoding time.
  62. '''
  63. def __init__(self, data, transform=None, cuda=False):
  64. self.data = data
  65. self.transform = transform
  66. self.cuda = cuda
  67. self.len = len(data)
  68. self.decodeImages()
  69. def decodeImages(self):
  70. self.loaded_data = []
  71. logger.info("Resizing Images")
  72. for i in tqdm.tqdm(self.data):
  73. i = i[0]
  74. img = Image.open(i)
  75. img = img.resize((32, 32), PIL.Image.ANTIALIAS)
  76. img.save(i)
  77. def __len__(self):
  78. return self.len
  79. def __getitem__(self, index):
  80. '''
  81. Replacing this with a more efficient implemnetation selection; removing c
  82. :param index:
  83. :return:
  84. '''
  85. assert (index < len(self.data))
  86. assert (index < self.len)
  87. img = Image.open(self.data[index][0])
  88. target = self.data[index][1]
  89. if self.transform is not None:
  90. img = self.transform(img)
  91. return img, target