用 PyTorch 搭建 CNN 實現 MNIST 手寫數字識別

在圖像識別領域,卷積神經網絡(CNN)?憑借其對空間特征的高效提取能力,成為手寫數字識別、人臉識別等任務的首選模型。而 MNIST(手寫數字數據集)作為入門級數據集,幾乎是每個深度學習學習者的 “第一個項目”。

本文將帶大家從零開始,用 PyTorch 搭建一個 CNN 模型完成 MNIST 手寫數字識別任務,不僅會貼出完整代碼,還會逐行解析核心邏輯,幫你搞懂 “每個參數為什么這么設”“每一層的作用是什么”,即使是剛接觸 PyTorch 的新手也能輕松跟上。

一、前置知識與環境準備

在開始前,我們需要先明確兩個核心背景,以及搭建好運行環境:

1. 核心背景速覽

  • MNIST 數據集:包含 70000 張 28×28 像素的灰度手寫數字圖片(0-9),其中 60000 張為訓練集,10000 張為測試集,每張圖片對應一個 “數字類別” 標簽(0-9)。
  • CNN 為什么適合?:相比全連接神經網絡,CNN 通過 “卷積層提取局部特征(邊緣、紋理)+ 池化層下采樣”,能大幅減少參數數量、避免過擬合,同時更好地保留圖像的空間結構信息。

2. 環境準備

需要安裝 PyTorch 和 TorchVision(PyTorch 官方的計算機視覺庫,內置 MNIST 數據集):

# pip安裝命令(根據系統自動匹配版本,若需指定CUDA版本可參考PyTorch官網)
pip install torch torchvision

驗證環境是否安裝成功:

import torch
print(torch.__version__)  # 輸出PyTorch版本,如2.0.1
print(torch.cuda.is_available())  # 輸出True表示支持GPU加速(需NVIDIA顯卡)

二、完整代碼先行(可直接運行)

先貼出完整可運行的代碼,后面會逐段拆解解析:

注意:nn.Sequential()是將網絡層組合在一起,內部不能寫函數

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# 1. 加載MNIST數據集
train_data = datasets.MNIST(root='data',          # 數據保存路徑train=True,           # 加載訓練集download=True,        # 若路徑下無數據則自動下載transform=ToTensor()  # 將圖像轉為Tensor(0-1歸一化+維度調整:(H,W,C)→(C,H,W))
)
test_data = datasets.MNIST(root='data',train=False,          # 加載測試集download=True,transform=ToTensor()
)# 2. 數據加載器(分批處理數據)
train_loader = DataLoader(train_data, batch_size=64)  # 每批64個樣本
test_loader = DataLoader(test_data, batch_size=64)# 3. 設備配置(優先GPU,其次CPU)
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')  # 打印當前使用的設備# 4. 定義CNN模型
class CNN(nn.Module):def __init__(self):super().__init__()# 卷積塊1:輸入(1,28,28) → 輸出(8,14,14)self.conv1 = nn.Sequential(# 卷積層:1個輸入通道→8個輸出通道,卷積核5×5,步長1,填充2nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, stride=1, padding=2),nn.ReLU(),  # 激活函數(引入非線性)nn.MaxPool2d(kernel_size=2)  # 池化層:2×2下采樣,尺寸減半)# 卷積塊2:輸入(8,14,14) → 輸出(32,7,7)self.conv2 = nn.Sequential(nn.Conv2d(8, 16, 5, 1, 2),  # 8→16通道,其他參數同上nn.ReLU(),nn.Conv2d(16, 32, 5, 1, 2),  # 16→32通道nn.ReLU(),nn.MaxPool2d(2)  # 下采樣后尺寸14→7)# 卷積塊3:輸入(32,7,7) → 輸出(64,7,7)(無池化,保留尺寸)self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),  # 32→64通道nn.ReLU(),nn.Conv2d(64, 64, 5, 1, 2),  # 64→64通道(加深特征提取)nn.ReLU())# 全連接層:輸入(64×7×7) → 輸出10(對應10個數字類別)self.out = nn.Linear(64 * 7 * 7, 10)# 前向傳播(定義數據在模型中的流動路徑)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)  # 展平:(batch_size, 64,7,7) → (batch_size, 64×7×7)output = self.out(x)return output# 5. 初始化模型并移至指定設備
model = CNN().to(device)
print(model)  # 打印模型結構,驗證是否正確# 6. 定義訓練函數
def train(dataloader, model, loss_fn, optimizer):model.train()  # 啟用訓練模式(如BatchNorm、Dropout會生效)batch_count = 1  # 計數批次,用于打印日志for X, y in dataloader:# 將數據移至指定設備(GPU/CPU)X, y = X.to(device), y.to(device)# 前向傳播:計算模型預測值pred = model(X)# 計算損失(多分類任務用CrossEntropyLoss)loss = loss_fn(pred, y)# 反向傳播:更新模型參數optimizer.zero_grad()  # 清空上一輪梯度(避免累積)loss.backward()        # 計算梯度(反向傳播)optimizer.step()       # 根據梯度更新參數(優化器執行)# 每100個批次打印一次損失(監控訓練進度)if batch_count % 100 == 0:loss_value = loss.item()  # 取出損失值(脫離計算圖)print(f'Batch: {batch_count:>4} | Loss: {loss_value:>6.4f}')batch_count += 1# 7. 定義測試函數
def test(dataloader, model, loss_fn):model.eval()  # 啟用評估模式(關閉BatchNorm、Dropout)total_samples = len(dataloader.dataset)  # 測試集總樣本數correct = 0  # 正確預測的樣本數total_loss = 0  # 總損失# 禁用梯度計算(測試階段無需更新參數,節省內存)with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)# 累積損失和正確數total_loss += loss_fn(pred, y).item()# pred.argmax(1):取每行最大概率的索引(即預測類別),與y比較correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 計算平均損失和準確率avg_loss = total_loss / len(dataloader)  # len(dataloader) = 總批次accuracy = (correct / total_samples) * 100  # 準確率(百分比)print(f'\nTest Result | Accuracy: {accuracy:>5.2f}% | Avg Loss: {avg_loss:>6.4f}\n')# 8. 配置訓練參數并執行
loss_fn = nn.CrossEntropyLoss()  # 多分類交叉熵損失(內置Softmax)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam優化器,學習率0.001
epochs = 10  # 訓練輪次(整個訓練集遍歷10次)# 循環訓練+測試
for epoch in range(epochs):print(f'=================== Epoch {epoch + 1}/{epochs} ===================')train(train_loader, model, loss_fn, optimizer)  # 訓練一輪test(test_loader, model, loss_fn)  # 測試一輪print("Training Finished!")

