[模型部署] 1. 模型導出

👋 你好!這里有實用干貨與深度分享?? 若有幫助,歡迎:?
👍 點贊 | ? 收藏 | 💬 評論 | ? 關注 ,解鎖更多精彩!?
📁 收藏專欄即可第一時間獲取最新推送🔔。?
📖后續我將持續帶來更多優質內容,期待與你一同探索知識,攜手前行,共同進步🚀。?



?人工智能

模型導出

本文介紹如何將深度學習模型導出為不同的部署格式,包括ONNX、TorchScript等,并對各種格式的優缺點和最佳實踐進行總結,幫助你高效完成模型部署準備。


1. 導出格式對比

格式優點缺點適用場景
ONNX- 跨平臺跨框架
- 生態豐富
- 標準統一
- 廣泛支持
- 可能存在算子兼容問題
- 部分高級特性支持有限
- 跨平臺部署
- 使用標準推理引擎
- 需要廣泛兼容性
TorchScript- 與PyTorch無縫集成
- 支持動態圖結構
- 調試方便
- 性能優化
- 僅限PyTorch生態
- 文件體積較大
- PyTorch生產環境
- 需要動態特性
- 性能要求高
TensorRT- 極致優化性能
- 支持GPU加速
- 低延遲推理
- 僅支持NVIDIA GPU
- 配置復雜
- 高性能推理場景
- 實時應用
- 邊緣計算
TensorFlow SavedModel- TensorFlow生態完整支持
- 部署便捷
- 跨框架兼容性差- TensorFlow生產環境

2. ONNX格式導出

2.1 基本導出

ONNX格式適用于跨平臺部署,支持多種推理引擎(如ONNXRuntime、TensorRT、OpenVINO等)。

import torch
import torch.onnxdef export_to_onnx(model, input_shape, save_path):# 設置模型為評估模式model.eval()# 創建示例輸入dummy_input = torch.randn(input_shape)# 導出模型torch.onnx.export(model,               # 要導出的模型dummy_input,        # 模型輸入save_path,          # 保存路徑export_params=True, # 導出模型參數opset_version=11,   # ONNX算子集版本do_constant_folding=True,  # 常量折疊優化input_names=['input'],     # 輸入名稱output_names=['output'],   # 輸出名稱dynamic_axes={'input': {0: 'batch_size'},  # 動態批次大小'output': {0: 'batch_size'}})print(f"Model exported to {save_path}")# 使用示例
model = YourModel()
model.load_state_dict(torch.load('model.pth'))
export_to_onnx(model, (1, 3, 224, 224), 'model.onnx')

2.2 驗證導出模型

導出后必須進行全面驗證,包括結構檢查和數值對比:

  1. 結構驗證
import onnx
import onnxruntime
import numpy as npdef verify_onnx_structure(onnx_path):# 加載并檢查模型結構onnx_model = onnx.load(onnx_path)onnx.checker.check_model(onnx_model)# 打印模型信息print("模型輸入:")for input in onnx_model.graph.input:print(f"- {input.name}: {input.type.tensor_type.shape}")print("\n模型輸出:")for output in onnx_model.graph.output:print(f"- {output.name}: {output.type.tensor_type.shape}")
  1. 數值精度對比
def compare_outputs(model, onnx_path, input_data):# PyTorch結果model.eval()with torch.no_grad():torch_output = model(torch.from_numpy(input_data))# ONNX結果ort_output = verify_onnx_model(onnx_path, input_data)# 比較差異diff = np.abs(torch_output.numpy() - ort_output).max()print(f"最大誤差: {diff}")return diff < 1e-5
  1. 驗證 ONNX 模型
import onnx
import onnxruntime
import numpy as npdef verify_onnx_model(onnx_path, input_data):# 加載ONNX模型onnx_model = onnx.load(onnx_path)onnx.checker.check_model(onnx_model)# 創建推理會話ort_session = onnxruntime.InferenceSession(onnx_path)# 準備輸入數據ort_inputs = {ort_session.get_inputs()[0].name: input_data}# 運行推理ort_outputs = ort_session.run(None, ort_inputs)return ort_outputs[0]input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
output = verify_onnx_model('model.onnx', input_data)

2.3 ONNX模型優化

使用ONNX Runtime提供的優化工具進一步提升性能:

