dataset.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. from torch.utils.data import Dataset
  2. from PIL import Image
  3. from torchvision import transforms
  4. from typing import List, Tuple
  5. import imgaug.augmenters as iaa
  6. import numpy as np
  7. from sklearn.model_selection import train_test_split
  8. class UnNormalize(object):
  9. def __init__(self, mean, std):
  10. self.mean = mean
  11. self.std = std
  12. def __call__(self, tensor):
  13. """
  14. Args:
  15. tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
  16. Returns:
  17. Tensor: Normalized image.
  18. """
  19. for t, m, s in zip(tensor, self.mean, self.std):
  20. t.mul_(s).add_(m)
  21. # The normalize code -> t.sub_(m).div_(s)
  22. return tensor
  23. class DocCleanDataset(Dataset):
  24. @staticmethod
  25. def prepareDataset(dataset:str, shuffle=True):
  26. # imgs_dir = "dataset/raw_data/imgs_Trainblocks"
  27. with open(f"{dataset}/train_block_names.txt") as train_block_names_file:
  28. image_names = train_block_names_file.read().splitlines()
  29. train_img_names, eval_img_names, _, _ = train_test_split(
  30. image_names, image_names, test_size=0.2, random_state=1, shuffle=shuffle)
  31. return train_img_names, eval_img_names, dataset
  32. def __init__(self, img_names: List[str], imgs_dir: str, normalized_tuple: Tuple[List[float], List[float]] = None, dev=False, img_aug=False):
  33. if dev:
  34. num = int(len(img_names) * 0.01)
  35. img_names = img_names[0:num]
  36. self.img_names = img_names
  37. self.imgs_dir = imgs_dir
  38. if normalized_tuple:
  39. mean, std = normalized_tuple
  40. self.normalized = transforms.Compose([
  41. transforms.ToTensor(),
  42. transforms.Normalize(mean=mean, std=std)
  43. ])
  44. self.aug_seq = iaa.Sometimes(0.7, iaa.OneOf([
  45. iaa.SaltAndPepper(p=(0.0, 0.05)),
  46. iaa.imgcorruptlike.MotionBlur(severity=2),
  47. iaa.SigmoidContrast(gain=(3, 10), cutoff=(0.4, 0.6)),
  48. iaa.imgcorruptlike.JpegCompression(severity=2),
  49. iaa.GammaContrast((0.5, 2.0)),
  50. iaa.LogContrast(gain=(0.5, 0.9)),
  51. iaa.GaussianBlur(sigma=(0, 1)),
  52. iaa.imgcorruptlike.SpeckleNoise(severity=1),
  53. iaa.AdditiveGaussianNoise(scale=(0.03*255, 0.2*255), per_channel=True),
  54. iaa.Add((-20, 20), per_channel=0.5),
  55. iaa.AddToBrightness((-30, 30))
  56. ]))
  57. self.img_aug = img_aug
  58. self.toTensor = transforms.ToTensor()
  59. def __len__(self):
  60. return len(self.img_names)
  61. def __getitem__(self, index):
  62. img = Image.open(f"{self.imgs_dir}/{self.img_names[index]}")
  63. gt = Image.open(f"{self.imgs_dir}/gt{self.img_names[index]}")
  64. if hasattr(self, 'normalized'):
  65. img_np = np.array(img)
  66. if self.img_aug == True:
  67. img_np = self.aug_seq.augment_images([np.array(img)])[0]
  68. normalized_img = self.normalized(img_np)
  69. img = self.toTensor(img_np)
  70. else:
  71. img = self.toTensor(img)
  72. normalized_img = img
  73. return img, normalized_img, self.toTensor(gt)