模型部署 - onnx 的導出和分析 -(1) - PyTorch 導出 ONNX - 學習記錄

onnx 的導出和分析

  • 一、PyTorch 導出 ONNX 的方法
    • 1.1、一個簡單的例子 -- 將線性模型轉成 onnx
    • 1.2、導出多個輸出頭的模型
    • 1.3、導出含有動態維度的模型
  • 二、pytorch 導出 onnx 不成功的時候如何解決
    • 2.1、修改 opset 的版本
    • 2.2、替換 pytorch 中的算子組合
    • 2.3、在 pytorch 登記( 注冊 ) onnx 中某些算子
      • 2.3.1、注冊方法一
      • 2.3.2、注冊方法二
    • 2.4、直接修改 onnx,創建 plugin

一、PyTorch 導出 ONNX 的方法

1.1、一個簡單的例子 – 將線性模型轉成 onnx

首先我們用 pytorch 定義一個線性模型,nn.Linear : 線性層執行的操作是 y = x * W^T + b,其中 x 是輸入,W 是權重,b 是偏置。(實際上就是一個矩陣乘法)

class Model(torch.nn.Module):def __init__(self, in_features, out_features, weights, bias=False):super().__init__()self.linear = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear.weight.copy_(weights)def forward(self, x):x = self.linear(x)return x

然后我們再定義一個函數,用于導出 onnx

def export_onnx():input   = torch.zeros(1, 1, 1, 4)weights = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)model   = Model(4, 3, weights)model.eval() #添加eval防止權重繼續更新torch.onnx.export(model         = model, args          = (input,),f             = "model.onnx",input_names   = ["input0"],output_names  = ["output0"],opset_version = 12)print("Finished onnx export")

可以看到,這里面的關鍵在函數 torch.onnx.export(),這是 pytorch 導出 onnx 的基本方式,這個函數的參數有很多,但只要一些基本的參數即可導出模型,下面是一些基本參數的定義:

  • model (torch.nn.Module): 需要導出的PyTorch模型
  • args (tuple or Tensor): 一個元組,其中包含傳遞給模型的輸入張量
  • f (str): 要保存導出模型的文件路徑。
  • input_names (list of str): 輸入節點的名字的列表
  • output_names (list of str): 輸出節點的名字的列表
  • opset_version (int): 用于導出模型的 ONNX 操作集版本

最后我們完整的運行一下代碼:

import torch
import torch.nn as nn
import torch.onnxclass Model(torch.nn.Module):def __init__(self, in_features, out_features, weights, bias=False):super().__init__()self.linear = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear.weight.copy_(weights)def forward(self, x):x = self.linear(x)return xdef export_onnx():input   = torch.zeros(1, 1, 1, 4)weights = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)model   = Model(4, 3, weights)model.eval() #添加eval防止權重繼續更新torch.onnx.export(model         = model, args          = (input,),f             = "model.onnx",input_names   = ["input0"],output_names  = ["output0"],opset_version = 12)print("Finished onnx export")if __name__ == "__main__":export_onnx()

導出模型后,我們用 netron 查看模型,在終端輸入

netron model.onnx

在這里插入圖片描述

1.2、導出多個輸出頭的模型

第一步:定義一個多輸出的模型:

class Model(torch.nn.Module):def __init__(self, in_features, out_features, weights1, weights2, bias=False):super().__init__()self.linear1 = nn.Linear(in_features, out_features, bias)self.linear2 = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear1.weight.copy_(weights1)self.linear2.weight.copy_(weights2)def forward(self, x):x1 = self.linear1(x)x2 = self.linear2(x)return x1, x2

第二步:編寫導出 onnx 的函數

def export_onnx():input    = torch.zeros(1, 1, 1, 4)weights1 = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)weights2 = torch.tensor([[2, 3, 4, 5],[3, 4, 5, 6],[4, 5, 6, 7]],dtype=torch.float32)model   = Model(4, 3, weights1, weights2)model.eval() #添加eval防止權重繼續更新torch.onnx.export(model         = model, args          = (input,),f             = "model.onnx",input_names   = ["input0"],output_names  = ["output0", "output1"],opset_version = 12)print("Finished onnx export")

