在模型訓練完畢后,我們通常希望將其部署到推理平臺中,比如 TensorRT、ONNX Runtime 或移動端框架。而 ONNX(Open Neural Network Exchange)正是 PyTorch 與這些平臺之間的橋梁。
本文將以一個圖像去噪模型 SimpleDenoiser
為例,手把手帶你完成 PyTorch 模型導出為 ONNX 格式的全過程,并解析每一行代碼背后的邏輯。
準備工作
我們假設你已經訓練好一個圖像去噪模型并保存為 .pth
文件,模型結構自編碼器實現如下(略):
class SimpleDenoiser(nn.Module):def __init__(self):super(SimpleDenoiser, self).__init__()self.encoder = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),nn.Conv2d(64, 64, 3, padding=1), nn.ReLU())self.decoder = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),nn.Conv2d(64, 3, 3, padding=1))def forward(self, x):x = self.encoder(x)x = self.decoder(x)return x
導出代碼分解
我們現在來看導出腳本的核心邏輯,并分塊解釋它的每一部分。
1. 導入模塊 & 設置路徑
//torch:核心框架//train.SimpleDenoiser:從訓練腳本復用模型結構//os:用于創建輸出目錄import torch
from train import SimpleDenoiser # 模型結構
import os
2. 導出函數定義
//這個函數接收三個參數://pth_path: 訓練得到的模型參數文件路徑//onnx_path: 導出的 ONNX 文件保存路徑//input_size: 模擬推理輸入的尺寸(默認 1×3×256×256)
def export_model_to_onnx(pth_path, onnx_path, input_size=(1, 3, 256, 256)):
3. 加載模型和權重
//自動檢測 CUDA 可用性,加載模型到對應設備;//使用 load_state_dict() 加載訓練好的參數;//model.eval() 讓模型切換到推理模式(關閉 Dropout/BatchNorm 更新);
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = SimpleDenoiser().to(device)
model.load_state_dict(torch.load(pth_path, map_location=device))
model.eval()
4. 構造假輸入(Dummy Input)
//ONNX 導出需要一個具體的輸入樣本,我們這里用 torch.randn 生成一個形狀為 (1, 3, 256, 256) 的隨機圖//像;//輸入必須放在同一個設備上(GPU 或 CPU);
dummy_input = torch.randn(*input_size).to(device)
5. 導出為 ONNX
torch.onnx.export(model, //要導出的模型dummy_input, //示例輸入張量onnx_path, // 導出路徑export_params=True, //是否導出權重opset_version=11, //ONNX 的算子集版本,通常推薦 11 或 13do_constant_folding=True, //優化常量表達式,減小模型體積input_names=['input'], //自定義輸入輸出張量的名稱output_names=['output'], //聲明哪些維度可以變動,比如 batch size、圖像大小等(部署時更靈活)dynamic_axes={'input': {0: 'batch_size', 2: 'height', 3: 'width'},'output': {0: 'batch_size', 2: 'height', 3: 'width'}}
)
6. 創建目錄并調用函數
//確保輸出文件夾存在,并調用導出函數生成最終模型。
if __name__ == "__main__":os.makedirs("onnx", exist_ok=True)export_model_to_onnx("weights/denoiser.pth", "onnx/denoiser.onnx")
導出后如何驗證?
pip install onnxruntime
import onnxruntime
import numpy as npsess = onnxruntime.InferenceSession("onnx/denoiser.onnx")
input = np.random.randn(1, 3, 256, 256).astype(np.float32)
output = sess.run(None, {"input": input})
print("輸出 shape:", output[0].shape)
?模型預覽:

總結
導出 ONNX 模型的流程主要包括:
-
加載模型結構 + 權重
-
準備 dummy 輸入張量
-
調用
torch.onnx.export()
進行導出 -
設置
dynamic_axes
可變尺寸以增強部署適配性
這套流程適用于大部分視覺模型(分類、去噪、分割等),也是后續進行 TensorRT 推理或移動端部署的基礎。