123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
- from dataset import DocCleanDataset
- from model import M64ColorNet
- import torch
- import argparse
- import glob
- from torchvision import transforms
- from PIL import Image
- from skimage import io
- from skimage.filters.rank import mean_bilateral
- from skimage import morphology
- import cv2
- import numpy as np
- import torchvision.transforms as T
- import os
- import shutil
- from generate_dataset import GetOverlappingBlocks, CombineToImage
# Command-line options for the inference script: which checkpoint to load
# and which folder holds the images to process.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--ckpt_path',
    type=str,
    # The checkpoint is only read (torch.load) by this script, never written.
    help='Path of the checkpoint file to load the model from',
    default="output/model.pt",
)
parser.add_argument(
    '--img_dir',
    type=str,
    help='Folder containing the images to run inference on',
    default="infer_imgs",
)
def padCropImg(img, patchRes=128, ovlpRatio=0.125):
    """Pad an image and cut it into a grid of overlapping square patches.

    The image is extended on its bottom and right sides by repeating the
    edge pixels until a whole number of patches fits, and is then sliced
    with a stride of ``patchRes - int(patchRes * ovlpRatio)`` pixels.
    Padding is always added, even when the image already fits exactly
    (in that case one extra stride is appended on that axis).

    Args:
        img: Array whose first two axes are height and width; any trailing
            axes (e.g. colour channels) are carried through unchanged.
        patchRes: Side length of each square patch, in pixels.
        ovlpRatio: Fraction of ``patchRes`` shared by neighbouring patches.

    Returns:
        ``uint8`` array of shape ``(ynum, xnum, patchRes, patchRes, ...)``
        where ``totalPatch[j, i]`` is the patch in grid row ``j`` and grid
        column ``i``. Values are stored as ``uint8`` whatever the input
        dtype is.

    Raises:
        ValueError: If the overlap leaves no positive stride.
    """
    ovlp = int(patchRes * ovlpRatio)
    stride = patchRes - ovlp
    if stride <= 0:
        raise ValueError(
            f"overlap ({ovlp}) must be smaller than the patch size ({patchRes})"
        )

    H, W = img.shape[:2]

    def _padded_size(size):
        # Number of strides needed beyond the first patch; images smaller
        # than one patch are simply grown to exactly one patch.
        steps = (size - patchRes) // stride + 1 if size >= patchRes else 0
        return steps * stride + patchRes

    padH = _padded_size(H) - H
    padW = _padded_size(W) - W

    # mode="edge" repeats the last row/column (same effect as
    # cv2.BORDER_REPLICATE); trailing axes are left unpadded.
    pad_width = [(0, padH), (0, padW)] + [(0, 0)] * (img.ndim - 2)
    padImg = np.pad(img, pad_width, mode="edge")

    ynum = (padImg.shape[0] - patchRes) // stride + 1
    xnum = (padImg.shape[1] - patchRes) // stride + 1
    totalPatch = np.zeros(
        (ynum, xnum, patchRes, patchRes) + padImg.shape[2:], dtype=np.uint8
    )
    for j in range(ynum):
        y = j * stride
        for i in range(xnum):
            x = i * stride
            totalPatch[j, i] = padImg[y:y + patchRes, x:x + patchRes]
    return totalPatch
def preProcess(img):
    """Smooth the first three channels of ``img`` in place and return it.

    Each of channels 0, 1 and 2 is replaced by the output of
    ``mean_bilateral`` evaluated over a disk of radius 20 with ``s0=10`` and
    ``s1=10``. The caller's array is modified; the returned object is that
    same array.
    """
    footprint = morphology.disk(20)
    for channel in range(3):
        img[:, :, channel] = mean_bilateral(
            img[:, :, channel], footprint, s0=10, s1=10
        )
    return img