import onnxruntime as ort
from onnxruntime.transformers import optimizerdef optimize_onnx_model(onnx_path, optimized_path):# 創建優化器配置opt_options = optimizer.OptimizationConfig(optimization_level=99,  # 最高優化級別enable_gelu_approximation=True,enable_layer_norm_optimization=True,enable_attention_fusion=True)# 優化模型optimized_model = optimizer.optimize_model(onnx_path, 'cpu',  # 或 'gpu'opt_options)# 保存優化后的模型optimized_model.save_model_to_file(optimized_path)print(f"優化后的模型已保存至 {optimized_path}")
  • optimizer.optimize_model() 第二個參數是優化目標設備,支持 ‘cpu’ 或 ‘gpu’。
    • 優化目標設備:指定模型優化時的目標硬件平臺。例如:
      • ‘cpu’:針對 CPU 進行優化(如調整算子、量化參數等)。
      • ‘gpu’:針對 GPU 進行優化(如使用 CUDA 內核、張量核心等)。
        *運行時設備:優化后的模型可以在其他設備上運行,但性能可能受影響。例如:
      • 針對 CPU 優化的模型可以在 GPU 上運行,但可能無法充分利用 GPU 特性。
      • 針對 GPU 優化的模型在 CPU 上運行可能會報錯或性能下降。
        建議保持優化目標與運行設備一致以獲得最佳性能。

3. TorchScript格式導出

3.1 trace導出

適用于前向計算圖結構固定的模型。

import torchdef export_torchscript_trace(model, input_shape, save_path):model.eval()example_input = torch.randn(input_shape)# 使用跟蹤法導出traced_model = torch.jit.trace(model, example_input)traced_model.save(save_path)print(f"Traced model exported to {save_path}")return traced_model

3.2 script導出

適用于包含條件分支、循環等動態結構的模型。

import torch
import torch.nn as nn@torch.jit.script
class ScriptableModel(nn.Module):def __init__(self):super().__init__()self.conv = nn.Conv2d(3, 64, 3)self.relu = nn.ReLU()def forward(self, x):x = self.conv(x)x = self.relu(x)return xdef export_torchscript_script(model, save_path):scripted_model = torch.jit.script(model)scripted_model.save(save_path)print(f"Scripted model exported to {save_path}")return scripted_model

3.3 TorchScript模型驗證

驗證TorchScript模型的正確性:

def verify_torchscript_model(original_model, ts_model_path, input_data):# 原始模型輸出original_model.eval()with torch.no_grad():original_output = original_model(input_data)# 加載TorchScript模型ts_model = torch.jit.load(ts_model_path)ts_model.eval()# TorchScript模型輸出with torch.no_grad():ts_output = ts_model(input_data)# 比較差異diff = torch.abs(original_output - ts_output).max().item()print(f"最大誤差: {diff}")return diff < 1e-5

4. 自定義算子處理

4.1 ONNX自定義算子

如需導出自定義算子,可通過ONNX擴展機制實現。

from onnx import helperdef create_custom_op():# 定義自定義算子custom_op = helper.make_node('CustomOp',           # 算子名稱inputs=['input'],     # 輸入outputs=['output'],   # 輸出domain='custom.domain')return custom_opdef register_custom_op():# 注冊自定義算子from onnxruntime.capi import _pybind_state as CC.register_custom_op('CustomOp', 'custom.domain')

4.2 TorchScript自定義算子

可通過C++擴展自定義TorchScript算子。

from torch.utils.cpp_extension import load# 編譯自定義C++算子
custom_op = load(name="custom_op",sources=["custom_op.cpp"],verbose=True
)# 在模型中使用自定義算子
class ModelWithCustomOp(nn.Module):def forward(self, x):return custom_op.forward(x)

4.3 自定義算子示例

下面是一個完整的自定義算子實現示例:

// custom_op.cpp
#include <torch/extension.h>torch::Tensor custom_forward(torch::Tensor input) {return input.sigmoid().mul(2.0);
}PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {m.def("forward", &custom_forward, "Custom forward function");
}
# 在Python中使用
import torch
from torch.utils.cpp_extension import load# 編譯自定義算子
custom_op = load(name="custom_op",sources=["custom_op.cpp"],verbose=True
)# 測試自定義算子
input_tensor = torch.randn(2, 3)
output = custom_op.forward(input_tensor)
print(output)

