UNet 由Ronneberger等人于2015年提出,專門針對醫學圖像分割任務,解決了早期卷積網絡在小樣本數據下的效率問題和細節丟失難題。
一 核心創新
1.1對稱編碼器-解碼器結構
實現上下文信息與高分辨率細節的雙向融合
如圖所示:編碼器進行了4步(紅框)到達了瓶頸層(紫框),每一步包含兩次3x3卷積+ReLU并通過通過2x2最大池化下采樣,到達瓶頸層后,解碼器也進行了4步(綠框),使用了轉置卷積上采樣后與編碼器對應層特征拼接(跳躍連接(灰色箭頭))后再進行兩次卷積。
可以看出解碼器和編碼器非常的對稱,呈現一個U型,所以叫UNet。
其中:
編碼器:通過池化逐漸擴大感受野。
解碼器:逐步恢復空間分辨率,精確定位目標邊界。
跳躍連接:將編碼器特征與解碼器特征拼接,融合多級信息解決深層網絡定位精度下降的問題
1.2跳躍連接(Skip Connections)
解決深層卷積神經網絡中空間信息丟失和細節模糊的核心問題。
因為編碼器下采樣會丟失細節,而解碼器上采樣又難以完全恢復位置信息,所以使用跳躍鏈接來補償細節。
1.2.1數學形式表達
設編碼器第? 層輸出為?
?,?解碼器第?
?層輸入為?
,?則跳躍連接操作:
: 沿通道維度拼接(Channel-wise Concatenation)?
? 轉置卷積/雙線性插值將解碼器輸出的分辨率提升至與編碼器相同
1.2.2特征融合方法
編碼器每層的輸出須與解碼器對應層上采樣后的尺寸匹配,拼接后總通道數為兩者之和。
(黑色圓圈)
# PyTorch代碼示例:拼接編碼器和解碼器特征
def forward(self, decoder_feat, encoder_feat):# decoder_feat: [B, C1, H, W] # encoder_feat: [B, C2, H, W]merged = torch.cat([decoder_feat, encoder_feat], dim=1) # 沿通道拼接return merged # 結果維度:[B, C1+C2, H, W]
?1.3端到端精細分割(End-to-End Fine Segmentation)
在少量標注數據下仍能輸出像素級預測
直接從原始輸入圖像生成像素級預測的模型設計范式,無需手動設計特征提取器或多階段后處理。
1.3.1核心
全流程自動映射:輸入 → 特征學習 → 高精度分割結果,中間過程由網絡自動優化
細節敏感機制:通過多層次特征融合、邊界增強模塊等手段保證細粒度分割
無后處理輸出:輸出可直接使用,無需形態學后處理
1.3.2技術實現
編碼器:通過卷積與池化逐層提取高層語義(形狀、位置)
# 編碼器層示例:每次下采樣通道數翻倍
class Encoder(nn.Module):def __init__(self, in_ch, out_ch):super().__init__()self.block = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1),#卷積nn.BatchNorm2d(out_ch),#標準化(歸一+線性變換)nn.ReLU(),#非線性激活nn.Conv2d(out_ch, out_ch, 3, padding=1),nn.BatchNorm2d(out_ch),nn.ReLU(),nn.MaxPool2d(2)#最大值池化)def forward(self, x):return self.block(x)
解碼器:上采樣恢復分辨率 + 跳躍連接補充細節
# 解碼器層示例:特征拼接后卷積
class Decoder(nn.Module):def __init__(self, in_ch, out_ch):super().__init__()self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)self.conv = nn.Sequential(nn.Conv2d(out_ch*2, out_ch, 3, padding=1), # 拼接后通道數翻倍nn.BatchNorm2d(out_ch),nn.ReLU())def forward(self, x, skip):x = self.up(x)x = torch.cat([x, skip], dim=1) # 與編碼器特征拼接return self.conv(x)
改良1: 注意力引導跳躍連接:通過空間注意力強化邊緣區域(在跳躍連接前應用空間注意力,突出邊緣信息)
class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super().__init__()self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2)self.sigmoid = nn.Sigmoid()def forward(self, x):avg = torch.mean(x, dim=1, keepdim=True)max_pool, _ = torch.max(x, dim=1, keepdim=True)concat = torch.cat([avg, max_pool], dim=1) # 沿通道維度拼接均值和最大值mask = self.sigmoid(self.conv(concat)) # 生成空間注意力掩碼return x * mask # 加權增強關鍵區域
改良2: 多尺度損失監督:在不同解碼層注入輔助損失。
class MultiScaleLoss(nn.Module):def __init__(self, losses):super().__init__()self.losses = losses # 各層對應的損失函數列表def forward(self, preds, target):total_loss = 0for pred, loss_fn in zip(preds, self.losses):# 將目標下采樣至與當前預測同尺寸_, _, H, W = pred.shaperesized_target = F.interpolate(target, size=(H,W), mode='nearest')total_loss += loss_fn(pred, resized_target)return total_loss
適用性擴展:該范式可遷移至其他密集預測任務,如衛星影像分析、自動駕駛場景理解等。
二 與傳統分割模型對比
模型 | 優勢 | 局限性 |
---|---|---|
FCN | 全卷積保留空間信息 | 輸出分辨率粗糙,跳躍連接簡單 |
SegNet | 使用池化索引提升精度 | 特征復用效率低 |
DeepLab | 空洞卷積擴大感受野 | 小目標分割邊緣模糊 |
UNet | 對稱結構+密集跳躍連接,細節恢復 | 原版對大尺度變化敏感 |
三 UNet的改良方法?
3.1跨尺度空洞卷積替換編碼器的普通卷積層
在底層使用擴張率=1捕捉細節,高層使用d=3或5擴大感受野。
# 原編碼器卷積塊
self.encoder_conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1),nn.ReLU(),nn.Conv2d(out_ch, out_ch, 3, padding=1),nn.ReLU()
)# 改進:跨尺度空洞卷積模塊
self.encoder_conv = CrossScaleDilatedConv(in_ch, out_ch)
3.2融入密集塊融合增強跳躍連接的特征傳遞
在編碼器和解碼器拼接前加入密集塊
class ImprovedSkipConnection(nn.Module):def __init__(self, in_ch):super().__init__()self.dense_block = DenseBlock(num_layers=4, in_channels=in_ch)def forward(self, enc_feat, dec_feat):enc_processed = self.dense_block(enc_feat) # 特征增強merged = torch.cat([enc_processed, dec_feat], dim=1)return merged# 在UNet解碼器中應用
def forward(self, x):# ... 編碼過程d4 = self.upconv4(d5)d4 = self.skip_conn4(e4, d4) # 使用改進的跳躍連接d4 = self.decoder_conv4(d4)# ...
四 核心代碼(未改良)
class UNet(nn.Module):def __init__(self, n_class=1):super().__init__()# 編碼器self.enc1 = EncoderBlock(3, 64)self.enc2 = EncoderBlock(64, 128)self.enc3 = EncoderBlock(128, 256)self.enc4 = EncoderBlock(256, 512)self.bottleneck = EncoderBlock(512, 1024)# 解碼器self.upconv4 = UpConv(1024, 512)self.dec4 = DecoderBlock(1024, 512) # 輸入1024因拼接self.upconv3 = UpConv(512, 256)self.dec3 = DecoderBlock(512, 256)self.upconv2 = UpConv(256, 128)self.dec2 = DecoderBlock(256, 128)self.upconv1 = UpConv(128, 64)self.dec1 = DecoderBlock(128, 64)self.final = nn.Conv2d(64, n_class, kernel_size=1)def forward(self, x):# 編碼e1 = self.enc1(x)e2 = self.enc2(F.max_pool2d(e1, 2))e3 = self.enc3(F.max_pool2d(e2, 2))e4 = self.enc4(F.max_pool2d(e3, 2))bn = self.bottleneck(F.max_pool2d(e4, 2))# 解碼d4 = self.dec4(self.upconv4(bn), e4)d3 = self.dec3(self.upconv3(d4), e3)d2 = self.dec2(self.upconv2(d3), e2)d1 = self.dec1(self.upconv1(d2), e1)return torch.sigmoid(self.final(d1))class EncoderBlock(nn.Module):def __init__(self, in_ch, out_ch):super().__init__()self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1),nn.BatchNorm2d(out_ch),nn.ReLU(),nn.Conv2d(out_ch, out_ch, 3, padding=1),nn.BatchNorm2d(out_ch),nn.ReLU())def forward(self, x):return self.conv(x)class UpConv(nn.Module):def __init__(self, in_ch, out_ch):super().__init__()self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)def forward(self, x):return self.up(x)class DecoderBlock(nn.Module):def __init__(self, in_ch, out_ch):super().__init__()self.conv = EncoderBlock(in_ch, out_ch)def forward(self, x, skip):x = torch.cat([x, skip], dim=1) # 通道拼接return self.conv(x)
UNet憑借其優雅的對稱結構和密集跳躍連接,成為醫學圖像分割的基準模型。通過集成跨尺度空洞卷積與密集塊融合等模塊,可顯著提升其對多尺度目標的適應性。