計算機視覺語義分割——Attention U-Net(Learning Where to Look for the Pancreas)
文章目錄
- 計算機視覺語義分割——Attention U-Net(Learning Where to Look for the Pancreas)
- 摘要
- Abstract
- 一、Attention U-Net
- 1. 基本思想
- 2. Attention Gate模塊
- 3. 軟注意力與硬注意力
- 4. 實驗
- 5. 代碼實踐
- 總結
摘要
本周學習了Attention U-Net模型,這是一種在U-Net基礎上改進的語義分割模型,主要應用于醫學影像分割任務。Attention U-Net通過引入注意力門(Attention Gate, AG)模塊,自動學習目標的形狀和尺寸,同時抑制無關區域,聚焦于顯著特征。相比標準U-Net模型,Attention U-Net在分割性能、模型敏感度和準確率方面均有顯著提升。AG模塊采用加性注意力機制,能夠有效突出感興趣區域的特征響應,減少冗余信息,并且易于與不同的CNN模型集成。此外,學習了軟注意力與硬注意力的原理及其在模型中的具體應用。最后,結合代碼實踐,進一步加深了對Attention U-Net網絡架構與實現的理解。
Abstract
This week focused on the Attention U-Net model, an improved semantic segmentation architecture based on U-Net, primarily designed for medical image segmentation tasks. By introducing Attention Gate (AG) modules, the model automatically learns the shape and size of targets while suppressing irrelevant regions and focusing on salient features. Compared to the standard U-Net, the Attention U-Net achieves significant improvements in segmentation performance, sensitivity, and accuracy. The AG module employs an additive attention mechanism, effectively highlighting features of interest and reducing redundant information, while being easily integrable into various CNN models. Additionally, the principles of soft attention and hard attention were studied along with their specific applications in the model. Practical implementation through code further enhanced the understanding of the Attention U-Net architecture and its functionality.
一、Attention U-Net
Attention Unet是Unet的改進版本,主要應用于醫學影像領域中的分割任務。Attention Unet是一種基于注意門(Attention Gate, AG)的模型,它會自動學習區分目標的外形和尺寸。這種有Attention Gate的模型在訓練時會學會抑制不相關的區域,注重于有用的顯著特征。Attention Gate很容易被整合進標準的CNN模型中,極少的額外計算量卻能帶來顯著的模型敏感度和準確率的提高。作者利用Attention U-Net模型,在兩個大型CT腹部數據集上進行了多類別的圖像分割。實驗結果表明,Attention Gate可以在保持計算效率的同時,持續提高U-Net在不同數據集和訓練規模下的預測性能。
1. 基本思想
既然Attention U-Net是U-Net的改進,那么需要先簡單回顧一下U-Net,來更好地對比。下面是U-Net的網絡結構圖:
從U-Net的結構圖中看出,為了避免在Decoder解碼器時丟失大量的細節,使用了Skip Connection跳躍連接,將Encoder編碼器中提取的信息直接連接到Decoder對應的層上。但是,Encoder提取的low-level feature有很多的冗余信息,也就是提取的特征不太好,可能對后面并沒有幫助,這就是U-Net網絡存在的問題。
該問題可以通過在U-Net上加入注意力機制,來減少冗余的Skip Connection,這也是設計Attention Gate的動機。
對比基礎的U-Net網絡結構,我們可以發現在跳躍連接中都加入了Attention Gate模塊。而原始U-Net只是單純的把同層的下采樣層的特征直接連接到上采樣層中。Attention U-Net的主要貢獻分為三點:
- 提出基于網格的AG,這會使注意力系數更能凸顯局部區域的特征。
- 首次在醫學圖像的CNN中使用Soft Attention,該模塊可以替代分類任務中的Hard Attention和器官定位任務中的定位模塊。
- 將U-Net改進為Attention U-Net,增加了模型對前景像素的敏感度,并設計實驗證明了這種改進是通用的。
2. Attention Gate模塊
U-Net(FCN類的網絡模型)中的卷積層會根據一層一層的局部信息來提取得到高維圖像的表示 x l x^l xl,最終在高維空間的離散像素會具有語義信息。通過U-Net這種順序結構的處理,由巨大的感受野所提取的信息會影響著模型的預測結果。因此,特征圖 x l x^l xl是由 l l l層的輸出依次通過一個線性變換(卷積操作)和一個非線性激活函數得到的。這個非線性激活函數一般選擇為ReLU函數,即: σ 1 ( x i , c l ) = max ? ( 0 , x i , c l ) {\sigma _1}(x_{i,c}^l) = \max (0,x_{i,c}^l) σ1?(xi,cl?)=max(0,xi,cl?),其中 i i i和 c c c分別代表著空間信息維度和通道信息維度。特征的激活可以寫為: x c l = σ 1 ( ∑ c ′ ∈ F l x c ′ l ? 1 ? k c ′ , c ) x_c^l = {\sigma _1}\left( {\sum\nolimits_{c' \in {F_l}} {x_{c'}^{l - 1} * {k_{c',c}}} } \right) xcl?=σ1?(∑c′∈Fl??xc′l?1??kc′,c?),這是通道維度上的寫法,其中 ? * ?代表卷積操作,空間信息維度的下標 i i i為了簡便而省略。因此,每個卷積層的操作過程可以寫為函數: f ( x l ; Φ l ) = x ( l + 1 ) f({x^l};{\Phi ^l}) = {x^{(l + 1)}} f(xl;Φl)=x(l+1), Φ l {\Phi ^l} Φl是可以訓練的卷積核參數。參數的學習是通過最小化訓練目標(如交叉熵損失),并使用SGD隨機梯度下降法進行學習的。本論文在基本思想的框架圖中,基于U-Net體系構建了Attention U-Net,粗粒度的特征圖會捕獲上下文信息,并突出顯示前景對象的類別和位置。隨后,通過skip connections合并以多個比例提取的特征圖,以合并粗粒度和細粒度的密集預測。
注意系數(Attention coefficients): α i ∈ [ 0 , 1 ] {\alpha _i} \in \left[ {0,1} \right] αi?∈[0,1],是為了突出顯著的圖像區域和抑制任務無關的特征響應。Attention Gate的輸出是輸入的特征圖與 α i {\alpha _i} αi?做Element-wise乘法(對應元素逐個相乘),即: x ^ i , c l = x i , c l ? α i l \hat x_{i,c}^l = x_{i,c}^l \cdot \alpha _i^l x^i,cl?=xi,cl??αil? 。在默認設置中,會為每個像素向量 x i l ∈ R F l x_i^l \in {R^{{F_l}}} xil?∈RFl?計算一個標量的注意值,其中 F l {{F_l}} Fl?對應于第 l l l層中特征圖的數量。如果有多個語義類別,則應當學習多維的注意系數。
在上圖Attention Gate的具體框圖中,門控向量 g i ∈ R F g {g_i} \in {R^{{F_g}}} gi?∈RFg?為每個像素 i i i確定焦點區域。門控向量包含上下文信息,以修剪低級特征響應。論文中選擇加性注意力來獲得門控系數,盡管這在計算上更昂貴,但從實驗上看,它的性能比乘法注意力要高。加性注意力的計算公式如下所示:
注意這里要結合上面的結構來一起分析,其中 σ 1 {\sigma _1} σ1?是ReLU激活函數, σ 2 {\sigma _2} σ2?是Sigmoid激活函數。 W x ∈ R F l × F i n t {W_x} \in {R^{{F_l} \times {F_{{\mathop{\rm int}} }}}} Wx?∈RFl?×Fint?, W g ∈ R F g × F i n t {W_g} \in {R^{{F_g} \times {F_{{\mathop{\rm int}} }}}} Wg?∈RFg?×Fint?, ψ ∈ R F i n t × 1 \psi \in {R^{{F_{{\mathop{\rm int}} }} \times 1}} ψ∈RFint?×1都是卷積操作, b ψ ∈ R {b_\psi } \in R bψ?∈R, b g ∈ R F i n t {b_g} \in {R^{{F_{{\mathop{\rm int}} }}}} bg?∈RFint?為偏置項,分別對應 ψ \psi ψ, W g {W_g} Wg?的卷積操作,而 W x {W_x} Wx?無偏置項。 F i n t {{F_{{\mathop{\rm int}} }}} Fint?一般比 F g {{F_g}} Fg?和 F l {{F_l}} Fl?要小。在圖像標注和分類任務中一般都用Softmax函數,之所以 σ2使用Sigmoid函數,是因為順序使用Softmax函數會輸出較稀疏的激活響應,而用Sigmoid函數能夠使訓練更好的收斂。門控信號不是全局圖像的表示矢量,而是在一定條件下部分圖像空間信息的網格信號,每個skip connection的門控信號都會匯總來自多個成像比例的信息。
上圖直觀地表明了 x x x和 g g g的輸入,注意系數(Attention coefficients): α i ∈ [ 0 , 1 ] {\alpha _i} \in \left[ {0,1} \right] αi?∈[0,1]是通過 σ 2 {\sigma _2} σ2?這個Sigmoid激活函數得到的。
舉一個具體的例子來理解。
上圖顯示了不同訓練時期(epochs:3、6、10、60、150)時的注意力系數,表明注意力越來越集中在感興趣的部分。
3. 軟注意力與硬注意力
軟注意力(Soft Attention):軟(確定性)注意力機制使用所有鍵的加權平均值來構建上下文向量。對于軟注意力,注意力模塊相對于輸入是可微的,因此整個系統仍然可以通過標準的反向傳播方法進行訓練。軟注意力數學描述如下:
其中 f ( q , k ) f\left( {q,k} \right) f(q,k)的有很多種計算方法,如下表所示:
硬注意力(Hard Attention):硬(隨機)注意力中的上下文向量是根據隨機采樣的鍵計算的。硬注意力可以實現如下:
注:多項式分布是二項式分布的推廣。二項式做n次伯努利實驗,規定了每次試驗的結果只有兩個。如果現在還是做n次試驗,只不過每次試驗的結果可以有m個,且m個結果發生的概率互斥且和為1,則發生其中一個結果X次的概率就是多項分布。概率密度函數是:
兩者的對比與改進方案:與軟注意力模型相比,硬注意力模型的計算成本更低,因為它不需要每次都計算所有元素的注意力權重。 然而,在輸入特征的每個位置做出艱難的決定會使模塊不可微且難以優化,因此可以通過最大化近似變分下限或等效地通過 REINFORCE 來訓練整個系統。 在此基礎上,Luong 等人提出了機器翻譯的全局注意力和局部注意力機制。 全局注意力類似于軟注意力。 局部注意力可以看作是硬注意力和軟注意力之間的有趣混合,其中一次只考慮源詞的一個子集。局部注意力這種方法在計算上比全局注意力或軟注意力更便宜。 同時,與硬注意力不同,這種方法幾乎在任何地方都是可微的,從而更容易實現和訓練。
4. 實驗
AGs(Attention Gates)是模塊化的,與應用類型無關; 因此,它可以很容易地適應分類和回歸任務。為了證明其對圖像分割的適用性,論文在具有挑戰性的腹部CT多標簽分割問題上評估Attention U-Net模型。特別是,由于形狀變化和組織對比度差,胰腺邊界描繪是一項艱巨的任務。Attention U-Net模型在分割性能,模型容量,計算時間和內存要求方面與標準3D U-Net進行了比較。
- 評估數據集:NIH-TCIA 和 這篇論文中的。
- 實施細節:有一個3D的模型,Adam,BN,deep-supervision和標準數據增強技術(仿射變換,軸向翻轉,隨機裁剪)。
- 注意力圖分析:我們通常觀察到AG最初具有均勻分布并且在所有位置,然后逐步更新和定位到目標器官邊界。在較粗糙的尺度上,AG提供了粗略的器官輪廓,這些器官在更精細的分辨率下逐漸細化。 此外,通過在每個圖像尺度上訓練多個AG,我們觀察到每個AG學習專注于器官的特定子集。
- 分割實驗: 性能比U-Net高 2?3% 。
5. 代碼實踐
# 導入TensorFlow及其子模塊
import tensorflow as tf
from tensorflow.keras import models, layers, regularizers
from tensorflow.keras import backend as K # Keras后端,用于底層操作# 定義Dice系數指標函數(常用于圖像分割任務評估)
def dice_coef(y_true, y_pred):# 將真實標簽和預測結果展平為一維向量y_true_f = K.flatten(y_true)y_pred_f = K.flatten(y_pred)# 計算兩個向量的交集(對應元素相乘后求和)intersection = K.sum(y_true_f * y_pred_f)# Dice系數公式:(2*交集 + 平滑項)/(真實標簽總和 + 預測結果總和 + 平滑項)# 平滑項(1.0)用于防止分母為零return (2.0 * intersection + 1.0) / (K.sum(y_true_f) + K.sum(y_pred_f) + 1.0)# 定義Jaccard系數(交并比)指標函數
def jacard_coef(y_true, y_pred):y_true_f = K.flatten(y_true)y_pred_f = K.flatten(y_pred)intersection = K.sum(y_true_f * y_pred_f)# Jaccard系數公式:(交集 + 平滑項)/(并集 + 平滑項)# 并集 = 真實標簽總和 + 預測結果總和 - 交集return (intersection + 1.0) / (K.sum(y_true_f) + K.sum(y_pred_f) - intersection + 1.0)# 定義基于Jaccard系數的損失函數
def jacard_coef_loss(y_true, y_pred):# 返回負的Jaccard系數,因為損失函數需要最小化,而Jaccard系數需要最大化return -jacard_coef(y_true, y_pred)# 定義基于Dice系數的損失函數
def dice_coef_loss(y_true, y_pred):# 返回負的Dice系數,原理同上return -dice_coef(y_true, y_pred)"""
代碼特性說明:
1. 適用于圖像分割任務,Dice和Jaccard系數能有效評估像素級預測精度
2. 使用K.flatten處理確保適用于不同形狀的輸出(如batch_size, height, width, channels)
3. 平滑項(+1.0)的作用:- 防止分母為零的數學錯誤- 當預測和標簽均為全黑時仍能給出合理評估值- 起到正則化作用,使指標對極端情況更魯棒
4. 損失函數通過取負數將指標最大化問題轉化為最小化問題,符合優化器的工作方式
"""
# 定義構建U-Net架構的各個模塊def conv_block(x, filter_size, size, dropout, batch_norm=False):"""卷積塊:包含兩個卷積層,可選批歸一化和Dropout參數:x: 輸入張量filter_size: 卷積核尺寸(整數,如3表示3x3卷積)size: 卷積核數量(輸出通道數)dropout: Dropout比率(0表示不使用)batch_norm: 是否使用批歸一化返回:經過卷積處理后的張量"""# 第一個卷積層conv = layers.Conv2D(size, (filter_size, filter_size), padding="same")(x)if batch_norm:conv = layers.BatchNormalization(axis=3)(conv) # 沿通道軸做批歸一化conv = layers.Activation("relu")(conv) # ReLU激活# 第二個卷積層conv = layers.Conv2D(size, (filter_size, filter_size), padding="same")(conv)if batch_norm:conv = layers.BatchNormalization(axis=3)(conv)conv = layers.Activation("relu")(conv)# 應用Dropout(當比率大于0時)if dropout > 0:conv = layers.Dropout(dropout)(conv)return convdef repeat_elem(tensor, rep):"""張量元素重復:沿著最后一個軸重復張量元素示例: 輸入形狀(None, 256,256,3),指定axis=3和rep=2時輸出形狀(None, 256,256,6)"""return layers.Lambda(lambda x, repnum: K.repeat_elements(x, repnum, axis=3), # 使用Keras后端函數arguments={'repnum': rep} # 傳入重復次數參數)(tensor)def gating_signal(input, out_size, batch_norm=False):"""生成門控信號:使用1x1卷積調整特征圖維度,匹配上采樣層尺寸返回:與上層特征圖維度相同的門控特征圖"""x = layers.Conv2D(out_size, (1, 1), padding='same')(input) # 1x1卷積調整通道數if batch_norm:x = layers.BatchNormalization()(x)x = layers.Activation('relu')(x) # ReLU激活return xdef attention_block(x, gating, inter_shape):"""注意力機制塊:通過門控信號學習空間注意力權重參數:x: 跳躍連接的特征圖gating: 門控信號(來自深層網絡)inter_shape: 中間特征通道數返回:應用注意力權重后的特征圖"""# 獲取輸入特征圖尺寸shape_x = K.int_shape(x)shape_g = K.int_shape(gating)# 將x的特征圖下采樣到門控信號尺寸theta_x = layers.Conv2D(inter_shape, (2, 2), strides=(2, 2), padding='same')(x)shape_theta_x = K.int_shape(theta_x)# 調整門控信號通道數phi_g = layers.Conv2D(inter_shape, (1, 1), padding='same')(gating)# 上采樣門控信號到theta_x的尺寸upsample_g = layers.Conv2DTranspose(inter_shape, (3, 3),strides=(shape_theta_x[1] // shape_g[1], shape_theta_x[2] // shape_g[2]),padding='same')(phi_g)# 合并特征concat_xg = layers.add([upsample_g, theta_x])act_xg = layers.Activation('relu')(concat_xg)# 生成注意力權重psi = layers.Conv2D(1, (1, 1), padding='same')(act_xg)sigmoid_xg = layers.Activation('sigmoid')(psi)# 上采樣注意力權重到原始x的尺寸shape_sigmoid = K.int_shape(sigmoid_xg)upsample_psi = layers.UpSampling2D(size=(shape_x[1] // shape_sigmoid[1], shape_x[2] // shape_sigmoid[2]))(sigmoid_xg)# 擴展注意力權重的通道數upsample_psi = repeat_elem(upsample_psi, shape_x[3])# 應用注意力權重y = layers.multiply([upsample_psi, x])# 最終卷積和批歸一化result = layers.Conv2D(shape_x[3], (1, 1), padding='same')(y)result_bn = layers.BatchNormalization()(result)return result_bn
# 定義Attention U-Net模型
def Attention_UNet(input_shape, NUM_CLASSES=1, dropout_rate=0.0, batch_norm=True):'''Attention UNet網絡實現參數:input_shape: 輸入張量的形狀(高度, 寬度, 通道數)NUM_CLASSES: 輸出類別數(默認為1,二分類問題)dropout_rate: Dropout層的丟棄率(0.0表示不使用)batch_norm: 是否使用批量歸一化(默認為True)'''# 網絡結構超參數配置FILTER_NUM = 64 # 第一層的基礎卷積核數量FILTER_SIZE = 3 # 卷積核尺寸UP_SAMP_SIZE = 2 # 上采樣比例# 輸入層(指定數據類型為float32)inputs = layers.Input(input_shape, dtype=tf.float32)# 下采樣路徑(編碼器部分)# 第一層:卷積塊 + 最大池化conv_128 = conv_block(inputs, FILTER_SIZE, FILTER_NUM, dropout_rate, batch_norm) # 128x128 特征圖pool_64 = layers.MaxPooling2D(pool_size=(2,2))(conv_128) # 下采樣到64x64# 第二層:卷積塊 + 最大池化conv_64 = conv_block(pool_64, FILTER_SIZE, 2*FILTER_NUM, dropout_rate, batch_norm) # 64x64 特征圖pool_32 = layers.MaxPooling2D(pool_size=(2,2))(conv_64) # 下采樣到32x32# 第三層:卷積塊 + 最大池化conv_32 = conv_block(pool_32, FILTER_SIZE, 4*FILTER_NUM, dropout_rate, batch_norm) # 32x32 特征圖pool_16 = layers.MaxPooling2D(pool_size=(2,2))(conv_32) # 下采樣到16x16# 第四層:卷積塊 + 最大池化conv_16 = conv_block(pool_16, FILTER_SIZE, 8*FILTER_NUM, dropout_rate, batch_norm) # 16x16 特征圖pool_8 = layers.MaxPooling2D(pool_size=(2,2))(conv_16) # 下采樣到8x8# 第五層:底層卷積塊(無池化)conv_8 = conv_block(pool_8, FILTER_SIZE, 16*FILTER_NUM, dropout_rate, batch_norm) # 8x8 特征圖# 上采樣路徑(解碼器部分,包含注意力門控)# 第六層:生成門控信號 -> 注意力機制 -> 上采樣 -> 特征拼接 -> 卷積塊gating_16 = gating_signal(conv_8, 8*FILTER_NUM, batch_norm) # 為16x16層生成門控信號att_16 = attention_block(conv_16, gating_16, 8*FILTER_NUM) # 計算注意力權重up_16 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE))(conv_8) # 上采樣到16x16up_16 = layers.concatenate([up_16, att_16], axis=3) # 拼接跳躍連接和注意力加權的特征up_conv_16 = conv_block(up_16, FILTER_SIZE, 8*FILTER_NUM, dropout_rate, batch_norm)# 第七層(結構同上)gating_32 = gating_signal(up_conv_16, 4*FILTER_NUM, batch_norm)att_32 = attention_block(conv_32, gating_32, 4*FILTER_NUM)up_32 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE))(up_conv_16)up_32 = layers.concatenate([up_32, att_32], axis=3)up_conv_32 = conv_block(up_32, FILTER_SIZE, 4*FILTER_NUM, dropout_rate, batch_norm)# 第八層(結構同上)gating_64 = gating_signal(up_conv_32, 2*FILTER_NUM, batch_norm)att_64 = attention_block(conv_64, gating_64, 2*FILTER_NUM)up_64 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE))(up_conv_32)up_64 = layers.concatenate([up_64, att_64], axis=3)up_conv_64 = conv_block(up_64, FILTER_SIZE, 2*FILTER_NUM, dropout_rate, batch_norm)# 第九層(結構同上)gating_128 = gating_signal(up_conv_64, FILTER_NUM, batch_norm)att_128 = attention_block(conv_128, gating_128, FILTER_NUM)up_128 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE))(up_conv_64)up_128 = layers.concatenate([up_128, att_128], axis=3)up_conv_128 = conv_block(up_128, FILTER_SIZE, FILTER_NUM, dropout_rate, batch_norm)# 輸出層:1x1卷積 + 批量歸一化 + 激活函數conv_final = layers.Conv2D(NUM_CLASSES, kernel_size=(1,1))(up_conv_128) # 調整通道數為類別數conv_final = layers.BatchNormalization(axis=3)(conv_final)conv_final = layers.Activation('sigmoid')(conv_final) # 二分類使用sigmoid,多分類需改為softmax# 構建并返回模型model = models.Model(inputs, conv_final, name="Attention_UNet")return model
總結
Attention U-Net通過在U-Net模型中引入注意力門模塊,解決了傳統跳躍連接中冗余信息的問題,顯著提高了語義分割任務中的模型性能和敏感度。通過加性注意力機制,AG模塊能夠突出感興趣區域的特征響應,同時抑制無關信息,使分割更加精準。此外,對軟注意力和硬注意力的概念及其對模型優化的作用也得到了深入理解。結合代碼實踐,進一步掌握了Attention U-Net的網絡架構及其實現方法。