def infer(model, img_path, output, transform, device, block_size=(256,256)):
    """Run the model over one image block by block and save the result as PNG.

    The image is read with OpenCV, converted to RGB and split into blocks by
    ``GetOverlappingBlocks``; each block is transformed, sent through
    ``model`` with gradients disabled, and the per-block predictions are
    merged back by ``CombineToImage`` using the original image's height,
    width and channel count. The merged image is converted to BGR and
    written to ``<output>/<image name without extension>.png``. The elapsed
    time (measured from just after the image is read) is printed.

    Args:
        model: Callable network, invoked on a batch holding a single block.
        img_path: Path of the image to process.
        output: Directory the PNG is written into.
        transform: Callable applied to every block; its result must support
            ``.to(device)`` and ``.unsqueeze(0)``.
        device: Device the transformed block is moved to before inference.
        block_size: Two block dimensions forwarded, in order, to
            ``GetOverlappingBlocks``.

    Raises:
        ValueError: If OpenCV cannot read ``img_path``.
        OSError: If OpenCV reports that the output file was not written.
    """
    # splitext drops only the final extension, so names such as "a.1.jpg"
    # and "a.2.jpg" no longer collapse onto the same output file.
    out_name = os.path.splitext(os.path.basename(img_path))[0]

    in_clr = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if in_clr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"could not read image: {img_path}")

    start_time = cv2.getTickCount()
    rgb_img = cv2.cvtColor(in_clr, cv2.COLOR_BGR2RGB)
    # NOTE(review): meaning of `part` is defined by GetOverlappingBlocks
    # (generate_dataset), which is not visible here.
    part = 8
    patches = GetOverlappingBlocks(rgb_img.copy(), block_size[0], block_size[1], part)

    preds = []
    with torch.no_grad():
        for patch in patches:
            batch = transform(patch).to(device).unsqueeze(0)
            pred = model(batch)
            # First item of the batch, with the leading axis moved last
            # (presumably CHW -> HWC; confirm against the model output).
            preds.append(pred.cpu().detach().numpy()[0].transpose(1, 2, 0))

    h, w, c = in_clr.shape
    image = CombineToImage(preds, h, w, c)
    # The third positional argument of cv2.cvtColor is the destination
    # array, so `part` must not be passed here.
    c_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    print(f"image_name:{os.path.basename(img_path)} doc clean time:{(cv2.getTickCount() - start_time)/ cv2.getTickFrequency()}")

    out_path = f"{output}/{out_name}.png"
    if not cv2.imwrite(out_path, c_image):
        raise OSError(f"could not write image: {out_path}")
def infer_test(output:str, img_dir:str, ckpt_path:str, model_cls) -> None:
    """Load a checkpoint and run inference on the images found in a folder.

    The output directory is deleted (if present) and recreated, the model is
    built from ``model_cls`` and restored from the checkpoint, and ``infer``
    is called for every ``*.jpg``, ``*.png`` and ``*.JPG`` file directly
    inside ``img_dir``.

    Args:
        output: Directory that receives the results. Any existing content
            is removed first.
        img_dir: Folder searched (non-recursively) for input images.
        ckpt_path: Checkpoint file; it must contain the keys ``"mean"``,
            ``"std"`` and ``"model_state_dict"``.
        model_cls: Zero-argument callable returning the model to restore.
    """
    shutil.rmtree(output, ignore_errors=True)
    # rmtree is best-effort, so tolerate a directory that survived it.
    os.makedirs(output, exist_ok=True)

    model = model_cls()
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    ckpt_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
    mean = ckpt_dict["mean"]
    std = ckpt_dict["std"]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    model.load_state_dict(ckpt_dict["model_state_dict"])

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    # Collect through a set: on case-insensitive file systems "*.jpg" and
    # "*.JPG" match the same files. Sorting gives a stable processing order.
    patterns = ("*.jpg", "*.png", "*.JPG")
    img_paths = sorted(
        {path for pattern in patterns for path in glob.glob(f"{img_dir}/{pattern}")}
    )
    if not img_paths:
        print(f"no images found in {img_dir}")
    for img_path in img_paths:
        infer(model, img_path, output, transform, device)
if __name__ == "__main__":
    # Script entry point: results are written to a sub-folder of the image
    # folder named "output_<checkpoint name>".
    args = parser.parse_args()
    # Checkpoint file name up to its first dot, e.g. "model" for
    # "output/model.pt".
    ckpt_stem, _, _ = os.path.basename(args.ckpt_path).partition(".")
    infer_test(
        output=f"{args.img_dir}/output_{ckpt_stem}",
        img_dir=args.img_dir,
        ckpt_path=args.ckpt_path,
        model_cls=M64ColorNet,
    )
-
|