onnx 的導出和分析
- 一、PyTorch 導出 ONNX 的方法
- 1.1、一個簡單的例子 -- 將線性模型轉成 onnx
- 1.2、導出多個輸出頭的模型
- 1.3、導出含有動態維度的模型
- 二、pytorch 導出 onnx 不成功的時候如何解決
- 2.1、修改 opset 的版本
- 2.2、替換 pytorch 中的算子組合
- 2.3、在 pytorch 登記( 注冊 ) onnx 中某些算子
- 2.3.1、注冊方法一
- 2.3.2、注冊方法二
- 2.4、直接修改 onnx,創建 plugin
一、PyTorch 導出 ONNX 的方法
1.1、一個簡單的例子 – 將線性模型轉成 onnx
首先我們用 pytorch 定義一個線性模型,nn.Linear : 線性層執行的操作是 y = x * W^T + b
,其中 x 是輸入,W 是權重,b 是偏置。(實際上就是一個矩陣乘法)
class Model(torch.nn.Module):def __init__(self, in_features, out_features, weights, bias=False):super().__init__()self.linear = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear.weight.copy_(weights)def forward(self, x):x = self.linear(x)return x
然后我們再定義一個函數,用于導出 onnx
def export_onnx():input = torch.zeros(1, 1, 1, 4)weights = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)model = Model(4, 3, weights)model.eval() #添加eval防止權重繼續更新torch.onnx.export(model = model, args = (input,),f = "model.onnx",input_names = ["input0"],output_names = ["output0"],opset_version = 12)print("Finished onnx export")
可以看到,這里面的關鍵在函數 torch.onnx.export()
,這是 pytorch 導出 onnx 的基本方式,這個函數的參數有很多,但只要一些基本的參數即可導出模型,下面是一些基本參數的定義:
- model (torch.nn.Module): 需要導出的PyTorch模型
- args (tuple or Tensor): 一個元組,其中包含傳遞給模型的輸入張量
- f (str): 要保存導出模型的文件路徑。
- input_names (list of str): 輸入節點的名字的列表
- output_names (list of str): 輸出節點的名字的列表
- opset_version (int): 用于導出模型的 ONNX 操作集版本
最后我們完整的運行一下代碼:
import torch
import torch.nn as nn
import torch.onnxclass Model(torch.nn.Module):def __init__(self, in_features, out_features, weights, bias=False):super().__init__()self.linear = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear.weight.copy_(weights)def forward(self, x):x = self.linear(x)return xdef export_onnx():input = torch.zeros(1, 1, 1, 4)weights = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)model = Model(4, 3, weights)model.eval() #添加eval防止權重繼續更新torch.onnx.export(model = model, args = (input,),f = "model.onnx",input_names = ["input0"],output_names = ["output0"],opset_version = 12)print("Finished onnx export")if __name__ == "__main__":export_onnx()
導出模型后,我們用 netron 查看模型,在終端輸入
netron model.onnx
1.2、導出多個輸出頭的模型
第一步:定義一個多輸出的模型:
class Model(torch.nn.Module):def __init__(self, in_features, out_features, weights1, weights2, bias=False):super().__init__()self.linear1 = nn.Linear(in_features, out_features, bias)self.linear2 = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear1.weight.copy_(weights1)self.linear2.weight.copy_(weights2)def forward(self, x):x1 = self.linear1(x)x2 = self.linear2(x)return x1, x2
第二步:編寫導出 onnx 的函數
def export_onnx():input = torch.zeros(1, 1, 1, 4)weights1 = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)weights2 = torch.tensor([[2, 3, 4, 5],[3, 4, 5, 6],[4, 5, 6, 7]],dtype=torch.float32)model = Model(4, 3, weights1, weights2)model.eval() #添加eval防止權重繼續更新torch.onnx.export(model = model, args = (input,),f = "model.onnx",input_names = ["input0"],output_names = ["output0", "output1"],opset_version = 12)print("Finished onnx export")
可以看到,和例 1.1 不一樣的地方是 torch.onnx.export 的 output_names
例1.1:output_names = [“output0”]
例1.2:output_names = [“output0”, “output1”]
運行一下完整代碼:
import torch
import torch.nn as nn
import torch.onnxclass Model(torch.nn.Module):def __init__(self, in_features, out_features, weights1, weights2, bias=False):super().__init__()self.linear1 = nn.Linear(in_features, out_features, bias)self.linear2 = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear1.weight.copy_(weights1)self.linear2.weight.copy_(weights2)def forward(self, x):x1 = self.linear1(x)x2 = self.linear2(x)return x1, x2def export_onnx():input = torch.zeros(1, 1, 1, 4)weights1 = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)weights2 = torch.tensor([[2, 3, 4, 5],[3, 4, 5, 6],[4, 5, 6, 7]],dtype=torch.float32)model = Model(4, 3, weights1, weights2)model.eval() #添加eval防止權重繼續更新torch.onnx.export(model = model, args = (input,),f = "model.onnx",input_names = ["input0"],output_names = ["output0", "output1"],opset_version = 12)print("Finished onnx export")if __name__ == "__main__":export_onnx()
用 netron 查看模型,結果如下,模型多出了一個輸出結果
1.3、導出含有動態維度的模型
完整運行代碼如下:
import torch
import torch.nn as nn
import torch.onnxclass Model(torch.nn.Module):def __init__(self, in_features, out_features, weights, bias=False):super().__init__()self.linear = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear.weight.copy_(weights)def forward(self, x):x = self.linear(x)return xdef export_onnx():input = torch.zeros(1, 1, 1, 4)weights = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)model = Model(4, 3, weights)model.eval() #添加eval防止權重繼續更新torch.onnx.export(model = model, args = (input,),f = "model.onnx",input_names = ["input0"],output_names = ["output0"],dynamic_axes = {'input0': {0: 'batch'},'output0': {0: 'batch'}},opset_version = 12)print("Finished onnx export")if __name__ == "__main__":export_onnx()
可以看到,比例 1.1 多了一行 torch.onnx.export 的 dynamic_axes
。我們可以用 dynamic_axes
來指定動態維度,其中 'input0': {0: 'batch'}
中的 0 表示在第 0 維度上的元素是動態的,這里取名為 ‘batch’
用 netron 查看模型:
可以看到相對于例1.1,他的維度 0 變成了動態的,并且名為 ‘batch’
二、pytorch 導出 onnx 不成功的時候如何解決
上面是 onnx 可以直接被導出的情況,是因為對應的 pytorch 和 onnx 版本都有相應支持的算子在里面。但是有些時候,我們不能順利的導出 onnx,下面記錄一下常見的解決思路 。
2.1、修改 opset 的版本
這是首先應該考慮的思路,因為有可能只是版本過低然后有些算子還不支持,所以考慮提高 opset 的版本
。
比如下面的這個報錯,提示當前 onnx 的 opset 版本不支持這個算子,那我們可以去官方手冊搜索一下是否在高的版本支持了這個算子
官方手冊地址:https://github.com/onnx/onnx/blob/main/docs/Operators.md
又比如說 Acosh
這個算子,在 since version 9
才開始支持,那我們用 7 的時候就是不合適的,升級 opset 版本即可
2.2、替換 pytorch 中的算子組合
有些時候 pytorch 中的一些算子操作在 onnx 中并沒有,那我們可以把這些算子替換成 onnx 支持的算子
2.3、在 pytorch 登記( 注冊 ) onnx 中某些算子
有些算子在 onnx 中是有的,但是在 pytorch 中沒被登記,則需要注冊一下
比如下面這個案例,我們想要導出 asinh 這個算子的模型
import torch
import torch.onnxclass Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):x = torch.asinh(x)return xdef export_norm_onnx():input = torch.rand(1, 5)model = Model()model.eval()file = "asinh.onnx"torch.onnx.export(model = model, args = (input,),f = file,input_names = ["input0"],output_names = ["output0"],opset_version = 9)print("Finished normal onnx export")if __name__ == "__main__":export_norm_onnx()
但是報錯,提示 opset_version = 9 不支持這個算子
但是我們打開官方手冊去搜索發現 asinh 在 version 9 又是支持的
這里的問題是 PyTorch 與 onnx 之間沒有建立 asinh 的映射
(沒有搭建橋梁),所以我們編寫一個注冊代碼,來手動注冊一下這個算子
2.3.1、注冊方法一
完整代碼如下:
import torch
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolicdef asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)
register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)class Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):x = torch.asinh(x)return xdef validate_onnx():input = torch.rand(1, 5)# PyTorch的推理model = Model()x = model(input)print("result from Pytorch is :", x)# onnxruntime的推理sess = onnxruntime.InferenceSession('asinh.onnx')x = sess.run(None, {'input0': input.numpy()})print("result from onnx is: ", x)def export_norm_onnx():input = torch.rand(1, 5)model = Model()model.eval()file = "asinh.onnx"torch.onnx.export(model = model, args = (input,),f = file,input_names = ["input0"],output_names = ["output0"],opset_version = 12)print("Finished normal onnx export")if __name__ == "__main__":export_norm_onnx()# 自定義完onnx以后必須要進行一下驗證validate_onnx()
這段代碼的關鍵在于 算子的注冊:
1、定義 asinh_symbolic 函數
def asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)
- 函數必須是 asinh_symbolic 這個名字
- g: 就是 graph,計算圖 (在計算圖中添加onnx算子)
- input :symblic的參數需要與Pytorch的asinh接口函數的參數對齊
(def asinh( input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: … )- 符號函數內部調用 g.op, 為 onnx 計算圖添加 Asinh 算子
- g.op中的第一個參數是onnx中的算子名字: Asinh
2、使用 register_custom_op_symbolic 函數
register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)
- aten 是"a Tensor Library"的縮寫,是一個實現張量運算的C++庫
- asinh 是在名為 aten 的一個c++命名空間下進行實現的
- 將 asinh_symbolic 這個符號函數,與PyTorch的 asinh 算子綁定
- register_op 中的第一個參數是PyTorch中的算子名字: aten::asinh
- 最后一個參數表示從第幾個 opset 開始支持(可自己設置)
3、自定義完 onnx 以后必須要進行一下驗證
,可使用 onnxruntime
2.3.2、注冊方法二
import torch
import torch.onnx
import onnxruntime
import functools
from torch.onnx import register_custom_op_symbolic
from torch.onnx._internal import registration_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)@_onnx_symbolic('aten::asinh')
def asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)class Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):x = torch.asinh(x)return xdef validate_onnx():input = torch.rand(1, 5)# PyTorch的推理model = Model()x = model(input)print("result from Pytorch is :", x)# onnxruntime的推理sess = onnxruntime.InferenceSession('asinh2.onnx')x = sess.run(None, {'input0': input.numpy()})print("result from onnx is: ", x)def export_norm_onnx():input = torch.rand(1, 5)model = Model()model.eval()file = "asinh2.onnx"torch.onnx.export(model = model, args = (input,),f = file,input_names = ["input0"],output_names = ["output0"],opset_version = 12)print("Finished normal onnx export")if __name__ == "__main__":export_norm_onnx()# 自定義完onnx以后必須要進行一下驗證validate_onnx()
與上面例子不同的是,這個注冊方式跟底層文件的寫法是一樣的(文件在虛擬環境中的 torch/onnx/symbolic_opset*.py )
通過torch._internal 中的 registration 來注冊這個算子,讓這個算子可以與底層C++實現的 aten::asinh 綁定
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)
@_onnx_symbolic('aten::asinh')
def asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)
2.4、直接修改 onnx,創建 plugin
直接手動創建一個 onnx
(這是一個思路,會在后續博客進行總結記錄)
參考鏈接