#!/usr/bin/env python
import argparse
import os
import shutil

import onnx
import tensorflow as tf
import torch
import torch.backends._nnapi.prepare
import torch.utils.bundled_inputs
import torch.utils.mobile_optimizer
from onnx_tf.backend import prepare
from torch.nn.utils import prune

from model import M64ColorNet

parser = argparse.ArgumentParser()
parser.add_argument('--ckpt_path',
                    type=str,
                    help='Path to the trained checkpoint (.pt) file to load',
                    default="output/model.pt")

def convert_to_tflite(out_dir: str, model: torch.nn.Module):
    """Convert the PyTorch model to TFLite via the ONNX -> TensorFlow path."""
    dummy_input = torch.randn((1, 3, 256, 256))
    onnx_path = f"{out_dir}/converted.onnx"
    torch.onnx.export(model, dummy_input, onnx_path, verbose=True,
                      input_names=['input'], output_names=['output'])
    tf_path = f"{out_dir}/tf_model"
    onnx_model = onnx.load(onnx_path)
    # prepare() converts an ONNX model to an internal representation of the
    # computational graph called TensorflowRep and returns the converted
    # representation.
    tf_rep = prepare(onnx_model)  # create a TensorflowRep object
    # export_graph() obtains the graph proto corresponding to the ONNX model
    # associated with the backend representation and serializes it to a
    # protobuf file (a TensorFlow SavedModel directory).
    tf_rep.export_graph(tf_path)
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_path)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tf_lite_model = converter.convert()
    tflite_path = f"{out_dir}/doc_clean.tflite"
    with open(tflite_path, 'wb') as f:
        f.write(tf_lite_model)

def convert_to_tflite_with_tiny(out_dir: str, fileName: str, model: torch.nn.Module):
    """Convert the PyTorch model to TFLite directly with the TinyNN converter."""
    from tinynn.converter import TFLiteConverter
    dummy_input = torch.rand((1, 3, 256, 256))
    # output_path = os.path.join(out_dir, 'out', 'mbv1_224.tflite')
    tflite_path = f"{out_dir}/{fileName}"
    # When converting quantized models, please ensure that the quantization backend is set.
    # torch.backends.quantized.engine = 'qnnpack'
    # The code below converts the model to the TFLite format.
    # If you want to perform dynamic quantization on the float model,
    # you may pass the following arguments:
    # `quantize_target_type='int8', hybrid_quantization_from_float=True, hybrid_per_channel=False`
    # As for static quantization (e.g. quantization-aware training and post-training quantization),
    # please refer to the code examples in the `examples/quantization` folder.
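    # A hedged sketch (not in the original script): with TinyNN's hybrid
    # quantization enabled via the arguments listed above, the converter call
    # would look roughly like this.
    # converter = TFLiteConverter(model, dummy_input, tflite_path,
    #                             quantize_target_type='int8',
    #                             hybrid_quantization_from_float=True,
    #                             hybrid_per_channel=False)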
    converter = TFLiteConverter(model, dummy_input, tflite_path)
    converter.convert()

if __name__ == "__main__":
    # Start from a clean output directory.
    out_dir = "output_tflite"
    shutil.rmtree(out_dir, ignore_errors=True)
    os.mkdir(out_dir)
    args = parser.parse_args()
    model, _, _, _, ssim, psnr = M64ColorNet.load_trained_model(args.ckpt_path)
    # Embed the evaluation metrics (SSIM / PSNR) in the exported file name.
    name = os.path.basename(args.ckpt_path).split(".")[0]
    fileName = f"ssim_{round(ssim, 2)}_psnr_{round(psnr, 2)}_{name}.tflite"
    # convert_to_tflite(out_dir, model)
    convert_to_tflite_with_tiny(out_dir, fileName, model)
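    # A hedged sanity check (assumption, not part of the original script): load the
    # exported .tflite with the TensorFlow Lite interpreter and run a random input
    # through it to confirm the converted graph executes and yields an output.
    # import numpy as np
    # interpreter = tf.lite.Interpreter(model_path=f"{out_dir}/{fileName}")
    # interpreter.allocate_tensors()
    # input_details = interpreter.get_input_details()
    # output_details = interpreter.get_output_details()
    # dummy = np.random.rand(*input_details[0]['shape']).astype(np.float32)
    # interpreter.set_tensor(input_details[0]['index'], dummy)
    # interpreter.invoke()
    # print(interpreter.get_tensor(output_details[0]['index']).shape)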