123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- import glob
- from skimage import io
- from skimage.filters.rank import mean_bilateral
- from skimage import morphology
- import os
- import shutil
- from generate_dataset import GetOverlappingBlocks, CombineToImage
- import paddle.inference as paddle_infer
- import argparse
- import numpy as np
- import cv2
- from PIL import Image
- from paddle.vision.transforms import functional as F
# Per-channel normalisation statistics passed to ``F.normalize`` in ``main``.
# They are applied to patches that were divided by 255 beforehand and whose
# channels come from the BGR->RGB converted input (presumably R, G, B order).
mean = [
    0.4805,
    0.4633,
    0.4085,
]
std = [
    0.2987,
    0.2875,
    0.2526,
]
def padCropImg(img, *, patchRes=128, overlapRatio=0.125):
    """Pad an image and cut it into a grid of overlapping square patches.

    The image is padded on the bottom and right by replicating its last
    row/column (edge replication), then tiled with ``patchRes`` x
    ``patchRes`` windows spaced ``stride = patchRes - int(patchRes *
    overlapRatio)`` pixels apart.

    Args:
        img: Image array of shape (H, W) or (H, W, C).
        patchRes: Side length of each square patch (default 128).
        overlapRatio: Fraction of ``patchRes`` shared by neighbouring
            patches (default 0.125, i.e. 16 pixels for 128-pixel patches).

    Returns:
        ``np.uint8`` array of shape ``(ynum, xnum, patchRes, patchRes, *C)``
        where ``result[j, i]`` is the window whose top-left corner sits at
        row ``j * stride`` and column ``i * stride`` of the padded image.
        Values are cast to uint8 on assignment, as before.

    Raises:
        ValueError: if ``patchRes`` / ``overlapRatio`` give a non-positive
            stride.
    """
    H, W = img.shape[:2]
    ovlp = int(patchRes * overlapRatio)
    stride = patchRes - ovlp
    if patchRes <= 0 or stride <= 0:
        raise ValueError(
            f"invalid patch geometry: patchRes={patchRes}, "
            f"overlapRatio={overlapRatio} gives stride={stride}"
        )

    def _padAmount(size):
        # Padding needed so that the padded size equals patchRes + k*stride.
        # NOTE: this rounding is kept from the original implementation; when
        # (size - patchRes) is already a multiple of stride it still adds one
        # extra stride (e.g. 128 -> 240 with the defaults). The clamp to 0
        # only matters for overlapRatio > 0.5, where k could go negative.
        steps = max(0, int((size - patchRes) / stride + 1))
        return steps * stride + patchRes - size

    padH = _padAmount(H)
    padW = _padAmount(W)
    # Pad only bottom/right; any trailing channel axes are left untouched.
    # mode="edge" replicates border pixels like cv2.BORDER_REPLICATE.
    padWidth = [(0, padH), (0, padW)] + [(0, 0)] * (img.ndim - 2)
    padImg = np.pad(img, padWidth, mode="edge")

    # The padded extent minus one patch is an exact multiple of stride.
    ynum = (padImg.shape[0] - patchRes) // stride + 1
    xnum = (padImg.shape[1] - patchRes) // stride + 1
    channels = padImg.shape[2:]
    totalPatch = np.zeros((ynum, xnum, patchRes, patchRes, *channels), dtype=np.uint8)
    for j in range(ynum):
        y = j * stride
        for i in range(xnum):
            x = i * stride
            totalPatch[j, i] = padImg[y:y + patchRes, x:x + patchRes]
    return totalPatch
def preProcess(img):
    """Smooth the first three channels with skimage's ``mean_bilateral``.

    Each channel is filtered over a ``morphology.disk(20)`` neighbourhood
    with ``s0=10, s1=10`` and written back into ``img`` itself, so the
    caller's array is modified in place; the same array is returned.

    NOTE(review): assumes ``img`` is an (H, W, >=3) integer image, as
    ``mean_bilateral`` expects integer input — confirm at call sites.
    """
    footprint = morphology.disk(20)
    for channel in range(3):
        img[:, :, channel] = mean_bilateral(
            img[:, :, channel], footprint, s0=10, s1=10
        )
    return img
def _predict_patch(predictor, input_handle, output_handle, patch):
    """Normalize one HWC patch, run the predictor on it, return HWC output."""
    # Scale to [0, 1], then standardize with the module-level mean/std.
    normalized = F.normalize(patch / 255., mean, std, data_format='HWC')
    chw = normalized.transpose(2, 0, 1).astype("float32")
    batch = np.expand_dims(chw, 0)  # (1, C, H, W): one patch per run
    # Size the input tensor from the data actually fed in, so it always
    # matches the single-patch batch regardless of --batch_size.
    input_handle.reshape(list(batch.shape))
    input_handle.copy_from_cpu(batch)
    predictor.run()
    output = output_handle.copy_to_cpu()  # numpy.ndarray, assumed (1, C, H, W)
    return np.squeeze(output, 0).transpose(1, 2, 0)


def main():
    """Run the Paddle inference model over one image and save the result.

    Reads ``--image_file``, splits it into blocks with ``GetOverlappingBlocks``
    (block size 256x256, ``part=8`` passed through), runs every block through
    a predictor built from ``--model_file`` / ``--params_file``, stitches the
    outputs with ``CombineToImage`` and writes ``res.jpg`` inside
    ``--output_dir`` (created if missing).

    Raises:
        FileNotFoundError: if the input image cannot be read.
        OSError: if the result image cannot be written.
    """
    args = parse_args()

    # Build the predictor and look up its I/O handles once; the handles are
    # reused for every patch.
    config = paddle_infer.Config(args.model_file, args.params_file)
    predictor = paddle_infer.create_predictor(config)
    input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
    output_handle = predictor.get_output_handle(predictor.get_output_names()[0])

    bgr_img = cv2.imread(args.image_file, cv2.IMREAD_COLOR)
    if bgr_img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {args.image_file}")
    rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)

    block_h, block_w = 256, 256
    part = 8
    patches = GetOverlappingBlocks(rgb_img.copy(), block_h, block_w, part)
    preds = [
        _predict_patch(predictor, input_handle, output_handle, patch)
        for patch in patches
    ]

    h, w, c = bgr_img.shape
    image = CombineToImage(preds, h, w, c)
    out_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Treat --output_dir as a directory: join instead of string concatenation
    # so a missing trailing separator does not produce "<dir>res.jpg".
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    out_path = os.path.join(args.output_dir, 'res.jpg')
    if not cv2.imwrite(out_path, out_img):
        raise OSError(f"failed to write result image: {out_path}")
def parse_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``model_file``, ``params_file``,
        ``image_file``, ``output_dir`` (str) and ``batch_size`` (int,
        default 1).
    """
    parser = argparse.ArgumentParser()
    # The four paths have no sensible default. Requiring them makes argparse
    # exit with a usage message instead of failing later with a None path.
    parser.add_argument("--model_file", type=str, required=True, help="model filename")
    parser.add_argument("--params_file", type=str, required=True, help="parameter filename")
    parser.add_argument("--image_file", type=str, required=True, help="image filename")
    parser.add_argument("--output_dir", type=str, required=True, help="output dir path")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with 0.
    raise SystemExit(main())
|