PyTorch ONNX导出与跨平台部署
请解释PyTorch模型转ONNX的流程。torch.onnx.export的关键参数(export_params/dynamic_axes/opset_version)是什么?如何调试ONNX导出错误?导出的ONNX如何用ONNX Runtime/TensorRT部署?
回答
小字辈
PyTorch转ONNX流程:
import torch.onnx
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model,
dummy_input,
"model.onnx",
export_params=True, # 导出模型参数
opset_version=17, # ONNX算子集版本
input_names=["input"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch_size"}, # 动态batch
"output": {0: "batch_size"}
}
)
关键参数:
- export_params:是否包含权重参数
- dynamic_axes:指定动态维度(如batch/seq_len)
- opset_version:越高支持越多算子,但兼容性可能下降
调试ONNX:
- 简化模型:onnx-simplifier
- 验证模型:onnxruntime python API
- 调试:逐个算子检查
部署:
- ONNX Runtime (CPU/GPU):跨平台推荐
- TensorRT (NVIDIA GPU):最高性能
- OpenVINO (Intel CPU/GPU)
import onnxruntime as ort
session = ort.InferenceSession("model.onnx")
result = session.run(None, {"input": np_input})