用PyTorch搭建卷積神經網絡實現MNIST手寫數字識別
在深度學習領域,卷積神經網絡(Convolutional Neural Network,簡稱CNN)是處理圖像數據的強大工具。它通過卷積層、池化層和全連接層等組件,自動提取圖像特征,在圖像分類、目標檢測等任務中表現卓越。本文將使用PyTorch框架,搭建一個CNN模型來實現MNIST手寫數字識別,并詳細解析每一步代碼。
一、MNIST數據集介紹
MNIST數據集是深度學習領域經典的入門數據集,包含70,000張手寫數字圖像,其中60,000張用于訓練,10,000張用于測試。這些圖像均為灰度圖,尺寸是28x28像素,并且已經做了居中處理,這在一定程度上減少了預處理的工作量,能夠加快模型的訓練和運行速度。
二、環境準備與數據加載
2.1 導入必要的庫
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
上述代碼導入了PyTorch的核心庫、神經網絡模塊、數據加載工具以及用于圖像數據處理和數據集管理的庫。
2.2 下載并加載數據集
training_data = datasets.MNIST(root='data',train=True,download=True,transform=ToTensor()
)test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor()
)
通過datasets.MNIST
函數分別下載訓練集和測試集。root
參數指定數據下載的路徑;train=True
表示下載訓練集數據,train=False
則表示下載測試集數據;download=True
確保如果數據尚未下載,會自動進行下載;transform=ToTensor()
將圖像數據轉換為PyTorch能夠處理的張量格式。
2.3 數據可視化
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img, label = training_data[i + 59000]figure.add_subplot(3, 3, i + 1)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()
這段代碼使用matplotlib
庫展示了訓練數據集中的部分手寫數字圖像,通過plt.imshow
函數將張量格式的圖像數據可視化,直觀感受MNIST數據集的內容。
2.4 創建數據加載器
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
DataLoader
用于將數據集打包成批次,batch_size
參數指定每個批次包含的數據樣本數量。將數據集分成批次進行訓練,能夠有效減少內存使用,并提高訓練速度。
三、設備配置
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
這段代碼檢測當前設備是否支持GPU(CUDA)或蘋果M系列芯片的GPU(MPS),如果都不支持,則使用CPU進行計算。后續模型和數據都會被移動到選定的設備上運行,以充分利用硬件資源加速訓練。
四、定義卷積神經網絡模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=16,kernel_size=5,stride=1,padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU())self.out = nn.Linear(64 * 7 * 7, 10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)output = self.out(x)return output
在這個自定義的CNN
類中,繼承自nn.Module
。__init__
方法中定義了網絡的結構:
- 卷積層(
nn.Conv2d
):用于提取圖像特征,通過設置in_channels
(輸入通道數)、out_channels
(輸出通道數,即卷積核個數)、kernel_size
(卷積核大小)、stride
(步長)和padding
(填充)等參數,控制卷積操作。 - 激活函數層(
nn.ReLU
):引入非線性,增強網絡的表達能力。 - 池化層(
nn.MaxPool2d
):對特征圖進行下采樣,減少數據量和計算量,同時保留主要特征。 - 全連接層(
nn.Linear
):將卷積層和池化層提取的特征映射到輸出類別(MNIST數據集中有10個數字類別)。
forward
方法定義了數據在網絡中的前向傳播路徑,確保數據按照網絡結構依次經過各層處理,最終輸出預測結果。
五、訓練與測試模型
5.1 定義損失函數和優化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
nn.CrossEntropyLoss
是適用于多分類任務的交叉熵損失函數,用于計算模型預測結果與真實標簽之間的差距。torch.optim.Adam
是一種常用的優化器,通過調整模型的參數(model.parameters()
)來最小化損失函數,lr
參數設置學習率,控制參數更新的步長。
5.2 訓練函數
def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num % 100 == 0:print(f'loss:{loss_value:>7f} [number:{batch_size_num}]')batch_size_num += 1
在訓練函數中:
model.train()
將模型設置為訓練模式,此時模型中的一些層(如Dropout層)會按照訓練規則工作。- 遍歷數據加載器中的每一個批次數據,將數據和標簽移動到指定設備上。
- 通過模型進行預測,計算損失值。
- 使用
optimizer.zero_grad()
清零梯度,loss.backward()
進行反向傳播計算梯度,optimizer.step()
根據梯度更新模型參數。 - 每隔100個批次,打印當前的損失值,以便觀察訓練過程中的損失變化。
5.3 測試函數
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f'Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}')
測試函數中:
model.eval()
將模型設置為測試模式,關閉一些在訓練過程中起作用但在測試時不需要的操作(如Dropout)。- 使用
with torch.no_grad()
上下文管理器,關閉梯度計算,因為在測試階段不需要更新模型參數,這樣可以節省計算資源。 - 遍歷測試數據,計算每個批次的損失值并累加,同時統計預測正確的樣本數量。
- 最后計算并打印測試集上的平均損失和準確率,評估模型的性能。
5.4 執行訓練和測試
epoch = 9
for i in range(epoch):print(i + 1)train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
通過設置訓練輪數(epoch
),循環調用訓練函數進行模型訓練,每一輪訓練結束后,調用測試函數評估模型在測試集上的性能。
六、總結
本文通過詳細的代碼解析,展示了如何使用PyTorch搭建一個簡單的卷積神經網絡來實現MNIST手寫數字識別任務。從數據加載、模型定義,到訓練和測試,每一個步驟都體現了CNN在圖像分類任務中的核心思想和實現方法。通過不斷調整模型結構、超參數等,還可以進一步提升模型的性能。卷積神經網絡在圖像領域的應用遠不止于此,它在更復雜的圖像任務和其他領域也有著廣泛的應用前景,希望本文能為大家深入學習深度學習提供一個良好的開端。