可以看到,和例 1.1 不一樣的地方是 torch.onnx.export 的 output_names
例1.1:output_names = [“output0”]
例1.2:output_names = [“output0”, “output1”]

運行一下完整代碼:

import torch
import torch.nn as nn
import torch.onnxclass Model(torch.nn.Module):def __init__(self, in_features, out_features, weights1, weights2, bias=False):super().__init__()self.linear1 = nn.Linear(in_features, out_features, bias)self.linear2 = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear1.weight.copy_(weights1)self.linear2.weight.copy_(weights2)def forward(self, x):x1 = self.linear1(x)x2 = self.linear2(x)return x1, x2def export_onnx():input    = torch.zeros(1, 1, 1, 4)weights1 = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)weights2 = torch.tensor([[2, 3, 4, 5],[3, 4, 5, 6],[4, 5, 6, 7]],dtype=torch.float32)model   = Model(4, 3, weights1, weights2)model.eval() #添加eval防止權重繼續更新torch.onnx.export(model         = model, args          = (input,),f             = "model.onnx",input_names   = ["input0"],output_names  = ["output0", "output1"],opset_version = 12)print("Finished onnx export")if __name__ == "__main__":export_onnx()

用 netron 查看模型,結果如下,模型多出了一個輸出結果
在這里插入圖片描述

1.3、導出含有動態維度的模型

完整運行代碼如下:

import torch
import torch.nn as nn
import torch.onnxclass Model(torch.nn.Module):def __init__(self, in_features, out_features, weights, bias=False):super().__init__()self.linear = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear.weight.copy_(weights)def forward(self, x):x = self.linear(x)return xdef export_onnx():input   = torch.zeros(1, 1, 1, 4)weights = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)model   = Model(4, 3, weights)model.eval() #添加eval防止權重繼續更新torch.onnx.export(model         = model, args          = (input,),f             = "model.onnx",input_names   = ["input0"],output_names  = ["output0"],dynamic_axes  = {'input0':  {0: 'batch'},'output0': {0: 'batch'}},opset_version = 12)print("Finished onnx export")if __name__ == "__main__":export_onnx()

可以看到,比例 1.1 多了一行 torch.onnx.export 的 dynamic_axes 。我們可以用 dynamic_axes 來指定動態維度,其中 'input0': {0: 'batch'} 中的 0 表示在第 0 維度上的元素是動態的,這里取名為 ‘batch’

用 netron 查看模型:
在這里插入圖片描述
可以看到相對于例1.1,他的維度 0 變成了動態的,并且名為 ‘batch’

二、pytorch 導出 onnx 不成功的時候如何解決

上面是 onnx 可以直接被導出的情況,是因為對應的 pytorch 和 onnx 版本都有相應支持的算子在里面。但是有些時候,我們不能順利的導出 onnx,下面記錄一下常見的解決思路 。

2.1、修改 opset 的版本

這是首先應該考慮的思路,因為有可能只是版本過低然后有些算子還不支持,所以考慮提高 opset 的版本

比如下面的這個報錯,提示當前 onnx 的 opset 版本不支持這個算子,那我們可以去官方手冊搜索一下是否在高的版本支持了這個算子
在這里插入圖片描述

官方手冊地址:https://github.com/onnx/onnx/blob/main/docs/Operators.md

在這里插入圖片描述
又比如說 Acosh 這個算子,在 since version 9 才開始支持,那我們用 7 的時候就是不合適的,升級 opset 版本即可

2.2、替換 pytorch 中的算子組合

有些時候 pytorch 中的一些算子操作在 onnx 中并沒有,那我們可以把這些算子替換成 onnx 支持的算子

2.3、在 pytorch 登記( 注冊 ) onnx 中某些算子

有些算子在 onnx 中是有的,但是在 pytorch 中沒被登記,則需要注冊一下
比如下面這個案例,我們想要導出 asinh 這個算子的模型

