1.論文介紹
ASPS: Augmented Segment Anything Model for Polyp Segmentation
ASPS:用于息肉分割的擴展SAM模型
2024年 arxiv
Paper Code
2.摘要
息肉分割在結直腸癌診斷中起著至關重要的作用。最近,Segment Anything Model(SAM)的出現利用其在大規模數據集上的強大預訓練能力,為息肉分割帶來了前所未有的潛力。然而,由于自然圖像和內窺鏡圖像之間的區域差距,SAM在實現有效的息肉分割方面遇到了兩個限制。首先,它的基于變壓器的結構優先考慮全局和低頻信息,可能會忽略局部細節,并在學習的特征中引入偏差。其次,當應用于內窺鏡圖像時,其較差的分布(OOD)性能導致不符合標準的預測和有偏差的置信度輸出。為了應對這些挑戰,我們提出了一種新的用于息肉分割的擴展SAM(ASPS)方法,該方法包括兩個模塊:跨分支特征增強(CFA)和不確定引導預測正則化(UPR)。CFA將可訓練的CNN編碼器分支與凍結的VIT編碼器集成在一起,在增強局部特征和高頻細節的同時,實現了特定領域知識的整合。此外,UPR巧妙地利用SAM的IOU分數來減少訓練過程中的不確定性,從而提高了面向對象設計的性能和領域泛化。大量的實驗結果證明了該方法在提高SAM在息肉分割中的性能方面的有效性和實用性。
Keywords: SAM,CNN+ViT,跨分支特征增強,不確定引導預測正則化
3.Introduction
自動息肉分割是診斷結直腸癌的關鍵工具,有助于有效的干預和及時的治療策略。最近SAM被引入,由于其巨大的模型大小和數據量,這種創新的方法為息肉分割領域引入了新的視角。它還具有增強的表示和特征提取能力,超過了現有的方法。然而,由于訓練數據和內窺鏡圖像之間的域差距,SAM在息肉分割任務中的性能并不令人滿意。這導致了兩個主要問題:第一,SAM不能充分捕捉息肉圖像的顯著特征,導致其學習表示的偏差。其次,它對分布外數據產生了錯誤的預測和不準確的置信度估計。此外,由于SAM依賴于提示,大大阻礙了其在臨床應用中的便利性。盡管有幾種方法改進了SAM,但這些方法要么依賴于提示,要么直接微調實體模型。Samus有效地集成了CNN和VIT,但其設計相當復雜,特別適合處理小圖像。因此,這些方法的有效性在某種程度上受到了限制。人們已經提出了各種方法來解決語義切分中無監督領域自適應的挑戰。MIC提出了一種用于目標領域情境學習的掩蔽圖像一致性模型;情境感知領域自適應通過交叉注意改進了情境遷移。然而,特定領域的信息集成和不確定性減少仍然沒有被探索。為了解決這些問題,本文從領域自適應的角度提出了一種新的基于SAM的方法,旨在增強特征提取能力和泛化能力,而不依賴于提示。本文提出了交叉分支特征增強模塊(CFA)和不確定性引導預測正則化模塊(UPR)。CFA加入了一個額外的可訓練卷積神經網絡(CNN)編碼器分支,該分支補充了凍結視覺轉換器(VIT)編碼器,以捕獲多尺度和多層次的特征。UPR調整歸一化層以促進內窺鏡領域的自適應,并利用提示確保準確的置信度估計,從而提高SAM的面向對象設計性能。
4.模型結構詳解
CFA模塊集成了CNN編碼器特征和全局VIT信息,導致了廣義特征表示學習。這種集成通過將深層信息聚合到淺層并結合來自淺層的位置信息來促進精細化分割輸出。同時,UPR的設計是為了最大限度地減少訓練期間的不確定性和校準置信度。UPR利用基于不確定性的訓練策略,利用gt作為指導性“提示”。提出的網絡遵循端到端訓練,在沒有提示的情況下,聯合優化兩個模塊以實現最佳性能。
Cross-branch Feature Augmentation Module 跨分支特征增強模塊:
SAM在息肉分割任務中仍存在一定的局限性,其中一個主要原因是SAM的圖像編碼器不能有效地從不可見的內窺鏡圖像中捕獲足夠的特征。為了解決這個問題,CFA模塊被設計成學習多尺度特征和多層次表示,從而增強編碼器的特征提取能力。
首先,為了實現自動分割,對SAM的體系結構進行了改進,去掉了它的提示輸入和提示編碼器部分,保留了它的圖像編碼和掩碼解碼部分。最近的研究表明,VIT更專注于低頻信號,而CNN更擅長處理高頻信號。因此,本文集成了一個基于CNN的并行分支來彌補高頻和局部特征的缺失。此外,通過提出一個額外的多頭交叉分支注意塊來增強SAM的掩碼解碼器,以促進從VIT編碼器和CNN編碼器中提取的特征的整合。對于VIT分支機構的 F V F_V FV?特征和CNN的 F C F_C FC?特征,跨分支特征注意力可以表達如下:
其中 Q = F v W Q Q=F_vW^Q Q=Fv?WQ, K = V = F c W K K=V=F_cW^K K=V=Fc?WK,并且d是Fv的每個頭部的通道數。考慮到CNN特征提供了更精確的位置信息,用CNN編碼器的最終輸出特征代替了SAM在掩碼解碼器中的原始位置嵌入。此外,將跨分支注意機制集成到掩碼解碼器的注意塊中,重復此過程兩次,以確保來自VIT和CNN編碼器的多尺度特征的集成。
其次,為了獲得更準確的分割結果,將來自編碼器的高層上下文和低層邊界信息與SAM的解碼器特征相結合,以增強輸出信息。具體地說,將從中間嵌入的VIT編碼器獲得的淺局部特征,從VIT編碼器獲得的最終全局特征與CNN的最終特征相結合,像上圖所示的,這種方法充分利用了豐富的邊緣信息、廣泛的全局上下文信息和每個編碼器分支的局部位置細節。因此,可以有效地集成VIT和CNN的多層次特征。
ViT,即原SAM的image encoder的最終特征和中間層特征,與CNN的最終特征相結合。并在原SAM的mask decoder中增加多頭注意力,以融合ViT與CNN的特征。
首先在原本的image_encoder中提取第4、8、12層特征作為interm embedding(代碼中這么設置),無prompt_encoder,在mask_decoder中,首先把來自CNN的特征,與interm embedding和image embedding融合,用CNN的特征代替ViT的位置嵌入,并在transformer塊中融合CNN與image embedding特征,然后在最后把fusion的特征再加上去。
除此之外,增加一個置信度計算。
Uncertainty-guided Prediction Regularization Module 不確定性引導的預測正則化模型:
為了增強SAM的泛化能力,本文提出了一種新的訓練策略,該策略包括在編碼器內選擇性激活LayerNorm。我們還以gt為線索,通過糾正置信度來進一步指導訓練過程。由于SAM是在自然數據上訓練的,其在息肉圖像中的性能可能會因域轉移而惡化。盡管LayerNorm的引入可能會減少訓練時間,但它從根本上改變了輸入數據的分布。當將SAM從自然圖像轉移到內窺鏡圖像時,數據分布和相應的特征空間分布都發生了偏移。這些分布差異可能會導致內部協變量的變化,從而影響模型的性能。為了提高SAM在內窺鏡領域的泛化能力,本文對編碼器的歸一化層進行了微調。在這個過程中,該模型有效地適應了目標領域的數據分布,并緩解了內部協變量變化的影響。
具體地說,SAM的VIT編碼器的層范數LayerNorm分為(1)transformer塊norm和(2)neck層norm,如上圖(A)所示。考慮到Neck層的特征更接近編碼器的輸出特征,最終決定訓練Neck層歸一化,這相當于對預先訓練的VIT編碼器的特征進行重新歸一化。
這里提到訓練策略在編碼器的neck層增加LayerNorm,原代碼中就是可選擇的,不理解這里怎么重新處理了。
此外,以前的研究已經證明,不確定性較低的預測往往表現出優越的分布外(OOD)性能,這也有利于領域適應。SAM會生成一個IOU分數輸出,這本身就代表著不確定性(或者置信度)。然而,在預測過程中,SAM可能會頻繁地對看不見的數據產生高置信度的錯誤預測,這是不可取的。為了緩解這個問題,我們努力減少模型在訓練過程中的不確定性(即增加置信度)。在[4]的啟發下,我們利用gt作為提示來指導模型的學習。首先,我們將SAM的IoU得分表示為圖像級別的置信度Ci。然后,我們計算像素級置信度cp,以使用公式來細化每個像素的不確定性。其中Up∈Rb×1×H×W.
Up表示像素不確定性,定義為 U p = 1 ? σ ( ∣ P ∣ ) Up=1?σ(|P|) Up=1?σ(∣P∣)。這里,σ表示Sigmoid函數,而P表示輸出預測。最終置信度被計算為圖像級置信度和像素級置信度之和,表示為 c = 1 / 2 ( C i + c p ) c=1/2 (Ci+cp) c=1/2(Ci+cp)。這種可信度是由伯努利分布決定的,它決定了是否將gt作為提示。換句話說,如果置信度足夠低,我們認為該模型需要特定的答案提示來學習正確的掩碼預測。因此,答案是必需的,作為提示,否則就沒有必要。提示的權重由置信度c確定,其表示如下:
然而,通過最小化損失函數,模型將傾向于使c=0,從而 P ′ P' P′將始終是GT。這意味著該模型實際上并不學習。因此,引入置信度損失來監督c,當c→為0時,置信度損失會增加,并且置信度損失的定義如下:
最終損失函數是分段損失Ls和置信度損失Lc之和,如公式中所定義的。5.這里,λ表示一個超參數。具體地,所采用的分段損失是CE損失、Dice損失和MSE損失的組合,如ls=Lce+0.5Ldice+Lmse。
這里說增加一個新損失:置信度損失。它由圖像級和像素級兩個組成。
5.實驗結果
創新點:
- image_encoder中用交叉注意力引入CNN特征;
- mask_decoder中使用CNN特征作為位置嵌入,融合CNN特征、image encoder的中間層特征與最終特征,再最后加上;
- 設置置信度損失,解決置信度錯誤地過高的問題。