論文信息
-
標題: CSAM: A 2.5D Cross-Slice Attention Module for Anisotropic Volumetric Medical Image Segmentation
-
論文鏈接: https://arxiv.org/pdf/2311.04942
-
GitHub鏈接: https://github.com/aL3x-O-o-Hung/CSAM
創新點
CSAM(跨切片注意力模塊)旨在解決傳統3D和2D醫學圖像分割方法在處理各向異性體積數據時的不足。其主要創新包括:
-
跨切片注意力機制: 通過在不同尺度的深度特征圖上應用語義、位置和切片注意力,CSAM能夠有效捕捉體積數據中不同切片之間的關系。
-
參數優化: CSAM設計了最小可訓練參數的結構,減少了模型的復雜性,同時保持了良好的性能。
-
2.5D方法的應用: 該模塊結合了2D卷積與體積信息,填補了3D和2D方法之間的空白,特別適用于MRI等各向異性數據。
方法
CSAM的實現方法包括以下幾個步驟:
-
特征提取: 使用卷積神經網絡(CNN)提取輸入的體積數據特征。
-
注意力機制: 在提取的特征圖上應用跨切片注意力機制,分別關注語義信息、位置關系和切片信息,以增強特征的表達能力。
-
模型訓練: 通過最小化損失函數來訓練模型,確保模型能夠有效學習到各向異性體積數據的特征。
效果
實驗結果表明,CSAM在多個醫學圖像分割任務中表現出色,尤其是在處理各向異性數據時,其性能優于傳統的3D和2D方法。具體效果包括:
-
分割精度: CSAM在分割精度上達到了新的狀態,能夠更好地識別和分割復雜的醫學圖像結構。
-
訓練效率: 由于參數較少,CSAM的訓練時間顯著低于其他復雜模型。
實驗結果
研究者進行了廣泛的實驗,以驗證CSAM的有效性和通用性。實驗包括:
-
數據集: 使用多個公開的醫學圖像數據集進行測試,涵蓋不同的醫學成像技術(如MRI)。
-
對比實驗: 將CSAM與現有的3D和2D分割模型進行比較,結果顯示CSAM在多個指標上均優于對比模型。
-
泛化能力: CSAM在不同任務和數據集上的表現一致,證明了其良好的泛化能力。
總結
CSAM作為一種新穎的2.5D跨切片注意力模塊,為各向異性體積醫學圖像分割提供了有效的解決方案。通過引入跨切片注意力機制,CSAM不僅提高了分割精度,還減少了模型的復雜性和訓練時間。實驗結果驗證了其在醫學圖像處理中的廣泛適用性和優越性能,為未來的研究提供了新的思路和方法。
代碼
import torch
import torch.nn.functional
from torch import nn
import torch.distributions as td
def custom_max(x,dim,keepdim=True):temp_x=xfor i in dim:temp_x=torch.max(temp_x,dim=i,keepdim=True)[0]if not keepdim:temp_x=temp_x.squeeze()return temp_xclass PositionalAttentionModule(nn.Module):def __init__(self):super(PositionalAttentionModule,self).__init__()self.conv=nn.Conv2d(in_channels=2,out_channels=1,kernel_size=(7,7),padding=3)def forward(self,x):max_x=custom_max(x,dim=(0,1),keepdim=True)avg_x=torch.mean(x,dim=(0,1),keepdim=True)att=torch.cat((max_x,avg_x),dim=1)att=self.conv(att)att=torch.sigmoid(att)return x*attclass SemanticAttentionModule(nn.Module):def __init__(self,in_features,reduction_rate=16):super(SemanticAttentionModule,self).__init__()self.linear=[]self.linear.append(nn.Linear(in_features=in_features,out_features=in_features//reduction_rate))self.linear.append(nn.ReLU())self.linear.append(nn.Linear(in_features=in_features//reduction_rate,out_features=in_features))self.linear=nn.Sequential(*self.linear)def forward(self,x):max_x=custom_max(x,dim=(0,2,3),keepdim=False).unsqueeze(0)avg_x=torch.mean(x,dim=(0,2,3),keepdim=False).unsqueeze(0)max_x=self.linear(max_x)avg_x=self.linear(avg_x)att=max_x+avg_xatt=torch.sigmoid(att).unsqueeze(-1).unsqueeze(-1)return x*attclass SliceAttentionModule(nn.Module):def __init__(self,in_features,rate=4,uncertainty=True,rank=5):super(SliceAttentionModule,self).__init__()self.uncertainty=uncertaintyself.rank=rankself.linear=[]self.linear.append(nn.Linear(in_features=in_features,out_features=int(in_features*rate)))self.linear.append(nn.ReLU())self.linear.append(nn.Linear(in_features=int(in_features*rate),out_features=in_features))self.linear=nn.Sequential(*self.linear)if uncertainty:self.non_linear=nn.ReLU()self.mean=nn.Linear(in_features=in_features,out_features=in_features)self.log_diag=nn.Linear(in_features=in_features,out_features=in_features)self.factor=nn.Linear(in_features=in_features,out_features=in_features*rank)def forward(self,x):max_x=custom_max(x,dim=(1,2,3),keepdim=False).unsqueeze(0)avg_x=torch.mean(x,dim=(1,2,3),keepdim=False).unsqueeze(0)max_x=self.linear(max_x)avg_x=self.linear(avg_x)att=max_x+avg_xif self.uncertainty:temp=self.non_linear(att)mean=self.mean(temp)diag=self.log_diag(temp).exp()factor=self.factor(temp)factor=factor.view(1,-1,self.rank)dist=td.LowRankMultivariateNormal(loc=mean,cov_factor=factor,cov_diag=diag)att=dist.sample()att=torch.sigmoid(att).squeeze().unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)return x*attclass CSAM(nn.Module):def __init__(self,num_slices,num_channels,semantic=True,positional=True,slice=True,uncertainty=True,rank=5):super(CSAM,self).__init__()self.semantic=semanticself.positional=positionalself.slice=sliceif semantic:self.semantic_att=SemanticAttentionModule(num_channels)if positional:self.positional_att=PositionalAttentionModule()if slice:self.slice_att=SliceAttentionModule(num_slices,uncertainty=uncertainty,rank=rank)def forward(self,x):if self.semantic:x=self.semantic_att(x)if self.positional:x=self.positional_att(x)if self.slice:x=self.slice_att(x)return xif __name__ == "__main__":dim=64# 如果GPU可用,將模塊移動到 GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 輸入張量 (batch_size, channels,height, width)x = torch.randn(2,dim,40,40).to(device)# 初始化 FullyAttentionalBlock 模塊block = CSAM(2,dim,) # kernel_size為height或者widthprint(block)block = block.to(device)# 前向傳播output = block(x)print("輸入:", x.shape)print("輸出:", output.shape)