從 PyTorch 到 TensorFlow Lite:模型訓練與推理

一、方案介紹

  1. 研發階段:利用 PyTorch 的動態圖特性進行快速原型驗證,快速迭代模型設計。
    • 靈活性與易用性:PyTorch 是一個非常靈活且易于使用的深度學習框架,特別適合研究和實驗。其動態計算圖特性使得模型的構建和調試變得更加直觀,開發者可以在運行時修改模型結構。
    • 快速原型開發:許多研究人員和開發者選擇 PyTorch 進行模型訓練,因為它支持快速原型開發和靈活的模型設計,能夠快速驗證新想法并進行迭代。
  2. 轉換階段:將訓練好的模型通過 TorchScript 導出為 ONNX 格式,再轉換為 TensorFlow 格式,最后生成 TFLite 模型。
    • 專為移動和嵌入式設備優化:TensorFlow Lite 是專為移動和嵌入式設備設計的推理框架,能夠在資源有限的環境中高效運行模型,確保在各種設備上實現實時推理。
    • 支持模型量化和優化:TFLite 支持模型量化和優化,能夠顯著減小模型大小并提高推理速度,適合在手機、邊緣設備等場景中使用。這使得開發者能夠在不犧牲準確度的情況下,提升模型的運行效率。
  3. 部署階段:將 TFLite 模型集成到 Android、iOS 或嵌入式系統中,確保模型能夠在目標設備上高效運行。
    • 內存和計算資源的優化:在推理階段,使用 TFLite 可以減少內存占用和計算資源消耗,尤其是在移動設備和嵌入式系統上。這對于需要長時間運行的應用尤為重要,可以延長設備的電池壽命。
    • 多種優化技術:TFLite 提供了多種優化技術,如模型量化(將浮點數轉換為整數),可以進一步提高推理速度并降低功耗。這使得在實時應用中能夠實現更快的響應時間,提升用戶體驗。
      在這里插入圖片描述

二、實例1:CNN模型的轉換

注:python 版本為3.10

2.1 pytorch模型訓練

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 檢查是否支持 MPS
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")# 定義 CNN 模型
class CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = nn.functional.relu(self.conv1(x))x = nn.functional.max_pool2d(x, 2)x = nn.functional.relu(self.conv2(x))x = nn.functional.max_pool2d(x, 2)x = x.view(-1, 64 * 7 * 7)x = nn.functional.relu(self.fc1(x))x = self.fc2(x)return x# 數據預處理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加載 MNIST 數據集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)# 初始化模型、損失函數和優化器
model = CNNModel().to(device)  # 將模型移動到 MPS 設備
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 訓練模型
for epoch in range(20):for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 將數據移動到 MPS 設備optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch [{epoch + 1}/20], Loss: {loss.item():.6f}')# 保存模型
torch.save(model.state_dict(), 'cnn_mnist.pth')
print("Model saved as cnn_mnist.pth")

2.2 pth模型轉onnx 并驗證一致性

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn# 定義 CNN 模型
class CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = nn.functional.relu(self.conv1(x))x = nn.functional.max_pool2d(x, 2)x = nn.functional.relu(self.conv2(x))x = nn.functional.max_pool2d(x, 2)x = x.view(-1, 64 * 7 * 7)x = nn.functional.relu(self.fc1(x))x = self.fc2(x)return x# 加載模型并進行推理
model = CNNModel()
model.load_state_dict(torch.load('cnn_mnist.pth', weights_only=True))  # 加載保存的模型權重
model.eval()  # 設置為評估模式# 創建一個示例輸入
dummy_input = torch.randn(1, 1, 28, 28)  # MNIST 圖像的形狀# 使用 PyTorch 進行推理
with torch.no_grad():pytorch_output = model(dummy_input)# 導出模型為 ONNX 格式
torch.onnx.export(model, dummy_input, 'cnn_mnist.onnx', export_params=True, opset_version=11)
print("Model exported to cnn_mnist.onnx")# 使用 ONNX 進行推理
onnx_model = onnx.load('cnn_mnist.onnx')
ort_session = ort.InferenceSession('cnn_mnist.onnx')# 準備輸入數據
onnx_input = dummy_input.numpy()  # 將 PyTorch 張量轉換為 NumPy 數組
onnx_input = onnx_input.astype(np.float32)  # 確保數據類型為 float32# 使用 ONNX 進行推理
onnx_output = ort_session.run(None, {ort_session.get_inputs()[0].name: onnx_input})# 比較輸出
pytorch_output_np = pytorch_output.numpy()  # 將 PyTorch 輸出轉換為 NumPy 數組
onnx_output_np = onnx_output[0]  # ONNX 輸出是一個列表,取第一個元素# 檢查輸出是否一致
if np.allclose(pytorch_output_np, onnx_output_np, atol=1e-5):print("The outputs are consistent between PyTorch and ONNX.")
else:print("The outputs are NOT consistent between PyTorch and ONNX.")# 打印輸出結果
print("PyTorch output:", pytorch_output_np)
print("ONNX output:", onnx_output_np)
The outputs are consistent between PyTorch and ONNX.
PyTorch output: [[ -1.5153266 -11.934659    0.5428004 -16.058285   -3.6684208  -4.596178-14.53585    -3.3159208  -5.7872214  -5.3301578]]
ONNX output: [[ -1.5153263 -11.934658    0.5428015 -16.058285   -3.66842    -4.5961757-14.53585    -3.3159204  -5.787223   -5.3301597]]

