使用PyTorch訓練VGG11模型:Fashion-MNIST圖像分類實戰

本文將通過代碼實戰,詳細講解如何使用?PyTorch?和?VGG11?模型在?Fashion-MNIST?數據集上進行圖像分類任務。代碼包含數據預處理、模型定義、訓練與評估全流程,并附上訓練結果的可視化圖表。所有代碼可直接復現,適合深度學習初學者和進階開發者參考。


1. 環境準備

確保已安裝以下庫:

pip install torch torchvision d2l
2. 代碼實現
2.1 導入依賴庫
from d2l import torch as d2l
from torchvision import models, transforms
import torch
2.2 數據預處理

由于VGG11默認接受RGB三通道輸入,需將Fashion-MNIST的灰度圖轉換為3通道:

# 定義數據預處理流程
transform = transforms.Compose([transforms.Resize(224),                # 調整圖像尺寸為224x224transforms.Grayscale(num_output_channels=3),  # 單通道轉三通道transforms.ToTensor()                   # 轉為Tensor格式
])
2.3 加載數據集
# 加載Fashion-MNIST數據集并應用預處理
batch_size = 64 * 3  # 增大批大小以利用GPU并行計算
train_data, test_data = d2l.load_data_fashion_mnist(batch_size, resize=224)# 替換原始數據集的數據增強方法
train_data.dataset.transform = transform
test_data.dataset.transform = transform
2.4 定義模型

使用PyTorch內置的VGG11模型(從頭訓練,不使用預訓練權重):

# 初始化VGG11模型(輸入通道為3,輸出類別為10)
net = models.vgg11(pretrained=False, num_classes=10)
2.5 模型訓練

調用D2L庫的封裝函數進行訓練(支持GPU加速):

# 設置超參數并啟動訓練
num_epochs = 10
lr = 0.01
device = d2l.try_gpu()  # 自動檢測GPU# 開始訓練
d2l.train_ch6(net, train_data, test_data, num_epochs, lr, device)
3. 訓練結果分析

下圖為訓練過程中的損失和準確率變化曲線:

關鍵指標
EpochTrain LossTrain AccTest AccSpeed (examples/sec)
10.8570.2%78.5%112.3
30.31288.6%88.1%117.7
50.3287.6%84.3%118.5
100.2191.8%85.7%119.0
  • 訓練損失(Train Loss):隨著訓練輪次增加,損失快速下降并趨于穩定。例如,第3輪時損失降至?0.312,表明模型快速收斂。

  • 訓練準確率(Train Acc):第3輪時達到?88.6%,說明模型對訓練數據的學習效果顯著。

  • 測試準確率(Test Acc):第3輪測試準確率?88.1%,與訓練準確率接近,表明模型泛化能力優秀,未出現明顯過擬合。

  • 訓練速度:在?cuda:0?設備上達到?117.7 examples/sec,充分利用GPU加速,適合大規模數據訓練。

4. 完整代碼?
from d2l import torch as d2l
from torchvision import models, transforms
import torch# 數據預處理
transform = transforms.Compose([transforms.Resize(224),transforms.Grayscale(num_output_channels=3),transforms.ToTensor()
])# 加載數據集
batch_size = 64 * 3
train_data, test_data = d2l.load_data_fashion_mnist(batch_size, resize=224)
train_data.dataset.transform = transform
test_data.dataset.transform = transform# 定義模型
net = models.vgg11(pretrained=False, num_classes=10)# 訓練配置
num_epochs = 10
lr = 0.01
device = d2l.try_gpu()# 啟動訓練
d2l.train_ch6(net, train_data, test_data, num_epochs, lr, device)
5. 常見問題
Q1:為什么將灰度圖轉為三通道?

VGG系列模型設計時默認接受RGB輸入(3通道)。盡管Fashion-MNIST為單通道,需通過復制通道數適配模型。

Q2:如何進一步提升準確率?
  • 增加訓練輪次(如?num_epochs=20)。

  • 使用更復雜模型(如VGG16、ResNet)。

  • 添加數據增強(隨機旋轉、亮度調整)。

