# pytorch2onnx.py (330 B)
  1. import torch
  2. import model
  3. myModel = model.ModelFactory.get_model('resnet', 'document')
  4. myModel.load_state_dict(torch.load('outputs/doc552023/doc_0505_0/doc_0505document_resnet.pth'))
  5. myModel.eval()
  6. dummy_input = torch.randn(1, 3, 32, 32)
  7. torch.onnx.export(myModel, dummy_input, "document_1.0.0.onnx", do_constant_folding=False)