# onnx_custom.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import argparse
  15. import os
  16. import onnx
  17. import onnx_graphsurgeon
  18. import numpy as np
  19. from collections import OrderedDict
  20. from paddle2onnx.command import program2onnx
  21. parser = argparse.ArgumentParser(description=__doc__)
  22. parser.add_argument(
  23. '--onnx_file', required=True, type=str, help='onnx model path')
  24. parser.add_argument(
  25. '--model_dir',
  26. type=str,
  27. default=None,
  28. help=("Directory include:'model.pdiparams', 'model.pdmodel', "
  29. "'infer_cfg.yml', created by tools/export_model.py."))
  30. parser.add_argument(
  31. "--opset_version",
  32. type=int,
  33. default=11,
  34. help="set onnx opset version to export")
  35. parser.add_argument(
  36. '--topk_all', type=int, default=300, help='topk objects for every images')
  37. parser.add_argument(
  38. '--iou_thres', type=float, default=0.7, help='iou threshold for NMS')
  39. parser.add_argument(
  40. '--conf_thres', type=float, default=0.01, help='conf threshold for NMS')
  41. def main(FLAGS):
  42. assert os.path.exists(FLAGS.onnx_file)
  43. onnx_model = onnx.load(FLAGS.onnx_file)
  44. graph = onnx_graphsurgeon.import_onnx(onnx_model)
  45. graph.toposort()
  46. graph.fold_constants()
  47. graph.cleanup()
  48. num_anchors = graph.outputs[1].shape[2]
  49. num_classes = graph.outputs[1].shape[1]
  50. scores = onnx_graphsurgeon.Variable(
  51. name='scores', shape=[-1, num_anchors, num_classes], dtype=np.float32)
  52. graph.layer(
  53. op='Transpose',
  54. name='lastTranspose',
  55. inputs=[graph.outputs[1]],
  56. outputs=[scores],
  57. attrs=OrderedDict(perm=[0, 2, 1]))
  58. attrs = OrderedDict(
  59. plugin_version="1",
  60. background_class=-1,
  61. max_output_boxes=FLAGS.topk_all,
  62. score_threshold=FLAGS.conf_thres,
  63. iou_threshold=FLAGS.iou_thres,
  64. score_activation=False,
  65. box_coding=0, )
  66. outputs = [
  67. onnx_graphsurgeon.Variable("num_dets", np.int32, [-1, 1]),
  68. onnx_graphsurgeon.Variable("det_boxes", np.float32,
  69. [-1, FLAGS.topk_all, 4]),
  70. onnx_graphsurgeon.Variable("det_scores", np.float32,
  71. [-1, FLAGS.topk_all]),
  72. onnx_graphsurgeon.Variable("det_classes", np.int32,
  73. [-1, FLAGS.topk_all])
  74. ]
  75. graph.layer(
  76. op='EfficientNMS_TRT',
  77. name="batched_nms",
  78. inputs=[graph.outputs[0], scores],
  79. outputs=outputs,
  80. attrs=attrs)
  81. graph.outputs = outputs
  82. graph.cleanup().toposort()
  83. onnx.save(onnx_graphsurgeon.export_onnx(graph), FLAGS.onnx_file)
  84. print(f"The modified onnx model is saved in {FLAGS.onnx_file}")
  85. if __name__ == '__main__':
  86. FLAGS = parser.parse_args()
  87. if FLAGS.model_dir is not None:
  88. assert os.path.exists(FLAGS.model_dir)
  89. program2onnx(
  90. model_dir=FLAGS.model_dir,
  91. save_file=FLAGS.onnx_file,
  92. model_filename="model.pdmodel",
  93. params_filename="model.pdiparams",
  94. opset_version=FLAGS.opset_version,
  95. enable_onnx_checker=True)
  96. main(FLAGS)