12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
- import argparse
- import shutil
- from pathlib import Path
- import os
- import torch
- from tinynn.converter import TFLiteConverter
- import model
- parser = argparse.ArgumentParser()
- parser.add_argument("-cm", "--cornerModel", help="Model for corner point refinement",
- default="../cornerModelWell")
- parser.add_argument("-dm", "--documentModel", help="Model for document corners detection",
- default="../documentModelWell")
- def load_doc_model(checkpoint_dir, dataset):
- _model = model.ModelFactory.get_model("resnet", dataset)
- _model.load_state_dict(torch.load(checkpoint_dir, map_location="cpu"))
- return _model
- if __name__ == "__main__":
- args = parser.parse_args()
- models = [
- {
- "name": "corner_model",
- "model": load_doc_model(
- args.cornerModel,
- "corner",
- ),
- },
- {
- "name": "doc_model",
- "model": load_doc_model(
- args.documentModel,
- "document",
- ),
- },
- ]
- out_dir = "output_tflite"
- shutil.rmtree(out_dir, ignore_errors=True)
- os.mkdir(out_dir)
- for item in models:
- _model = item["model"]
- _model.eval()
- dummy_input = torch.rand((1, 3, 32, 32))
- modelPath = f'{out_dir}/{item["name"]}.tflite'
- converter = TFLiteConverter(_model, dummy_input, modelPath)
- converter.convert()
- # scripted = torch.jit.script(_model)
- # optimized_model = optimize_for_mobile(scripted, backend='metal')
- # print(torch.jit.export_opnames(optimized_model))
- # optimized_model._save_for_lite_interpreter(f'{output}/{item["name"]}_metal.ptl')
-
- # scripted_model = torch.jit.script(_model)
- # optimized_model = optimize_for_mobile(scripted_model, backend='metal')
- # print(torch.jit.export_opnames(optimized_model))
- # optimized_model._save_for_lite_interpreter(f'{output}/{item["name"]}_metal.pt')
- # torch.save(_model, f'{output}/{item["name"]}.pth')
|