pytorch2paddle.py 561 B

12345678910111213141516171819
  1. import torch
  2. import numpy as np
  3. # 构建输入
  4. import model
  5. input_data = np.random.rand(1, 3, 32, 32).astype("float32")
  6. # 获取PyTorch Module
  7. myModel = model.ModelFactory.get_model('resnet', 'corner')
  8. myModel.load_state_dict(torch.load('outputs/corner552023/corner_0505_1/corner_0505corner_resnet.pth'))
  9. # 设置为eval模式
  10. myModel.eval()
  11. # 进行转换
  12. from x2paddle.convert import pytorch2paddle
  13. pytorch2paddle(myModel,
  14. save_dir="pd_model_trace1",
  15. jit_type="trace",
  16. input_examples=[torch.tensor(input_data)])