論文:https://arxiv.org/abs/2212.04497
代碼:GitHub - Amshaker/unetr_plus_plus: UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation
機構:Mohamed Bin Zayed University of Artificial Intelligence1, University of California Merced2, Google Research3, Linkoping University4
UNETR++作者和UNETR居然完全不沾邊來著,繼續找思路所以主要寫寫方法部分,別的部分簡略一點.....!感覺挺有收獲的!
摘要
由于Transformer模型的成功,最近的工作研究了它們在三維醫學分割任務中的適用性。在Transformer模型中,自注意力機制是努力獲取遠程依賴關系的主要構建塊之一。然而,自注意運算具有二次復雜度,這被證明是一個計算瓶頸,特別是在體積醫學成像中,其中輸入是三維的,有許多切片。在本文中,我們提出了一種名為unetr++的三維醫學圖像分割方法,該方法既提供了高質量的分割mask,又在參數、計算成本和推理速度方面具有效率。我們設計的核心是引入一種新的高效成對注意(efficient paired attention, EPA)塊,該塊使用基于空間和通道注意的一對相互依賴的分支有效地學習空間和通道方面的判別特征。我們的空間注意公式是有效的,具有相對于輸入序列長度的線性復雜性(linear complexity)。為了實現空間分支和以通道為中心的分支之間的通信,我們共享查詢(query)和鍵映射(key mapping)功能的權重,這些功能提供了互補的好處(配對關注),同時還減少了整體網絡參數。我們對Synapse、BTCV、ACDC、BRaTs和Decathlon-Lung這五個基準進行了廣泛的評估,揭示了我們在效率和準確性方面的貢獻的有效性。在Synapse上,我們的UNETR++設置了一個新的最先進的骰子得分為87.2%,同時與文獻中最好的方法相比,在參數和FLOPs方面都降低了71%以上,效率顯著。
背景
早期基于CNN的網絡受限于他們的感受野,但是基于transformer的方法計算成本高
后面也冒出了一些混合方法,一些用基于transformer的encoder和卷積的decoder,另外一些設計編碼器和解碼器子網的混合塊。但是這些網絡主要關注于提高分割進度,這反過來又大大增加了模型在參數和FLOPs的大小,導致魯棒性不理想。我們認為這是由于他們低效的self-attention的設計,在體數據分割中顯露出更大的問題。此外,這些現有的方法沒有捕捉到空間和通道特征之間的顯式依賴關系,這可以提高分割質量。在這項工作中,我們的目標是在一個統一的框架中同時提高分割精度和模型效率。
貢獻
1)我們提出了一種高效的混合分層結構用于三維醫學圖像分割,命名為unetr++,力求在參數、FLOPs和推理速度方面實現更好的分割精度和效率。基于最近的UNETR框架[13],我們提出的UNETR++分層方法引入了一種新的高效的對注意力(EPA)塊,該塊通過在兩個分支中應用空間和通道注意力有效地捕獲豐富的相互依賴的空間和通道特征。我們在EPA中的空間注意將鍵和值投射到一個固定的低維空間,使自注意計算相對于輸入令牌的數量呈線性。另一方面,我們的通道注意通過在通道維度中執行查詢和鍵之間的點積操作來強調通道特征映射之間的依賴關系。此外,為了捕獲空間和通道特征之間的強相關性,查詢和鍵的權重在分支之間共享,這也有助于控制網絡參數的數量。相反,值的權重保持獨立,以強制學習兩個分支中的互補特征。
2)我們通過在五個基準上進行全面實驗來驗證我們的UNETR++方法:
Synapse[19]、BTCV[19]、ACDC[1]、BRaTs[24] Decathlon-Lungs[30]。定性和定量結果都證明了UNETR++的有效性,與文獻中已有的方法相比,在分割精度和模型效率方面都有更好的表現。
相關工作
CNN-based Segmentation Methods?
unet,多尺度三維全卷積,nnunet
金字塔[35]、大核[26]、擴展卷積[6]和可變形卷積[20]等方法,在基于cnn的框架內編碼整體上下文信息
Transformers-based Segmentation Methods
ViT, 1d-embedding,shifted windows for 2D
Hybrid Segmentation Methods
TransFuse[34]提出了一種帶有BiFusion模塊的并行cnn-Transformer架構,用于融合編碼器中的多級特征。
MedT[31]在自注意中引入了門控的位置敏感軸向注意機制來控制編碼器中的位置嵌入信息,而解碼器中的ConvNet模塊產生分割模型。
TransUNet[5]結合了Transformer和U-Net架構,其中Transformer對來自卷積特征的嵌入圖像補丁進行編碼,解碼器將上采樣編碼特征與高分辨率CNN特征相結合進行定位。
Ds-transunet[21]采用雙尺度編碼器 Swin-transformer[22]處理多尺度輸入,并通過自注意編碼來自不同語義尺度的局部和全局特征表示。
UNETR,三維混合模型, 該模型將變壓器的遠程空間依賴關系與CNN的感應偏置結合成“u形”編碼器結構。其參數量是nnunet的2.5倍,但是如果nnFormer要在UNETR的基礎上獲得了更好的性能,需要進一步增加了1.6X參數和2.8Xflop。?UNETR:用于三維醫學圖像分割的Transformer-CSDN博客
nnFormer,?該方法適應swing - unet[3]架構。在這里,卷積層將輸入掃描轉換成三維patches,并引入基于體積的自關注模塊來構建分層特征金字塔。nnFormer在取得良好性能的同時,其計算復雜度明顯高于UNETR和其他混合方法。
我們認為上述混合方法難以有效捕獲特征通道之間的相互依賴關系,以獲得豐富的特征表示,既編碼空間信息,也編碼通道間的特征依賴關系。
方法
我們首先確定了我們要設計混合框架的兩個理想屬性:
1)Efficient Global Attention 高效的全局注意力:
在體積醫學分割的情況下,計算上是昂貴的,并且在混合設計中交織窗口關注和卷積組件時變得更加成問題。與這些方法不同的是,我們認為跨特征通道計算自關注而不是計算體積維度,有望將相對于體積維度的復雜性從二次型降低到線性型。此外,通過將鍵和值的空間矩陣投影到較低維空間中,可以有效地學習空間注意信息。
2) Enriched Spatial-channel Feature Representation豐富的空間通道特征表示:
現有的混合體醫學圖像分割方法大多是通過注意力計算來捕獲空間特征,而忽略了以編碼不同通道特征映射之間相互依賴關系的形式來獲取通道信息。
整體框架
我們的UNETR++框架基于最近推出的UNETR[13],在編碼器和解碼器之間使用跳過連接,然后是卷積塊(ConvBlocks)來生成預測掩碼。
我們的unetr++采用分層設計,而不是在整個編碼器中使用固定的特征分辨率,其中特征的分辨率在每個階段逐漸降低兩倍。在我們的UNETR++框架中,編碼器有四個階段,其中第一階段包括Patch embedding,將體積輸入劃分為3D補丁,然后是我們新穎的高效成對注意(EPA)塊。
Patch embedding
UNETR++的這個部分和 UNETR挺像的呢,但是有點好奇的是為什么UNETR里面用的直接是P,而沒有分為P1,P2,P3這樣,到時候看看代碼其中P1,P2,P3是否不同好了
把3D輸入?x∈R? HxWxD?變成不重疊的補丁?xu∈R Nx(P1,P2,P3),其中P1,P2,P3是每個patch的分辨率, N=H/P1 x W/P2 xD/P3,是序列長度。
然后,將這些補丁投影到C通道維度,得到的特征圖尺寸為 H/P1 x W/P2 xD/P3 x C
對于每個剩余的編碼器階段,我們使用非重疊卷積的下采樣層將分辨率降低兩倍,然后是EPA塊。
在我們提出的unetr++框架中,每個EPA塊包括兩個注意模塊,通過使用共享關鍵字查詢方案對空間和通道維度的信息進行編碼,有效地學習豐富的空間通道特征表示。
在我們提出的unetr++框架中,每個EPA塊包括兩個注意模塊,通過使用共享keys-queries方案對空間和通道維度的信息進行編碼,有效地學習豐富的空間通道特征表示。編碼器級通過skip-connection 與解碼器級連接以合并不同分辨率的輸出。這可以恢復下采樣操作期間丟失的空間信息,從而預測更精確的輸出。與編碼器類似,解碼器也包括四個階段,其中每個解碼器階段包括一個上采樣層,使用反卷積將特征圖的分辨率提高兩倍,然后是EPA塊(最后一個解碼器除外)。Channel 的數量在每兩個解碼器階段之間減少2倍。因此,最后一個解碼器的輸出與卷積特征映射融合,以恢復空間信息并增強特征表示。然后將結果輸出饋送到3x3x3和1x1x1個卷積塊中以生成voxel-wise的最終掩碼預測。
Efficient PairedAttention Block
空間注意模塊將自注意的復雜度從二次型降低到線性型。另一方面,通道注意模塊有效地學習了通道特征映射之間的相互依賴關系。EPA塊基于兩個注意模塊之間的共享keys-queries查詢方案,以相互通知,以產生更好和更有效的特征表示。這可能是由于通過共享keys-queries來學習互補特性,但使用不同的value layer。
如圖所示,輸入特征映射x被饋送到EPA塊的通道和空間注意模塊。
Q和K線性層的權值是在兩個注意模塊之間共享的,每個注意模塊使用不同的V層。兩個注意模塊計算為:
其中,^X s和^X c分別表示空間和通道注意圖。SA為空間注意模塊,CA為通道注意模塊。Qshared、Kshared、Vspatial和Vchannel分別是共享查詢、共享鍵、空間值層和通道值層的矩陣。!就是這里的QK都是共享的但是值做單獨注意
Spatial attention
我們用這個模塊把獲取空間信息的復雜度從O(n^2)降低到O(np) (所以到底和原先的相比怎么降的呢🤔),其中n為記號的個數,p為投影向量的維數,其中p << n。
給定shape為 HW DXC的歸一化張量x,我們使用三個線性層計算Qshared, Kshared和vspace投1影,收益率Qshared = WQX, Kshared=WKX, vspace =WVX,其中,WQ、WK、WV分別為Qshared、Kshared、Vspatial的投影權值。
1)Kshared和Vspatial層從HWD XC投影到形狀為p C的低維矩陣中。(壞了我怎么記得是把channel壓癟,我再回去看看先)
2)其次,通過將Qshared層乘以投影Kshared的轉置來計算空間注意圖,然后使用softmax來度量每個特征與其他空間特征之間的相似性。
3)這些相似度乘以投影的vspace層,生成shapeHWDxC的最終空間注意圖。空間注意的定義如下:
(我記憶中的空間注意力是CBAM的這個↓
)
Channel attention?
該模塊通過在通道值層和通道注意圖之間的通道維度中應用點積運算來捕獲特征通道之間的相互依賴關系。
利用空間注意模塊相同的Qshared和Kshared,計算通道的值層,利用線性層學習互補特征,得到Vchannel = WVX,維數為 HWDxC,其中wv為Vchannel的投影權值。
定義如下
式中,Vchannel、Qshared、Kshared分別表示通道值層、共享查詢、共享鍵,d為每個向量的大小。
最后,我們對兩個關注模塊的輸出進行和融合,并通過卷積塊對其進行變換,以獲得豐富的特征表示。EPA塊的最終輸出^X為:
其中,^X s和^X c表示空間和通道注意圖,Conv1和Conv3分別為1x1x1和3x3x3卷積塊。
dbq我到時候在琢磨一下CBAM的通道和空間注意力和這個什么關系好了,感覺不太一樣
損失函數
soft dice loss + cross-entropy loss
式中,I為類數;V為體素數;Yv;i和Pv;i分別表示類i在體素v處的真實情況和輸出概率。
實驗
數據集
Synapse 多器官CT分割
BTCV?多器官CT分割
ACDC 心臟自動診斷
BraTS 腦腫瘤分割
Decathlon-Lung
實現細節
Pytorch v1.10.1, MONAI庫(可惡這個也用的是那個庫,我有空直接進行一個學!)
硬件:A100 40GB GPU
1k epochs
learning rate :0.01 , weight decay :3e^5.
評估指標
Dice Similarity Coefficient (DSC
95% Hausdorff Distance (HD95
結果
Synapse
BTCV
ACDC
BRATS
Lungs
展望
為了觀察UNETR++的潛在局限性,我們分析了Synapse的不同異常情況。雖然我們的預測比現有的方法更好,更接近真實情況,但我們發現,在一些情況下,我們的模型和現有的方法一樣,難以分割某些器官。當一些切片中器官的幾何形狀異常(由細邊界描繪)時,我們的模型和現有的模型很難準確地分割它們。原因可能是與正常樣本相比,具有這種異常形狀的訓練樣本的可用性有限。我們計劃在預處理階段應用幾何數據增強技術來解決這個問題。