In computer vision, UNet is widely used for image segmentation thanks to its strong performance. This article introduces an improved UNet architecture, UNetWithCrossAttention, which adds a cross-attention mechanism to strengthen the model's ability to fuse features.
1. The Cross-Attention Mechanism
Cross attention is a mechanism that lets a model dynamically extract relevant information from an auxiliary feature to enhance a main feature. In our implementation, the CrossAttention class provides this functionality:
class CrossAttention(nn.Module):
    def __init__(self, channels):
        super(CrossAttention, self).__init__()
        self.query_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x1, x2):
        batch_size, C, height, width = x1.size()
        # Project into query, key, and value spaces
        proj_query = self.query_conv(x1).view(batch_size, -1, height * width).permute(0, 2, 1)
        proj_key = self.key_conv(x2).view(batch_size, -1, height * width)
        proj_value = self.value_conv(x2).view(batch_size, -1, height * width)
        # Compute the attention map (scaled by the square root of the number of spatial positions)
        energy = torch.bmm(proj_query, proj_key)
        attention = torch.softmax(energy / math.sqrt(proj_key.size(-1)), dim=-1)
        # Apply the attention weights to the values
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, height, width)
        # Residual connection with a learnable, zero-initialized scale
        out = self.gamma * out + x1
        return out
The module works as follows:
- The main feature x1 is projected to the query, and the auxiliary feature x2 is projected to the key and value
- The similarity between the query and the key gives the attention weights
- The attention weights are used to take a weighted sum over the values
- A residual connection fuses the result back into the original main feature
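To make the tensor shapes concrete, here is a minimal usage sketch; the shapes are chosen purely for illustration and it assumes the CrossAttention class defined above:

import torch

attn = CrossAttention(channels=64)
x_main = torch.randn(2, 64, 32, 32)  # main feature
x_aux = torch.randn(2, 64, 32, 32)   # auxiliary feature: same channels and spatial size
out = attn(x_main, x_aux)
print(out.shape)                     # torch.Size([2, 64, 32, 32])

Note that both inputs must share the same channel count and spatial resolution, since the query comes from x1 while the key and value come from x2.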
2. The Double-Convolution Block
DoubleConv is the basic building block of UNet. It stacks two consecutive convolution layers and can optionally apply cross attention to its output:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(DoubleConv, self).__init__()
        self.use_cross_attention = use_cross_attention
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        if use_cross_attention:
            self.cross_attention = CrossAttention(out_channels)

    def forward(self, x, aux_feature=None):
        x = self.conv1(x)
        x = self.conv2(x)
        if self.use_cross_attention and aux_feature is not None:
            x = self.cross_attention(x, aux_feature)
        return x
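A quick sketch of how the optional attention path behaves, with hypothetical shapes and assuming the classes and imports above:

block = DoubleConv(in_channels=3, out_channels=64, use_cross_attention=True)
x = torch.randn(1, 3, 64, 64)
aux = torch.randn(1, 64, 64, 64)   # must match the block's output: 64 channels, same spatial size
y = block(x, aux)                  # attention applied
y_plain = block(x)                 # aux_feature=None, attention skipped
print(y.shape, y_plain.shape)      # both torch.Size([1, 64, 64, 64])

Because the attention operates on the block's output, aux_feature has to match out_channels and the output's spatial size; when it is omitted, the block behaves like a plain UNet double convolution.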
3. Downsampling and Upsampling Blocks
The downsampling block Down combines max pooling with a double convolution:
class Down(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(Down, self).__init__()
        self.downsampling = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            DoubleConv(in_channels, out_channels, use_cross_attention)
        )

    def forward(self, x, aux_feature=None):
        return self.downsampling[1](self.downsampling[0](x), aux_feature)
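Note that forward indexes into the nn.Sequential so that aux_feature can be forwarded to the DoubleConv. The pooling halves the spatial resolution while the convolutions change the channel count; a small sketch with illustrative shapes (assuming the classes above):

down = Down(64, 128)
feat = torch.randn(1, 64, 64, 64)
print(down(feat).shape)  # torch.Size([1, 128, 32, 32])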
The upsampling block Up uses a transposed convolution to upsample and then concatenates the skip-connection features:
class Up(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(Up, self).__init__()
        self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels, use_cross_attention)

    def forward(self, x1, x2, aux_feature=None):
        x1 = self.upsampling(x1)
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x, aux_feature)
        return x
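Continuing the sketch above, an Up block doubles the spatial resolution of the deeper feature and concatenates it with the corresponding encoder feature along the channel dimension before the double convolution (shapes again only illustrative):

up = Up(128, 64)
deep = torch.randn(1, 128, 32, 32)  # feature from the deeper level
skip = torch.randn(1, 64, 64, 64)   # encoder feature at the target resolution
print(up(deep, skip).shape)         # torch.Size([1, 64, 64, 64])

Because the concatenation assumes matching spatial sizes, input images whose height and width are divisible by 16 (four 2x downsamplings) avoid any shape mismatch.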
4. The Complete UNetWithCrossAttention Architecture
Putting the modules above together gives the complete UNetWithCrossAttention:
class UNetWithCrossAttention(nn.Module):
    def __init__(self, in_channels=1, num_classes=1, use_cross_attention=False):
        super(UNetWithCrossAttention, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.use_cross_attention = use_cross_attention
        # Encoder
        self.in_conv = DoubleConv(in_channels, 64, use_cross_attention)
        self.down1 = Down(64, 128, use_cross_attention)
        self.down2 = Down(128, 256, use_cross_attention)
        self.down3 = Down(256, 512, use_cross_attention)
        self.down4 = Down(512, 1024, use_cross_attention)
        # Decoder
        self.up1 = Up(1024, 512, use_cross_attention)
        self.up2 = Up(512, 256, use_cross_attention)
        self.up3 = Up(256, 128, use_cross_attention)
        self.up4 = Up(128, 64, use_cross_attention)
        self.out_conv = OutConv(64, num_classes)

    def forward(self, x, aux_feature=None):
        # Encoder path
        x1 = self.in_conv(x, aux_feature)
        x2 = self.down1(x1, aux_feature)
        x3 = self.down2(x2, aux_feature)
        x4 = self.down3(x3, aux_feature)
        x5 = self.down4(x4, aux_feature)
        # Decoder path
        x = self.up1(x5, x4, aux_feature)
        x = self.up2(x, x3, aux_feature)
        x = self.up3(x, x2, aux_feature)
        x = self.up4(x, x1, aux_feature)
        x = self.out_conv(x)
        return x
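The listing refers to an OutConv output head that is not defined in the article. In standard UNet implementations this is simply a 1x1 convolution that maps the final 64 channels to num_classes; a minimal sketch under that assumption:

class OutConv(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__()
        # 1x1 convolution: per-pixel projection to class logits
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

With that in place, the model can be exercised end to end (illustrative shapes; cross attention left disabled, so no auxiliary feature is needed):

model = UNetWithCrossAttention(in_channels=1, num_classes=1)
x = torch.randn(1, 1, 256, 256)
print(model(x).shape)  # torch.Size([1, 1, 256, 256])

When use_cross_attention=True, keep in mind that the same aux_feature is passed to every level, so it must match the channel count and spatial size of the feature map at each level where attention is actually applied; in practice you would typically enable attention only at the level whose shape matches the auxiliary feature, or project the auxiliary feature accordingly.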
5. Application Scenarios and Advantages
This cross-attention UNet architecture is particularly well suited to the following scenarios:
- Multi-modal image segmentation: when auxiliary information from a different imaging modality is available, cross attention helps the model fuse it effectively
- Temporal image analysis: for video sequences, features from the previous frame can serve as auxiliary features to improve segmentation of the current frame
- Weakly supervised learning: additional weak supervision signals can be injected into the main network through cross attention
Compared with a standard UNet, this architecture offers the following advantages:
- It can dynamically attend to the most relevant parts of the auxiliary feature
- The attention mechanism enables finer-grained feature fusion
- It preserves UNet's original multi-scale feature extraction
- The residual connection avoids losing information from the main path
6. Summary
This article presented an enhanced UNet architecture that uses cross attention to make better use of auxiliary features. The design keeps UNet's original strengths while adding a flexible feature-fusion capability, making it a good fit for complex vision tasks that need to integrate information from multiple sources.
In practice, you can choose at which levels to enable cross attention based on the task, and adjust the attention module's complexity to trade off accuracy against computational cost.
I hope this article helps you understand how cross attention can be used inside UNet. If you have any questions or suggestions, feel free to leave a comment!
The complete code is as follows:
import torch.nn as nn
import torch
import math


class CrossAttention(nn.Module):
    def __init__(self, channels):
        super(CrossAttention, self).__init__()
        self.query_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x1, x2):
        """
        x1: main feature (batch, channels, height, width)
        x2: auxiliary feature (batch, channels, height, width)
        """
        batch_size, C, height, width = x1.size()
        # Project into query, key, and value spaces
        proj_query = self.query_conv(x1).view(batch_size, -1, height * width).permute(0, 2, 1)  # (B, N, C')
        proj_key = self.key_conv(x2).view(batch_size, -1, height * width)  # (B, C', N)
        proj_value = self.value_conv(x2).view(batch_size, -1, height * width)  # (B, C, N)
        # Compute the attention map
        energy = torch.bmm(proj_query, proj_key)  # (B, N, N)
        attention = torch.softmax(energy / math.sqrt(proj_key.size(-1)), dim=-1)
        # Apply attention
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # (B, C, N)
        out = out.view(batch_size, C, height, width)
        # Residual connection
        out = self.gamma * out + x1
        return out


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(DoubleConv, self).__init__()
        self.use_cross_attention = use_cross_attention
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        if use_cross_attention:
            self.cross_attention = CrossAttention(out_channels)

    def forward(self, x, aux_feature=None):
        x = self.conv1(x)
        x = self.conv2(x)
        if self.use_cross_attention and aux_feature is not None:
            x = self.cross_attention(x, aux_feature)
        return x


class Down(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(Down, self).__init__()
        self.downsampling = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            DoubleConv(in_channels, out_channels, use_cross_attention)
        )

    def forward(self, x, aux_feature=None):
        return self.downsampling[1](self.downsampling[0](x), aux_feature)


class Up(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(Up, self).__init__()
        self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels, use_cross_attention)

    def forward(self, x1, x2, aux_feature=None):
        x1 = self.upsampling(x1)
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x, aux_feature)
        return x


class OutConv(nn.Module):
    # Output head (assumed standard UNet definition): a 1x1 convolution to num_classes
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNetWithCrossAttention(nn.Module):
    def __init__(self, in_channels=1, num_classes=1, use_cross_attention=False):
        super(UNetWithCrossAttention, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.use_cross_attention = use_cross_attention
        # Encoder
        self.in_conv = DoubleConv(in_channels, 64, use_cross_attention)
        self.down1 = Down(64, 128, use_cross_attention)
        self.down2 = Down(128, 256, use_cross_attention)
        self.down3 = Down(256, 512, use_cross_attention)
        self.down4 = Down(512, 1024, use_cross_attention)
        # Decoder
        self.up1 = Up(1024, 512, use_cross_attention)
        self.up2 = Up(512, 256, use_cross_attention)
        self.up3 = Up(256, 128, use_cross_attention)
        self.up4 = Up(128, 64, use_cross_attention)
        self.out_conv = OutConv(64, num_classes)

    def forward(self, x, aux_feature=None):
        # Encoder path
        x1 = self.in_conv(x, aux_feature)
        x2 = self.down1(x1, aux_feature)
        x3 = self.down2(x2, aux_feature)
        x4 = self.down3(x3, aux_feature)
        x5 = self.down4(x4, aux_feature)
        # Decoder path
        x = self.up1(x5, x4, aux_feature)
        x = self.up2(x, x3, aux_feature)
        x = self.up3(x, x2, aux_feature)
        x = self.up4(x, x1, aux_feature)
        x = self.out_conv(x)
        return x