在深度學習領域,MNIST 手寫數字識別是經典的入門級項目,就像編程世界里的 “Hello, World”。卷積神經網絡(Convolutional Neural Network,CNN)作為處理圖像數據的強大工具,在該任務中展現出卓越的性能。本文將結合具體的 PyTorch 代碼,詳細解析如何利用 CNN 實現 MNIST 手寫數字識別,帶大家從代碼實踐深入理解背后的技術原理。
一、數據準備:加載與預處理 MNIST 數據集
MNIST 數據集包含 6 萬張訓練圖像和 1 萬張測試圖像,涵蓋 0 - 9 這十個數字的手寫體。我們借助torchvision
庫中的datasets.MNIST
函數來加載數據,具體代碼如下:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensortraining_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)
上述代碼中,root="data"
指定數據集的存儲路徑;train=True
表示加載訓練集,train=False
用于加載測試集;download=True
確保本地無數據集時自動下載;transform=ToTensor()
將圖像數據轉換為 PyTorch 張量格式,并把像素值從 0 - 255 歸一化到 0 - 1 區間,便于后續處理。
為直觀感受數據,我們用matplotlib
庫繪制 9 張訓練圖像及其標簽:
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img, label = training_data[i + 59000]figure.add_subplot(3, 3, i + 1)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")a = img.squeeze()
plt.show()
完成數據加載后,使用DataLoader
將數據封裝成批次,方便模型訓練和測試:
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
batch_size=64
意味著每次訓練或測試,模型會同時處理 64 個樣本,能提高計算效率和訓練穩定性。
二、模型構建:搭建卷積神經網絡架構
我們定義一個名為CNN
的類,繼承自nn.Module
,用于構建卷積神經網絡:
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=16,kernel_size=3,stride=1,padding=1,),nn.ReLU(),nn.MaxPool2d(2))self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 3, 1, 1),nn.ReLU(),nn.MaxPool2d(2),)self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 3, 1, 1),nn.ReLU(),)self.out = nn.Linear(64 * 7 * 7, 10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)output = self.out(x)return output
- 卷積層(
nn.Conv2d
):在conv1
、conv2
和conv3
中,通過卷積層提取圖像特征。例如conv1
中的nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
,in_channels=1
表示輸入圖像為單通道灰度圖,out_channels=16
表示輸出 16 個特征圖,kernel_size=3
指定 3×3 的卷積核,stride=1
是步長,padding=1
用于保持圖像尺寸不變。 - 激活函數(
nn.ReLU
):緊跟在卷積層之后,為模型引入非線性,幫助模型學習復雜的模式。 - 池化層(
nn.MaxPool2d
):通過下采樣操作,如nn.MaxPool2d(2)
將圖像尺寸減半,減少數據量和模型參數,同時保留重要特征,防止過擬合。 - 全連接層(
nn.Linear
):self.out = nn.Linear(64 * 7 * 7, 10)
將卷積層輸出的特征圖展平后連接到全連接層,輸出 10 個神經元對應 0 - 9 十個數字類別,完成最終分類。
最后,將模型移動到合適的計算設備(GPU、MPS 或 CPU)上:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
model = CNN().to(device)
print(model)
三、模型訓練與測試:優化與評估
3.1 訓練函數
def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num % 100 == 0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1
在訓練函數中,model.train()
將模型設為訓練模式。遍歷數據加載器,將每一批數據和標簽移至指定設備,前向傳播計算預測值,通過交叉熵損失函數nn.CrossEntropyLoss()
計算損失,optimizer.zero_grad()
清空梯度,loss.backward()
反向傳播計算梯度,optimizer.step()
更新模型參數,每 100 個批次打印一次損失值。
3.2 測試函數
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")return test_loss, correct
測試函數中,model.eval()
將模型設為評估模式,關閉如 Dropout 等訓練時的操作。在with torch.no_grad()
下遍歷測試數據,計算測試損失和正確預測的樣本數,最后計算平均損失和準確率并輸出。
3.3 執行訓練與測試
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
epochs = 10
for t in range(epochs):print(f"Epoch {t + 1}\n--------------------")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)
我們選用交叉熵損失函數和 Adam 優化器,學習率設為 0.01,通過 10 個訓練周期不斷優化模型,訓練完成后在測試集上評估模型性能,得到最終的準確率和平均損失。
四、總結與展望
通過上述代碼實踐,我們成功利用卷積神經網絡實現了 MNIST 手寫數字識別。從數據加載、模型構建到訓練測試,每個環節都緊密相連,展示了 CNN 在圖像識別任務中的強大能力。