2.3 onnx模型轉tflite

參考這個項目:onnx2tflite

git clone https://github.com/MPolaris/onnx2tflite.git
cd onnx2tflite
conda install tensorflow=2.11.0
pip install .
python -m onnx2tflite --weights ../pth2onnx/cnn_mnist.onnx

在這里插入圖片描述

2.4 onnx模型和tflite一致性驗證

import numpy as np
import onnxruntime as ort
import tensorflow as tf# 1. 加載 ONNX 模型
onnx_model_path = 'cnn_mnist.onnx'
onnx_session = ort.InferenceSession(onnx_model_path)# 2. 加載 TFLite 模型
tflite_model_path = 'cnn_mnist.tflite'
tflite_interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
tflite_interpreter.allocate_tensors()# 3. 準備輸入數據
# 假設輸入數據是 MNIST 數據集的一部分,形狀為 (1, 28, 28, 1)
input_data = np.random.rand(1, 28, 28, 1).astype(np.float32)  # Keras 輸入
input_data_onnx = input_data.transpose(0, 3, 1, 2)  # 轉換為 ONNX 輸入格式 (1, 1, 28, 28)# 4. 使用相同的輸入數據進行推理# ONNX 模型推理
onnx_input_name = onnx_session.get_inputs()[0].name
onnx_output = onnx_session.run(None, {onnx_input_name: input_data_onnx})[0]
print("ONNX Output:", onnx_output)# TFLite 模型推理
tflite_input_details = tflite_interpreter.get_input_details()
tflite_output_details = tflite_interpreter.get_output_details()# 檢查 TFLite 輸入形狀
print("TFLite Input Shape:", tflite_input_details[0]['shape'])# 設置 TFLite 輸入
# 確保輸入數據的形狀與 TFLite 模型的輸入要求一致
tflite_interpreter.set_tensor(tflite_input_details[0]['index'], input_data)
tflite_interpreter.invoke()
tflite_output = tflite_interpreter.get_tensor(tflite_output_details[0]['index'])
print("TFLite Output:", tflite_output)# 5. 比較輸出結果
# 計算輸出的差異
onnx_difference = np.abs(onnx_output - tflite_output)# 輸出結果
print("Difference (ONNX vs TFLite):", onnx_difference)# 檢查是否一致
if np.all(onnx_difference < 1e-5):  # 設定一個閾值print("The outputs are consistent between ONNX and TFLite models.")
else:print("The outputs are not consistent between ONNX and TFLite models.")
ONNX Output: [[ -3.7372704  -6.5073314  -1.1807165  -2.4232314 -10.638929    2.2660115-4.5868526  -2.7494073  -0.5609715  -6.331989 ]]
TFLite Input Shape: [ 1 28 28  1]
TFLite Output: [[ -3.7372704   -6.5073323   -1.180716    -2.4232314  -10.6389282.2660117   -4.5868545   -2.7494078   -0.56097114  -6.331988  ]]
Difference (ONNX vs TFLite): [[0.0000000e+00 9.5367432e-07 4.7683716e-07 0.0000000e+00 9.5367432e-072.3841858e-07 1.9073486e-06 4.7683716e-07 3.5762787e-07 9.5367432e-07]]
The outputs are consistent between ONNX and TFLite models.

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/85238.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/85238.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/85238.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

