"""Run a Paddle Inference model patch-by-patch over one image.

The input image is split into overlapping blocks, every block is normalised
and pushed through the predictor, and the predicted blocks are stitched back
together and written to ``<output_dir>/res.jpg``.
"""
import glob
from skimage import io
from skimage.filters.rank import mean_bilateral
from skimage import morphology
import os
import shutil
from generate_dataset import GetOverlappingBlocks, CombineToImage
import paddle.inference as paddle_infer
import argparse
import numpy as np
import cv2
from PIL import Image
from paddle.vision.transforms import functional as F

# Per-channel statistics used to normalise each patch before inference.
mean = [0.4805, 0.4633, 0.4085]
std = [0.2987, 0.2875, 0.2526]


def padCropImg(img):
    """Pad an image and cut it into a grid of overlapping 128x128 patches.

    The image is extended on its bottom and right edges (replicating the
    border pixels) so that a whole number of patches fits, then sliced with
    a 16-pixel overlap between neighbouring patches.

    Args:
        img: Image array indexed as ``[row, col, channel]``. Presumably
            3-channel uint8, since patches are stored in a
            ``(..., 3)`` uint8 buffer -- confirm against callers.

    Returns:
        A ``uint8`` array of shape ``(ynum, xnum, 128, 128, 3)`` where
        ``[j, i]`` is the patch at grid row ``j`` and grid column ``i``.
    """
    patch_res = 128
    overlap = int(patch_res * 0.125)
    stride = patch_res - overlap
    height, width = img.shape[0], img.shape[1]

    # Amount of bottom/right padding needed so the patch grid covers the
    # whole image; int() truncation is kept exactly as in the original math.
    pad_h = (int((height - patch_res) / stride + 1) * stride + patch_res) - height
    pad_w = (int((width - patch_res) / stride + 1) * stride + patch_res) - width
    padded = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_REPLICATE)

    ynum = int((padded.shape[0] - patch_res) / stride) + 1
    xnum = int((padded.shape[1] - patch_res) / stride) + 1

    patches = np.zeros((ynum, xnum, patch_res, patch_res, 3), dtype=np.uint8)
    for j in range(ynum):
        y = j * stride
        for i in range(xnum):
            x = i * stride
            patches[j, i] = padded[y:y + patch_res, x:x + patch_res]

    return patches


def preProcess(img):
    """Smooth the first three channels of ``img`` in place.

    Each channel is replaced by the output of the rank ``mean_bilateral``
    filter using a disk footprint of radius 20 and ``s0 = s1 = 10``.

    Args:
        img: Image array indexed as ``[row, col, channel]`` with at least
            three channels. It is modified in place.

    Returns:
        The same array object that was passed in.
    """
    footprint = morphology.disk(20)
    for channel in range(3):
        img[:, :, channel] = mean_bilateral(
            img[:, :, channel], footprint, s0=10, s1=10)
    return img


def main():
    """Command-line entry point: run inference on one image and save it.

    Raises:
        FileNotFoundError: If the input image cannot be read.
        IOError: If the result image cannot be written.
    """
    args = parse_args()

    # Create the config.
    config = paddle_infer.Config(args.model_file, args.params_file)

    # Create the predictor from the config.
    predictor = paddle_infer.create_predictor(config)

    # Get the input names and the handle of the first input.
    input_names = predictor.get_input_names()
    input_handle = predictor.get_input_handle(input_names[0])

    # Set up the input.
    block_size = (256, 256)
    fake_input = cv2.imread(args.image_file, 1)
    if fake_input is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(
            "could not read image file: {}".format(args.image_file))
    rgb_img = cv2.cvtColor(fake_input, cv2.COLOR_BGR2RGB)

    part = 8
    M, N = block_size
    patches = GetOverlappingBlocks(rgb_img.copy(), M, N, part)

    preds = []
    for patch in patches:
        patch = patch / 255.
        normalized_image = F.normalize(patch, mean, std, data_format='HWC')
        # HWC -> CHW, then add a leading batch axis of size 1.
        normalized_image = normalized_image.transpose(2, 0, 1)
        normalized_image = normalized_image.astype("float32")
        real_input = np.expand_dims(normalized_image, 0)

        # NOTE(review): real_input always holds a single patch, so this
        # declared shape only matches the data when --batch_size is 1.
        input_handle.reshape([args.batch_size, 3, 256, 256])
        input_handle.copy_from_cpu(real_input)

        # Run the predictor.
        predictor.run()

        # Fetch the output.
        output_names = predictor.get_output_names()
        output_handle = predictor.get_output_handle(output_names[0])
        output_data = output_handle.copy_to_cpu()  # numpy.ndarray

        # Drop the batch axis and go back from CHW to HWC.
        output = np.squeeze(output_data, 0)
        output = output.transpose(1, 2, 0)
        preds.append(output)

    h, w, c = fake_input.shape
    image = CombineToImage(preds, h, w, c)

    # `part` used to be passed here as a third positional argument, which is
    # cvtColor's `dst` output array, not an option; it has been dropped.
    # NOTE(review): confirm whether `part` was meant for CombineToImage.
    c_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Join path components properly so a directory given without a trailing
    # separator does not produce a sibling file named "<dir>res.jpg".
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, 'res.jpg')
    if not cv2.imwrite(output_path, c_image):
        # cv2.imwrite reports failure through its return value only.
        raise IOError("could not write result image: {}".format(output_path))


def parse_args():
    """Parse the command-line arguments of this script.

    Returns:
        An ``argparse.Namespace`` with ``model_file``, ``params_file``,
        ``image_file``, ``output_dir`` and ``batch_size``.
    """
    parser = argparse.ArgumentParser()
    # The three input paths are mandatory: without them the script can only
    # fail later with a less helpful error.
    parser.add_argument("--model_file", type=str, required=True,
                        help="model filename")
    parser.add_argument("--params_file", type=str, required=True,
                        help="parameter filename")
    parser.add_argument("--image_file", type=str, required=True,
                        help="image filename")
    parser.add_argument("--output_dir", type=str, default=".",
                        help="output dir path")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="batch size")
    return parser.parse_args()


if __name__ == "__main__":
    main()