三、核心代碼逐段解析

上面的代碼看似長,但邏輯很清晰,我們按 “數據→模型→訓練→測試” 的流程拆解核心部分。

1. 數據加載與預處理

MNIST 數據集的加載全靠torchvision.datasets.MNIST,無需手動下載和解析,非常方便。關鍵參數解析:

  • root='data':數據會保存在當前目錄的data文件夾下(自動創建);
  • train=True/FalseTrue加載 6 萬張訓練集,False加載 1 萬張測試集;
  • transform=ToTensor():這是核心預處理步驟,作用有兩個:
    1. 將圖像從 “PIL 格式(0-255 像素值)” 轉為 “Tensor 格式(0-1 歸一化值)”,避免大數值導致梯度爆炸;
    2. 調整維度:從圖像默認的(高度H, 寬度W, 通道C)轉為 PyTorch 要求的(通道C, 高度H, 寬度W)(MNIST 是灰度圖,C=1)。

然后用DataLoader將數據集分批:

  • batch_size=64:每次訓練取 64 個樣本計算梯度(batch_size 越大,訓練越穩定,但內存占用越高);
  • DataLoader會自動打亂訓練集(默認shuffle=True),避免模型學習到 “樣本順序” 的無關特征。

2. 設備配置:GPU 加速有多重要?

代碼中這行是 “硬件適配” 的關鍵:

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
  • cuda:NVIDIA 顯卡的 GPU 加速(訓練 10 輪可能只需 1-2 分鐘);
  • mps:蘋果芯片(M1/M2)的 GPU 加速;
  • cpu:默認選項(訓練 10 輪可能需要 10-20 分鐘,速度慢很多)。

后續通過model.to(device)X.to(device),將模型和數據都移到指定設備上,確保計算在同一設備進行(否則會報錯)。

3. CNN 模型搭建(核心中的核心)

我們定義的CNN類繼承自nn.Module(PyTorch 所有模型的基類),核心是__init__(定義層)和forward(定義數據流動)。

先看模型結構總覽

輸入(1,28,28) → 卷積塊1 → 輸出(8,14,14) → 卷積塊2 → 輸出(32,7,7) → 卷積塊3 → 輸出(64,7,7) → 展平 → 全連接層 → 輸出(10)
(1)卷積層參數解析

