#!/usr/bin/env python
import argparse
import os
import shutil

import onnx
import torch
import torch.backends._nnapi.prepare
import torch.utils.bundled_inputs
import torch.utils.mobile_optimizer
from onnx_tf.backend import prepare
import tensorflow as tf
from model import M64ColorNet
from torch.nn.utils import prune

parser = argparse.ArgumentParser()
parser.add_argument('--ckpt_path', type=str,
                    help='Path to the trained checkpoint file to convert',
                    default="output/model.pt")


def convert_to_tflite(out_dir: str, model: torch.nn.Module):
    dummy_input = torch.randn((1, 3, 256, 256))
    onnx_path = f"{out_dir}/converted.onnx"
    # Export the PyTorch model to ONNX with named input/output tensors.
    torch.onnx.export(model, dummy_input, onnx_path, verbose=True,
                      input_names=['input'], output_names=['output'])
    tf_path = f"{out_dir}/tf_model"
    onnx_model = onnx.load(onnx_path)
    # The prepare function converts an ONNX model to an internal representation
    # of the computational graph called TensorflowRep and returns
    # the converted representation.
    tf_rep = prepare(onnx_model)  # creating TensorflowRep object
    # The export_graph function obtains the graph proto corresponding to the ONNX
    # model associated with the backend representation and serializes
    # it to a protobuf file.
    tf_rep.export_graph(tf_path)
    # Convert the exported SavedModel to TFLite with default optimizations.
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_path)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tf_lite_model = converter.convert()
    tflite_path = f"{out_dir}/doc_clean.tflite"
    with open(tflite_path, 'wb') as f:
        f.write(tf_lite_model)


def convert_to_tflite_with_tiny(out_dir: str, fileName: str, model: torch.nn.Module):
    from tinynn.converter import TFLiteConverter
    dummy_input = torch.rand((1, 3, 256, 256))
    # output_path = os.path.join(out_dir, 'out', 'mbv1_224.tflite')
    tflite_path = f"{out_dir}/{fileName}"
    # When converting quantized models, please ensure the quantization backend is set.
    # torch.backends.quantized.engine = 'qnnpack'
    # The code section below converts the model to the TFLite format.
    # If you want to perform dynamic quantization on the float model,
    # you may pass the following arguments:
    # `quantize_target_type='int8', hybrid_quantization_from_float=True, hybrid_per_channel=False`
    # As for static quantization (e.g. quantization-aware training and post-training quantization),
    # please refer to the code examples in the `examples/quantization` folder.
    converter = TFLiteConverter(model, dummy_input, tflite_path)
    converter.convert()


if __name__ == "__main__":
    out_dir = "output_tflite"
    shutil.rmtree(out_dir, ignore_errors=True)
    os.mkdir(out_dir)
    args = parser.parse_args()
    model, _, _, _, ssim, psnr = M64ColorNet.load_trained_model(args.ckpt_path)
    name = os.path.basename(args.ckpt_path).split(".")[0]
    fileName = f"ssim_{round(ssim, 2)}_psnr_{round(psnr, 2)}_{name}.tflite"
    # convert_to_tflite(out_dir, model)
    convert_to_tflite_with_tiny(out_dir, fileName, model)
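
    # --- Optional sanity check (added sketch, not part of the original conversion flow) ---
    # Assumes the exported file can be loaded with TensorFlow's tf.lite.Interpreter; it
    # feeds a random tensor matching the interpreter's reported input shape and prints
    # the output shape to confirm the converted graph runs end to end.
    import numpy as np

    interpreter = tf.lite.Interpreter(model_path=f"{out_dir}/{fileName}")
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    dummy = np.random.rand(*input_details[0]['shape']).astype(input_details[0]['dtype'])
    interpreter.set_tensor(input_details[0]['index'], dummy)
    interpreter.invoke()
    print("TFLite output shape:", interpreter.get_tensor(output_details[0]['index']).shape)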