隨著深度學習技術的不斷發展,神經網絡架構變得越來越復雜,而這些復雜網絡在訓練時常常遇到梯度消失、梯度爆炸以及計算效率低等問題。為了克服這些問題,研究者們提出了多種網絡架構,包括 殘差網絡(ResNet)、加權殘差連接(WRC) 和 跨階段部分連接(CSP)。
本文將詳細介紹這三種網絡架構的基本概念、工作原理以及如何在 PyTorch 中實現它們。我們會通過代碼示例來展示每個技術的實現方式,并重點講解其中的核心部分。
目錄
一、殘差網絡(ResNet)
1.1 殘差網絡的背景與原理
1.2 殘差塊的實現
重點
二、加權殘差連接(WRC)
2.1 WRC的提出背景
2.2 WRC的實現
重點
三、跨階段部分連接(CSP)
3.1 CSP的提出背景
3.2 CSP的實現
重點
四、總結
一、殘差網絡(ResNet)
1.1 殘差網絡的背景與原理
有關于殘差網絡,詳情可以查閱以下博客,更為詳細與新手向:
YOLO系列基礎(三)從ResNet殘差網絡到C3層-CSDN博客
深層神經網絡的訓練常常遭遇梯度消失或梯度爆炸的問題,導致訓練效果不好。為了解決這一問題,微軟的何凱明等人提出了 殘差網絡(ResNet),引入了“跳躍連接(skip connections)”的概念,使得信息可以直接繞過某些層傳播,從而避免了深度網絡訓練中的問題。
在傳統的神經網絡中,每一層都試圖學習輸入到輸出的映射。但在 ResNet 中,網絡不再直接學習從輸入到輸出的映射,而是學習輸入與輸出之間的“殘差”,即
其中 是網絡學到的殘差部分,
是輸入。
這種方式顯著提升了網絡的訓練效果,并且讓深層網絡的訓練變得更加穩定。
1.2 殘差塊的實現
下面是一個簡單的殘差塊實現,它包括了兩層卷積和一個跳躍連接。跳躍連接幫助保持梯度的流動,避免深層網絡中的梯度消失問題。
圖例如下:
代碼示例如下:
import torch
import torch.nn as nn
import torch.nn.functional as F# 定義殘差塊
class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels):super(ResidualBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.bn2 = nn.BatchNorm2d(out_channels)# 如果輸入和輸出的通道數不同,則使用1x1卷積調整尺寸if in_channels != out_channels:self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)else:self.shortcut = nn.Identity()def forward(self, x):out = F.relu(self.bn1(self.conv1(x))) # 第一層卷積后激活out = self.bn2(self.conv2(out)) # 第二層卷積out += self.shortcut(x) # 殘差連接return F.relu(out) # ReLU激活# 構建ResNet
class ResNet(nn.Module):def __init__(self, num_classes=10):super(ResNet, self).__init__()self.layer1 = ResidualBlock(3, 64)self.layer2 = ResidualBlock(64, 128)self.layer3 = ResidualBlock(128, 256)self.fc = nn.Linear(256, num_classes)def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = F.adaptive_avg_pool2d(x, (1, 1)) # 全局平均池化x = torch.flatten(x, 1) # 展平x = self.fc(x) # 全連接層return x# 示例:構建一個簡單的 ResNet
model = ResNet(num_classes=10)
print(model)
重點
- 殘差連接的實現:在
ResidualBlock
類中,out += self.shortcut(x)
實現了輸入與輸出的加法操作,這是殘差學習的核心。 - 處理輸入和輸出通道數不一致的情況:如果輸入和輸出的通道數不同,通過使用 1x1 卷積調整輸入的維度,確保加法操作能夠進行。
二、加權殘差連接(WRC)
2.1 WRC的提出背景
傳統的殘差網絡通過簡單的跳躍連接將輸入和輸出相加,但在某些情況下,不同層的輸出對最終結果的貢獻是不同的。為了讓網絡更靈活地調整各層貢獻,加權殘差連接(WRC) 引入了可學習的權重。公式如下
其中 是網絡學到的殘差部分,
是輸入,
?和?
是權重。
WRC通過為每個殘差連接引入可學習的權重 和
,使得網絡能夠根據任務需求自適應地調整每個連接的貢獻。
2.2 WRC的實現
以下是 WRC 的實現代碼,我們為每個殘差連接引入了權重參數 alpha
和 beta
,這些參數通過訓練進行優化。
圖例如下:
可以看到,加權殘差快其實就是給殘差網絡的兩條分支加個權而已?
代碼示例如下:?
class WeightedResidualBlock(nn.Module):def __init__(self, in_channels, out_channels):super(WeightedResidualBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.bn2 = nn.BatchNorm2d(out_channels)# 權重初始化self.alpha = nn.Parameter(torch.ones(1)) # 可學習的權重self.beta = nn.Parameter(torch.ones(1)) # 可學習的權重# 如果輸入和輸出的通道數不同,則使用1x1卷積調整尺寸if in_channels != out_channels:self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)else:self.shortcut = nn.Identity()def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))# 加權殘差連接:使用可學習的權重 alpha 和 betaout = self.alpha * out + self.beta * self.shortcut(x)return F.relu(out)# 示例:構建一個加權殘差塊
model_wrc = WeightedResidualBlock(3, 64)
print(model_wrc)
重點
-
可學習的權重
alpha
和beta
:我們為殘差塊中的兩個加法項(即殘差部分和輸入部分)引入了可學習的權重。通過訓練,這些權重可以自動調整,使網絡能夠根據任務需求更好地融合輸入和輸出。 -
加權殘差連接的實現:在
forward
方法中,out = self.alpha * out + self.beta * self.shortcut(x)
表示加權殘差連接,其中alpha
和beta
是可學習的參數。
三、跨階段部分連接(CSP)
3.1 CSP的提出背景
雖然 ResNet 和 WRC 提供了有效的殘差學習和信息融合機制,但在一些更復雜的網絡中,信息的傳遞依然面臨冗余和計算開銷較大的問題。為了解決這一問題,跨階段部分連接(CSP) 提出了更加高效的信息傳遞方式。CSP通過選擇性地傳遞部分信息而不是所有信息,減少了計算量并保持了模型的表達能力。
3.2 CSP的實現
CSP通過分割輸入特征,并在不同階段進行不同的處理,從而減少冗余的信息傳遞。下面是 CSP 的實現代碼。
CSP思想圖例如下:
特征分割(Feature Splitting):CSP通過分割輸入特征圖,并將分割后的特征圖分別送入不同的子網絡進行處理。一般來說,一條分支的子網絡會比較簡單,一條分支的自網絡則是原來主干網絡的一部分。
重點
- 部分特征選擇性連接:將輸入特征分為兩部分。每部分特征單獨經過卷積處理后,通過
torch.cat()
進行拼接,形成最終的輸出。 - 跨階段部分連接:CSP塊通過分割輸入特征并在不同階段處理,有效地減少了計算開銷,并且保持了網絡的表達能力。
四、總結
本文介紹了 殘差網絡(ResNet)、加權殘差連接(WRC) 和 跨階段部分連接(CSP) 這三種網絡架構。
finally,求贊求贊求贊~