conv1的第一個卷積層為例:

nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, stride=1, padding=2)
  • in_channels=1:輸入通道數(MNIST 是灰度圖,所以 1);
  • out_channels=8:輸出通道數 = 卷積核數量(8 個卷積核,提取 8 種不同特征);
  • kernel_size=5:卷積核大小(5×5 的窗口,比 3×3 能提取更復雜的局部特征);
  • stride=1:卷積核每次滑動 1 個像素(步長越小,特征保留越完整);
  • padding=2:填充(在圖像邊緣補 2 個像素),目的是讓卷積后圖像尺寸不變:
    👉 尺寸計算公式:輸出尺寸 = (輸入尺寸 - 卷積核尺寸 + 2×padding) / stride + 1
    👉 代入:(28 - 5 + 2×2)/1 + 1 = 28,所以卷積后還是 28×28。
(2)激活函數 ReLU

每個卷積層后都加nn.ReLU(),作用是引入非線性

  • 沒有激活函數的話,多個卷積層疊加還是線性變換,無法擬合復雜數據;
  • ReLU 的公式:ReLU(x) = max(0, x),計算簡單、梯度不易消失,是目前最常用的激活函數。
(3)池化層 MaxPool2d

nn.MaxPool2d(kernel_size=2)是 2×2 最大池化,作用是下采樣

  • 尺寸減半:28×28→14×14,14×14→7×7,大幅減少后續計算量;
  • 保留關鍵特征:取 2×2 窗口的最大值,相當于 “強化局部最顯著的特征”,提高模型魯棒性。
(4)展平與全連接層

卷積塊 3 輸出的是(batch_size, 64, 7, 7)的張量(batch_size 是每批樣本數),需要用x.view(x.size(0), -1)展平為(batch_size, 64×7×7)的一維向量,才能輸入全連接層:

  • x.size(0):獲取 batch_size(確保展平后每一行對應一個樣本);
  • -1:讓 PyTorch 自動計算剩余維度(64×7×7=3136);
  • 全連接層nn.Linear(3136, 10):將 3136 維特征映射到 10 維(對應 0-9 的 10 個類別)。

4. 訓練函數:模型如何 “學習”?

訓練的核心是 “前向傳播算損失→反向傳播求梯度→優化器更新參數” 的循環:

  1. model.train():啟用訓練模式(比如如果模型有 BatchNorm,會計算當前批次的均值和方差);
  2. 前向傳播:pred = model(X),用當前模型參數計算預測值;
  3. 計算損失:loss = loss_fn(pred, y),用CrossEntropyLoss(多分類任務專用,內置了 Softmax,無需手動在模型輸出加 Softmax);
  4. 反向傳播:
    • optimizer.zero_grad():清空上一輪的梯度(如果不清空,梯度會累積,導致參數更新錯誤);
    • loss.backward():自動計算所有可訓練參數的梯度(PyTorch 的自動微分機制);
    • optimizer.step():用計算出的梯度更新參數(Adam 優化器會自適應調整學習率,比 SGD 收斂更快)。

5. 測試函數:模型學得怎么樣?

測試階段不需要更新參數,核心是計算 “準確率” 和 “平均損失”:

  1. model.eval():啟用評估模式(關閉 BatchNorm 的批次統計更新、關閉 Dropout);
  2. with torch.no_grad():禁用梯度計算(節省內存,加速測試);
  3. 準確率計算:pred.argmax(1) == y,比較預測類別和真實類別,求和后除以總樣本數。

四、預期結果與優化方向

1. 預期訓練結果

在 GPU 上訓練 10 輪后,通常能達到:

  • 測試準確率:98.5% 以上(甚至 99%);
  • 測試平均損失:0.04 以下。

訓練過程中,損失會逐漸下降,準確率會逐漸上升(如果出現損失不下降或準確率波動,可能是學習率太大或 batch_size 太小)。

2. 模型優化方向

如果想進一步提升性能,可以嘗試這些改進:

  1. 增加 Dropout 層:在卷積層或全連接層后加nn.Dropout(0.2),隨機 “關閉” 20% 的神經元,防止過擬合;
  2. 使用學習率調度:比如torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5),每 5 輪將學習率減半,后期精細調整;
  3. 加深網絡:增加卷積塊數量(比如再加一個 conv4),或增加每個卷積層的輸出通道數;
  4. 數據增強:用torchvision.transforms添加旋轉、平移、縮放等操作,比如:
    transform = transforms.Compose([transforms.RandomRotation(5),  # 隨機旋轉±5度transforms.ToTensor()
    ])

