論文標題
SAS: Simulated Attention Score
論文地址
https://arxiv.org/pdf/2507.07694
代碼
見論文附錄
作者背景
摩根士丹利,斯坦福大學,微軟研究院,新加坡國立大學,得克薩斯大學奧斯汀分校,香港大學
動機
多頭注意力是 Transformer 的核心組件,它通過引入多組 QKV 投影來捕獲不同的特征子空間,從而在機器翻譯、問答等任務中取得巨大成功。研究表明,注意力頭的數量對 Transformer 性能至關重要:在保證每個頭的隱藏維度充分大的前提下,注意力頭數越多可以使模型效果越好。但問題在于,直接增加頭數或維度往往伴隨著模型參數量和計算開銷的劇增,這在訓練和部署中代價高昂
目前也有一些注意力架構旨在提高計算效率,例如共享部分 K 和 V 的 MQA、GQA;使用矩陣分解的 MLA、MFA、TPA 等。但這些方法主要關注降低內存/計算成本,而非提升注意力的表達能力
于是作者希望在不顯著增加參數的前提下,設計一種新的注意力架構,實現近似于使用了更多注意力頭和更高每頭維度的性能提升
本文方法
本文提出 SAS(Simulated Attention Score,模擬注意力分數),核心思想是在注意力計算中引入額外的映射層,將低維的頭表示投射到更高維空間,以此“虛擬地”增大注意力頭數和每頭的隱藏維度
一、擴展注意力頭
對于查詢Q,其特征維度為 [B, T, H, D],分別表示 batch_size,序列長度,頭數和隱藏維度。為了擴充 H,需要把其他維度拉平,得到張量 Q_0,維度為 [B * T * D, H] ;然后使用一個 H * H’ 的線性變換得到 Q_1,維度為 [B * T * D, H’],其中 H’ > H;Q_1 過一個 ReLU 引入非線性;最后再過一個 H’ * H’ 的線性層,并加上 Q_1 的殘差連接
于是我們獲得了更多的注意力頭,其中殘差連接的引入可以穩定訓練;值得注意的是,原始頭數 H 和擴展后的頭數 H’ 都遠小于每頭的特征維度 D,所以這個兩層 MLP 的參數開銷相對整模型來說可以忽略不計
除了使用 MLP 來擴展維度,作者還嘗試了卷積方案。具體地,將查詢 Q 的維度整理成 [B * T, H, D],類似于多通道特征圖,然后使用卷積變換將 H 擴展成 H’,同樣地,H’ > H,最后再過第二層卷積以及殘差連接
類似地,在 K、V 中都應用上述擴展流程
二、擴展注意力維度
直覺上,每個注意力頭內部特征維度 D 越大,其能夠捕獲的子空間信息越豐富。因此作者進一步在 Q 和 K 上也引入了類似的維度擴展映射。這里之所以不對 V 進行擴展,是因為 V
直接決定了注意力模塊的輸出張量隱藏維度,擴大 V 的每頭維度到 D 會導致后續前饋層的參數量大幅增加,違背了不顯著增加計算量的初衷
三、注意力聚合
在標準多頭注意力中,會將所有頭的輸出向量拼接,再通過一個輸出投影矩陣 O 映射回模型的隱藏維度。然而,由于 SAS 對注意力頭數進行了擴增,若仍按傳統方式拼接勢必導致輸出維度變大,進而導致 O 的參數量大大增加(H * hidden 變為 H’ * hidden)。為此,作者提出了參數高效注意力聚合機制,旨在不增加輸出層參數規模的情況下完成對多頭輸出的整合
實現過程非常簡單:假設注意力頭數擴展了 r 倍,即 r * H = H’,那么便把所有頭劃分成 r 組,每組都按照原本的計算流程與 O 相乘,得到 r 組輸出結果,最后取平均作為注意力模塊的最終輸出傳向前饋層
實驗結果
作者在多種基準任務和數據集上對SAS進行了驗證,包括語言模型預訓練及下游任務評估,全面展示了SAS在準確率和效率方面的優勢
一、預訓練效果
下圖對比了SAS與標準MHA、MQA、GQA、MLA、TPA等方法在ArXiv和Books3數據集上的表現。結果表明,無論是短序列訓練(長度512)還是長序列訓練(長度1024),SAS均取得了最低的驗證困惑度
除了取得更好的性能,SAS還加速了模型的收斂。作者報告,在 Books3 數據集、序列長度512的訓練中,MHA模型在5萬步時達到29.86的驗證困惑度,而SAS模型在3萬步時就達到了相近的30.49,即 SAS 可以節約 40% 左右的計算資源
此外,作者還在更大的訓練長度、更大的模型尺寸上做了驗證,結果表明相比于其他注意力機制 SAS 具備穩定的優勢
二、下游任務效果
作者評測了在多個下游任務基準(ARC、HellaSwag、PIQA、ScIQ、SocialIQA、WinoGrande)上 SAS 與其他注意力模型的效果,可見在多種參數量、訓練數據量的實驗設置下,SAS 大部分情況下都表現出了最優性能