源碼
import torch
from torch import nn
from torchsummary import summary"""
DenseNet的核心組件:稠密層(DenseLayer)
實現特征復用機制,每個層的輸出會與所有前序層的輸出在通道維度拼接
"""class DenseLayer(nn.Module):def __init__(self, input_channels, growth_rate):super().__init__()# 批歸一化 + ReLU + 1x1卷積 (瓶頸層,減少計算量)self.bn1 = nn.BatchNorm2d(input_channels)self.conv1 = nn.Conv2d(input_channels, 4 * growth_rate, kernel_size=1)# 批歸一化 + ReLU + 3x3卷積 (特征提取層)self.bn2 = nn.BatchNorm2d(4 * growth_rate)self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1)self.relu = nn.ReLU()def forward(self, x):# 前向傳播:BN->ReLU->Conv(1x1)->BN->ReLU->Conv(3x3)out = self.conv1(self.relu(self.bn1(x)))out = self.conv2(self.relu(self.bn2(out)))# 將新特征與輸入特征在通道維度拼接(實現特征復用)return torch.cat([x, out], 1)"""
稠密塊(DenseBlock):由多個稠密層組成
每個稠密層的輸入包含前面所有層的特征圖
"""class DenseBlock(nn.Module):def __init__(self, num_layers, input_channels, growth_rate):super().__init__()layers = []# 構建num_layers個稠密層for i in range(num_layers):# 每層的輸入通道數 = 初始通道數 + 已添加的特征圖數layers.append(DenseLayer(input_channels + i * growth_rate, growth_rate))self.block = nn.Sequential(*layers)def forward(self, x):return self.block(x)"""
過渡層(TransitionLayer):用于壓縮特征圖尺寸和通道數
包含1x1卷積和平均池化
"""class TransitionLayer(nn.Module):def __init__(self, input_channels, output_channels):super().__init__()# 壓縮通道數的1x1卷積self.bn = nn.BatchNorm2d(input_channels)self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=1)# 下采樣用的平均池化self.pool = nn.AvgPool2d(2, stride=2)self.relu = nn.ReLU()def forward(self, x):# 前向傳播:BN->ReLU->Conv(1x1)->AvgPoolout = self.conv(self.relu(self.bn(x)))return self.pool(out)"""
完整的DenseNet網絡結構
包含初始卷積層、多個稠密塊+過渡層、分類層
"""class DynamicDenseNet(nn.Module):def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=5):super().__init__()# 初始卷積層(標準CNN開始結構)self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3), # 下采樣nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # 進一步下采樣)# 構建稠密塊和過渡層num_channels = 64 # 初始通道數for i, num_layers in enumerate(block_config):# 添加稠密塊block = DenseBlock(num_layers, num_channels, growth_rate)self.features.add_module(f'denseblock{i + 1}', block)# 更新通道數(每個稠密層增加growth_rate個通道)num_channels += num_layers * growth_rate# 不是最后一個塊時添加過渡層if i != len(block_config) - 1:trans = TransitionLayer(num_channels, num_channels // 2)self.features.add_module(f'transition{i + 1}', trans)num_channels = num_channels // 2 # 過渡層壓縮通道數# 分類層self.classifier = nn.Sequential(nn.BatchNorm2d(num_channels),nn.ReLU(),nn.AdaptiveAvgPool2d((1, 1)), # 全局平均池化nn.Flatten(),nn.Linear(num_channels, num_classes) # 全連接輸出分類結果)def forward(self, x):features = self.features(x)return self.classifier(features)# 測試代碼
if __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = DynamicDenseNet().to(device)# 打印網絡結構和參數統計(輸入尺寸為3x224x224)print(summary(model, (3, 224, 224)))
流程圖
設計理念
密集連接機制
在DenseNet中,每個層都與其后續的所有層直接連接。這意味著:
- 第l層的輸入 = 所有前序層(0到l-1)的特征圖拼接
- 數學表示:x_l = H_l([x_0, x_1, ..., x_{l-1}])
- 與傳統架構相比,緩解了梯度消失問題,增強了特征傳播
特征復用機制
- 每個層都可以訪問所有前序層的特征圖
- 網絡自動學習在不同層級復用特征
- 減少了特征冗余,提高了參數效率
瓶頸層設計
每個DenseLayer包含:
- ?BN-ReLU-Conv(1×1)層?:作為瓶頸層,減少特征圖數量和計算量
- 將輸入通道壓縮到4×growth_rate
- ?BN-ReLU-Conv(3×3)層?:主特征提取層
- 輸出growth_rate個特征圖(通常growth_rate=12-48)
增長率(growth_rate)參數
- 控制每個層添加到特征圖的通道數
- 較小的growth_rate也能獲得優異性能(如k=12 vs ResNet k=64)
- 決定模型容量和參數效率的關鍵超參數
過渡層設計
- ?1×1卷積?:壓縮特征通道數(通常減少50%)
- ?2×2平均池化?:下采樣特征圖尺寸
- 公式:θ = 壓縮因子(通常0.5)
- output_channels = θ × input_channels
?充電:BatchNorm2d的用法
batchnorm2d是PyTorch中用于2D輸入的批歸一化(Batch Normalization)層。
參數 | 類型 | 默認值 | 說明 |
---|---|---|---|
num_features | int | - | 輸入通道數C |
eps | float | 1e-5 | 數值穩定項 |
momentum | float | 0.1 | 運行統計量更新系數 |
affine | bool | True | 是否啟用γ/β可學習參數 |
track_running_stats | bool | True | 是否記錄運行統計量 |
通常只需要設置輸入通道數即可。比如:
conv = nn.Conv2d(in_c, out_c, 3)
bn = nn.BatchNorm2d(out_c) # 注意與卷積輸出通道一致
relu = nn.ReLU()
output = relu(bn(conv(input)))
bn層可以做初始化設置,比如:
bn = nn.BatchNorm2d(64)
# 初始化縮放因子為1,偏移為0
nn.init.constant_(bn.weight, 1)
nn.init.constant_(bn.bias, 0)
?需要注意的是,當批次數量太小時,使用bn層可能表現不穩定。當batch<16時,建議使用GroupNorm方法做替代