👋 你好!這里有實用干貨與深度分享?? 若有幫助,歡迎:?
👍 點贊 | ? 收藏 | 💬 評論 | ? 關注 ,解鎖更多精彩!?
📁 收藏專欄即可第一時間獲取最新推送🔔。?
📖后續我將持續帶來更多優質內容,期待與你一同探索知識,攜手前行,共同進步🚀。?
?
模型導出
本文介紹如何將深度學習模型導出為不同的部署格式,包括ONNX、TorchScript等,并對各種格式的優缺點和最佳實踐進行總結,幫助你高效完成模型部署準備。
1. 導出格式對比
格式 | 優點 | 缺點 | 適用場景 |
---|---|---|---|
ONNX | - 跨平臺跨框架 - 生態豐富 - 標準統一 - 廣泛支持 | - 可能存在算子兼容問題 - 部分高級特性支持有限 | - 跨平臺部署 - 使用標準推理引擎 - 需要廣泛兼容性 |
TorchScript | - 與PyTorch無縫集成 - 支持動態圖結構 - 調試方便 - 性能優化 | - 僅限PyTorch生態 - 文件體積較大 | - PyTorch生產環境 - 需要動態特性 - 性能要求高 |
TensorRT | - 極致優化性能 - 支持GPU加速 - 低延遲推理 | - 僅支持NVIDIA GPU - 配置復雜 | - 高性能推理場景 - 實時應用 - 邊緣計算 |
TensorFlow SavedModel | - TensorFlow生態完整支持 - 部署便捷 | - 跨框架兼容性差 | - TensorFlow生產環境 |
2. ONNX格式導出
2.1 基本導出
ONNX格式適用于跨平臺部署,支持多種推理引擎(如ONNXRuntime、TensorRT、OpenVINO等)。
import torch
import torch.onnxdef export_to_onnx(model, input_shape, save_path):# 設置模型為評估模式model.eval()# 創建示例輸入dummy_input = torch.randn(input_shape)# 導出模型torch.onnx.export(model, # 要導出的模型dummy_input, # 模型輸入save_path, # 保存路徑export_params=True, # 導出模型參數opset_version=11, # ONNX算子集版本do_constant_folding=True, # 常量折疊優化input_names=['input'], # 輸入名稱output_names=['output'], # 輸出名稱dynamic_axes={'input': {0: 'batch_size'}, # 動態批次大小'output': {0: 'batch_size'}})print(f"Model exported to {save_path}")# 使用示例
model = YourModel()
model.load_state_dict(torch.load('model.pth'))
export_to_onnx(model, (1, 3, 224, 224), 'model.onnx')
2.2 驗證導出模型
導出后必須進行全面驗證,包括結構檢查和數值對比:
- 結構驗證
import onnx
import onnxruntime
import numpy as npdef verify_onnx_structure(onnx_path):# 加載并檢查模型結構onnx_model = onnx.load(onnx_path)onnx.checker.check_model(onnx_model)# 打印模型信息print("模型輸入:")for input in onnx_model.graph.input:print(f"- {input.name}: {input.type.tensor_type.shape}")print("\n模型輸出:")for output in onnx_model.graph.output:print(f"- {output.name}: {output.type.tensor_type.shape}")
- 數值精度對比
def compare_outputs(model, onnx_path, input_data):# PyTorch結果model.eval()with torch.no_grad():torch_output = model(torch.from_numpy(input_data))# ONNX結果ort_output = verify_onnx_model(onnx_path, input_data)# 比較差異diff = np.abs(torch_output.numpy() - ort_output).max()print(f"最大誤差: {diff}")return diff < 1e-5
- 驗證 ONNX 模型
import onnx
import onnxruntime
import numpy as npdef verify_onnx_model(onnx_path, input_data):# 加載ONNX模型onnx_model = onnx.load(onnx_path)onnx.checker.check_model(onnx_model)# 創建推理會話ort_session = onnxruntime.InferenceSession(onnx_path)# 準備輸入數據ort_inputs = {ort_session.get_inputs()[0].name: input_data}# 運行推理ort_outputs = ort_session.run(None, ort_inputs)return ort_outputs[0]input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
output = verify_onnx_model('model.onnx', input_data)
2.3 ONNX模型優化
使用ONNX Runtime提供的優化工具進一步提升性能:
import onnxruntime as ort
from onnxruntime.transformers import optimizerdef optimize_onnx_model(onnx_path, optimized_path):# 創建優化器配置opt_options = optimizer.OptimizationConfig(optimization_level=99, # 最高優化級別enable_gelu_approximation=True,enable_layer_norm_optimization=True,enable_attention_fusion=True)# 優化模型optimized_model = optimizer.optimize_model(onnx_path, 'cpu', # 或 'gpu'opt_options)# 保存優化后的模型optimized_model.save_model_to_file(optimized_path)print(f"優化后的模型已保存至 {optimized_path}")
optimizer.optimize_model()
第二個參數是優化目標設備,支持 ‘cpu’ 或 ‘gpu’。- 優化目標設備:指定模型優化時的目標硬件平臺。例如:
- ‘cpu’:針對 CPU 進行優化(如調整算子、量化參數等)。
- ‘gpu’:針對 GPU 進行優化(如使用 CUDA 內核、張量核心等)。
*運行時設備:優化后的模型可以在其他設備上運行,但性能可能受影響。例如: - 針對 CPU 優化的模型可以在 GPU 上運行,但可能無法充分利用 GPU 特性。
- 針對 GPU 優化的模型在 CPU 上運行可能會報錯或性能下降。
建議保持優化目標與運行設備一致以獲得最佳性能。
- 優化目標設備:指定模型優化時的目標硬件平臺。例如:
3. TorchScript格式導出
3.1 trace導出
適用于前向計算圖結構固定的模型。
import torchdef export_torchscript_trace(model, input_shape, save_path):model.eval()example_input = torch.randn(input_shape)# 使用跟蹤法導出traced_model = torch.jit.trace(model, example_input)traced_model.save(save_path)print(f"Traced model exported to {save_path}")return traced_model
3.2 script導出
適用于包含條件分支、循環等動態結構的模型。
import torch
import torch.nn as nn@torch.jit.script
class ScriptableModel(nn.Module):def __init__(self):super().__init__()self.conv = nn.Conv2d(3, 64, 3)self.relu = nn.ReLU()def forward(self, x):x = self.conv(x)x = self.relu(x)return xdef export_torchscript_script(model, save_path):scripted_model = torch.jit.script(model)scripted_model.save(save_path)print(f"Scripted model exported to {save_path}")return scripted_model
3.3 TorchScript模型驗證
驗證TorchScript模型的正確性:
def verify_torchscript_model(original_model, ts_model_path, input_data):# 原始模型輸出original_model.eval()with torch.no_grad():original_output = original_model(input_data)# 加載TorchScript模型ts_model = torch.jit.load(ts_model_path)ts_model.eval()# TorchScript模型輸出with torch.no_grad():ts_output = ts_model(input_data)# 比較差異diff = torch.abs(original_output - ts_output).max().item()print(f"最大誤差: {diff}")return diff < 1e-5
4. 自定義算子處理
4.1 ONNX自定義算子
如需導出自定義算子,可通過ONNX擴展機制實現。
from onnx import helperdef create_custom_op():# 定義自定義算子custom_op = helper.make_node('CustomOp', # 算子名稱inputs=['input'], # 輸入outputs=['output'], # 輸出domain='custom.domain')return custom_opdef register_custom_op():# 注冊自定義算子from onnxruntime.capi import _pybind_state as CC.register_custom_op('CustomOp', 'custom.domain')
4.2 TorchScript自定義算子
可通過C++擴展自定義TorchScript算子。
from torch.utils.cpp_extension import load# 編譯自定義C++算子
custom_op = load(name="custom_op",sources=["custom_op.cpp"],verbose=True
)# 在模型中使用自定義算子
class ModelWithCustomOp(nn.Module):def forward(self, x):return custom_op.forward(x)
4.3 自定義算子示例
下面是一個完整的自定義算子實現示例:
// custom_op.cpp
#include <torch/extension.h>torch::Tensor custom_forward(torch::Tensor input) {return input.sigmoid().mul(2.0);
}PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {m.def("forward", &custom_forward, "Custom forward function");
}
# 在Python中使用
import torch
from torch.utils.cpp_extension import load# 編譯自定義算子
custom_op = load(name="custom_op",sources=["custom_op.cpp"],verbose=True
)# 測試自定義算子
input_tensor = torch.randn(2, 3)
output = custom_op.forward(input_tensor)
print(output)
5. 模型部署示例
5.1 ONNXRuntime部署
import onnxruntime as ort
import numpy as np
from PIL import Image
import torchvision.transforms as transformsdef preprocess_image(image_path, input_shape):# 圖像預處理transform = transforms.Compose([transforms.Resize((input_shape[2], input_shape[3])),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])image = Image.open(image_path).convert('RGB')image_tensor = transform(image).unsqueeze(0).numpy()return image_tensordef onnx_inference(onnx_path, image_path, input_shape=(1, 3, 224, 224)):# 加載ONNX模型session = ort.InferenceSession(onnx_path)# 預處理圖像input_data = preprocess_image(image_path, input_shape)# 獲取輸入輸出名稱input_name = session.get_inputs()[0].nameoutput_name = session.get_outputs()[0].name# 執行推理result = session.run([output_name], {input_name: input_data})return result[0]
5.2 TorchScript部署
import torch
from PIL import Image
import torchvision.transforms as transformsdef torchscript_inference(model_path, image_path):# 加載TorchScript模型model = torch.jit.load(model_path)model.eval()# 圖像預處理transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 加載并處理圖像image = Image.open(image_path).convert('RGB')input_tensor = transform(image).unsqueeze(0)# 執行推理with torch.no_grad():output = model(input_tensor)return output
6. 常見問題與解決方案
6.1 ONNX導出失敗
問題: 導出ONNX時出現算子不支持錯誤
解決方案:
# 嘗試使用更高版本的opset
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=13)# 或替換不支持的操作
class ModelWrapper(nn.Module):def __init__(self, model):super().__init__()self.model = modeldef forward(self, x):# 替換不支持的操作為等效操作return self.model(x)
6.2 TorchScript跟蹤失敗
問題: 動態控制流導致trace失敗
解決方案:
# 使用script而非trace
scripted_model = torch.jit.script(model)# 或修改模型結構避免動態控制流
class TraceFriendlyModel(nn.Module):def __init__(self, original_model):super().__init__()self.model = original_modeldef forward(self, x):# 移除動態控制流return self.model.forward_fixed(x)
6.3 推理性能問題
問題: 導出模型推理速度慢
解決方案:
# 1. 使用量化
from torch.quantization import quantize_dynamic
quantized_model = quantize_dynamic(model, {nn.Linear, nn.Conv2d}, dtype=torch.qint8)# 2. 使用TensorRT優化ONNX
import tensorrt as trt
# TensorRT優化代碼...# 3. 使用ONNX Runtime優化
import onnxruntime as ort
session = ort.InferenceSession("model.onnx", providers=['CUDAExecutionProvider'])
7. 最佳實踐
-
選擇合適的導出格式
- ONNX:適合跨平臺、跨框架部署,兼容性強
- TorchScript:適合PyTorch生態內部署,支持靈活性高
- 根據目標平臺和性能需求選擇
-
優化導出模型
- 使用合適的opset版本(建議11及以上)
- 啟用常量折疊等優化選項
- 導出后務必驗證模型正確性
- 考慮使用量化和剪枝優化模型大小
-
處理動態輸入
- 設置動態維度(如batch_size)
- 測試不同輸入大小,確保模型魯棒性
- 記錄支持的輸入范圍和約束
-
文檔和版本控制
- 記錄導出配置和依賴版本
- 保存模型元數據(如輸入輸出規格)
- 對模型文件進行版本化管理
- 維護模型卡片(Model Card)記錄關鍵信息
-
調試技巧
- 使用ONNX Graph Viewer等可視化工具分析模型結構
- 使用Netron查看計算圖和參數分布
- 比較原始與導出模型輸出,檢查數值精度差異
- 遇到兼容性問題時查閱官方文檔和社區經驗
8. 參考資源
- ONNX官方文檔
- PyTorch TorchScript教程
- ONNX Runtime文檔
- TensorRT開發者指南
- Netron模型可視化工具
?
?
📌 感謝閱讀!若文章對你有用,別吝嗇互動~?
👍 點個贊 | ? 收藏備用 | 💬 留下你的想法 ,關注我,更多干貨持續更新!