5. 模型部署示例

5.1 ONNXRuntime部署

import onnxruntime as ort
import numpy as np
from PIL import Image
import torchvision.transforms as transformsdef preprocess_image(image_path, input_shape):# 圖像預處理transform = transforms.Compose([transforms.Resize((input_shape[2], input_shape[3])),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])image = Image.open(image_path).convert('RGB')image_tensor = transform(image).unsqueeze(0).numpy()return image_tensordef onnx_inference(onnx_path, image_path, input_shape=(1, 3, 224, 224)):# 加載ONNX模型session = ort.InferenceSession(onnx_path)# 預處理圖像input_data = preprocess_image(image_path, input_shape)# 獲取輸入輸出名稱input_name = session.get_inputs()[0].nameoutput_name = session.get_outputs()[0].name# 執行推理result = session.run([output_name], {input_name: input_data})return result[0]

5.2 TorchScript部署

import torch
from PIL import Image
import torchvision.transforms as transformsdef torchscript_inference(model_path, image_path):# 加載TorchScript模型model = torch.jit.load(model_path)model.eval()# 圖像預處理transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 加載并處理圖像image = Image.open(image_path).convert('RGB')input_tensor = transform(image).unsqueeze(0)# 執行推理with torch.no_grad():output = model(input_tensor)return output

6. 常見問題與解決方案

6.1 ONNX導出失敗

問題: 導出ONNX時出現算子不支持錯誤

解決方案:

# 嘗試使用更高版本的opset
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=13)# 或替換不支持的操作
class ModelWrapper(nn.Module):def __init__(self, model):super().__init__()self.model = modeldef forward(self, x):# 替換不支持的操作為等效操作return self.model(x)

6.2 TorchScript跟蹤失敗

問題: 動態控制流導致trace失敗

解決方案:

# 使用script而非trace
scripted_model = torch.jit.script(model)# 或修改模型結構避免動態控制流
class TraceFriendlyModel(nn.Module):def __init__(self, original_model):super().__init__()self.model = original_modeldef forward(self, x):# 移除動態控制流return self.model.forward_fixed(x)

6.3 推理性能問題

問題: 導出模型推理速度慢

解決方案:

# 1. 使用量化
from torch.quantization import quantize_dynamic
quantized_model = quantize_dynamic(model, {nn.Linear, nn.Conv2d}, dtype=torch.qint8)# 2. 使用TensorRT優化ONNX
import tensorrt as trt
# TensorRT優化代碼...# 3. 使用ONNX Runtime優化
import onnxruntime as ort
session = ort.InferenceSession("model.onnx", providers=['CUDAExecutionProvider'])

7. 最佳實踐

  1. 選擇合適的導出格式

    • ONNX:適合跨平臺、跨框架部署,兼容性強
    • TorchScript:適合PyTorch生態內部署,支持靈活性高
    • 根據目標平臺和性能需求選擇
  2. 優化導出模型

    • 使用合適的opset版本(建議11及以上)
    • 啟用常量折疊等優化選項
    • 導出后務必驗證模型正確性
    • 考慮使用量化和剪枝優化模型大小
  3. 處理動態輸入

    • 設置動態維度(如batch_size)
    • 測試不同輸入大小,確保模型魯棒性
    • 記錄支持的輸入范圍和約束
  4. 文檔和版本控制

    • 記錄導出配置和依賴版本
    • 保存模型元數據(如輸入輸出規格)
    • 對模型文件進行版本化管理
    • 維護模型卡片(Model Card)記錄關鍵信息
  5. 調試技巧

    • 使用ONNX Graph Viewer等可視化工具分析模型結構
    • 使用Netron查看計算圖和參數分布
    • 比較原始與導出模型輸出,檢查數值精度差異
    • 遇到兼容性問題時查閱官方文檔和社區經驗

8. 參考資源

  • ONNX官方文檔
  • PyTorch TorchScript教程
  • ONNX Runtime文檔
  • TensorRT開發者指南
  • Netron模型可視化工具

?
?



📌 感謝閱讀!若文章對你有用,別吝嗇互動~?
👍 點個贊 | ? 收藏備用 | 💬 留下你的想法 ,關注我,更多干貨持續更新!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/81122.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/81122.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/81122.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

mac的Cli為什么輸入python3才有用python --version顯示無效,pyenv入門筆記,如何查看mac自帶的標準庫模塊

