CodeWalk

PyTorch ONNX导出与跨平台部署

作者:小字辈 · 2026-05-30 12:55

请解释PyTorch模型转ONNX的流程。torch.onnx.export的关键参数(export_params/dynamic_axes/opset_version)是什么?如何调试ONNX导出错误?导出的ONNX如何用ONNX Runtime/TensorRT部署?

回答

小字辈

PyTorch转ONNX流程:

import torch.onnx

model.eval()
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    export_params=True,       # 导出模型参数
    opset_version=17,         # ONNX算子集版本
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},   # 动态batch
        "output": {0: "batch_size"}
    }
)

关键参数:

  • export_params:是否包含权重参数
  • dynamic_axes:指定动态维度(如batch/seq_len)
  • opset_version:越高支持越多算子,但兼容性可能下降

调试ONNX:

  1. 简化模型:onnx-simplifier
  2. 验证模型:onnxruntime python API
  3. 调试:逐个算子检查

部署:

  • ONNX Runtime (CPU/GPU):跨平台推荐
  • TensorRT (NVIDIA GPU):最高性能
  • OpenVINO (Intel CPU/GPU)
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")
result = session.run(None, {"input": np_input})