ソースを参照

add model_convert

kangtan 1 年間 前
コミット
fd869fef9e
3 ファイル変更29 行追加0 行削除
  1. 0 0
      README.txt
  2. 10 0
      model_convert/pytorch2onnx.py
  3. 19 0
      model_convert/pytorch2paddle.py

+ 0 - 0
README.txt


+ 10 - 0
model_convert/pytorch2onnx.py

@@ -0,0 +1,10 @@
+import torch
+import model
+
+
+myModel = model.ModelFactory.get_model('resnet', 'document')
+myModel.load_state_dict(torch.load('outputs/doc552023/doc_0505_0/doc_0505document_resnet.pth'))
+myModel.eval()
+
+dummy_input = torch.randn(1, 3, 32, 32)
+torch.onnx.export(myModel, dummy_input, "document_1.0.0.onnx", do_constant_folding=False)

+ 19 - 0
model_convert/pytorch2paddle.py

@@ -0,0 +1,19 @@
+import torch
+import numpy as np
+
+# 构建输入
+import model
+
+input_data = np.random.rand(1, 3, 32, 32).astype("float32")
+# 获取PyTorch Module
+
+myModel = model.ModelFactory.get_model('resnet', 'corner')
+myModel.load_state_dict(torch.load('outputs/corner552023/corner_0505_1/corner_0505corner_resnet.pth'))
+# 设置为eval模式
+myModel.eval()
+# 进行转换
+from x2paddle.convert import pytorch2paddle
+pytorch2paddle(myModel,
+               save_dir="pd_model_trace1",
+               jit_type="trace",
+               input_examples=[torch.tensor(input_data)])