"""Convert a trained PyTorch model to PaddlePaddle with X2Paddle.

The PyTorch module is rebuilt via ``model.ModelFactory.get_model('resnet',
'corner')``, its weights are restored from a saved state dict, and it is then
traced with a random example input and converted by X2Paddle's
``pytorch2paddle``, which writes the result to ``SAVE_DIR``.
"""
import numpy as np
import torch
from x2paddle.convert import pytorch2paddle

import model

# Trained weights for the module (a state dict passed to ``load_state_dict``).
CHECKPOINT_PATH = "outputs/corner552023/corner_0505_1/corner_0505corner_resnet.pth"
# Directory the converted Paddle model is written to.
SAVE_DIR = "pd_model_trace1"
# Shape of the example input used for tracing. Presumably NCHW
# (batch, channels, height, width) — TODO confirm against the model definition.
INPUT_SHAPE = (1, 3, 32, 32)

# Build the example input. Random float32 values are enough for tracing as
# long as the model has no data-dependent control flow.
input_data = np.random.rand(*INPUT_SHAPE).astype("float32")

# Rebuild the PyTorch module and restore its trained weights.
myModel = model.ModelFactory.get_model('resnet', 'corner')
# map_location="cpu" lets a checkpoint saved from a GPU run be loaded on a
# CPU-only machine. load_state_dict copies the values into the module's own
# parameters, so the loaded weights are the same either way.
myModel.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
# eval() disables training-only behavior (e.g. dropout, batch-norm statistic
# updates) so the traced graph reflects inference.
myModel.eval()

# Trace the module with the example input and convert it to Paddle.
pytorch2paddle(
    myModel,
    save_dir=SAVE_DIR,
    jit_type="trace",
    input_examples=[torch.tensor(input_data)],
)