根據你的終端輸出&#xff0c;可以得出以下結論&#xff1a; 1. 你的 Mac 當前只有一個 Python 版本 系統默認的 Python 3 位于 /usr/bin/python3&#xff08;這是 macOS 自帶的 Python&#xff09;通過 which python3 確認當前使用的就是系統自帶的 Pythonbrew list python …

Java注解詳解:從入門到實戰應用篇

1. 引言 Java注解&#xff08;Annotation&#xff09;是JDK 5.0引入的一種元數據機制&#xff0c;用于為代碼提供附加信息。它廣泛應用于框架開發、代碼生成、編譯檢查等領域。本文將從基礎到實戰&#xff0c;全面解析Java注解的核心概念和使用場景。 2. 注解基礎概念 2.1 什…

前端方法的總結及記錄

個人簡介 &#x1f468;?&#x1f4bb;?個人主頁&#xff1a; 魔術師 &#x1f4d6;學習方向&#xff1a; 主攻前端方向&#xff0c;正逐漸往全棧發展 &#x1f6b4;個人狀態&#xff1a; 研發工程師&#xff0c;現效力于政務服務網事業 &#x1f1e8;&#x1f1f3;人生格言&…

組件導航 (HMRouter)+flutter項目搭建-混合開發+分欄效果

組件導航 (Navigation)flutter項目搭建 接上一章flutter項目的環境變量配置并運行flutter 1.flutter創建項目并運行 flutter create fluter_hmrouter 進入ohos目錄打開編輯器先自動簽名 編譯項目-生成簽名包 flutter build hap --debug 運行項目 HMRouter搭建安裝 1.安…

城市排水管網流量監測系統解決方案

一、方案背景 隨著工業的不斷發展和城市人口的急劇增加&#xff0c;工業廢水和城市污水的排放量也大量增加。目前&#xff0c;我國已成為世界上污水排放量大、增加速度快的國家之一。然而&#xff0c;總體而言污水處理能力較低&#xff0c;有相當部分未經處理的污水直接或間接排…

TCP/IP 知識體系

TCP/IP 知識體系 一、TCP/IP 定義 全稱&#xff1a;Transmission Control Protocol/Internet Protocol&#xff08;傳輸控制協議/網際協議&#xff09;核心概念&#xff1a; 跨網絡實現信息傳輸的協議簇&#xff08;包含 TCP、IP、FTP、SMTP、UDP 等協議&#xff09;因 TCP 和…

5G行業專網部署費用詳解:投資回報如何最大化?

隨著數字化轉型的加速&#xff0c;5G行業專網作為企業提升生產效率、保障業務安全和實現智能化管理的重要基礎設施&#xff0c;正受到越來越多行業客戶的關注。部署5G專網雖然前期投入較大&#xff0c;但通過合理規劃和技術選擇&#xff0c;能夠實現投資回報的最大化。 在5G行…

網頁工具-OTU/ASV表格物種分類匯總工具

AI輔助下開發了個工具&#xff0c;功能如下&#xff0c;分享給大家&#xff1a; 基于Shiny開發的用戶友好型網頁應用&#xff0c;專為微生物組數據分析設計。該工具能夠自動處理OTU/ASV_taxa表格&#xff08;支持XLS/XLSX/TSV/CSV格式&#xff09;&#xff0c;通過調用QIIME1&a…

【超分辨率專題】一種考量視頻編碼比特率優化能力的超分辨率基準

這是一個Benchmark&#xff0c;超分辨率視頻編碼&#xff08;2024&#xff09; 專題介紹一、研究背景二、相關工作2.1 SR的發展2.2 SR benchmark的發展 三、Benchmark細節3.1 數據集制作3.2 模型選擇3.3 編解碼器和壓縮標準選擇3.4 Benchmark pipeline3.5 質量評估和主觀評價研…

保姆教程-----安裝MySQL全過程

1.電腦從未安裝過mysql的&#xff0c;先找到mysql官網&#xff1a;MySQL :: Download MySQL Community Server 然后下載完成后&#xff0c;找到文件&#xff0c;然后雙擊打開 2. 選擇安裝的產品和功能 依次點開“MySQL Servers”、“MySQL Servers”、“MySQL Servers 5.7”、…

【React中函數組件和類組件區別】