Q3:訓練時顯存不足怎么辦?
  • 減小?batch_size(如設為64)。

  • 啟用混合精度訓練(添加?torch.cuda.amp)。


6. 總結

本文使用PyTorch實現了VGG11模型在Fashion-MNIST數據集上的分類任務,最終測試準確率達?85.7%,并在第3輪即達到?88.1%?的測試準確率,訓練速度高達?117.7 examples/sec,展現了優秀的性能與效率。通過代碼解析與結果分析,讀者可快速掌握從數據預處理到模型訓練的完整流程,并根據實際需求調整模型或超參數進一步優化性能。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/77223.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/77223.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/77223.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

汽車BMS技術分享及其HIL測試方案

一、BMS技術簡介 在全球碳中和目標的戰略驅動下,新能源汽車產業正以指數級速度重塑交通出行格局。動力電池作為電動汽車的"心臟",其性能與安全性不僅直接決定了車輛的續航里程、使用壽命等關鍵指標,更深刻影響著消費者對電動汽車的…

打造船岸“5G+AI”智能慧眼 智驅力賦能客船數智管理

項目介紹 船舶在航行、作業過程中有著嚴格的規范要求,但在實際航行與作業中往往會因為人為的疏忽,發生事故,導致人員重大傷亡和財產損失; 為推動安全治理模式向事前預防轉型,實現不安全狀態和行為智能預警&#xff0c…

C#二叉樹

C#二叉樹 二叉樹是一種常見的數據結構,它是由節點組成的一種樹形結構,其中每個節點最多有兩個子節點。二叉樹的一個節點通常包含三部分:存儲數據的變量、指向左子節點的指針和指向右子節點的指針。二叉樹可以用于多種算法和操作,…

WinForm真入門(11)——ComboBox控件詳解

WinForm中 ComboBox 控件詳解? ComboBox 是 WinForms 中一個集文本框與下拉列表于一體的控件,支持用戶從預定義選項中選擇或直接輸入內容。以下從核心屬性、事件、使用場景到高級技巧的全面解析: 一、ComboBox 核心屬性? 屬性說明示例?Items?下拉…

超詳細解讀:數據庫MVCC機制

之前文章:Mysql鎖_exclusivelock for update寫鎖-CSDN博客 中有提到通過MVCC來實現快照讀,從而解決幻讀問題,這里詳細介紹下MVCC。 一、前言 表1:實例表t idk1122 表2:事務A、B、C的執行流程 事務A事務B事務Cstart …

【SpringCloud】從入門到精通【上】

今天主播我把黑馬新版微服務課程MQ高級之前的內容都看完了,雖然在看視頻的時候也記了筆記,但是看完之后還是忘得差不多了,所以打算寫一篇博客再溫習一下內容。 課程坐標:黑馬程序員SpringCloud微服務開發與實戰 微服務 認識單體架構 單體架…

力扣hot100_回溯(2)_python版本

一、39. 組合總和(中等) 代碼: class Solution:def combinationSum(self, candidates: List[int], target: int) -> List[List[int]]:ans []path []def dfs(i: int, left: int) -> None:if left 0:# 找到一個合法組合ans.append(pa…

AI平臺如何實現推理?數算島是一個開源的AI平臺(主要用于管理和調度分布式AI訓練和推理任務。)

數算島是一個開源的AI平臺,主要用于管理和調度分布式AI訓練和推理任務。它基于Kubernetes構建,支持多種深度學習框架(如TensorFlow、PyTorch等)。以下是數算島實現模型推理的核心原理、架構及具體實現步驟: 一、數算島…

cesium項目之cesiumlab地形數據加載

之前的文章我們有提到,使用cesiumlab加載地形出現了一些錯誤,沒有解決,今天作者終于找到了解決方法,下面描述一下具體步驟,首先在地理數據云下載dem數據,在cesiumlab中使用地形切片,得到terrain…

[Vue]App.vue講解