import torch
import torch.onnxclass Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):x = torch.asinh(x)return xdef export_norm_onnx():input   = torch.rand(1, 5)model   = Model()model.eval()file    = "asinh.onnx"torch.onnx.export(model         = model, args          = (input,),f             = file,input_names   = ["input0"],output_names  = ["output0"],opset_version = 9)print("Finished normal onnx export")if __name__ == "__main__":export_norm_onnx()

但是報錯,提示 opset_version = 9 不支持這個算子
在這里插入圖片描述

但是我們打開官方手冊去搜索發現 asinh 在 version 9 又是支持的
在這里插入圖片描述
這里的問題是 PyTorch 與 onnx 之間沒有建立 asinh 的映射 (沒有搭建橋梁),所以我們編寫一個注冊代碼,來手動注冊一下這個算子

2.3.1、注冊方法一

完整代碼如下:

import torch
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolicdef asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)
register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)class Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):x = torch.asinh(x)return xdef validate_onnx():input = torch.rand(1, 5)# PyTorch的推理model = Model()x     = model(input)print("result from Pytorch is :", x)# onnxruntime的推理sess  = onnxruntime.InferenceSession('asinh.onnx')x     = sess.run(None, {'input0': input.numpy()})print("result from onnx is:    ", x)def export_norm_onnx():input   = torch.rand(1, 5)model   = Model()model.eval()file    = "asinh.onnx"torch.onnx.export(model         = model, args          = (input,),f             = file,input_names   = ["input0"],output_names  = ["output0"],opset_version = 12)print("Finished normal onnx export")if __name__ == "__main__":export_norm_onnx()# 自定義完onnx以后必須要進行一下驗證validate_onnx()

這段代碼的關鍵在于 算子的注冊:

1、定義 asinh_symbolic 函數

def asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)
  1. 函數必須是 asinh_symbolic 這個名字
  2. g: 就是 graph,計算圖 (在計算圖中添加onnx算子)
  3. input :symblic的參數需要與Pytorch的asinh接口函數的參數對齊
    (def asinh( input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: … )
  4. 符號函數內部調用 g.op, 為 onnx 計算圖添加 Asinh 算子
  5. g.op中的第一個參數是onnx中的算子名字: Asinh

2、使用 register_custom_op_symbolic 函數

register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)
  1. aten 是"a Tensor Library"的縮寫,是一個實現張量運算的C++庫
  2. asinh 是在名為 aten 的一個c++命名空間下進行實現的
  3. 將 asinh_symbolic 這個符號函數,與PyTorch的 asinh 算子綁定
  4. register_op 中的第一個參數是PyTorch中的算子名字: aten::asinh
  5. 最后一個參數表示從第幾個 opset 開始支持(可自己設置)

3、自定義完 onnx 以后必須要進行一下驗證,可使用 onnxruntime

2.3.2、注冊方法二

import torch
import torch.onnx
import onnxruntime
import functools
from torch.onnx import register_custom_op_symbolic
from torch.onnx._internal import registration_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)@_onnx_symbolic('aten::asinh')
def asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)class Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):x = torch.asinh(x)return xdef validate_onnx():input = torch.rand(1, 5)# PyTorch的推理model = Model()x     = model(input)print("result from Pytorch is :", x)# onnxruntime的推理sess  = onnxruntime.InferenceSession('asinh2.onnx')x     = sess.run(None, {'input0': input.numpy()})print("result from onnx is:    ", x)def export_norm_onnx():input   = torch.rand(1, 5)model   = Model()model.eval()file    = "asinh2.onnx"torch.onnx.export(model         = model, args          = (input,),f             = file,input_names   = ["input0"],output_names  = ["output0"],opset_version = 12)print("Finished normal onnx export")if __name__ == "__main__":export_norm_onnx()# 自定義完onnx以后必須要進行一下驗證validate_onnx()

與上面例子不同的是,這個注冊方式跟底層文件的寫法是一樣的(文件在虛擬環境中的 torch/onnx/symbolic_opset*.py )

