from model import M64ColorNet
import torch
import argparse
import glob
from torchvision import transforms
from skimage.filters.rank import mean_bilateral
from skimage import morphology
import cv2
import numpy as np
import os
import shutil
from generate_dataset import GetOverlappingBlocks, CombineToImage

parser = argparse.ArgumentParser()
parser.add_argument('--ckpt_path', type=str, help='Path to the trained model checkpoint (.pt file)', default="output/model.pt")
parser.add_argument('--img_dir', type=str, help='Directory containing the images to run inference on', default="infer_imgs")


def padCropImg(img):
    """Pad an image and split it into a grid of overlapping 128x128 patches.

    Currently unused in this script; kept as a reference implementation of the
    patch-splitting scheme.
    """
    H = img.shape[0]
    W = img.shape[1]
    patchRes = 128
    pH = patchRes
    pW = patchRes
    ovlp = int(patchRes * 0.125)
    padH = (int((H - patchRes) / (patchRes - ovlp) + 1) * (patchRes - ovlp) + patchRes) - H
    padW = (int((W - patchRes) / (patchRes - ovlp) + 1) * (patchRes - ovlp) + patchRes) - W
    padImg = cv2.copyMakeBorder(img, 0, padH, 0, padW, cv2.BORDER_REPLICATE)
    ynum = int((padImg.shape[0] - pH) / (pH - ovlp)) + 1
    xnum = int((padImg.shape[1] - pW) / (pW - ovlp)) + 1
    totalPatch = np.zeros((ynum, xnum, patchRes, patchRes, 3), dtype=np.uint8)
    for j in range(0, ynum):
        for i in range(0, xnum):
            x = int(i * (pW - ovlp))
            y = int(j * (pH - ovlp))
            totalPatch[j, i] = padImg[y:int(y + patchRes), x:int(x + patchRes)]
    return totalPatch


def preProcess(img):
    """Denoise each channel with a mean-bilateral filter (currently unused in this script)."""
    img[:, :, 0] = mean_bilateral(img[:, :, 0], morphology.disk(20), s0=10, s1=10)
    img[:, :, 1] = mean_bilateral(img[:, :, 1], morphology.disk(20), s0=10, s1=10)
    img[:, :, 2] = mean_bilateral(img[:, :, 2], morphology.disk(20), s0=10, s1=10)
    return img


def infer(model, img_path, output, transform, device, block_size=(256, 256)):
    """Clean one document image: split it into overlapping blocks, run the model
    on each block, then stitch the predictions back into a full-size image."""
    out_name = os.path.basename(img_path).split(".")[0]
    in_clr = cv2.imread(img_path, cv2.IMREAD_COLOR)
    start_time = cv2.getTickCount()
    M = block_size[0]
    N = block_size[1]
    rgb_img = cv2.cvtColor(in_clr, cv2.COLOR_BGR2RGB)
    part = 8  # block-splitting parameter expected by GetOverlappingBlocks
    patches = GetOverlappingBlocks(rgb_img.copy(), M, N, part)
    preds = []
    with torch.no_grad():
        for idx, patch in enumerate(patches):
            input_tensor = transform(patch).to(device)
            pred = model(input_tensor.unsqueeze(0))
            pred = pred.cpu().detach().numpy()[0].transpose(1, 2, 0)
            # cv2.imwrite(f"{output}/{out_name}_{idx}.png", cv2.cvtColor(patch, cv2.COLOR_BGR2RGB))
            # cv2.imwrite(f"{output}/{out_name}_{idx}_pred.png", cv2.cvtColor(pred * 255, cv2.COLOR_BGR2RGB))
            preds.append(pred)
            # print(f"pred idx={idx}")
    # h, w, c = preds[0].shape
    h, w, c = in_clr.shape
    image = CombineToImage(preds, h, w, c)
    c_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # back to BGR for cv2.imwrite
    print(f"image_name:{os.path.basename(img_path)} doc clean time:{(cv2.getTickCount() - start_time) / cv2.getTickFrequency()}")
    cv2.imwrite(f"{output}/{out_name}.png", c_image)


def infer_test(output: str, img_dir: str, ckpt_path: str, model_cls):
    """Load a checkpoint (with the normalization stats stored in it) and run
    inference on the selected images in img_dir, writing results to output."""
    shutil.rmtree(output, ignore_errors=True)
    os.makedirs(output)
    model = model_cls()
    ckpt_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
    mean = ckpt_dict["mean"]
    std = ckpt_dict["std"]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    model.load_state_dict(ckpt_dict["model_state_dict"])
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    # img_paths = glob.glob(f"{img_dir}/*.jpg") + glob.glob(f"{img_dir}/*.png") + glob.glob(f"{img_dir}/*.JPG")
    img_paths = glob.glob(f"{img_dir}/002.JPG") + glob.glob(f"{img_dir}/21654792158_.pic.jpg")
    for img_path in img_paths:
        infer(model, img_path, output, transform, device)


if __name__ == "__main__":
    args = parser.parse_args()
    ckpt_name = os.path.basename(args.ckpt_path).split(".")[0]
    output = f"{args.img_dir}/output_{ckpt_name}"
    infer_test(output, args.img_dir, args.ckpt_path, M64ColorNet)
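
# Example invocation (a sketch; assumes this file is saved as infer.py and that a
# checkpoint produced by the training code exists at output/model.pt):
#
#   python infer.py --ckpt_path output/model.pt --img_dir infer_imgs
#
# Cleaned images are written as PNGs to <img_dir>/output_<checkpoint name>/.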