|
@@ -0,0 +1,19 @@
|
|
|
|
+import torch
|
|
|
|
+import numpy as np
|
|
|
|
+
|
|
|
|
+# 构建输入
|
|
|
|
+import model
|
|
|
|
+
|
|
|
|
+input_data = np.random.rand(1, 3, 32, 32).astype("float32")
|
|
|
|
+# 获取PyTorch Module
|
|
|
|
+
|
|
|
|
+myModel = model.ModelFactory.get_model('resnet', 'corner')
|
|
|
|
+myModel.load_state_dict(torch.load('outputs/corner552023/corner_0505_1/corner_0505corner_resnet.pth'))
|
|
|
|
+# 设置为eval模式
|
|
|
|
+myModel.eval()
|
|
|
|
+# 进行转换
|
|
|
|
+from x2paddle.convert import pytorch2paddle
|
|
|
|
+pytorch2paddle(myModel,
|
|
|
|
+ save_dir="pd_model_trace1",
|
|
|
|
+ jit_type="trace",
|
|
|
|
+ input_examples=[torch.tensor(input_data)])
|