U-Net is one of the most successful architectures in medical image segmentation: its symmetric encoder-decoder structure and skip connections let it capture multi-scale features effectively. This article walks through an improved U-Net implementation that boosts performance by adding Squeeze-and-Excitation (SE) modules.
I. Architecture Overview
This improved U-Net keeps the core structure of the classic U-Net but adds an SE module after each convolution block. The main components are:

- SE attention module: strengthens the feature response of important channels
- Double convolution block: the basic feature-extraction unit
- Encoder-decoder structure: progressive downsampling and upsampling
- Skip connections: fuse low-level and high-level features
II. Core Components in Detail
1. SE Attention Module (SELayer)
```python
class SELayer(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)   # squeeze: global average pooling
        y = self.fc(y).view(b, c, 1, 1)   # excitation: per-channel weights
        return x * y.expand_as(x)         # rescale the input feature map
```
The SE module works in the following steps:

- Global average pooling compresses the spatial information into a per-channel descriptor
- Two fully connected layers learn the dependencies between channels
- A Sigmoid activation produces the channel weights
- The weights are applied to the original feature map

This mechanism lets the model adaptively emphasize informative feature channels and suppress unimportant ones.
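As a quick sanity check (my addition, with arbitrary assumed tensor sizes), the snippet below applies the SELayer defined above to a random tensor and confirms that the output keeps the input shape while each channel is rescaled:

```python
import torch

# Hypothetical smoke test: batch of 2, 64 channels, 32x32 spatial size
se = SELayer(in_channels=64, reduction=16)
x = torch.randn(2, 64, 32, 32)
y = se(x)

print(y.shape)  # torch.Size([2, 64, 32, 32]) -- shape is preserved
# Each channel of y is the corresponding channel of x multiplied by a
# learned scalar weight in (0, 1), so only channel magnitudes change.
```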
2. Improved Double Convolution Block (DoubleConv)
```python
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            SELayer(out_channels)  # add the SE module
        )

    def forward(self, x):
        return self.double_conv(x)
```
Each double convolution block contains:

- Two 3x3 convolution layers that preserve the spatial resolution (padding=1)
- Batch normalization and a ReLU activation after each convolution
- An SE module at the end for channel-attention weighting (a quick shape check follows this list)
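A minimal check (my addition, with assumed sizes) that DoubleConv changes only the channel count while padding=1 keeps the spatial resolution intact:

```python
block = DoubleConv(in_channels=3, out_channels=64)
x = torch.randn(1, 3, 128, 128)  # hypothetical single RGB image
print(block(x).shape)            # torch.Size([1, 64, 128, 128]): channels change, H and W do not
```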
3. The Complete Improved U-Net (ImprovedUNet)
The encoder downsamples step by step with max pooling; the decoder upsamples with transposed convolutions and fuses features through skip connections:
```python
class ImprovedUNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        ...  # layer initialization (see the complete code in Section VI)

    def forward(self, x):
        # encoder path
        x1 = self.inc(x)     # initial convolution
        x2 = self.down1(x1)  # downsampling stage 1
        x3 = self.down2(x2)  # downsampling stage 2
        x4 = self.down3(x3)  # downsampling stage 3
        x5 = self.down4(x4)  # downsampling stage 4
        # decoder path
        x = self.up1(x5)
        x = self.double_conv_up1(torch.cat(
            [x, F.interpolate(x4, size=x.shape[2:],
                              mode='bilinear', align_corners=True)], dim=1))
        # ...the remaining upsampling stages follow the same pattern
        return self.outc(x)
```
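To make the resolution bookkeeping concrete, here is a sketch (my addition) that traces the encoder shapes for a 256x256 RGB input, assuming the full ImprovedUNet definition from the complete code in Section VI (channels 64 → 128 → 256 → 512 → 512 with factor = 2):

```python
model = ImprovedUNet(n_channels=3, n_classes=1)
x = torch.randn(2, 3, 256, 256)

x1 = model.inc(x);    print(x1.shape)  # (2,  64, 256, 256)
x2 = model.down1(x1); print(x2.shape)  # (2, 128, 128, 128)
x3 = model.down2(x2); print(x3.shape)  # (2, 256,  64,  64)
x4 = model.down3(x3); print(x4.shape)  # (2, 512,  32,  32)
x5 = model.down4(x4); print(x5.shape)  # (2, 512,  16,  16) -- 1024 // factor channels
# Each decoder stage then doubles H and W with a transposed convolution and
# concatenates the matching encoder feature map before a DoubleConv.
```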
III. Innovations and Advantages
- SE module integration: an SE module after every double convolution block lets the model adaptively recalibrate channel-wise feature responses
- Improved feature fusion: skip-connection feature maps are resized with bilinear interpolation, ensuring exact spatial alignment before concatenation
- Parameter efficiency: the factor parameter controls the decoder channel counts, trading off model capacity against computational cost (see the parameter-count sketch after this list)
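To quantify that capacity/cost trade-off, a small helper (my addition; count_parameters is a hypothetical name, not part of the original code) can report the model's parameter count:

```python
def count_parameters(m):
    # total number of trainable parameters
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

model = ImprovedUNet(n_channels=3, n_classes=1)
print(f"trainable parameters: {count_parameters(model):,}")
# With factor = 2 the decoder channels are halved relative to a symmetric
# U-Net decoder, shrinking the transposed-convolution and DoubleConv layers.
```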
IV. Performance Analysis
Compared with the original U-Net, this improved version has the following potential advantages:

- Better feature selection, since the SE modules highlight important features
- More stable training, thanks to the extensive use of batch normalization
- More accurate boundary prediction, owing to the improved feature-fusion scheme (a small alignment sketch follows this list)
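The alignment point is easiest to see with an odd-sized input, where pooling and transposed convolution can produce mismatched shapes. The sketch below (my addition; the sizes correspond to a hypothetical 250x250 input) uses the same F.interpolate call as the model's forward pass:

```python
# After three poolings of a 250x250 input, x4 is 31x31; after a fourth
# pooling and one transposed convolution, the decoder feature is 30x30.
dec = torch.randn(1, 256, 30, 30)   # decoder feature after up1
skip = torch.randn(1, 512, 31, 31)  # encoder feature x4, one pixel larger
skip = F.interpolate(skip, size=dec.shape[2:], mode='bilinear', align_corners=True)
fused = torch.cat([dec, skip], dim=1)
print(fused.shape)  # torch.Size([1, 768, 30, 30]) -- exact spatial match
```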
V. Usage Example
```python
# create a model instance
model = ImprovedUNet(n_channels=3, n_classes=1)

# test with a random input
input_tensor = torch.randn(2, 3, 256, 256)  # two 256x256 RGB images
output = model(input_tensor)                # output shape: [2, 1, 256, 256]
```
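For completeness, here is a minimal single-training-step sketch (my addition; the loss, optimizer, and learning rate are assumptions, with BCEWithLogitsLoss matching the single-logit output of n_classes=1):

```python
import torch.optim as optim

criterion = torch.nn.BCEWithLogitsLoss()  # the model returns raw logits
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train_step(images, masks):
    # images: (B, 3, H, W); masks: (B, 1, H, W) with values in {0, 1}
    optimizer.zero_grad()
    logits = model(images)
    loss = criterion(logits, masks)
    loss.backward()
    optimizer.step()
    return loss.item()

# hypothetical single step on random data
loss = train_step(torch.randn(2, 3, 256, 256),
                  torch.randint(0, 2, (2, 1, 256, 256)).float())
```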
VI. Complete Code
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# SE module
class SELayer(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

# improved convolution block
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            SELayer(out_channels)  # add the SE module
        )

    def forward(self, x):
        return self.double_conv(x)

# improved U-Net model
class ImprovedUNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(64, 128))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(128, 256))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(256, 512))
        factor = 2
        self.down4 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(512, 1024 // factor))
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(1024 // factor, 512 // factor, kernel_size=2, stride=2))
        self.double_conv_up1 = DoubleConv(512 // factor + 512, 512 // factor)
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(512 // factor, 256 // factor, kernel_size=2, stride=2))
        self.double_conv_up2 = DoubleConv(256 // factor + 256, 256 // factor)
        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(256 // factor, 128 // factor, kernel_size=2, stride=2))
        self.double_conv_up3 = DoubleConv(128 // factor + 128, 128 // factor)
        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(128 // factor, 64, kernel_size=2, stride=2))
        self.double_conv_up4 = DoubleConv(64 + 64, 64)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5)
        x = self.double_conv_up1(torch.cat(
            [x, F.interpolate(x4, size=x.shape[2:], mode='bilinear', align_corners=True)], dim=1))
        x = self.up2(x)
        x = self.double_conv_up2(torch.cat(
            [x, F.interpolate(x3, size=x.shape[2:], mode='bilinear', align_corners=True)], dim=1))
        x = self.up3(x)
        x = self.double_conv_up3(torch.cat(
            [x, F.interpolate(x2, size=x.shape[2:], mode='bilinear', align_corners=True)], dim=1))
        x = self.up4(x)
        x = self.double_conv_up4(torch.cat(
            [x, F.interpolate(x1, size=x.shape[2:], mode='bilinear', align_corners=True)], dim=1))
        logits = self.outc(x)
        return logits

# create an instance of the improved U-Net model
model = ImprovedUNet(n_channels=3, n_classes=1)
print(model)

# generate a random input
input_tensor = torch.randn(2, 3, 256, 256)

# forward pass
output = model(input_tensor)
print(output.shape)
```
VII. Application Scenarios
This improved U-Net is particularly well suited to:

- Medical image segmentation (CT/MRI)
- Remote-sensing image parsing
- Any dense prediction task that requires accurate boundary prediction
VIII. Summary
By integrating SE modules into U-Net, we obtain an improved architecture that adaptively focuses on important features. This design improves the model's feature-selection ability without a significant increase in computational cost, allowing it to perform better across a wide range of image segmentation tasks.