在 React 中,函數組件和類組件是兩種構建組件的方式,它們在多個方面存在區別,以下詳細介紹: 1. 語法和定義 類組件:使用 ES6 的類(class)語法定義,繼承自 React.Component。需要通過 this.props 來訪問傳遞給組件的屬性(props),并且通常要實現 render 方法返回 JSX…

[基礎] HPOP、SGP4與SDP4軌道傳播模型深度解析與對比

HPOP、SGP4與SDP4軌道傳播模型深度解析與對比 文章目錄 HPOP、SGP4與SDP4軌道傳播模型深度解析與對比第一章 引言第二章 模型基礎理論2.1 歷史演進脈絡2.2 動力學方程統一框架 第三章 數學推導與攝動機制3.1 SGP4核心推導3.1.1 J?攝動解析解3.1.2 大氣阻力建模改進 3.2 SDP4深…

搭建運行若依微服務版本ruoyi-cloud最新教程

搭建運行若依微服務版本ruoyi-cloud 一、環境準備 JDK > 1.8MySQL > 5.7Maven > 3.0Node > 12Redis > 3 二、后端 2.1數據庫準備 在navicat上創建數據庫ry-seata、ry-config、ry-cloud運行SQL文件ry_20250425.sql、ry_config_20250224.sql、ry_seata_2021012…

Google I/O 2025 觀看攻略一鍵收藏,開啟技術探索之旅!

AIGC開放社區https://lerhk.xetlk.com/sl/1SAwVJ創業邦https://weibo.com/1649252577/PrNjioJ7XCSDNhttps://live.csdn.net/room/csdnnews/OOFSCy2g/channel/collectiondetail?sid2941619DONEWShttps://www.donews.com/live/detail/958.html鳳凰科技https://flive.ifeng.com/l…

ORACLE 11.2.0.4 數據庫磁盤空間爆滿導致GAP產生

前言 昨天晚上深夜接到客戶電話&#xff0c;反應數據庫無法正常使用&#xff0c;想進入服務器檢查時&#xff0c;登錄響應非常慢。等兩分鐘后進入服務器且通過sqlplus進入數據庫也很慢。通過檢查服務器磁盤空間發現數據庫所在區已經爆滿&#xff0c;導致數據庫在運行期間新增審…

計算機視覺---目標追蹤(Object Tracking)概覽

一、核心定義與基礎概念 1. 目標追蹤的定義 定義&#xff1a;在視頻序列或連續圖像中&#xff0c;對一個或多個感興趣目標&#xff08;如人、車輛、物體等&#xff09;的位置、運動軌跡進行持續估計的過程。核心任務&#xff1a;跨幀關聯目標&#xff0c;解決“同一目標在不同…

windows系統中下載好node無法使用npm

原因是 Windows PowerShell禁用導致的npm無法正常使用 解決方法管理員打開Windows PowerShell 輸入Set-ExecutionPolicy -Scope CurrentUser RemoteSigned 按Y 確認就解決了

Nginx模塊配置與請求處理詳解

Nginx 作為模塊化設計的 Web 服務器,其核心功能通過不同模塊協同完成。以下是各模塊的詳細配置案例及數據流轉解析: 一、核心模塊配置案例 1. Handler 模塊(內容生成) 功能:直接生成響應內容(如靜態文件、重定向等) # 示例1:靜態文件處理(ngx_http_static_module)…

Elasticsearch 學習(一)如何在Linux 系統中下載、安裝

目錄 一、Elasticsearch 下載二、使用 yum、dnf、zypper 命令下載安裝三、使用 Docker 本地快速啟動安裝&#xff08;ESKibana&#xff09;【測試推薦】3.1 介紹3.2 下載、安裝、啟動3.3 訪問3.4 修改配置&#xff0c;支持ip訪問 官網地址&#xff1a; https://www.elastic.co/…

Java Map雙列集合深度解析:HashMap、LinkedHashMap、TreeMap底層原理與實戰應用

Java Map雙列集合深度解析&#xff1a;HashMap、LinkedHashMap、TreeMap底層原理與實戰應用 一、Map雙列集合概述 1. 核心特點 鍵值對結構&#xff1a;每個元素由鍵&#xff08;Key&#xff09;和值&#xff08;Value&#xff09;組成。鍵唯一性&#xff1a;鍵不可重復&#…