Concept
A Residual Network (ResNet) is a deep convolutional neural network architecture designed to mitigate the vanishing- and exploding-gradient problems that arise when training deep networks, and to make very deep networks trainable at all. ResNet was proposed in 2015; its core idea is the "residual block", which overcomes the optimization difficulties of training deep networks.
The conventional view is that each layer of a neural network gradually learns higher-level feature representations, but in practice simply adding layers can degrade performance, because deeper networks become harder to optimize during training. ResNet addresses this by introducing "skip connections" (residual connections): instead of learning the desired mapping directly, the stacked layers learn only a residual, and the skip connection adds the original input back to the layers' output, so the block computes F(x) + x.
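The idea can be sketched in plain Python. Here `f` is a hypothetical stand-in for the stacked layers (in a real network it would be learned convolutions, not this toy function):

```python
def f(x):
    # Hypothetical stand-in for the learned layers (conv/BN/ReLU).
    # A real network learns this mapping during training.
    return 0.1 * x

def residual_block(x):
    # The block outputs f(x) + x: the layers only need to learn the
    # residual (the difference from the identity). If f(x) goes to 0,
    # the block reduces to the identity mapping, so extra depth cannot
    # easily hurt the network.
    return f(x) + x

print(residual_block(2.0))  # approximately 2.2: the input plus a small correction
```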
Structure of a residual block
Input
  ↓
Convolution
  ↓
Batch Normalization
  ↓
ReLU
  ↓
Convolution
  ↓
Batch Normalization
  ↓
Addition (residual connection: the input is added here)
  ↓
ReLU
  ↓
Output
Code implementation
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # If the input and output shapes differ, use a 1x1 convolution
        # on the shortcut branch to match the dimensions
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(residual)
        out = self.relu(out)
        return out

# Create a residual block instance
residual_block = ResidualBlock(in_channels=64, out_channels=128, stride=2)
print(residual_block)
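With stride=2, conv1 halves the spatial resolution of the main branch, and the 1x1 shortcut convolution does the same on the skip branch, so the two tensors remain addable. A quick arithmetic check of this, using the standard convolution output-size formula (a self-contained sketch, independent of PyTorch):

```python
def conv_out_size(size, kernel_size, stride, padding):
    # Standard convolution output-size formula:
    # floor((size + 2*padding - kernel_size) / stride) + 1
    return (size + 2 * padding - kernel_size) // stride + 1

# Main branch, conv1: 3x3 kernel, stride 2, padding 1 on a 32x32 feature map
print(conv_out_size(32, kernel_size=3, stride=2, padding=1))  # 16

# Shortcut branch: 1x1 kernel, stride 2, no padding on the same 32x32 input
print(conv_out_size(32, kernel_size=1, stride=2, padding=0))  # 16
```

Both branches produce 16x16 feature maps, so the element-wise addition in `forward` is well defined even when the block downsamples.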