????????數據增強能讓模型看到更多 “變種” 樣本,提升泛化能力。

五、總結

本文用 PyTorch 實現了一個基礎的 CNN 模型,完成了 MNIST 手寫數字識別任務,核心收獲包括:

  1. 掌握了 PyTorch 加載數據集、搭建 CNN 模型的基本流程;
  2. 理解了卷積層、池化層、激活函數的作用和參數意義;
  3. 熟悉了 “訓練 - 測試” 的循環邏輯,以及 GPU 加速的配置方法。

MNIST 是入門任務,但 CNN 的核心思想(特征提取 + 下采樣)可以遷移到更復雜的圖像任務(如 CIFAR-10、ImageNet)。建議大家動手修改代碼,比如調整卷積核大小、學習率、網絡層數,觀察結果變化,這樣才能真正理解每個參數的影響~

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/95105.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/95105.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/95105.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

CTFshow系列——命令執行web61-68

本篇文章介紹了不同了方法進行題目的解析以及原因講解。 文章目錄Web61嘗試了一下,被過濾的payload如下:所以,根據上述思路,這里嘗試過的payload為:Web62(同Web61)Web63(同Web62&…

.Net程序員就業現狀以及學習路線圖(二)

一、.NET程序員就業現狀分析 1. 市場需求與崗位分布 2025年.NET開發崗位全國招聘職位約1676個,占全國技術崗位的0.009%,主要集中在一線城市如深圳、上海等地。就業單位類型分布為:軟件公司占43.3%,研發機構占33.1%,物聯…

MTK Linux DRM分析(二十二)- MTK mtk_drm_crtc.c(Part1)

一、代碼分析 mtk_drm_crtc.c以mtk_crtc_comp_is_busy函數為界限進行拆分分析 static const struct drm_crtc_funcs mtk_crtc_funcs = {.set_config = drm_atomic_helper_set_config,.page_flip = drm_atomic_helper_page_flip,.destroy = mtk_drm_crtc_destroy,.reset = mtk…

stm32f103c8t6 led閃燈實驗

目錄 閃燈原理 2種接線方式控制閃燈 使用推挽接法 使用開漏接法 看原理圖 寫代碼 閃燈原理 LED燈有個2-10mA的電流就可以點亮 3.3/5100.006A6mA 2種接線方式控制閃燈 使用推挽接法 當設置推挽模式時,CPU控制寄存器寫0,IO引腳輸出低電壓&#xff0…

“我同意”按鈕別亂點——你的“職業EULA”漏洞掃描報告

尊敬的審核: 本人文章《“我同意”按鈕別亂點——你的“職業EULA”漏洞掃描報告》 1. 純屬技術交流,無任何違法內容 2. 所有法律引用均來自公開條文 3. 請依據《網絡安全法》第12條“不得無故刪除合法內容”處理 附:本文結構已通過區塊鏈存證…

Product Hunt 每日熱榜 | 2025-09-01

1. A01 標語:你個人的新聞助手 介紹:A01 是你的新聞助手,可以幫你關注你關心的任何話題。只需告訴它你想了解什么,它就能為你帶來最新的文章。 產品網站: 立即訪問 Product Hunt: View on Product Hunt…

【OpenFeign】基礎使用

【OpenFeign】基礎使用1. Feign介紹1.1 使用示例1.2 Feign與RPC對比1.3 SpringCloud Alibaba快速整合OpenFeign1.3.1 詳細代碼1. Feign介紹 1.什么是 Feign Feign 是 Netflix 開發的一個 聲明式的 HTTP 客戶端,在 Spring Cloud 中被廣泛使用。它的目標是&#xff…

訪問相同的url,相同入參的請求,Apifox/Postman可以正常響應結果,而本地調用不行(或結果不同)

文章目錄問題概述Apifox查看實際請求總結問題概述 開發中有一個需求需要去別的系統中拿數據,配置好相關參數后發起請求時發現響應結果和在Apifox上不同,Apifox上正常顯示數據,而本地調用后返回數據不存在。 這就很奇怪了,想了很多…

數據結構(C語言篇):(七)雙向鏈表

目錄 前言 一、概念與結構 二、雙向鏈表的實現 2.1 頭文件的準備 2.2 函數的實現 2.2.1 LTPushBack( )函數(尾插) (1)LTBuyNode( ) (2)LTInit( ) (3)LTPrint( ) &#x…

從拿起簡歷(resume)重新找工作開始聊起

經濟蕭條或經濟衰退在經濟相關學術上似乎有著嚴格的定義,我不知道我們的經濟是否已經走向了衰退或者蕭條,但有一點那是肯定的,那就現在我們的經濟肯定是不景氣的。經濟不景氣會怎么樣?是的,會有很多人失業,…

OS+MySQL+(其他)八股小記

魯迅先生曾經說過,每天進步一點點,媽媽夸我小天才。 依舊今日八股,這是我在多個文檔整合一起的,可能格式有些問題,請諒解。 操作系統 1.進程和線程的區別? 進程是代碼在數據集合的一次執行活動,…

Transformer的并行計算與長序列處理瓶頸總結

🌟 第0層:極簡版(30秒理解)一句話核心:Transformer像圓桌會議——所有人都能同時交流(并行優勢),但人越多會議越混亂(長序列瓶頸)。核心問題 并行優勢&#x…

Vue 3 useId 完全指南:生成唯一標識符的最佳實踐

📖 概述 useId() 是 Vue 3 中的一個組合式 API 函數,用于生成唯一的標識符。它確保在服務端渲染(SSR)和客戶端渲染之間生成一致的 ID,避免水合不匹配的問題。 🎯 基本概念 什么是 useId? useId…

CGroup 資源控制組 + Docker 網絡模式

1 CGroup 資源控制組1.1 為什么需要 CGroup - 容器本質 宿主機上一組進程 - 若無資源邊界,一個暴走容器即可拖垮整機 - CGroup 提供**內核級硬限制**,比 ulimit、nice 更可靠1.2 核心概念 3 件套 | 概念 | 一句話解釋 | 查看方式 | | Hierarchy | 樹…

【ArcGIS微課1000例】0150:如何根據地名獲取經緯度坐標

本文介紹了三種獲取地理坐標的方法:1)在ArcGIS Pro中通過搜索功能定位目標點(如月牙泉)并查看其WGS84坐標;2)使用ArcGIS內置工具獲取坐標;3)推薦三個在線工具(maplocation、地球在線、yanue)支持批量查詢和多地圖源坐標轉換。強調了使用WGS84坐標系以減少誤差,并展示…

