卷積神經網絡實現mnist手寫數字集識別案例

手寫數字識別是計算機視覺領域的“Hello World”,也是深度學習入門的經典案例。它通過訓練模型識別0-9的手寫數字圖像(如MNIST數據集),幫助我們快速掌握神經網絡的核心流程。本文將以PyTorch框架為基礎,帶你從數據加載、模型構建到訓練評估,完整實現一個手寫數字識別系統。

二、數據加載與預處理:認識MNIST數據集

1. MNIST數據集簡介

MNIST是手寫數字的標準數據集,包含:

  • 訓練集:60,000張28x28的灰度圖(0-9數字)
  • 測試集:10,000張同尺寸圖片
  • 每張圖片已歸一化(像素值0-1),標簽為0-9的整數

2. 代碼實現:下載與加載數據

使用torchvision.datasets可直接下載MNIST,transforms.ToTensor()將圖片轉為PyTorch張量(通道優先格式:[1,28,28],1為灰度通道數)。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# 下載訓練集(60,000張)
train_data = datasets.MNIST(root="data",       # 數據存儲路徑train=True,        # 標記為訓練集download=True,     # 自動下載(首次運行時)transform=ToTensor()  # 轉為張量(shape: [1,28,28])
)# 下載測試集(10,000張)
test_data = datasets.MNIST(root="data",train=False,       # 標記為測試集download=True,transform=ToTensor()
)

3. 數據封裝:DataLoader批量加載

DataLoader將數據集打包為可迭代的批量數據,支持隨機打亂(訓練集)、多線程加載等。

device = "cuda" if torch.cuda.is_available() else "cpu"  # 自動選擇GPU/CPU
batch_size = 64  # 每批64張圖片(可根據顯存調整)# 訓練集DataLoader(打亂順序)
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# 測試集DataLoader(不打亂順序)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

三、模型構建:設計卷積神經網絡(CNN)

1. 為什么選擇CNN?

手寫數字識別需要捕捉圖像的局部特征(如筆畫邊緣、拐點),而CNN的卷積層通過滑動窗口提取局部模式,池化層降低計算量,全連接層完成分類,非常適合處理圖像任務。

2. 模型結構詳解(附代碼注釋)

以下是我們定義的CNN模型,包含3個卷積塊和1個全連接輸出層:

class CNN(nn.Module):def __init__(self):super().__init__()  # 繼承PyTorch模塊基類# 卷積塊1:輸入1通道(灰度圖)→ 輸出8通道特征圖self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,    # 輸入通道數(灰度圖)out_channels=8,   # 輸出8個特征圖(8個卷積核)kernel_size=5,    # 卷積核尺寸5x5(覆蓋局部區域)stride=1,         # 滑動步長1(不跳躍)padding=2         # 邊緣填充2圈0(保持輸出尺寸不變)),nn.ReLU(),  # 非線性激活(引入復雜模式)nn.MaxPool2d(kernel_size=2)  # 最大池化(2x2窗口,尺寸減半))# 卷積塊2:特征抽象(8→16→32通道)self.conv2 = nn.Sequential(nn.Conv2d(8, 16, 5, 1, 2),  # 8→16通道,5x5卷積,填充2(尺寸不變)nn.ReLU(),nn.Conv2d(16, 32, 5, 1, 2), # 16→32通道,5x5卷積,填充2(尺寸不變)nn.ReLU(),nn.MaxPool2d(kernel_size=2)  # 尺寸減半(14→7))# 卷積塊3:特征精煉(32→256通道,保留空間信息)self.conv3 = nn.Sequential(nn.Conv2d(32, 256, 5, 1, 2),  # 32→256通道,5x5卷積,填充2(尺寸不變)nn.ReLU())# 全連接輸出層:256*7*7維特征→10類概率self.out = nn.Linear(256 * 7 * 7, 10)  # 10對應0-9數字類別def forward(self, x):"""前向傳播:定義數據流動路徑"""x = self.conv1(x)  # 輸入:[64,1,28,28] → 輸出:[64,8,14,14](池化后尺寸減半)x = self.conv2(x)  # 輸入:[64,8,14,14] → 輸出:[64,32,7,7](兩次卷積+池化)x = self.conv3(x)  # 輸入:[64,32,7,7] → 輸出:[64,256,7,7](僅卷積)x = x.view(x.size(0), -1)  # 展平:[64,256,7,7] → [64,256*7*7](全連接需要一維輸入)output = self.out(x)       # 輸出:[64,10](每個樣本對應10類的得分)return output

