end2end.py 3.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import argparse
  2. import onnx
  3. import onnx_graphsurgeon as gs
  4. import numpy as np
  5. from pathlib import Path
  6. from paddle2onnx.legacy.command import program2onnx
  7. from collections import OrderedDict
  8. def main(opt):
  9. model_dir = Path(opt.model_dir)
  10. save_file = Path(opt.save_file)
  11. assert model_dir.exists() and model_dir.is_dir()
  12. if save_file.is_dir():
  13. save_file = (save_file / model_dir.stem).with_suffix('.onnx')
  14. elif save_file.is_file() and save_file.suffix != '.onnx':
  15. save_file = save_file.with_suffix('.onnx')
  16. input_shape_dict = {'image': [opt.batch_size, 3, *opt.img_size],
  17. 'scale_factor': [opt.batch_size, 2]}
  18. program2onnx(str(model_dir), str(save_file),
  19. 'model.pdmodel', 'model.pdiparams',
  20. opt.opset, input_shape_dict=input_shape_dict)
  21. onnx_model = onnx.load(save_file)
  22. try:
  23. import onnxsim
  24. onnx_model, check = onnxsim.simplify(onnx_model)
  25. assert check, 'assert check failed'
  26. except Exception as e:
  27. print(f'Simplifier failure: {e}')
  28. onnx.checker.check_model(onnx_model)
  29. graph = gs.import_onnx(onnx_model)
  30. graph.fold_constants()
  31. graph.cleanup().toposort()
  32. mul = concat = None
  33. for node in graph.nodes:
  34. if node.op == 'Div' and node.i(0).op == 'Mul':
  35. mul = node.i(0)
  36. if node.op == 'Concat' and node.o().op == 'Reshape' and node.o().o().op == 'ReduceSum':
  37. concat = node
  38. assert mul.outputs[0].shape[1] == concat.outputs[0].shape[2], 'Something wrong in outputs shape'
  39. anchors = mul.outputs[0].shape[1]
  40. classes = concat.outputs[0].shape[1]
  41. scores = gs.Variable(name='scores', shape=[opt.batch_size, anchors, classes], dtype=np.float32)
  42. graph.layer(op='Transpose', name='lastTranspose',
  43. inputs=[concat.outputs[0]],
  44. outputs=[scores],
  45. attrs=OrderedDict(perm=[0, 2, 1]))
  46. graph.inputs = [graph.inputs[0]]
  47. attrs = OrderedDict(
  48. plugin_version="1",
  49. background_class=-1,
  50. max_output_boxes=opt.topk_all,
  51. score_threshold=opt.conf_thres,
  52. iou_threshold=opt.iou_thres,
  53. score_activation=False,
  54. box_coding=0, )
  55. outputs = [gs.Variable("num_dets", np.int32, [opt.batch_size, 1]),
  56. gs.Variable("det_boxes", np.float32, [opt.batch_size, opt.topk_all, 4]),
  57. gs.Variable("det_scores", np.float32, [opt.batch_size, opt.topk_all]),
  58. gs.Variable("det_classes", np.int32, [opt.batch_size, opt.topk_all])]
  59. graph.layer(op='EfficientNMS_TRT', name="batched_nms",
  60. inputs=[mul.outputs[0], scores],
  61. outputs=outputs,
  62. attrs=attrs)
  63. graph.outputs = outputs
  64. graph.cleanup().toposort()
  65. onnx.save(gs.export_onnx(graph), save_file)
  66. def parse_opt():
  67. parser = argparse.ArgumentParser()
  68. parser.add_argument('--model-dir', type=str,
  69. default=None,
  70. help='paddle static model')
  71. parser.add_argument('--save-file', type=str,
  72. default=None,
  73. help='onnx model save path')
  74. parser.add_argument('--opset', type=int, default=11, help='opset version')
  75. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')
  76. parser.add_argument('--batch-size', type=int, default=1, help='batch size')
  77. parser.add_argument('--topk-all', type=int, default=100, help='topk objects for every images')
  78. parser.add_argument('--iou-thres', type=float, default=0.45, help='iou threshold for NMS')
  79. parser.add_argument('--conf-thres', type=float, default=0.25, help='conf threshold for NMS')
  80. opt = parser.parse_args()
  81. opt.img_size *= 2 if len(opt.img_size) == 1 else 1
  82. return opt
  83. if __name__ == '__main__':
  84. opt = parse_opt()
  85. main(opt)