手寫數字識別是計算機視覺領域的“Hello World”,也是深度學習入門的經典案例。它通過訓練模型識別0-9的手寫數字圖像(如MNIST數據集),幫助我們快速掌握神經網絡的核心流程。本文將以PyTorch框架為基礎,帶你從數據加載、模型構建到訓練評估,完整實現一個手寫數字識別系統。
二、數據加載與預處理:認識MNIST數據集
1. MNIST數據集簡介
MNIST是手寫數字的標準數據集,包含:
- 訓練集:60,000張28x28的灰度圖(0-9數字)
- 測試集:10,000張同尺寸圖片
- 每張圖片已歸一化(像素值0-1),標簽為0-9的整數
2. 代碼實現:下載與加載數據
使用torchvision.datasets
可直接下載MNIST,transforms.ToTensor()
將圖片轉為PyTorch張量(通道優先格式:[1,28,28]
,1為灰度通道數)。
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# 下載訓練集(60,000張)
train_data = datasets.MNIST(root="data", # 數據存儲路徑train=True, # 標記為訓練集download=True, # 自動下載(首次運行時)transform=ToTensor() # 轉為張量(shape: [1,28,28])
)# 下載測試集(10,000張)
test_data = datasets.MNIST(root="data",train=False, # 標記為測試集download=True,transform=ToTensor()
)
3. 數據封裝:DataLoader批量加載
DataLoader
將數據集打包為可迭代的批量數據,支持隨機打亂(訓練集)、多線程加載等。
device = "cuda" if torch.cuda.is_available() else "cpu" # 自動選擇GPU/CPU
batch_size = 64 # 每批64張圖片(可根據顯存調整)# 訓練集DataLoader(打亂順序)
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# 測試集DataLoader(不打亂順序)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
三、模型構建:設計卷積神經網絡(CNN)
1. 為什么選擇CNN?
手寫數字識別需要捕捉圖像的局部特征(如筆畫邊緣、拐點),而CNN的卷積層通過滑動窗口提取局部模式,池化層降低計算量,全連接層完成分類,非常適合處理圖像任務。
2. 模型結構詳解(附代碼注釋)
以下是我們定義的CNN模型,包含3個卷積塊和1個全連接輸出層:
class CNN(nn.Module):def __init__(self):super().__init__() # 繼承PyTorch模塊基類# 卷積塊1:輸入1通道(灰度圖)→ 輸出8通道特征圖self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, # 輸入通道數(灰度圖)out_channels=8, # 輸出8個特征圖(8個卷積核)kernel_size=5, # 卷積核尺寸5x5(覆蓋局部區域)stride=1, # 滑動步長1(不跳躍)padding=2 # 邊緣填充2圈0(保持輸出尺寸不變)),nn.ReLU(), # 非線性激活(引入復雜模式)nn.MaxPool2d(kernel_size=2) # 最大池化(2x2窗口,尺寸減半))# 卷積塊2:特征抽象(8→16→32通道)self.conv2 = nn.Sequential(nn.Conv2d(8, 16, 5, 1, 2), # 8→16通道,5x5卷積,填充2(尺寸不變)nn.ReLU(),nn.Conv2d(16, 32, 5, 1, 2), # 16→32通道,5x5卷積,填充2(尺寸不變)nn.ReLU(),nn.MaxPool2d(kernel_size=2) # 尺寸減半(14→7))# 卷積塊3:特征精煉(32→256通道,保留空間信息)self.conv3 = nn.Sequential(nn.Conv2d(32, 256, 5, 1, 2), # 32→256通道,5x5卷積,填充2(尺寸不變)nn.ReLU())# 全連接輸出層:256*7*7維特征→10類概率self.out = nn.Linear(256 * 7 * 7, 10) # 10對應0-9數字類別def forward(self, x):"""前向傳播:定義數據流動路徑"""x = self.conv1(x) # 輸入:[64,1,28,28] → 輸出:[64,8,14,14](池化后尺寸減半)x = self.conv2(x) # 輸入:[64,8,14,14] → 輸出:[64,32,7,7](兩次卷積+池化)x = self.conv3(x) # 輸入:[64,32,7,7] → 輸出:[64,256,7,7](僅卷積)x = x.view(x.size(0), -1) # 展平:[64,256,7,7] → [64,256*7*7](全連接需要一維輸入)output = self.out(x) # 輸出:[64,10](每個樣本對應10類的得分)return output
3. 關鍵參數計算(以輸入28x28為例)
- conv1后:卷積核5x5,填充2,輸出尺寸
(28-5+2*2)/1 +1=28
;池化后尺寸28/2=14
→ 輸出[64,8,14,14]
- conv2后:兩次卷積保持14x14,池化后
14/2=7
→ 輸出[64,32,7,7]
- conv3后:卷積保持7x7 → 輸出
[64,256,7,7]
- 展平后:
256*7*7=12544
維向量 → 全連接到10類
四、訓練配置:損失函數與優化器
1. 損失函數:交叉熵損失(CrossEntropyLoss)
手寫數字識別是多分類任務,交叉熵損失函數直接衡量模型輸出概率與真實標簽的差異。PyTorch的nn.CrossEntropyLoss
已集成Softmax操作(無需手動添加)。
2. 優化器:隨機梯度下降(SGD)
優化器負責根據損失值更新模型參數。這里選擇SGD(學習率lr=0.1
),簡單且對小數據集友好(也可嘗試Adam等更復雜的優化器)。
model = CNN().to(device) # 模型加載到GPU/CPU
loss_fn = nn.CrossEntropyLoss() # 交叉熵損失
optimizer = torch.optim.SGD(model.parameters(), lr=0.1) # SGD優化器
五、訓練循環:讓模型“學習”特征
1. 訓練邏輯概述
訓練過程的核心是“前向傳播→計算損失→反向傳播→更新參數”,重復直到模型收斂。具體步驟:
- 模型設為訓練模式(
model.train()
); - 遍歷訓練數據,按批輸入模型;
- 計算預測值與真實標簽的損失;
- 反向傳播計算梯度(
loss.backward()
); - 優化器更新參數(
optimizer.step()
); - 清空梯度(
optimizer.zero_grad()
)避免累積。
2. 代碼實現:訓練函數
def train(dataloader, model, loss_fn, optimizer):model.train() # 開啟訓練模式(影響Dropout/BatchNorm等層)total_loss = 0 # 記錄總損失for batch_idx, (x, y) in enumerate(dataloader):x, y = x.to(device), y.to(device) # 數據加載到GPU/CPU# 1. 前向傳播:模型預測pred = model(x)# 2. 計算損失:預測值 vs 真實標簽loss = loss_fn(pred, y)total_loss += loss.item() # 累加批次損失# 3. 反向傳播:計算梯度optimizer.zero_grad() # 清空歷史梯度loss.backward() # 反向傳播計算當前梯度# 4. 更新參數:根據梯度調整模型權重optimizer.step()# 每100個批次打印一次損失(監控訓練進度)if (batch_idx + 1) % 100 == 0:print(f"批次 {batch_idx+1}/{len(dataloader)}, 當前損失: {loss.item():.4f}")avg_loss = total_loss / len(dataloader)print(f"訓練完成,平均損失: {avg_loss:.4f}")
六、測試評估:驗證模型泛化能力
1. 測試邏輯概述
測試階段需關閉模型的隨機操作(如Dropout),用測試集評估模型的泛化能力。核心指標是準確率(正確預測的樣本比例)。
2. 代碼實現:測試函數
def test(dataloader, model):model.eval() # 開啟評估模式(關閉Dropout等隨機層)correct = 0 # 記錄正確預測數total = 0 # 記錄總樣本數with torch.no_grad(): # 關閉梯度計算(節省內存)for x, y in dataloader:x, y = x.to(device), y.to(device)pred = model(x) # 模型預測# 統計正確數:pred.argmax(1)取預測概率最大的類別correct += (pred.argmax(1) == y).sum().item()total += y.size(0) # 累加批次樣本數accuracy = correct / totalprint(f"測試準確率: {accuracy * 100:.2f}%")return accuracy
七、完整訓練與結果
1. 運行訓練循環
我們訓練10個epoch(遍歷整個訓練集10次):
# 訓練10輪
for epoch in range(10):print(f"
===== 第 {epoch+1} 輪訓練 =====")train(train_dataloader, model, loss_fn, optimizer)# 測試最終效果
print("
===== 最終測試 =====")
test_acc = test(test_dataloader, model)
2. 典型輸出結果
假設訓練10輪后,測試準確率可能達到98.5%+(具體取決于超參數和硬件):
===== 第 1 輪訓練 =====
批次 100/938, 當前損失: 0.2145
...
訓練完成,平均損失: 0.1234===== 第 10 輪訓練 =====
批次 100/938, 當前損失: 0.0321
...
訓練完成,平均損失: 0.0189===== 最終測試 =====
測試準確率: 98.76%
八、改進方向:讓模型更強大
當前模型已能較好識別手寫數字,但仍有優化空間:
1. 調整超參數
- 學習率:若損失下降緩慢,降低
lr
(如0.01);若波動大,增大lr
。 - 批量大小:增大
batch_size
(如128)可加速訓練(需更大顯存)。 - 訓練輪次:增加
epoch
(如20輪),但需防止過擬合(訓練損失持續下降,測試損失上升)。
2. 添加正則化
- Batch Normalization:在卷積層后添加
nn.BatchNorm2d(out_channels)
,加速收斂并穩定訓練。self.conv1 = nn.Sequential(nn.Conv2d(1,8,5,1,2),nn.BatchNorm2d(8), # 新增nn.ReLU(),nn.MaxPool2d(2) )
- Dropout:在全連接層前添加
nn.Dropout(p=0.5)
,隨機斷開神經元,防止過擬合。self.out = nn.Sequential(nn.Dropout(0.5), # 新增nn.Linear(256*7*7, 10) )
3. 使用更深的網絡
當前模型僅3個卷積塊,對于復雜任務(如ImageNet),可使用ResNet等殘差網絡,通過跳躍連接(Skip Connection)解決深層網絡的梯度消失問題。
九、總結
通過本文,你已完成從數據加載到模型訓練的全流程,掌握了:
- 數據預處理:使用
torchvision
加載標準數據集,DataLoader
批量管理數據; - 模型構建:設計CNN的核心組件(卷積層、激活函數、池化層);
- 訓練與評估:理解損失函數、優化器的作用,掌握訓練循環和測試邏輯。
手寫數字識別是深度學習的起點,你可以嘗試修改模型結構(如增加卷積層)、更換數據集(如Fashion-MNIST)或調整超參數,進一步探索深度學習的魅力!
動手建議:運行代碼時,嘗試將device
改為cpu
(無GPU時),觀察訓練速度變化;或修改kernel_size
(如3x3),對比模型性能差異。