ResNet(殘差網絡)是深度學習中的經典模型,通過引入殘差連接解決了深層網絡訓練中的梯度消失問題。本文將從殘差塊的定義開始,逐步實現一個ResNet模型,并在Fashion MNIST數據集上進行訓練和測試。
1. 殘差塊(Residual Block)實現
殘差塊通過跳躍連接(Shortcut Connection)將輸入直接傳遞到輸出,緩解了深層網絡的訓練難題。以下是殘差塊的PyTorch實現:
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lclass Residual(nn.Module):def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)else:self.conv3 = Noneself.relu = nn.ReLU(inplace=True)def forward(self, x):y = F.relu(self.bn1(self.conv1(x)))y = self.bn2(self.conv2(y))if self.conv3:x = self.conv3(x)y += xreturn F.relu(y)
代碼解析:
-
use_1x1conv
:當輸入和輸出通道數不一致時,使用1x1卷積調整通道數。 -
strides
:控制特征圖下采樣的步長。 -
殘差相加后再次使用ReLU激活,增強非線性表達能力。
2. 構建ResNet模型
ResNet由多個殘差塊堆疊而成,以下代碼構建了一個簡化版ResNet-18:
# 初始卷積層
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)def resnet_block(input_channels, num_channels, num_residuals, first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block: # 第一個塊需下采樣blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blk# 堆疊殘差塊
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))# 完整網絡結構
net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1, 1)),nn.Flatten(),nn.Linear(512, 10)
)
模型結構說明:
-
AdaptiveAvgPool2d
:自適應平均池化,將特征圖尺寸統一為1x1。 -
Flatten
:展平特征用于全連接層分類。
3. 數據加載與預處理
使用Fashion MNIST數據集,批量大小為256:
train_data, test_data = d2l.load_data_fashion_mnist(batch_size=256)
4. 模型訓練與測試
設置訓練參數:10個epoch,學習率0.05,并使用GPU加速:
d2l.train_ch6(net, train_data, test_data, num_epochs=10, lr=0.05, device=d2l.try_gpu())
訓練結果:
loss 0.124, train acc 0.952, test acc 0.860
4921.4 examples/sec on cuda:0
5. 結果可視化
訓練過程中損失和準確率變化如下圖所示:
分析:
-
訓練準確率(紫色虛線)迅速上升并穩定在95%以上。
-
測試準確率(綠色點線)達到86%,表明模型具有良好的泛化能力。
-
損失值(藍色實線)持續下降,未出現過擬合。
6. 完整代碼
整合所有代碼片段(需安裝d2l
庫):
# 殘差塊定義、模型構建、訓練代碼見上文
7. 總結
本文實現了ResNet的核心組件——殘差塊,并構建了一個簡化版ResNet模型。通過實驗驗證,模型在Fashion MNIST數據集上表現良好。讀者可嘗試調整網絡深度或超參數以進一步提升性能。
改進方向:
-
增加殘差塊數量構建更深的ResNet(如ResNet-34/50)。
-
使用數據增強策略提升泛化能力。
-
嘗試不同的優化器和學習率調度策略。
注意事項:
-
確保已安裝PyTorch和
d2l
庫。 -
GPU環境可顯著加速訓練,若使用CPU需調整批量大小。
希望本文能幫助您理解ResNet的實現細節!如有疑問,歡迎在評論區留言討論。