4.2.5 Spark SQL 分區自動推斷

在本節實戰中&#xff0c;我們學習了Spark SQL的分區自動推斷功能&#xff0c;這是一種提升查詢性能的有效手段。通過創建具有不同分區的目錄結構&#xff0c;并在這些目錄中放置JSON文件&#xff0c;我們模擬了一個分區表的環境。使用Spark SQL讀取這些數據時&#xff0c;Spar…

數據結構:導論

目錄 什么是“第一性原理”&#xff1f; 什么是“數據結構”&#xff1f; 數據結構解決的根本問題是什么&#xff1f; 數據結構的兩大分類 數據結構的基本操作 數據結構與算法的關系 學習數據結構的底層目標 什么是“第一性原理”&#xff1f; 在正式進入數據結構之前&…

汽車制造場景下Profibus轉Profinet網關核心功能與應用解析

在當今工業自動化的浪潮中&#xff0c;各種通訊協議層出不窮&#xff0c;而其中PROFIBUS與PROFINET作為兩種主流的工業通信標準&#xff0c;它們之間的轉換需求日益增長。特別是對于那些希望實現老舊設備與現代化網絡無縫對接的企業來說&#xff0c;一個高效、穩定的網關產品顯…

qt ubuntu 20.04 交叉編譯

一、交叉編譯環境搭建 1.下載交叉編譯工具鏈&#xff1a;https://developer.arm.com/downloads/-/gnu-a 可以根據自己需要下載對應版本&#xff0c;當前最新版本是10.3, 筆者使用10.3編譯后的glibc.so版本太高&#xff08;glibc_2.3.3, glibc_2.3.4, glibc_2.3.5&#xff09;…

在Babylon.js中創建3D文字:簡單而強大的方法

引言 在3D場景中添加文字是許多WebGL項目的常見需求。Babylon.js提供了多種創建3D文字的方法&#xff0c;其中使用TextBlock結合平面網格是一種簡單而高效的方式。本文將介紹如何使用Babylon.js的GUI系統在3D空間中創建美觀的文字效果。 方法概述 Babylon.js的GUI系統允許我…

油桃TV v20250519 一款電視端應用網站聚合TV播放器 支持安卓4.1

油桃TV v20250519 一款電視端應用網站聚合TV播放器 支持安卓4.1 應用簡介&#xff1a; 油桃TV是一款開源電視端應用網站聚合瀏覽器&#xff0c;它把大家常見需求的一些網站都整合到了這個應用上&#xff0c;并進行了電視端…

Perl單元測試實戰指南:從Test::Class入門到精通的完整方案

閱讀原文 前言:為什么Perl開發者需要重視單元測試? "這段代碼昨天還能運行,今天就出問題了!"——這可能是每位Perl開發者都經歷過的噩夢。在沒有充分測試覆蓋的情況下,即使是微小的改動也可能導致系統崩潰。單元測試正是解決這一痛點的最佳實踐,它能幫助我們在…

OpenCv高階(十三)——人臉檢測

文章目錄 前言一、人臉檢測—haar特征二、人臉檢測---級聯分類器1、級聯分類器2、如何訓練級聯分類器3、已存在的級聯分類器 三、代碼分析1、人臉檢測的簡單使用2、人臉微笑檢測&#xff08;1&#xff09; 初始化視頻源&#xff08;2&#xff09;主循環處理每一幀&#xff08;3…

無線通信模塊簡介

QuecPython 是運行在無線通信模塊上的開發框架。對于首次接觸物聯網開發的用戶而言&#xff0c;無線通信模塊可能是一個相對陌生的概念。本文主要針對無線通信和蜂窩網絡本身&#xff0c;以及模塊的概念、特性和開發方式進行簡要的介紹。 無線通信和蜂窩網絡 物聯網對無線通信…

Unity 中實現首尾無限循環的 ListView

之前已經實現過&#xff1a; Unity 中實現可復用的 ListView-CSDN博客文章瀏覽閱讀5.6k次&#xff0c;點贊2次&#xff0c;收藏27次。源碼已放入我的 github&#xff0c;地址&#xff1a;Unity-ListView前言實現一個列表組件&#xff0c;表現方面最核心的部分就是重寫布局&…

