2 Commits ed26257706 ... fb6c1e8b9b

Author SHA1 Message Date
  kangtan fb6c1e8b9b Merge branch 'tools/da_kit' of http://git.kdan.cc:8865/Others/DocumentAIKit into tools/da_kit 1 year ago
  kangtan fd869fef9e add model_convert 1 year ago
2 changed files with 29 additions and 0 deletions
  1. 10 0
      model_convert/pytorch2onnx.py
  2. 19 0
      model_convert/pytorch2paddle.py

+ 10 - 0
model_convert/pytorch2onnx.py

@@ -0,0 +1,10 @@
+import torch
+import model
+
+
+myModel = model.ModelFactory.get_model('resnet', 'document')
+myModel.load_state_dict(torch.load('outputs/doc552023/doc_0505_0/doc_0505document_resnet.pth'))
+myModel.eval()
+
+dummy_input = torch.randn(1, 3, 32, 32)
+torch.onnx.export(myModel, dummy_input, "document_1.0.0.onnx", do_constant_folding=False)

+ 19 - 0
model_convert/pytorch2paddle.py

@@ -0,0 +1,19 @@
+import torch
+import numpy as np
+
+# 构建输入
+import model
+
+input_data = np.random.rand(1, 3, 32, 32).astype("float32")
+# 获取PyTorch Module
+
+myModel = model.ModelFactory.get_model('resnet', 'corner')
+myModel.load_state_dict(torch.load('outputs/corner552023/corner_0505_1/corner_0505corner_resnet.pth'))
+# 设置为eval模式
+myModel.eval()
+# 进行转换
+from x2paddle.convert import pytorch2paddle
+pytorch2paddle(myModel,
+               save_dir="pd_model_trace1",
+               jit_type="trace",
+               input_examples=[torch.tensor(input_data)])