mobile_model_converter.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import argparse
  2. import shutil
  3. from pathlib import Path
  4. import os
  5. import torch
  6. from tinynn.converter import TFLiteConverter
  7. import model
  8. parser = argparse.ArgumentParser()
  9. parser.add_argument("-cm", "--cornerModel", help="Model for corner point refinement",
  10. default="../cornerModelWell")
  11. parser.add_argument("-dm", "--documentModel", help="Model for document corners detection",
  12. default="../documentModelWell")
  13. def load_doc_model(checkpoint_dir, dataset):
  14. _model = model.ModelFactory.get_model("resnet", dataset)
  15. _model.load_state_dict(torch.load(checkpoint_dir, map_location="cpu"))
  16. return _model
  17. if __name__ == "__main__":
  18. args = parser.parse_args()
  19. models = [
  20. {
  21. "name": "corner_model",
  22. "model": load_doc_model(
  23. args.cornerModel,
  24. "corner",
  25. ),
  26. },
  27. {
  28. "name": "doc_model",
  29. "model": load_doc_model(
  30. args.documentModel,
  31. "document",
  32. ),
  33. },
  34. ]
  35. out_dir = "output_tflite"
  36. shutil.rmtree(out_dir, ignore_errors=True)
  37. os.mkdir(out_dir)
  38. for item in models:
  39. _model = item["model"]
  40. _model.eval()
  41. dummy_input = torch.rand((1, 3, 32, 32))
  42. modelPath = f'{out_dir}/{item["name"]}.tflite'
  43. converter = TFLiteConverter(_model, dummy_input, modelPath)
  44. converter.convert()
  45. # scripted = torch.jit.script(_model)
  46. # optimized_model = optimize_for_mobile(scripted, backend='metal')
  47. # print(torch.jit.export_opnames(optimized_model))
  48. # optimized_model._save_for_lite_interpreter(f'{output}/{item["name"]}_metal.ptl')
  49. # scripted_model = torch.jit.script(_model)
  50. # optimized_model = optimize_for_mobile(scripted_model, backend='metal')
  51. # print(torch.jit.export_opnames(optimized_model))
  52. # optimized_model._save_for_lite_interpreter(f'{output}/{item["name"]}_metal.pt')
  53. # torch.save(_model, f'{output}/{item["name"]}.pth')