3. 關鍵參數計算(以輸入28x28為例)

  • conv1后:卷積核5x5,填充2,輸出尺寸(28-5+2*2)/1 +1=28;池化后尺寸28/2=14 → 輸出[64,8,14,14]
  • conv2后:兩次卷積保持14x14,池化后14/2=7 → 輸出[64,32,7,7]
  • conv3后:卷積保持7x7 → 輸出[64,256,7,7]
  • 展平后256*7*7=12544維向量 → 全連接到10類

四、訓練配置:損失函數與優化器

1. 損失函數:交叉熵損失(CrossEntropyLoss)

手寫數字識別是多分類任務,交叉熵損失函數直接衡量模型輸出概率與真實標簽的差異。PyTorch的nn.CrossEntropyLoss已集成Softmax操作(無需手動添加)。

2. 優化器:隨機梯度下降(SGD)

優化器負責根據損失值更新模型參數。這里選擇SGD(學習率lr=0.1),簡單且對小數據集友好(也可嘗試Adam等更復雜的優化器)。

model = CNN().to(device)  # 模型加載到GPU/CPU
loss_fn = nn.CrossEntropyLoss()  # 交叉熵損失
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # SGD優化器

五、訓練循環:讓模型“學習”特征

1. 訓練邏輯概述

訓練過程的核心是“前向傳播→計算損失→反向傳播→更新參數”,重復直到模型收斂。具體步驟:

  1. 模型設為訓練模式(model.train());
  2. 遍歷訓練數據,按批輸入模型;
  3. 計算預測值與真實標簽的損失;
  4. 反向傳播計算梯度(loss.backward());
  5. 優化器更新參數(optimizer.step());
  6. 清空梯度(optimizer.zero_grad())避免累積。

2. 代碼實現:訓練函數

def train(dataloader, model, loss_fn, optimizer):model.train()  # 開啟訓練模式(影響Dropout/BatchNorm等層)total_loss = 0  # 記錄總損失for batch_idx, (x, y) in enumerate(dataloader):x, y = x.to(device), y.to(device)  # 數據加載到GPU/CPU# 1. 前向傳播:模型預測pred = model(x)# 2. 計算損失:預測值 vs 真實標簽loss = loss_fn(pred, y)total_loss += loss.item()  # 累加批次損失# 3. 反向傳播:計算梯度optimizer.zero_grad()  # 清空歷史梯度loss.backward()        # 反向傳播計算當前梯度# 4. 更新參數:根據梯度調整模型權重optimizer.step()# 每100個批次打印一次損失(監控訓練進度)if (batch_idx + 1) % 100 == 0:print(f"批次 {batch_idx+1}/{len(dataloader)}, 當前損失: {loss.item():.4f}")avg_loss = total_loss / len(dataloader)print(f"訓練完成,平均損失: {avg_loss:.4f}")

六、測試評估:驗證模型泛化能力

1. 測試邏輯概述

測試階段需關閉模型的隨機操作(如Dropout),用測試集評估模型的泛化能力。核心指標是準確率(正確預測的樣本比例)。

2. 代碼實現:測試函數

def test(dataloader, model):model.eval()  # 開啟評估模式(關閉Dropout等隨機層)correct = 0   # 記錄正確預測數total = 0     # 記錄總樣本數with torch.no_grad():  # 關閉梯度計算(節省內存)for x, y in dataloader:x, y = x.to(device), y.to(device)pred = model(x)  # 模型預測# 統計正確數:pred.argmax(1)取預測概率最大的類別correct += (pred.argmax(1) == y).sum().item()total += y.size(0)  # 累加批次樣本數accuracy = correct / totalprint(f"測試準確率: {accuracy * 100:.2f}%")return accuracy

七、完整訓練與結果

1. 運行訓練循環

我們訓練10個epoch(遍歷整個訓練集10次):