通過torch._internal 中的 registration 來注冊這個算子,讓這個算子可以與底層C++實現的 aten::asinh 綁定

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)
@_onnx_symbolic('aten::asinh')
def asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)

2.4、直接修改 onnx,創建 plugin

直接手動創建一個 onnx (這是一個思路,會在后續博客進行總結記錄)

參考鏈接

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/715799.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/715799.shtml
英文地址,請注明出處:http://en.pswp.cn/news/715799.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

vscode+remote突然無法連接服務器以及ssh連接出問題時的排錯方法

文章目錄 設備描述狀況描述解決方法當ssh連接出問題時的排錯方法 設備描述 主機:win11,使用vscode的remote-ssh插件 服務器:阿里云的2C2GUbuntu 22.04 UFIE 狀況描述 之前一直使用的是vscode的remote服務,都是能夠正常連接服務…

【Qt】界面布局

Qt常用布局 除Qt Designer支持可視化設計和布局界面之外,Qt 提供了代碼方式來進行界面布局, 以下是幾種常用的界面布局方式: 水平布局(QHBoxLayout)和垂直布局(QVBoxLayout): QHBo…

Redis常用數據結構--Zset

Zset ZADDZCARDZCOUNTZRANGE/ZREVRANGEZRANGEBYSCOREZPOPMAX/ZPOPMINBZPOPMAX/BZPOPMINZRANK/ZREVRANKZSCOREZREMZREMRANGEBYRANKZREMRANGEBYSCOREZINCRBYZINTERSTORE/ZUNIONSTORE內部編碼使?場景 ZADD 添加或者更新指定的元素以及關聯的分數到 zset 中,分數應該符…

如何在 Angular 測試中使用 spy

簡介 Jasmine spy 用于跟蹤或存根函數或方法。spy 是一種檢查函數是否被調用或提供自定義返回值的方法。我們可以使用spy 來測試依賴于服務的組件,并避免實際調用服務的方法來獲取值。這有助于保持我們的單元測試專注于測試組件本身的內部而不是其依賴關系。 在本…

空調壓縮機補充潤滑油的方法

空調壓縮機補充潤滑油的方法有三種,從吸氣截止閥旁邊通孔吸入,從加油孔中加入,從曲軸箱下部加入,具體操作步驟如下: 1關閉吸氣截止閥,啟動壓縮機幾分鐘,將曲軸箱中制冷劑排入冷凝器&#xff0c…

vue2結合electron開發桌面端應用

一、Electron是什么? Electron是一個使用 JavaScript、HTML 和 CSS 構建桌面應用程序的框架。 嵌入 Chromium 和 Node.js 到 二進制的 Electron 。允許您保持一個 JavaScript 代碼代碼庫并創建可在Windows、macOS和Linux上運行的跨平臺應用 。 Electron 經常與 Ch…

scrapy 中間件

就是發送請求的時候,會經過,中間件。中間件會處理,你的請求 下面是代碼: # Define here the models for your spider middleware # # See documentation in: # https://docs.scrapy.org/en/latest/topics/spider-middleware.html…

【快速上手ProtoBuf】基本使用

文章目錄 1 :peach:初識 ProtoBuf:peach:1.1 :apple:序列化概念:apple:1.2 :apple:ProtoBuf 是什么:apple:1.3 :apple:ProtoBuf 的使用特點:apple: 2 :peach:創建 .proto ?件:peach:3 :peach:編譯 .proto 文件:peach:3 :peach:序列化與反序列化的使用:peach: 1 🍑初…

洛谷 2036.PERKET

采用遞歸法的方式進行題解。 思路:首先我們知道在n種材料當中,我們需要從中選擇至少有一種得配料才行。也就是說,我們選擇的配料數目是自己決定的,而不是那種組合型得對于你有要求的組合型遞歸方式。 所以我們會想到用指數型得遞…

(五)網絡優化與超參數選擇--九五小龐

