dataloaders.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. ''' Document Localization using Recursive CNN
  2. Maintainer : Khurram Javed
  3. Email : kjaved@ualberta.ca '''
  4. import logging
  5. import PIL
  6. import torch.utils.data as td
  7. import tqdm
  8. from PIL import Image
  9. import cv2
  10. logger = logging.getLogger('iCARL')
  11. class HddLoader(td.Dataset):
  12. def __init__(self, data, transform=None, cuda=False):
  13. self.data = data
  14. self.transform = transform
  15. self.cuda = cuda
  16. self.len = len(data[0])
  17. def __len__(self):
  18. return self.len
  19. def __getitem__(self, index):
  20. '''
  21. Replacing this with a more efficient implemnetation selection; removing c
  22. :param index:
  23. :return:
  24. '''
  25. assert (index < len(self.data[0]))
  26. assert (index < self.len)
  27. img = Image.open(self.data[0][index])
  28. target = self.data[1][index]
  29. if self.transform is not None:
  30. img = self.transform(img)
  31. return img, target
  32. class RamLoader(td.Dataset):
  33. def __init__(self, data, transform=None, cuda=False):
  34. self.data = data
  35. self.transform = transform
  36. self.cuda = cuda
  37. self.len = len(data[0])
  38. self.loadInRam()
  39. def loadInRam(self):
  40. self.loaded_data = []
  41. logger.info("Loading data in RAM")
  42. for i in tqdm.tqdm(self.data[0]):
  43. # img = Image.open(i)
  44. img = cv2.imread(i)
  45. if self.transform is not None:
  46. img = self.transform(img)
  47. self.loaded_data.append(img)
  48. def __len__(self):
  49. return self.len
  50. def __getitem__(self, index):
  51. '''
  52. Replacing this with a more efficient implemnetation selection; removing c
  53. :param index:
  54. :return:
  55. '''
  56. assert (index < len(self.data[0]))
  57. assert (index < self.len)
  58. target = self.data[1][index]
  59. img = self.loaded_data[index]
  60. return img, target
  61. class SingleFolderLoaderResized(td.Dataset):
  62. '''
  63. This loader class decodes all the images into tensors; this removes the decoding time.
  64. '''
  65. def __init__(self, data, transform=None, cuda=False):
  66. self.data = data
  67. self.transform = transform
  68. self.cuda = cuda
  69. self.len = len(data)
  70. self.decodeImages()
  71. def decodeImages(self):
  72. self.loaded_data = []
  73. logger.info("Resizing Images")
  74. for i in tqdm.tqdm(self.data):
  75. i = i[0]
  76. img = Image.open(i)
  77. img = img.resize((32, 32), PIL.Image.ANTIALIAS)
  78. img.save(i)
  79. def __len__(self):
  80. return self.len
  81. def __getitem__(self, index):
  82. '''
  83. Replacing this with a more efficient implemnetation selection; removing c
  84. :param index:
  85. :return:
  86. '''
  87. assert (index < len(self.data))
  88. assert (index < self.len)
  89. img = Image.open(self.data[index][0])
  90. target = self.data[index][1]
  91. if self.transform is not None:
  92. img = self.transform(img)
  93. return img, target