HTML應用指南:利用GET請求獲取MSN財經股價數據并可視化

隨著數字化金融服務的不斷深化,及時、準確的財經信息已成為投資者決策與市場分析的重要支撐。MSN財經股價數據服務作為廣受信賴的金融信息平臺,依托微軟強大的技術架構與數據整合能力,持續為全球用戶提供全面、可靠的證券市場數據。平臺不僅提…

雅思聽力第四課:配對題核心技巧與詞匯深化

現在,請拿出劍橋真題,開始你的刻意練習! 內容大綱 課程核心目標舊題回顧與基礎鞏固配對題/匹配題核心解題策略考點總結與精聽訓練表 一、課程核心目標 掌握第二部分配對題的解題策略攻克第三部分匹配題的改寫難點系統整理高頻場景詞匯與特…

SQL Server從入門到項目實踐(超值版)讀書筆記 25

第12章 存儲過程的應用 🎉學習指引 存儲過程(Stored Procedure)是在大型數據庫系統中,一組為了完成特定功能的SQL語句集,存儲過程時數據庫中的一個重要對象,它代替了傳統的逐條執行SQL語句的方式。本章就來…

20.29 QLoRA適配器實戰:24GB顯卡輕松微調650億參數大模型

QLoRA適配器實戰:24GB顯卡輕松微調650億參數大模型 QLoRA 適配器配置深度解析 一、QLoRA 適配器核心原理 QLoRA 作為當前大模型微調領域的前沿技術,通過量化與低秩適配的協同設計,在保證模型效果的前提下實現了顯存占用的革命性降低。其核心由三大技術支柱構成: 4位量化…

QMainWindow使用QTabWidget添加多個QWidget

QTabWidget添加其它Wdiget的2個函數如下&#xff1a; QTabWidget的介紹可參考官網QTabWidget Class | Qt Widgets | Qt 6.9.1 直接上代碼&#xff0c;代碼如下&#xff1a; #include <QMainWindow>#include <QApplication> #include <QVBoxLayout> #includ…