【C++】 類和對象(上)

1.類的定義 1.1類的定義格式 ? class為定義類的關鍵字&#xff0c;后跟一個類的名字&#xff0c;{}中為類的主體&#xff0c;注意類定義結束時后?分號不能省 略。類體中內容稱為類的成員&#xff1a;類中的變量稱為類的屬性或成員變量;類中的函數稱為類的?法或 者成員函數。…

Transformer架構詳解:從Attention到ChatGPT

Transformer架構詳解&#xff1a;從Attention到ChatGPT 系統化學習人工智能網站&#xff08;收藏&#xff09;&#xff1a;https://www.captainbed.cn/flu 文章目錄 Transformer架構詳解&#xff1a;從Attention到ChatGPT摘要引言一、Attention機制&#xff1a;Transformer的…

Rock9.x(Linux)安裝Redis7

&#x1f49a;提醒&#xff1a;1&#xff09;注意權限問題 &#x1f49a; 查是否已經安裝了gcc gcc 是C語言編譯器&#xff0c;Redis是用C語言開發的&#xff0c;我們需要編譯它。 gcc --version如果沒有安裝gcc&#xff0c;那么我們手動安裝 安裝GCC sudo dnf -y install…

EasyExcel使用導出模版后設置 CellStyle失效問題解決

EasyExcel使用導出模版后在CellWriteHandler的afterCellDispose方法設置 CellStyle失效問題解決方法 問題描述&#xff1a;excel 模版塞入數據后&#xff0c;需要設置單元格的個性化設置時失效&#xff0c;本文以設置數據格式為例&#xff08;設置列的數據展示時需要加上千分位…

【Day41】

DAY 41 簡單CNN 知識回顧 數據增強卷積神經網絡定義的寫法batch歸一化&#xff1a;調整一個批次的分布&#xff0c;常用與圖像數據特征圖&#xff1a;只有卷積操作輸出的才叫特征圖調度器&#xff1a;直接修改基礎學習率 卷積操作常見流程如下&#xff1a; 1. 輸入 → 卷積層 →…

Express教程【002】:Express監聽GET和POST請求

文章目錄 2、監聽post和get請求2.1 監聽GET請求2.2 監聽POST請求 2、監聽post和get請求 創建02-app.js文件。 2.1 監聽GET請求 1??通過app.get()方法&#xff0c;可以監聽客戶端的GET請求&#xff0c;具體的語法格式如下&#xff1a; // 1、導入express const express req…

C# 文件 I/O 操作詳解:從基礎到高級應用

在軟件開發中&#xff0c;文件操作&#xff08;I/O&#xff09;是一項基本且重要的功能。無論是讀取配置文件、存儲用戶數據&#xff0c;還是處理日志文件&#xff0c;C# 都提供了豐富的 API 來高效地進行文件讀寫操作。本文將全面介紹 C# 中的文件 I/O 操作&#xff0c;涵蓋基…

Vue-Router簡版手寫實現

1. 路由庫工程設計 首先&#xff0c;我們需要創建幾個核心文件來組織我們的路由庫&#xff1a; src/router/index.tsRouterView.tsRouterLink.tsuseRouter.tsinjectionsymbols.tshistory.ts 2. injectionSymbols.ts 定義一些注入符號來在應用中共享狀態&#xff1a; import…

Electron-vite【實戰】MD 編輯器 -- 文件列表(含右鍵快捷菜單,重命名文件,刪除本地文件,打開本地目錄等)

最終效果 頁面 src/renderer/src/App.vue <div class"dirPanel"><div class"panelTitle">文件列表</div><div class"searchFileBox"><Icon class"searchFileInputIcon" icon"material-symbols-light:…

Remote Sensing投稿記錄(投稿郵箱寫錯、申請大修延期...)風雨波折投稿路

歷時近一個半月&#xff0c;我中啦&#xff01; RS是中科院二區&#xff0c;2023-2024影響因子4.2&#xff0c;五年影響因子4.9。 投稿前特意查了下預警&#xff0c;發現近五年都不在預警名單中&#xff0c;甚至最新中科院SCI分區&#xff08;2025年3月&#xff09;在各小類上…