在大語言模型(LLM)訓練過程中,Masked Attention(掩碼注意力) 是一個關鍵機制,它決定了 模型如何在訓練時只利用過去的信息,而不會看到未來的 token。這篇文章將幫助你理解 Masked Attention 的作用、實現方式,以及為什么它能確保當前 token 只依賴于過去的 token,而不會泄露未來的信息。
1. Masked Attention 在 LLM 訓練中的作用
在 LLM 訓練時,我們通常使用 自回歸(Autoregressive) 方式來讓模型學習文本的生成。例如,給定輸入序列:
"The cat is very"
模型需要預測下一個 token:
"cute"
但是,為了保證模型的生成方式符合自然語言流向,每個 token 只能看到它之前的 token,不能看到未來的 token。
Masked Attention 的作用就是:
- 屏蔽未來的 token,使當前 token 只能關注之前的 token
- 保證訓練階段的注意力機制符合推理時的因果(causal)生成方式
- 防止信息泄露,讓模型學會自回歸生成文本
如果沒有 Masked Attention,模型在訓練時可以“偷看”未來的 token,導致它學到的規律無法泛化到推理階段,從而影響文本生成的效果。
舉例說明
假設輸入是 "The cat is cute",模型按 token 級別計算注意力:
(1) 沒有 Mask(BERT 方式)
Token | The | cat | is | cute |
---|---|---|---|---|
The | ? | ? | ? | ? |
cat | ? | ? | ? | ? |
is | ? | ? | ? | ? |
cute | ? | ? | ? | ? |
每個 token 都能看到整個句子,適用于 BERT 這種雙向模型。
(2) 有 Mask(GPT 方式)
Token | The | cat | is | cute |
---|---|---|---|---|
The | ? | ? | ? | ? |
cat | ? | ? | ? | ? |
is | ? | ? | ? | ? |
cute | ? | ? | ? | ? |
每個 token 只能看到它自己及之前的 token,保證訓練和推理時的生成順序一致。
2. Masked Attention 的工作原理
?在標準的 自注意力(Self-Attention) 機制中,注意力分數是這樣計算的:
其中:
-
Q, K, V ?是 Query(查詢)、Key(鍵)和 Value(值)矩陣
-
計算所有 token 之間的相似度
-
如果不做 Masking,每個 token 都能看到所有的 token
而在 Masked Attention 中,我們會使用一個 上三角掩碼(Upper Triangular Mask),使得未來的 token 不能影響當前 token:
Mask 是一個 上三角矩陣,其中:
-
未來 token 的位置填充
,確保 softmax 之后它們的注意力權重為 0
-
只允許關注當前 token 及之前的 token
例如,假設有 4 個 token:
經過 softmax 之后:
最終,每個 token 只會關注它自己和它之前的 token,完全忽略未來的 token!
3. Masked Attention 計算下三角部分的值時,如何保證未來信息不會泄露?
換句話說,我們需要證明 Masked Attention 計算出的下三角部分的值(即歷史 token 之間的注意力分數)不會受到未來 token 的影響。
1. 問題重述
Masked Attention 的核心計算是:
其中:
-
Q, K, V 是整個序列的矩陣。
-
計算的是所有 token 之間的注意力分數。
-
Mask 確保 softmax 后未來 token 的注意力分數變為 0。
這個問題可以分解成兩個關鍵點:
-
未來 token 是否影響了下三角部分的 Q 或 K?
-
即使未來 token 參與了 Q, K 計算,為什么它們不會影響下三角的注意力分數?
2. 未來 token 是否影響了 Q 或 K?
我們先看 Transformer 計算 Q, K, V 的方式:
這里:
-
X 是整個輸入序列的表示。
-
是相同的投影矩陣,作用于所有 token。
由于 每個 token 的 Q, K, V 只取決于它自己,并不會在計算時使用未來 token 的信息,所以:
-
計算第 i?個 token 的
時,并沒有用到
,所以未來 token 并不會影響當前 token 的 Q, K, V。
結論 1:未來 token 不會影響當前 token 的 Q 和 K。
3. Masked Attention 如何確保下三角部分不包含未來信息?
即使 Q, K 沒有未來信息,我們仍然要證明 計算出的注意力分數不會受到未來信息影響。
我們來看注意力計算:
這是一個 所有 token 之間的相似度矩陣,即:
然后,我們應用 因果 Mask(Causal Mask):
Mask 讓右上角(未來 token 相關的部分)變成 :
然后計算 softmax:
由于 ,所有未來 token 相關的注意力分數都變成 0:
最后,我們計算:
由于未來 token 的注意力權重是 0,它們的 V 在計算中被忽略。因此,下三角部分(歷史 token 之間的注意力)完全不受未來 token 影響。
結論 2:未來 token 的信息不會影響下三角部分的 Attention 計算。
4. 為什么 Masked Attention 能防止未來信息泄露?
你可能會問:
即使有 Mask,計算 Attention 之前,我們不是還是用到了整個序列的 Q, K, V 嗎?未來 token 的 Q, K, V 不是已經算出來了嗎?
的確,每個 token 的 Q, K, V 是獨立計算的,但 Masked Attention 確保了:
-
計算 Q, K, V 時,每個 token 只依賴于它自己的輸入
-
只來自 token i,不會用到未來的信息
-
未來的 token 并不會影響當前 token 的 Q, K, V
-
-
Masked Softmax 阻止了未來 token 的影響
-
雖然 Q, K, V 都計算了,但 Masking 讓未來 token 的注意力分數變為 0,確保計算出的 Attention 結果不包含未來信息。
-
最終,當前 token 只能看到過去的信息,未來的信息被完全屏蔽!
5. 訓練時使用 Masked Attention 的必要性
Masked Attention 的一個關鍵作用是 讓訓練階段和推理階段保持一致。
-
訓練時:模型學習如何根據 歷史 token 預測 下一個 token,確保生成文本時符合自然語言流向。
-
推理時:模型生成每個 token 后,仍然只能訪問過去的 token,而不會看到未來的 token。
如果 訓練時沒有 Masked Attention,模型會學習到“作弊”策略,直接利用未來信息進行預測。但在推理時,模型無法“偷看”未來的信息,導致生成質量急劇下降。
6. 結論
Masked Attention 是 LLM 訓練的核心機制之一,其作用在于:
- 確保當前 token 只能訪問過去的 token,不會泄露未來信息
- 讓訓練階段與推理階段保持一致,避免模型在推理時“失效”
- 利用因果 Mask 讓 Transformer 具備自回歸能力,學會按序生成文本
Masked Attention 本質上是 Transformer 訓練過程中對信息流動的嚴格約束,它確保了 LLM 能夠正確學習自回歸生成任務,是大模型高質量文本生成的基礎。