PyTorch 是目前最流行的深度學習框架之一,以其靈活性和易用性受到開發者的喜愛。本文將帶你從零開始,用 PyTorch 實現一個簡單的神經網絡,用于解決經典的 MNIST 手寫數字分類問題。我們將涵蓋數據準備、模型構建、訓練和預測的完整流程,并提供可運行的代碼示例。
1. 環境準備
首先,確保你已安裝 PyTorch 和相關依賴。本例使用 Python 3.8+ 和 PyTorch。你可以通過以下命令安裝:
pip install torch torchvision
我們將使用 MNIST 數據集,它包含 28x28 像素的手寫數字圖像(0-9),目標是訓練一個神經網絡來識別這些數字。
2. 數據準備
MNIST 數據集可以通過 PyTorch 的?torchvision?模塊直接加載。我們需要將數據加載為張量,并進行歸一化處理以加速訓練。
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 定義數據預處理:將圖像轉換為張量并歸一化
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,)) # MNIST 的均值和標準差
])# 加載 MNIST 數據集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 創建數據加載器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
代碼說明:
transforms.ToTensor()?將圖像轉換為 PyTorch 張量,并將像素值從 [0, 255] 縮放到 [0, 1]。
transforms.Normalize?標準化數據,加速梯度下降收斂。
DataLoader?用于批量加載數據,batch_size=64?表示每次處理 64 張圖像。
3. 構建神經網絡
我們將定義一個簡單的全連接神經網絡,包含兩個隱藏層,適合處理 MNIST 的分類任務。
import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.flatten = nn.Flatten() # 將 28x28 圖像展平為 784 維向量self.fc1 = nn.Linear(28 * 28, 128) # 第一個全連接層self.relu = nn.ReLU() # 激活函數self.fc2 = nn.Linear(128, 64) # 第二個全連接層self.fc3 = nn.Linear(64, 10) # 輸出層,10 個類別(0-9)def forward(self, x):x = self.flatten(x)x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.fc3(x)return x# 實例化模型
model = SimpleNN()
代碼說明:
nn.Module?是 PyTorch 模型的基類,自定義模型需要繼承它。
forward?方法定義了前向傳播的計算流程。
網絡結構:輸入層 (784) → 隱藏層1 (128) → ReLU → 隱藏層2 (64) → ReLU → 輸出層 (10)。
4. 定義損失函數和優化器
我們使用交叉熵損失(適合分類任務)和 Adam 優化器來訓練模型。
import torch.optim as optim# 定義損失函數和優化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
代碼說明:
nn.CrossEntropyLoss?結合了 softmax 和負對數似然損失,適合多分類任務。
Adam?優化器以 0.001 的學習率更新模型參數。
5. 訓練模型
接下來,我們訓練模型 5 個 epoch,觀察損失變化。
def train(model, train_loader, criterion, optimizer, epochs=5):model.train() # 切換到訓練模式for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:optimizer.zero_grad() # 清零梯度outputs = model(images) # 前向傳播loss = criterion(outputs, labels) # 計算損失loss.backward() # 反向傳播optimizer.step() # 更新參數running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")# 開始訓練
train(model, train_loader, criterion, optimizer)
代碼說明:
model.train()?啟用訓練模式(影響 dropout 和 batch norm 等層)。
每次迭代清零梯度、計算損失、反向傳播并更新參數。
每輪 epoch 打印平均損失。
6. 測試模型
訓練完成后,我們在測試集上評估模型的準確率。
def test(model, test_loader, criterion):model. # 切換到評估模式correct = 0total = 0test_loss = 0.0with torch.no_grad(): # 禁用梯度計算for images, labels in test_loader:outputs = model(images)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = torch.max(outputs.data, 1) # 獲取預測類別total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f"Test Loss: {test_loss/len(test_loader):.4f}, Accuracy: {accuracy:.2f}%")# 測試模型
test(model, test_loader, criterion)
代碼說明:
model.?切換到評估模式,禁用 dropout 等。
使用?torch.no_grad()?減少內存消耗。
計算測試集的損失和準確率。
7. 進行預測
最后,我們用訓練好的模型對單張圖像進行預測。
import matplotlib.pyplot as plt# 獲取一張測試圖像
dataiter = iter(test_loader)
images, labels = next(dataiter)
image, label = images[0], labels[0]# 預測
model.
with torch.no_grad():output = model(image.unsqueeze(0)) # 增加 batch 維度_, predicted = torch.max(output, 1)# 顯示圖像和預測結果
plt.imshow(image.squeeze(), cmap='gray')
plt.title(f"Predicted: {predicted.item()}, Actual: {label.item()}")
plt.savefig('prediction.png') # 保存圖像
代碼說明:
從測試集取一張圖像,調用模型進行預測。
使用 Matplotlib 顯示圖像及其預測結果,保存為 PNG 文件。
8. 完整代碼
以下是完整的可運行代碼,整合了上述所有步驟:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 數據準備
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 定義模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 64)self.fc3 = nn.Linear(64, 10)def forward(self, x):x = self.flatten(x)x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.fc3(x)return x# 實例化模型、損失函數和優化器
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 訓練函數
def train(model, train_loader, criterion, optimizer, epochs=5):model.train()for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")# 測試函數
def test(model, test_loader, criterion):model.correct = 0total = 0test_loss = 0.0with torch.no_grad():for images, labels in test_loader:outputs = model(images)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f"Test Loss: {test_loss/len(test_loader):.4f}, Accuracy: {accuracy:.2f}%")# 訓練和測試
train(model, train_loader, criterion, optimizer)
test(model, test_loader, criterion)# 預測單張圖像
dataiter = iter(test_loader)
images, labels = next(dataiter)
image, label = images[0], labels[0]
model.
with torch.no_grad():output = model(image.unsqueeze(0))_, predicted = torch.max(output, 1)
plt.imshow(image.squeeze(), cmap='gray')
plt.title(f"Predicted: {predicted.item()}, Actual: {label.item()}")
plt.savefig('prediction.png')
9. 總結
通過本文,可以了解到如何用 PyTorch 實現一個簡單的神經網絡,包括:
加載和預處理 MNIST 數據集。
構建一個全連接神經網絡。
使用交叉熵損失和 Adam 優化器進行訓練。
在測試集上評估模型性能。
對單張圖像進行預測并可視化結果。
這個模型雖然簡單,但在 MNIST 數據集上通常能達到 95% 以上的準確率。可以進一步嘗試調整網絡結構(如增加層數)、優化超參數(如學習率)或使用卷積神經網絡(CNN)來提升性能。希望這篇文章對你理解 PyTorch 和深度學習有所幫助!