ONNX(Open Neural Network Exchange)
ONNX 是一種用于表示深度學習模型的開放格式,使得不同深度學習框架(如 PyTorch、TensorFlow、Caffe2 等)之間的模型能夠相互交換。
需安裝:
pip install --upgrade onnx onnxscript onnxruntime
Pytorch張量
可使用torch.rand()方法創建0~1均勻分布的隨機數,使用torch.randn()方法創建標準正態分布隨機數,使用torch.zeros()和torch.ones()方法創建全0和全1的張量。
在構造張量時使用dtype明確其類型。
PyTorch針對torch.float32和torch.int64類型有專門這樣的簡寫形式是因為,這兩種類型特別重要,模型的輸入類型一般都是torch.float32,而模型分類問題的標簽類型一般為torch.int64。
torch.onnx.export 是 PyTorch 自帶的把模型轉換成 ONNX 格式的函數。前三個參數分別是要轉換的模型、模型的任意一組輸入、導出的 ONNX 文件的文件名。
簡單示例
import torch
import torchvision.models as models # 加載一個預訓練的 PyTorch 模型
model = models.resnet18(pretrained=True)
model.eval() # 創建一個虛擬輸入張量(這里使用隨機數據)
dummy_input = torch.randn(1, 3, 224, 224) # 假設輸入是一張 224x224 的 RGB 圖像 # 導出模型為 ONNX 格式
torch.onnx.export(model, dummy_input, "resnet18.onnx", verbose=True, input_names=["input_0"], output_names=["output_0"])
注意點:
確保你的 PyTorch 模型在導出之前已經處于評估模式(
model.eval()
)示例輸入(dummy input)應該與你的模型訓練時使用的輸入具有相同的形狀和數據類型。
在將輸入數據傳遞給 ONNX Runtime 之前,請確保它們已經轉換為 NumPy 數組,并且位于 CPU 上