計算病理學(computational pathology)下的深度學習方法需要手動注釋大型 WSI 數據集,并且通常存在領域適應性和可解釋性較差的問題。作者報告了一種可解釋的弱監督深度學習方法,只需要WSI級標簽。將該方法命名為聚類約束注意力多實例學習 (CLAM,clustering-constrained-attention multiple-instance learning),它使用注意力來識別具有高診斷價值的子區域,以準確對整個WSI進行分類,并在已識別的代表性區域上進行實例級聚類以約束和細化特征空間。通過將 CLAM 應用于腎細胞癌和非小細胞肺癌的亞型分類以及淋巴結轉移的檢測,表明它可用于定位 WSI 上的形態特征,其性能優于標準弱監督分類算法。
來自:Data-efficient and weakly supervised computational pathology on whole-slide images, Nature Biomedical Engineering, 2021
工程地址:https://github.com/mahmoodlab/CLAM
目錄
- CLAM概述
- 方法
- Instance-level clustering
- Smooth SVM loss
- 訓練細節
CLAM概述
- 圖1a:分割后,我們可以從WSI中提取patches。
- 圖1b:patches被預訓練的CNN編碼成特征表示,在訓練和推理過程中,每個WSI中提取的patch作為特征向量傳遞給CLAM。使用注意力網絡將patch信息聚合為WSI級表示,用于最終的診斷預測。
- 圖1c:對于每個類,注意力網絡對WSI中的每個patch進行排名,根據其對WSI診斷的重要性分配注意力分數(左)。注意力pooling根據每個patch的注意力得分對其進行加權,并將patch級別的特征總結為WSI級別的表示(右下)。在訓練過程中,給定GT標簽,強參與(紅色)和弱參與(藍色)patch可以額外用作代表性樣本以監督聚類層,聚類層學習豐富的patch級特征空間,可在不同類別的正實例和負實例之間分離(右上)。
- 圖1d:注意力得分可以可視化為熱圖,以識別ROI(解釋用于診斷的重要形態學)。
方法
CLAM是一個高通量的深度學習工具箱,旨在解決計算病理學中的弱監督分類任務,其中訓練集中的每個WSI是具有已知WSI級別的單個數據點,但對于WSI中的任何像素或patch都沒有類別特定的信息或注釋。CLAM建立在MIL框架之上,該框架將每個WSI(稱為bag)視為由許多(多達數十萬)較小的區域或patch(稱為instance)組成的集合。MIL框架通常將其范圍限制在一個正類和一個負類的二元分類問題上,并基于這樣的假設:如果至少有一個patch屬于正類,那么整個WSI應該被分類為正類(陽性),而如果所有patch都屬于負類,則WSI應該被分類為負類(陰性)。這一假設體現在max-pooling聚合函數上,它簡單地使用正類預測概率最高的patch進行WSI級預測,這也使得MIL不適合多類分類問題。
除了Max-pooling之外,雖然可以使用其他聚合函數,但它們依然不能提供簡單、直觀的模型可解釋性機制。相比之下,CLAM通常適用于多類別分類,它是圍繞可訓練和可解釋的基于注意力的pooling函數構建的,從patch級表示中聚合每個類別的WSI級表示。在多分類注意力pooling設計中,注意力網絡預測了一個多類分類問題中對應于 N N N個類別的 N N N個不同的注意力分數集。這使得網絡能夠明確地了解哪些形態學特征應該被視為每個類的積極證據(類相關的特征)和消極證據(非信息性的,缺乏類定義的特征),并總結WSI級表示。
具體來說,對于表示為 K K K個實例(patch)的WSI,我們將對應于第 k k k個patch的實例級嵌入表示為 z k z_{k} zk?。在CLAM中,第一個全連接層 W 1 ∈ R 512 × 1024 W_{1}\in\R^{512\times 1024} W1?∈R512×1024進一步將每個固定的patch級表示 z k ∈ R 1024 z_{k}\in\R^{1024} zk?∈R1024壓縮為 h k ∈ R 512 h_{k}\in\R^{512} hk?∈R512。注意力網絡由幾個堆疊的全連接層組成;如果將注意力網絡的前兩層 U a ∈ R 256 × 512 U_{a}\in\R^{256\times 512} Ua?∈R256×512+ V a ∈ R 256 × 512 V_{a}\in\R^{256\times 512} Va?∈R256×512和 W 1 W_{1} W1?共同視為所有類共享的注意力主干的一部分,注意力網絡將分為 N N N個平行分支: W a , 1 , . . . , W a , N ∈ R 1 × 256 W_{a,1},...,W_{a,N}\in\R^{1\times 256} Wa,1?,...,Wa,N?∈R1×256。同樣, N N N個并行獨立分類器 W c , 1 , . . . , W c , N W_{c,1},...,W_{c,N} Wc,1?,...,Wc,N?對每個特定類的WSI表示進行評分。
因此,第 i i i類的第 k k k個patch的注意力分數記為 a i , k a_{i,k} ai,k?,并且根據第 i i i類注意力分數聚合WSI表示記為 h s l i d e , i ∈ R 512 h_{slide,i}\in\R^{512} hslide,i?∈R512:
分類層 W c , i W_{c,i} Wc,i?給出相應的非歸一化WSI級分數 s s l i d e , i s_{slide,i} sslide,i?: s s l i d e , i = W c , i h s l i d e , i s_{slide,i}=W_{c,i}h_{slide,i} sslide,i?=Wc,i?hslide,i?。我們在模型的注意力主干的每一層后使用dropout( P = 0.25 P=0.25 P=0.25)進行正則化。
對于推理,通過對WSI級預測分數應用softmax函數來計算每個類的預測概率分布。
Instance-level clustering
為了進一步鼓勵學習特定于類的特征,我們在訓練期間加入一個額外的二值聚類目標。對于 N N N個類中的每一個,在第一個層 W 1 W_{1} W1?之后加一個全連接層。將第 i i i個類對應的聚類層權重記為 W i n s t , i ∈ R 2 × 512 W_{inst,i}\in\R^{2\times 512} Winst,i?∈R2×512,則第 k k k個patch預測的聚類分數為 p i , k p_{i,k} pi,k?: p i , k = W i n s t , i h k p_{i,k}=W_{inst,i}h_{k} pi,k?=Winst,i?hk?。
鑒于我們無法訪問patch級標簽,我們使用注意力網絡的輸出在每次訓練迭代中為每張WSI生成偽標簽,以監督聚類。聚類中只優化最強參與和最弱參與的區域。為了避免混淆,對于給定的WSI,對于GT標簽 Y ∈ { 1 , . . . , N } Y\in\left\{1,...,N\right\} Y∈{1,...,N},我們將GT類別對應的注意力分支 W a , Y W_{a,Y} Wa,Y?稱為"in-the-class",其余的 N ? 1 N-1 N?1個注意力分支稱為"out-the-class"。如果將in-the-class的注意力分數的排序列表(升序)表示為 a ~ Y , 1 , . . . , a ~ Y , K \widetilde{a}_{Y,1},...,\widetilde{a}_{Y,K} a Y,1?,...,a Y,K?,我們將注意力得分最低的 B B B個patch分配給負類標簽 ( y Y , b = 0 ) (y_{Y,b}=0) (yY,b?=0),其中, 1 ≤ b ≤ B 1\leq b\leq B 1≤b≤B。注意力得分最高的 B B B個patch分配給正類標簽 ( y Y , b = 1 ) (y_{Y,b}=1) (yY,b?=1),其中, B + 1 ≤ b ≤ 2 B B+1\leq b\leq 2B B+1≤b≤2B。直觀地說,由于在訓練過程中每個注意力分支都受到WSI級別標簽的監督,因此高注意力分數的 B B B個patch被期望成為 Y Y Y類別的強參與陽性證據,而低注意分數的 B B B個patch被期望成為 Y Y Y類別的強參與陰性證據。聚類任務可以直觀地解釋為約束patch級特征空間 h k h_k hk?,使每個類別的強參與特征證據與其陰性證據線性可分。
對于癌癥亞型問題,所有類別通常被認為是互斥的(也就是說,它們不能出現在同一張WSI中),因為將in-the-class注意力分支中最受關注和最不受關注的片段分別聚類為正證據和負證據,因此對N?1個out-the-class注意力分支施加額外的監督是有意義的。也就是說,給定GT的WSI標簽 Y Y Y,任取類別 i i i不屬于 Y Y Y,如果我們假設WSI上的patch都不屬于 i i i類,那么注意力得分最高的 B B B個patch就不能成為 i i i類的正證據(由于互斥性)。
因此,除了對從in-the-class注意力分支中選擇的 2 B 2B 2B個patch進行聚類外,還將所有out-the-class注意力分支中最受關注的前 B B B個patch分配為負聚類標簽,因為它們被認為是假陽性證據。另一方面,如果互斥性假設不成立(例如,癌癥與非癌癥問題,其中一張WSI可以包含來自腫瘤組織和正常組織的patch),那么就不會監督來自out-the-class分支的高注意力的patch的聚類,因為我們不知道它們是否為假陽性。
實例級聚類算法如下:
Smooth SVM loss
對于實例級聚類任務,我們使用平滑的top-1 SVM loss,它是基于多分類SVM loss的,神經網絡模型輸出一個預測分數向量 s s s,其中 s s s中的每個條目對應于模型對單個類的預測。給定所有GT標簽 y ∈ { 1 , . . . , N } y\in\left\{1,...,N\right\} y∈{1,...,N},多類別SVM loss對分類器進行線性懲罰,僅當該差值大于指定的裕度 α α α時,對GT類的預測分數與其余類的最高預測分數之間的差值進行懲罰。Smooth變體(公式5)在多分類SVM損失中加入了溫度標度 τ τ τ,它已被證明具有非稀疏梯度的無限可微性,并且在有效實現算法時適用于深度神經網絡的優化。平滑支持向量機損失可以看作是廣泛使用的交叉熵分類損失的一種推廣,適用于不同的邊界有限值選擇和不同的溫度尺度。
經驗表明,當數據標簽有噪聲或數據有限時,向損失函數引入margin可以減少過擬合。在訓練過程中,創建的用于監督實例級聚類任務的偽標簽必然是有噪聲的。也就是說,強參與的patch可能不一定對應于GT類,同樣,弱參與的patch也不能保證是該類的負證據。因此,代替廣泛使用的交叉熵損失,將二進制top-1平滑SVM損失應用于網絡聚類層的輸出。在所有的實驗中, α α α和 τ τ τ都被設置為1.0。
訓練細節
在訓練過程中,WSI被隨機采樣。每張WSI的多項采樣概率與GT類的頻率成反比(來自代表性不足的類的WSI相對于其他類更有可能被采樣),以減輕訓練集中的類不平衡。注意力模塊的權重參數隨機初始化,并使用WSI標簽和模型其余部分端到端訓練,總的損失是WSI級損失 L s l i d e L_{slide} Lslide?和instance-level損失 L p a t c h L_{patch} Lpatch?之和。
為了計算 L s l i d e L_{slide} Lslide?,使用標準交叉熵損失將 s s l i d e s_{slide} sslide?與真實的WSI級標簽進行比較,為了計算 L p a t c h L_{patch} Lpatch?,使用二元Smooth SVM損失將每個采樣patch的實例級聚類預測分數 p k p_k pk?與相應的偽聚類標簽進行比較(回想一下,對于非亞型問題,從in-the-class分支中總共采樣了 2 B 2B 2B個patch。而對于亞型問題,從in-the-class分支中采樣 2 B 2B 2B個patch,從 N ? 1 N?1 N?1個out-the-class注意力分支各采樣 B B B個patch)。
數據集摘要見補充表8: