convert_model_to_tflite.py

#!/usr/bin/env python
import argparse
import os
import shutil

import onnx
import torch
import torch.backends._nnapi.prepare
import torch.utils.bundled_inputs
import torch.utils.mobile_optimizer
from onnx_tf.backend import prepare
import tensorflow as tf

from model import M64ColorNet
from torch.nn.utils import prune

parser = argparse.ArgumentParser()
parser.add_argument('--ckpt_path',
                    type=str,
                    help='Path to the stored checkpoint (.pt) file to convert',
                    default="output/model.pt")

def convert_to_tflite(out_dir: str, model: torch.nn.Module):
    dummy_input = torch.randn((1, 3, 256, 256))
    onnx_path = f"{out_dir}/converted.onnx"
    torch.onnx.export(model, dummy_input, onnx_path, verbose=True,
                      input_names=['input'], output_names=['output'])
    tf_path = f"{out_dir}/tf_model"
    onnx_model = onnx.load(onnx_path)
    # prepare() converts an ONNX model to an internal representation of the
    # computational graph (a TensorflowRep object) and returns it.
    tf_rep = prepare(onnx_model)
    # export_graph() obtains the graph proto corresponding to the ONNX model
    # associated with the backend representation and serializes it to a
    # SavedModel directory.
    tf_rep.export_graph(tf_path)
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_path)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tf_lite_model = converter.convert()
    tflite_path = f"{out_dir}/doc_clean.tflite"
    with open(tflite_path, 'wb') as f:
        f.write(tf_lite_model)

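
# Optional sanity check (a sketch, not part of the original pipeline): load the
# exported .tflite file with tf.lite.Interpreter and compare its output against
# the PyTorch model on the same dummy input. The helper name `verify_tflite`
# and the comparison logic are illustrative assumptions, not the repo's API.
def verify_tflite(tflite_path: str, model: torch.nn.Module):
    import numpy as np
    dummy_input = torch.randn((1, 3, 256, 256))
    with torch.no_grad():
        # Assumes the model returns a single tensor for a single input.
        torch_out = model(dummy_input).numpy()
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # Depending on the converter, the TFLite input may have a different layout
    # (e.g. NHWC); check input_details[0]['shape'] before feeding data.
    interpreter.set_tensor(input_details[0]['index'], dummy_input.numpy())
    interpreter.invoke()
    tflite_out = interpreter.get_tensor(output_details[0]['index'])
    # Optimized conversion can introduce small numerical differences.
    print("max abs diff:", np.abs(torch_out - tflite_out).max())
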
def convert_to_tflite_with_tiny(out_dir: str, fileName: str, model: torch.nn.Module):
    from tinynn.converter import TFLiteConverter
    dummy_input = torch.rand((1, 3, 256, 256))
    # output_path = os.path.join(out_dir, 'out', 'mbv1_224.tflite')
    tflite_path = f"{out_dir}/{fileName}"
    # When converting quantized models, please ensure the quantization backend is set.
    # torch.backends.quantized.engine = 'qnnpack'
    # The code below converts the model to the TFLite format.
    # If you want to perform dynamic quantization on float models,
    # you may pass the following arguments:
    # `quantize_target_type='int8', hybrid_quantization_from_float=True, hybrid_per_channel=False`
    # As for static quantization (e.g. quantization-aware training and post-training quantization),
    # please refer to the code examples in the `examples/quantization` folder.
    converter = TFLiteConverter(model, dummy_input, tflite_path)
    converter.convert()

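
# Hedged sketch of the dynamic-quantization variant described in the comments
# above. The keyword arguments are taken verbatim from those comments; the
# function name is an illustrative assumption, and whether these arguments fit
# your installed TinyNN version should be checked against its documentation.
def convert_to_tflite_with_tiny_quantized(out_dir: str, fileName: str, model: torch.nn.Module):
    from tinynn.converter import TFLiteConverter
    dummy_input = torch.rand((1, 3, 256, 256))
    tflite_path = f"{out_dir}/{fileName}"
    converter = TFLiteConverter(model, dummy_input, tflite_path,
                                quantize_target_type='int8',
                                hybrid_quantization_from_float=True,
                                hybrid_per_channel=False)
    converter.convert()
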
if __name__ == "__main__":
    out_dir = "output_tflite"
    shutil.rmtree(out_dir, ignore_errors=True)
    os.mkdir(out_dir)
    args = parser.parse_args()
    model, _, _, _, ssim, psnr = M64ColorNet.load_trained_model(args.ckpt_path)
    name = os.path.basename(args.ckpt_path).split(".")[0]
    fileName = f"ssim_{round(ssim, 2)}_psnr_{round(psnr, 2)}_{name}.tflite"
    # convert_to_tflite(out_dir, model)
    convert_to_tflite_with_tiny(out_dir, fileName, model)
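
# Example invocation (a usage sketch; the checkpoint path shown is the argparse default):
#   python convert_model_to_tflite.py --ckpt_path output/model.pt
# The converted model is written to output_tflite/ssim_<ssim>_psnr_<psnr>_model.tflite.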