頁面中可以看見的內容不再在index.html中進行編輯,而是在App.vue中進行編輯。 組件化開發 在傳統的html開發中,一個頁面的資源往往都寫在同一個html文件中。這種模式在開發小規模、樣式簡單的項目時會相當便捷,但當項目規模越來越大&#xf…

sql-labs靶場 less-1

文章目錄 sqli-labs靶場less 1 聯合注入 sqli-labs靶場 每道題都從以下模板講解,并且每個步驟都有圖片,清晰明了,便于復盤。 sql注入的基本步驟 注入點注入類型 字符型:判斷閉合方式 (‘、"、’、“”&#xf…

藍橋杯-小明的彩燈(差分)

問題描述: 差分數組 1. 什么是差分數組? 差分數組 c 是原數組 a 的“差值表示”,其定義如下: c[0] a[0]c[i] a[i] - a[i-1] (i ≥ 1) 差分數組記錄了相鄰元素的差值。例如,原數組 a [1, …

精品可編輯PPT | 基于湖倉一體構建數據中臺架構大數據湖數據倉庫一體化中臺解決方案

本文介紹了基于湖倉一體構建數據中臺架構的技術創新與實踐。它詳細闡述了數據湖、數據倉庫和數據中臺的概念,分析了三者的區別與協作關系,指出數據湖可存儲大規模結構化和非結構化數據,數據倉庫用于高效存儲和快速查詢以支持決策,…

最近api.themoviedb.org無法連接的問題解決

修改NAS的host需要用到SSH終端連接工具,比如常見的Putty,XShell,或者FinalShell等都可以,我個人還是習慣Putty。 1.輸入命令“ sudo -i ”回車,提示輸入密碼,密碼就是我們NAS的登錄密碼,輸入的…

0.機器學習基礎

0.人工智能概述: (1)必備三要素: 數據算法計算力 CPU、GPU、TPUGPU和CPU對比: GPU主要適合計算密集型任務;CPU主要適合I/O密集型任務; 【筆試問題】什么類型程序適合在GPU上運行&#xff1…

多類型醫療自助終端智能化升級路徑(代碼版.下)

醫療人機交互層技術實施方案 一、多模態交互體系 1. 醫療語音識別引擎 # 基于Wav2Vec2的醫療ASR系統 from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC import torchaudioclass MedicalASR:def __init__(self):self.processor = Wav2Vec2Processor.from_pretrai…

前端基礎:React項目打包部署服務器教程

問題背景 我做了一個React框架的前端的Node項目,是一個單頁面應用。 頁面路由用的是,然后使用了React.lazy在路由層級對每一個不同頁面進行了懶加載,只有打開那個頁面才會加載對應資源。 然后現在我用了Webpack5對項目進行了打包&#xff…

【深度學習:理論篇】--Pytorch基礎入門

目錄 1.Pytorch--安裝 2.Pytorch--張量 3.Pytorch--定義 4.Pytorch--運算 4.1.Tensor數據類型 4.2.Tensor創建 4.3.Tensor運算 4.4.Tensor--Numpy轉換 4.5.Tensor--CUDA(GPU) 5.Pytorch--自動微分 (autograd) 5.1.back…

使用 Spring Boot 快速構建企業微信 JS-SDK 權限簽名后端服務

使用 Spring Boot 快速構建企業微信 JS-SDK 權限簽名后端服務 本篇文章將介紹如何使用 Spring Boot 快速構建一個用于支持企業微信 JS-SDK 權限校驗的后端接口,并提供一個簡單的 HTML 頁面進行功能測試。適用于需要在企業微信網頁端使用掃一掃、定位、錄音等接口的…

工程師 - FTDI SPI converter

中國網站:FTDIChip- 首頁 UMFT4222EV-D UMFT4222EV-D - FTDI 可以下載Datasheet。 UMFT4222EVUSB2.0 to QuadSPI/I2C Bridge Development Module Future Technology Devices International Ltd. The UMFT4222EV is a development module which uses FTDI’s FT4222H…