doc_clear_infer.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import glob
  2. from skimage import io
  3. from skimage.filters.rank import mean_bilateral
  4. from skimage import morphology
  5. import os
  6. import shutil
  7. from generate_dataset import GetOverlappingBlocks, CombineToImage
  8. import paddle.inference as paddle_infer
  9. import argparse
  10. import numpy as np
  11. import cv2
  12. from PIL import Image
  13. from paddle.vision.transforms import functional as F
  14. mean=[0.4805, 0.4633, 0.4085]
  15. std=[0.2987, 0.2875, 0.2526]
  16. def padCropImg(img):
  17. H = img.shape[0]
  18. W = img.shape[1]
  19. patchRes = 128
  20. pH = patchRes
  21. pW = patchRes
  22. ovlp = int(patchRes * 0.125)
  23. padH = (int((H - patchRes)/(patchRes - ovlp) + 1) * (patchRes - ovlp) + patchRes) - H
  24. padW = (int((W - patchRes)/(patchRes - ovlp) + 1) * (patchRes - ovlp) + patchRes) - W
  25. padImg = cv2.copyMakeBorder(img, 0, padH, 0, padW, cv2.BORDER_REPLICATE)
  26. ynum = int((padImg.shape[0] - pH)/(pH - ovlp)) + 1
  27. xnum = int((padImg.shape[1] - pW)/(pW - ovlp)) + 1
  28. totalPatch = np.zeros((ynum, xnum, patchRes, patchRes, 3), dtype=np.uint8)
  29. for j in range(0, ynum):
  30. for i in range(0, xnum):
  31. x = int(i * (pW - ovlp))
  32. y = int(j * (pH - ovlp))
  33. totalPatch[j, i] = padImg[y:int(y + patchRes), x:int(x + patchRes)]
  34. return totalPatch
  35. def preProcess(img):
  36. img[:,:,0] = mean_bilateral(img[:,:,0], morphology.disk(20), s0=10, s1=10)
  37. img[:,:,1] = mean_bilateral(img[:,:,1], morphology.disk(20), s0=10, s1=10)
  38. img[:,:,2] = mean_bilateral(img[:,:,2], morphology.disk(20), s0=10, s1=10)
  39. return img
  40. def main():
  41. args = parse_args()
  42. # 创建 config
  43. config = paddle_infer.Config(args.model_file, args.params_file)
  44. # 根据 config 创建 predictor
  45. predictor = paddle_infer.create_predictor(config)
  46. # 获取输入的名称
  47. input_names = predictor.get_input_names()
  48. input_handle = predictor.get_input_handle(input_names[0])
  49. # 设置输入
  50. block_size=(256,256)
  51. fake_input = cv2.imread(args.image_file, 1)
  52. rgb_img = cv2.cvtColor(fake_input, cv2.COLOR_BGR2RGB)
  53. part = 8
  54. M = block_size[0]
  55. N = block_size[1]
  56. patches = GetOverlappingBlocks(rgb_img.copy(),M,N,part)
  57. preds = []
  58. for idx, patch in enumerate(patches):
  59. patch = patch / 255.
  60. normalized_image = F.normalize(patch, mean, std, data_format='HWC')
  61. normalized_image = normalized_image.transpose(2,0,1)
  62. normalized_image = normalized_image.astype("float32")
  63. real_input = np.expand_dims(normalized_image, 0)
  64. input_handle.reshape([args.batch_size, 3, 256, 256])
  65. input_handle.copy_from_cpu(real_input)
  66. # 运行predictor
  67. predictor.run()
  68. # 获取输出
  69. output_names = predictor.get_output_names()
  70. output_handle = predictor.get_output_handle(output_names[0])
  71. output_data = output_handle.copy_to_cpu() # numpy.ndarray类型
  72. output = np.squeeze(output_data, 0)
  73. output = output.transpose(1,2,0)
  74. preds.append(output)
  75. h, w, c = fake_input.shape
  76. image = CombineToImage(preds, h, w, c)
  77. c_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR, part)
  78. cv2.imwrite(args.output_dir + 'res.jpg', c_image)
  79. def parse_args():
  80. parser = argparse.ArgumentParser()
  81. parser.add_argument("--model_file", type=str, help="model filename")
  82. parser.add_argument("--params_file", type=str, help="parameter filename")
  83. parser.add_argument("--image_file", type=str, help="image filename")
  84. parser.add_argument("--output_dir", type=str, help="output dir path")
  85. parser.add_argument("--batch_size", type=int, default=1, help="batch size")
  86. return parser.parse_args()
  87. if __name__ == "__main__":
  88. main()