原文鏈接
2105.14103 (arxiv.org)
原文翻譯
Abstract
我們介紹了 Attention Free Transformer (AFT),這是 Transformer [1] 的有效變體,它消除了點積自注意力的需要。在 AFT 層,鍵key和值value首先與一組學習的位置偏差position biases相結合,其結果以元素方式與查詢相乘。這種新操作的內存復雜度為線性 w.r.t。上下文大小和特征維度,使其與大輸入和模型大小兼容。我們還引入了 AFT-local 和 AFT-conv,這是兩個模型變體,它利用了局部性和空間權重共享的思想,同時保持全局連通性。我們在兩個自回歸建模任務(CIFAR10 和 Enwik8)以及圖像識別任務(ImageNet-1K 分類)上進行了廣泛的實驗。我們表明 AFT 在所有基準測試中都表現出具有競爭力的性能,同時提供了出色的效率。
1 Introduction
以Transformers[1]為代表的自注意機制推動了各種機器學習問題的發展,包括語言理解[2,3]和計算機視覺應用[4 - 6]。與卷積神經網絡(cnn)或循環神經網絡(rnn)等經典模型架構不同,變形金剛可以在序列中的每對元素之間進行直接交互,這使得它們在捕獲長期依賴關系方面特別強大。
然而,變壓器需要很高的計算成本。這一挑戰的原因是需要執行具有二次時間和空間復雜性的注意力操作,這涉及上下文大小。這使得transformer難以擴展到具有大上下文大小的輸入。最近的許多工作都致力于解決transformer的可伸縮性問題[7 -13]。這里的共同思想是近似全注意力操作,使用的技術包括稀疏性、局域敏感散列、低秩分解、核近似等。
在本文中,我們提出了一個不使用或近似標準點積注意力的計算模塊。因此,我們將我們的模型命名為不使用注意力的Transformer?(AFT)。與點積注意力類似,AFT 由查詢、鍵和值 (Q, K, V) 三個量的交互組成。不同之處在于,在 AFT 中,鍵和值(上下文)首先與一組可學習的位置偏執相結合然后使用元素乘法將查詢與縮減的上下文相結合。有關說明,請參見圖 2。
AFT 保留了在上下文中任意兩個點之間的直接交互,這是點積注意力的主要優勢。事實上,AFT 可以解釋為執行注意力,其中注意力頭的數量與模型特征維度相同,而注意力圖不需要顯式計算(詳見第 3.1 節)。這導致內存復雜度線性 w.r.t。輸入和模型大小。
Q、K、V 的重新排列計算排序在最近的“線性化注意力”工作中也被發現 [11, 13 –15]。不同之處在于 AFT 以元素方式組合 k 和 v,而所有線性注意力論文都依賴于矩陣點積。后一種方法導致復雜度與模型特征維度的二次方,這對大型模型大小不友好。有關 AFT 與其他變體相比的復雜性分析,請參見表 1。
根據經驗,我們觀察到經過訓練的 Transformer 往往表現出廣泛的局部模式(見圖 1)。這促使我們提出了兩種 AFT 變體:AFT-local 和 AFT-conv。在 AFT-local 中,學習到的位置偏差被限制在局部區域,同時保持全局連接。AFT-conv 通過施加空間權重共享進一步擴展了這種設計,有效地使其成為具有全局感受野的 CNN 變體。我們表明,局部性約束不僅提供了更好的參數和計算效率,而且大大提高了模型在所有任務中的表現。
我們在圖像自回歸建模、字符級語言建模和圖像分類任務上使用 AFT 進行了實驗。我們表明,AFT 提供了具有競爭力的性能,通常匹配或擊敗標準 Transformer 和其他變體(的準確度),同時提供了出色的效率。我們還對 AFT 的幾種設計選擇進行了廣泛的消融研究,并討論了它的獨特屬性,例如與 Transformer的兼容性、稀疏性和輸入大小的可變性。
2 Multi-Head Attention
Transformers 的核心是多頭注意力 (MHA) 操作。在自注意模式下,給定一個輸入序列 X ∈ R^T ×d 和頭部的數量 h,MHA 對每個頭部 i 執行縮放的點積注意力,定義為:
其中 W Q i ∈ R^d×dk , W K i ∈ R^d×dk , W V i ∈ R^d×dv 是頭部 i 的線性變換,σ 是默認設置為 sof tmax 函數的非線性(應用于矩陣的每一行)。dk, dv 分別是鍵和值的維度。MHA 將 h 個注意力頭的輸出沿通道維度拼接起來,得到特征維度 hdv。除非另有說明,我們假設dk=dv和h=d/dk。這意味著查詢、鍵和值在每個頭內都是相同的維度,輸出維度與輸入的維度匹配。
3 Methodology
3.1 Attention Free Transformer
我們現在定義 Attention free Transformer (AFT),它是 MHA 的插件替換,而不需要更改 Transformer 的其他架構方面。給定輸入 X,AFT 首先將它們線性變換為 Q = XW^Q, K=XW^K,V =XW^V ,然后進行以下操作 2:
其中 是元素乘積; σq 是應用于query的非線性,默認為 sigmoid; w ∈ RT ×T 是學習的成對位置偏差(參見圖 2 的說明)。
簡而言之,對于每個目標位置t, AFT執行value的加權平均值,其結果與query進行元素間乘法相結合。具體來說,相結合的權重只是由鍵和一組學習得到的成對位置偏差組成。這提供了不需要計算和存儲昂貴的注意力矩陣的直接優勢,同時像MHA那樣維護查詢和值之間的全局交互。為了進一步了解AFT與MHA的關系,我們可以將方程2改寫為:
這里我們使用上標 i 來索引矩陣的特征維度; <·, · >; 表示向量的點積。在這個重新排列的形式中,我們能夠再次根據注意力來表達 AFT。具體來說,對于每個位置,我們對每個維度都有一個注意力向量 ai t ∈ RT,由 Q、K、w 組成。換句話說,AFT 可以解釋為執行隱式注意力,頭部數量與特征維度一樣多,其中注意力矩陣采用分解形式。
下略