網絡容量 網絡中神經單元數越多,層數越多,神經網路的擬合能力越強。但是訓練速度,難度越大,越容易產生過擬合。 如何選擇超參數 所謂超參數,也就是搭建神經網路中,需要我們自己去選擇(不是通…

LeetCode 刷題 [C++] 第45題.跳躍游戲 II

題目描述 給定一個長度為 n 的 0 索引整數數組 nums。初始位置為 nums[0]。 每個元素 nums[i] 表示從索引 i 向前跳轉的最大長度。換句話說&#xff0c;如果你在 nums[i] 處&#xff0c;你可以跳轉到任意 nums[i j] 處: 0 < j < nums[i]i j < n 返回到達 nums[n …

遞歸函數(c++題解)

題目描述 對于一個遞歸函數w(a, b, c)。 如果a < 0 or b < 0 or c < 0就返回值1。 如果a > 20 or b > 20 or c > 20就返回W(20,20,20)。 如果a < b并且b < c 就返回w(a, b, c ? 1) w(a, b ? 1, c ? 1) ? w(a, b ? 1, c)&#xff0c; 其它別…

計算機網絡知多少-第1篇

一、 從輸入URL到頁面展示到底發生了什么&#xff1f; 1. 首先瀏覽器會查看電腦本地緩存&#xff0c;如果有直接返回&#xff0c;否則需要進行下一步的網絡請求。 2. 在網絡請求之前&#xff0c;需要先進行DNS解析&#xff0c;來找到請求域名的ip地址。如果是HTTPS請求&#…

【C語言】熟悉文件基礎知識

歡迎關注個人主頁&#xff1a;逸狼 創造不易&#xff0c;可以點點贊嗎~ 如有錯誤&#xff0c;歡迎指出~ 文件 為了數據持久化保存&#xff0c;使用文件&#xff0c;否則數據存儲在內存中&#xff0c;程序退出&#xff0c;內存回收&#xff0c;數據就會丟失。 程序設計中&…

微信小程序,h5端自適應登陸方式

微信小程序端只顯示登陸(獲取opid),h5端顯示通過賬戶密碼登陸 例如: 通過下面的變量控制: const isWeixin ref(false); // #ifdef MP-WEIXIN isWeixin.value true; // #endif

Git 查看提交歷史

命令說明git log查看歷史提交記錄git blame (file)以列表形式查看指定文件的歷史修改記錄 git log 在使用 Git 提交了若干更新之后&#xff0c;又或者克隆了某個項目&#xff0c;想回顧下提交歷史&#xff0c;我們可以使用 git log 命令查看。 git log 命令用于查看 Git 倉庫中…

LIN基礎:從LIN Frame開始

目錄&#xff1a; 1、LIN的網絡拓撲 2、LIN Frame 1&#xff09;Header 2&#xff09;Response 3、LIN的通信規則 1&#xff09;LIN的發送行為示例 2&#xff09;LIN的接收行為示例 雖然LIN總線的通信速率不高&#xff0c;工程中&#xff0c;最高的速率也就19200bps。…

c語言extern關鍵字

extern 是C和C中的關鍵字&#xff0c;用于聲明一個變量或函數的存在&#xff0c;但不進行定義。 它通常用于在一個源文件中引用另一個源文件中定義的變量或函數。 例如&#xff0c;extern int x; 表示 x 是一個整數變量&#xff0c;但它的實際定義將在其他文件中。在引用它的文…

StarRocks——Stream Load 事務接口實現原理

目錄 前言 一、StarRocks 數據導入 二、StarRocks 事務寫入原理 三、InLong 實時寫入StarRocks原理 3.1 InLong概述 3.2 基本原理 3.3 詳細流程 3.3.1 任務寫入數據 3.3.2 任務保存檢查點 3.3.3 任務如何確認保存點成功 3.3.4 任務如何初始化 3.4 Exactly Once 保證…

Leetcode - 周賽386

目錄 一&#xff0c;3046. 分割數組 二&#xff0c;3047. 求交集區域內的最大正方形面積 三&#xff0c;3048. 標記所有下標的最早秒數 I 四&#xff0c;3049. 標記所有下標的最早秒數 II 一&#xff0c;3046. 分割數組 將題目給的數組nums分成兩個數組&#xff0c;且這兩個…