# 訓練10輪
for epoch in range(10):print(f"
====={epoch+1} 輪訓練 =====")train(train_dataloader, model, loss_fn, optimizer)# 測試最終效果
print("
===== 最終測試 =====")
test_acc = test(test_dataloader, model)

2. 典型輸出結果

假設訓練10輪后,測試準確率可能達到98.5%+(具體取決于超參數和硬件):

===== 第 1 輪訓練 =====
批次 100/938, 當前損失: 0.2145
...
訓練完成,平均損失: 0.1234===== 第 10 輪訓練 =====
批次 100/938, 當前損失: 0.0321
...
訓練完成,平均損失: 0.0189===== 最終測試 =====
測試準確率: 98.76%

八、改進方向:讓模型更強大

當前模型已能較好識別手寫數字,但仍有優化空間:

1. 調整超參數

  • 學習率:若損失下降緩慢,降低lr(如0.01);若波動大,增大lr
  • 批量大小:增大batch_size(如128)可加速訓練(需更大顯存)。
  • 訓練輪次:增加epoch(如20輪),但需防止過擬合(訓練損失持續下降,測試損失上升)。

2. 添加正則化

  • Batch Normalization:在卷積層后添加nn.BatchNorm2d(out_channels),加速收斂并穩定訓練。
    self.conv1 = nn.Sequential(nn.Conv2d(1,8,5,1,2),nn.BatchNorm2d(8),  # 新增nn.ReLU(),nn.MaxPool2d(2)
    )
    
  • Dropout:在全連接層前添加nn.Dropout(p=0.5),隨機斷開神經元,防止過擬合。
    self.out = nn.Sequential(nn.Dropout(0.5),  # 新增nn.Linear(256*7*7, 10)
    )
    

3. 使用更深的網絡

當前模型僅3個卷積塊,對于復雜任務(如ImageNet),可使用ResNet等殘差網絡,通過跳躍連接(Skip Connection)解決深層網絡的梯度消失問題。

九、總結

通過本文,你已完成從數據加載到模型訓練的全流程,掌握了:

  • 數據預處理:使用torchvision加載標準數據集,DataLoader批量管理數據;
  • 模型構建:設計CNN的核心組件(卷積層、激活函數、池化層);
  • 訓練與評估:理解損失函數、優化器的作用,掌握訓練循環和測試邏輯。

手寫數字識別是深度學習的起點,你可以嘗試修改模型結構(如增加卷積層)、更換數據集(如Fashion-MNIST)或調整超參數,進一步探索深度學習的魅力!

動手建議:運行代碼時,嘗試將device改為cpu(無GPU時),觀察訓練速度變化;或修改kernel_size(如3x3),對比模型性能差異。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/94607.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/94607.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/94607.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

實戰筆記——構建智能Agent:SpreadJS代碼助手

目錄 前言 解決思路 需求理解 MCP Server LangGraph 本教程目標 技術棧 第一部分:構建 MCP Server - 工具服務化的基礎架構 第二部分:Tools 實現 第三部分:基于 LangGraph 構建智能 Agent 第四部分:服務器和前端搭建 前…

【Word】用 Python 輕松實現 Word 文檔對比并生成可視化 HTML 報告

在日常工作和學習中,我們經常需要對兩個版本的文檔進行比對,比如合同修改、論文修訂、報告更新等。手動逐字檢查不僅耗時費力,還容易遺漏細節。 今天,我將帶你使用 Python python-docx difflib 實現一個自動化 Word 文檔對比工具…

從0開始搭建一個前端項目(vue + vite + typescript)

版本 node:v22.17.1 pnpm:v10.13.1 vue:^3.5.18 vite:^7.0.6 typescipt:~5.8.0腳手架初始化vue pnpm create vuelatest只選擇: TypeScript, JSX 3. 用vscode打開創建的項目,并刪除多余的代碼esl…

1.ImGui-環境安裝

免責聲明:內容僅供學習參考,請合法利用知識,禁止進行違法犯罪活動! 本次游戲沒法給 內容參考于:微塵網絡安全 IMGUI是一個被廣泛應用到逆向里面的,它可以用來做外部的繪制,比如登錄界面&…

基于springboot的二手車交易系統

博主介紹:java高級開發,從事互聯網行業六年,熟悉各種主流語言,精通java、python、php、爬蟲、web開發,已經做了六年的畢業設計程序開發,開發過上千套畢業設計程序,沒有什么華麗的語言&#xff0…

修改win11任務欄時間字體和小圖標顏色

1 打開運行提示框 在桌面按快捷鍵winR,然后如下圖所示輸入regedit2 查找路徑 1、在路徑處粘貼路徑計算機\HKEY_CURRENT_USER\Software\Microsoft\Windows\CurrentVersion\Themes\Personalize 2、如下圖所示,雙擊打開ColorPrevalence,將里面的…

第13集 當您的USB設備不在已實測支持列表,如何讓TOS-WLink支持您的USB設備--答案Wireshark USB抓包

問:當您的USB設備不在已實測支持列表,如何讓TOS-WLink支持您的USB設備? 答案:使用Wireshark USB抓包,日志發給我 為什么要抓包: USB設備種類繁多;TOS-WLink是單片機,內存緊張&#…

[靈動微電子 MM32BIN560CN MM32SPIN0280]讀懂電機MCU之比較器

作為剛接觸微控制器的初學者,在看到MM32SPIN0280用戶手冊中“比較器”相關內容時,是不是會感到困惑?比如“5個通用比較器”“輪詢功能”“遲滯電壓”這些術語,好像都和電機控制有關,但又不知道具體怎么用。別擔心&…

? 貳 ? ? 安全架構:數字銀行安全體系規劃

👍點「贊」📌收「藏」👀關「注」💬評「論」 🔥更多文章戳👉Whoami!-CSDN博客🚀 在金融科技深度融合的背景下,信息安全已從單純的技術攻防擴展至架構、合規、流程與創新的…

布隆過濾器完全指南:從原理到實戰

布隆過濾器完全指南:從原理到實戰 摘要:本文深入解析布隆過濾器的核心原理、實現細節和實際應用,提供完整的Java實現代碼,并探討性能優化策略。適合想要深入理解概率數據結構的開發者閱讀。 前言 在大數據時代,如何快速判斷一個元素是否存在于海量數據集合中?傳統的Hash…

?嵌入式Linux學習 - 網絡服務器實現與客戶端的通信

1.單循環服務器 2.并發服務器 1. 設置socket屬性 2. 進程 ?3. 線程 3.多路IO復用模型 - 提高并發程度 1. 區別 2. IO處理模型 1. 阻塞IO模型 2. 非阻塞IO模型 3. 信號驅動IO 4. IO多路復用 3. 特點 4. 函數接口 1. select 2. poll 3. epoll 半包 1.單循環服務…

Mybatis中緩存機制的理解以及優缺點

文章目錄一、MyBatis 緩存機制詳解1. 一級緩存(Local Cache)2. 二級緩存(Global Cache)3. 緩存執行順序二、MyBatis 緩存的優點三、MyBatis 緩存的缺點四、適用場景與最佳實踐總結MyBatis 提供了完善的緩存機制,用于減…

Rust 登堂 之 類型轉換(三)

Rust 是類型安全的語言,因此在Rust 中做類型轉換不是一件簡單的事,這一章節,我們將對Rust 中的類型轉換進行詳盡講解。 高能預警,本章節有些難,可以考慮學了進階后回頭再看 as 轉換 先來看一段代碼 fn main() {let a…

【MySQL 為什么默認會給 id 建索引? MySQL 主鍵索引 = 聚簇索引?】

MySQL 索引 MySQL 為什么默認會給 id 建索引? & MySQL 主鍵索引 聚簇索引? 結論:在 MySQL (InnoDB) 中,主鍵索引是自動創建的聚簇索引,不需要刪除,其他索引是補充優化。 1. MySQL 的id 索引是怎么來的…

[光學原理與應用-321]:皮秒深紫外激光器產品不同階段使用的工具軟件、對應的輸出文件

在皮秒深紫外激光器的開發過程中,不同階段使用的工具軟件及其對應的輸出文件如下:一、設計階段工具軟件:Zemax OpticStudio:用于光學系統的初步設計和仿真,包括光線追跡、像差分析、優化設計等。MATLAB:用于…

openEuler常用操作指令

openEuler常用操作指令 一、前言 1.簡介 openEuler是由開放原子開源基金會孵化的全場景開源操作系統項目,面向數字基礎設施四大核心場景(服務器、云計算、邊緣計算、嵌入式),全面支持ARM、x86、RISC-V、loongArch、PowerPC、SW…

Python爬蟲實戰:構建網易云音樂個性化音樂播放列表同步系統

1. 引言 1.1 研究背景 在數字音樂生態中,各大音樂平臺憑借獨家版權、個性化推薦等優勢占據不同市場份額。根據國際唱片業協會(IFPI)2024 年報告,全球流媒體音樂用戶已突破 50 億,其中超過 60% 的用戶同時使用 2 個及以上音樂平臺。用戶在不同平臺積累的播放列表包含大量…

vscode 配置 + androidStudio配置

插件代碼片段 餓了么 icon{"Print to console": {"prefix": "ii-ep-","body": ["i-ep-"],"description": "elementPlus Icon"} }Ts 初始化模版{"Print to console": {"prefix": &q…

DQN(深度Q網絡):深度強化學習的里程碑式突破

本文由「大千AI助手」原創發布,專注用真話講AI,回歸技術本質。拒絕神話或妖魔化。搜索「大千AI助手」關注我,一起撕掉過度包裝,學習真實的AI技術! ? 1. DQN概述:當深度學習遇見強化學習 DQN(D…

個人博客運行3個月記錄

個人博客 自推一波,目前我的Hexo個人博客已經優化的足夠好了, 已經足夠穩定的和簡單進行發布和管理,但還是有不少問題,總之先記下來再說 先總結下 關于評論系統方面,我從Waline (快速上手 | Waline) 更換成了&#x…