# infer.py
  1. from dataset import DocCleanDataset
  2. from model import M64ColorNet
  3. import torch
  4. import argparse
  5. import glob
  6. from torchvision import transforms
  7. from PIL import Image
  8. from skimage import io
  9. from skimage.filters.rank import mean_bilateral
  10. from skimage import morphology
  11. import cv2
  12. import numpy as np
  13. import torchvision.transforms as T
  14. import os
  15. import shutil
  16. from generate_dataset import GetOverlappingBlocks, CombineToImage
  17. parser = argparse.ArgumentParser()
  18. parser.add_argument('--ckpt_path',
  19. type=str,
  20. help='This is the path where to store the ckpt file',
  21. default="output/model.pt")
  22. parser.add_argument('--img_dir',
  23. type=str,
  24. help='This is a folder where to store images to infer',
  25. default="infer_imgs")
  26. def padCropImg(img):
  27. H = img.shape[0]
  28. W = img.shape[1]
  29. patchRes = 128
  30. pH = patchRes
  31. pW = patchRes
  32. ovlp = int(patchRes * 0.125)
  33. padH = (int((H - patchRes)/(patchRes - ovlp) + 1) * (patchRes - ovlp) + patchRes) - H
  34. padW = (int((W - patchRes)/(patchRes - ovlp) + 1) * (patchRes - ovlp) + patchRes) - W
  35. padImg = cv2.copyMakeBorder(img, 0, padH, 0, padW, cv2.BORDER_REPLICATE)
  36. ynum = int((padImg.shape[0] - pH)/(pH - ovlp)) + 1
  37. xnum = int((padImg.shape[1] - pW)/(pW - ovlp)) + 1
  38. totalPatch = np.zeros((ynum, xnum, patchRes, patchRes, 3), dtype=np.uint8)
  39. for j in range(0, ynum):
  40. for i in range(0, xnum):
  41. x = int(i * (pW - ovlp))
  42. y = int(j * (pH - ovlp))
  43. totalPatch[j, i] = padImg[y:int(y + patchRes), x:int(x + patchRes)]
  44. return totalPatch
  45. def preProcess(img):
  46. img[:,:,0] = mean_bilateral(img[:,:,0], morphology.disk(20), s0=10, s1=10)
  47. img[:,:,1] = mean_bilateral(img[:,:,1], morphology.disk(20), s0=10, s1=10)
  48. img[:,:,2] = mean_bilateral(img[:,:,2], morphology.disk(20), s0=10, s1=10)
  49. return img
  50. def infer(model, img_path, output, transform, device, block_size=(256,256)):
  51. out_name = os.path.basename(img_path).split(".")[0]
  52. in_clr = cv2.imread(img_path,1)
  53. start_time = cv2.getTickCount()
  54. M = block_size[0]
  55. N = block_size[1]
  56. rgb_img = cv2.cvtColor(in_clr, cv2.COLOR_BGR2RGB)
  57. part = 8
  58. patches = GetOverlappingBlocks(rgb_img.copy(),M,N,part)
  59. preds = []
  60. with torch.no_grad():
  61. for idx, patch in enumerate(patches):
  62. input = transform(patch).to(device)
  63. pred = model(input.unsqueeze(0))
  64. pred = pred.cpu().detach().numpy()[0].transpose(1,2,0)
  65. # cv2.imwrite(f"{output}/{out_name}_{idx}.png", cv2.cvtColor(patch, cv2.COLOR_BGR2RGB))
  66. # cv2.imwrite(f"{output}/{out_name}_{idx}_pred.png", cv2.cvtColor(pred * 255, cv2.COLOR_BGR2RGB))
  67. preds.append(pred)
  68. # print(f"pred idx={idx}")
  69. # h, w, c = preds[0].shape
  70. h, w, c = in_clr.shape
  71. image = CombineToImage(preds, h, w, c)
  72. c_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR,part)
  73. print(f"image_name:{os.path.basename(img_path)} doc clean time:{(cv2.getTickCount() - start_time)/ cv2.getTickFrequency()}")
  74. cv2.imwrite(f"{output}/{out_name}.png", c_image)
  75. def infer_test(output:str, img_dir:str, ckpt_path:str, model_cls):
  76. shutil.rmtree(output, ignore_errors=True)
  77. os.makedirs(output)
  78. model = model_cls()
  79. ckpt_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
  80. mean = ckpt_dict["mean"]
  81. std = ckpt_dict["std"]
  82. transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
  83. model.load_state_dict(ckpt_dict["model_state_dict"])
  84. device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  85. model.to(device)
  86. model.eval()
  87. # img_paths = glob.glob(f"{img_dir}/*.jpg") + glob.glob(f"{img_dir}/*.png") +glob.glob(f"{img_dir}/*.JPG")
  88. img_paths = glob.glob(f"{img_dir}/002.JPG") + glob.glob(f"{img_dir}/21654792158_.pic.jpg")
  89. for img_path in img_paths:
  90. infer(model, img_path, output, transform, device)
  91. if __name__ == "__main__":
  92. args = parser.parse_args()
  93. ckpt_name = os.path.basename(args.ckpt_path).split(".")[0]
  94. output = f"{args.img_dir}/output_{ckpt_name}"
  95. infer_test(output, args.img_dir, args.ckpt_path, M64ColorNet)