一、方案介紹
- 研發階段:利用 PyTorch 的動態圖特性進行快速原型驗證,快速迭代模型設計。
- 靈活性與易用性:PyTorch 是一個非常靈活且易于使用的深度學習框架,特別適合研究和實驗。其動態計算圖特性使得模型的構建和調試變得更加直觀,開發者可以在運行時修改模型結構。
- 快速原型開發:許多研究人員和開發者選擇 PyTorch 進行模型訓練,因為它支持快速原型開發和靈活的模型設計,能夠快速驗證新想法并進行迭代。
- 轉換階段:將訓練好的模型通過 TorchScript 導出為 ONNX 格式,再轉換為 TensorFlow 格式,最后生成 TFLite 模型。
- 專為移動和嵌入式設備優化:TensorFlow Lite 是專為移動和嵌入式設備設計的推理框架,能夠在資源有限的環境中高效運行模型,確保在各種設備上實現實時推理。
- 支持模型量化和優化:TFLite 支持模型量化和優化,能夠顯著減小模型大小并提高推理速度,適合在手機、邊緣設備等場景中使用。這使得開發者能夠在不犧牲準確度的情況下,提升模型的運行效率。
- 部署階段:將 TFLite 模型集成到 Android、iOS 或嵌入式系統中,確保模型能夠在目標設備上高效運行。
- 內存和計算資源的優化:在推理階段,使用 TFLite 可以減少內存占用和計算資源消耗,尤其是在移動設備和嵌入式系統上。這對于需要長時間運行的應用尤為重要,可以延長設備的電池壽命。
- 多種優化技術:TFLite 提供了多種優化技術,如模型量化(將浮點數轉換為整數),可以進一步提高推理速度并降低功耗。這使得在實時應用中能夠實現更快的響應時間,提升用戶體驗。
二、實例1:CNN模型的轉換
注:python 版本為3.10
2.1 pytorch模型訓練
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 檢查是否支持 MPS
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")# 定義 CNN 模型
class CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = nn.functional.relu(self.conv1(x))x = nn.functional.max_pool2d(x, 2)x = nn.functional.relu(self.conv2(x))x = nn.functional.max_pool2d(x, 2)x = x.view(-1, 64 * 7 * 7)x = nn.functional.relu(self.fc1(x))x = self.fc2(x)return x# 數據預處理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加載 MNIST 數據集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)# 初始化模型、損失函數和優化器
model = CNNModel().to(device) # 將模型移動到 MPS 設備
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 訓練模型
for epoch in range(20):for images, labels in train_loader:images, labels = images.to(device), labels.to(device) # 將數據移動到 MPS 設備optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch [{epoch + 1}/20], Loss: {loss.item():.6f}')# 保存模型
torch.save(model.state_dict(), 'cnn_mnist.pth')
print("Model saved as cnn_mnist.pth")
2.2 pth模型轉onnx 并驗證一致性
import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn# 定義 CNN 模型
class CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = nn.functional.relu(self.conv1(x))x = nn.functional.max_pool2d(x, 2)x = nn.functional.relu(self.conv2(x))x = nn.functional.max_pool2d(x, 2)x = x.view(-1, 64 * 7 * 7)x = nn.functional.relu(self.fc1(x))x = self.fc2(x)return x# 加載模型并進行推理
model = CNNModel()
model.load_state_dict(torch.load('cnn_mnist.pth', weights_only=True)) # 加載保存的模型權重
model.eval() # 設置為評估模式# 創建一個示例輸入
dummy_input = torch.randn(1, 1, 28, 28) # MNIST 圖像的形狀# 使用 PyTorch 進行推理
with torch.no_grad():pytorch_output = model(dummy_input)# 導出模型為 ONNX 格式
torch.onnx.export(model, dummy_input, 'cnn_mnist.onnx', export_params=True, opset_version=11)
print("Model exported to cnn_mnist.onnx")# 使用 ONNX 進行推理
onnx_model = onnx.load('cnn_mnist.onnx')
ort_session = ort.InferenceSession('cnn_mnist.onnx')# 準備輸入數據
onnx_input = dummy_input.numpy() # 將 PyTorch 張量轉換為 NumPy 數組
onnx_input = onnx_input.astype(np.float32) # 確保數據類型為 float32# 使用 ONNX 進行推理
onnx_output = ort_session.run(None, {ort_session.get_inputs()[0].name: onnx_input})# 比較輸出
pytorch_output_np = pytorch_output.numpy() # 將 PyTorch 輸出轉換為 NumPy 數組
onnx_output_np = onnx_output[0] # ONNX 輸出是一個列表,取第一個元素# 檢查輸出是否一致
if np.allclose(pytorch_output_np, onnx_output_np, atol=1e-5):print("The outputs are consistent between PyTorch and ONNX.")
else:print("The outputs are NOT consistent between PyTorch and ONNX.")# 打印輸出結果
print("PyTorch output:", pytorch_output_np)
print("ONNX output:", onnx_output_np)
The outputs are consistent between PyTorch and ONNX.
PyTorch output: [[ -1.5153266 -11.934659 0.5428004 -16.058285 -3.6684208 -4.596178-14.53585 -3.3159208 -5.7872214 -5.3301578]]
ONNX output: [[ -1.5153263 -11.934658 0.5428015 -16.058285 -3.66842 -4.5961757-14.53585 -3.3159204 -5.787223 -5.3301597]]
2.3 onnx模型轉tflite
參考這個項目:onnx2tflite
git clone https://github.com/MPolaris/onnx2tflite.git
cd onnx2tflite
conda install tensorflow=2.11.0
pip install .
python -m onnx2tflite --weights ../pth2onnx/cnn_mnist.onnx
2.4 onnx模型和tflite一致性驗證
import numpy as np
import onnxruntime as ort
import tensorflow as tf# 1. 加載 ONNX 模型
onnx_model_path = 'cnn_mnist.onnx'
onnx_session = ort.InferenceSession(onnx_model_path)# 2. 加載 TFLite 模型
tflite_model_path = 'cnn_mnist.tflite'
tflite_interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
tflite_interpreter.allocate_tensors()# 3. 準備輸入數據
# 假設輸入數據是 MNIST 數據集的一部分,形狀為 (1, 28, 28, 1)
input_data = np.random.rand(1, 28, 28, 1).astype(np.float32) # Keras 輸入
input_data_onnx = input_data.transpose(0, 3, 1, 2) # 轉換為 ONNX 輸入格式 (1, 1, 28, 28)# 4. 使用相同的輸入數據進行推理# ONNX 模型推理
onnx_input_name = onnx_session.get_inputs()[0].name
onnx_output = onnx_session.run(None, {onnx_input_name: input_data_onnx})[0]
print("ONNX Output:", onnx_output)# TFLite 模型推理
tflite_input_details = tflite_interpreter.get_input_details()
tflite_output_details = tflite_interpreter.get_output_details()# 檢查 TFLite 輸入形狀
print("TFLite Input Shape:", tflite_input_details[0]['shape'])# 設置 TFLite 輸入
# 確保輸入數據的形狀與 TFLite 模型的輸入要求一致
tflite_interpreter.set_tensor(tflite_input_details[0]['index'], input_data)
tflite_interpreter.invoke()
tflite_output = tflite_interpreter.get_tensor(tflite_output_details[0]['index'])
print("TFLite Output:", tflite_output)# 5. 比較輸出結果
# 計算輸出的差異
onnx_difference = np.abs(onnx_output - tflite_output)# 輸出結果
print("Difference (ONNX vs TFLite):", onnx_difference)# 檢查是否一致
if np.all(onnx_difference < 1e-5): # 設定一個閾值print("The outputs are consistent between ONNX and TFLite models.")
else:print("The outputs are not consistent between ONNX and TFLite models.")
ONNX Output: [[ -3.7372704 -6.5073314 -1.1807165 -2.4232314 -10.638929 2.2660115-4.5868526 -2.7494073 -0.5609715 -6.331989 ]]
TFLite Input Shape: [ 1 28 28 1]
TFLite Output: [[ -3.7372704 -6.5073323 -1.180716 -2.4232314 -10.6389282.2660117 -4.5868545 -2.7494078 -0.56097114 -6.331988 ]]
Difference (ONNX vs TFLite): [[0.0000000e+00 9.5367432e-07 4.7683716e-07 0.0000000e+00 9.5367432e-072.3841858e-07 1.9073486e-06 4.7683716e-07 3.5762787e-07 9.5367432e-07]]
The outputs are consistent between ONNX and TFLite models.