Masked Attention 在 LLM 訓練中的作用與原理

大語言模型(LLM)訓練過程中,Masked Attention(掩碼注意力) 是一個關鍵機制,它決定了 模型如何在訓練時只利用過去的信息,而不會看到未來的 token。這篇文章將幫助你理解 Masked Attention 的作用、實現方式,以及為什么它能確保當前 token 只依賴于過去的 token,而不會泄露未來的信息。

1. Masked Attention 在 LLM 訓練中的作用

在 LLM 訓練時,我們通常使用 自回歸(Autoregressive) 方式來讓模型學習文本的生成。例如,給定輸入序列:

"The cat is very"

模型需要預測下一個 token:

"cute"

但是,為了保證模型的生成方式符合自然語言流向,每個 token 只能看到它之前的 token,不能看到未來的 token

Masked Attention 的作用就是:

  • 屏蔽未來的 token,使當前 token 只能關注之前的 token
  • 保證訓練階段的注意力機制符合推理時的因果(causal)生成方式
  • 防止信息泄露,讓模型學會自回歸生成文本

如果沒有 Masked Attention,模型在訓練時可以“偷看”未來的 token,導致它學到的規律無法泛化到推理階段,從而影響文本生成的效果。

舉例說明

假設輸入是 "The cat is cute",模型按 token 級別計算注意力:

(1) 沒有 Mask(BERT 方式)
TokenThecatiscute
The????
cat????
is????
cute????

每個 token 都能看到整個句子,適用于 BERT 這種雙向模型。

(2) 有 Mask(GPT 方式)
TokenThecatiscute
The????
cat????
is????
cute????

每個 token 只能看到它自己及之前的 token,保證訓練和推理時的生成順序一致。

2. Masked Attention 的工作原理

?在標準的 自注意力(Self-Attention) 機制中,注意力分數是這樣計算的:

A = \text{softmax} \left( \frac{Q K^T}{\sqrt{d_k}} \right)

其中:

  • Q, K, V ?是 Query(查詢)、Key(鍵)和 Value(值)矩陣

  • Q K^T 計算所有 token 之間的相似度

  • 如果不做 Masking,每個 token 都能看到所有的 token

而在 Masked Attention 中,我們會使用一個 上三角掩碼(Upper Triangular Mask),使得未來的 token 不能影響當前 token:

S' = \frac{Q K^T}{\sqrt{d_k}} + \text{mask}

Mask 是一個 上三角矩陣,其中:

  • 未來 token 的位置填充 -\infty,確保 softmax 之后它們的注意力權重為 0

  • 只允許關注當前 token 及之前的 token

例如,假設有 4 個 token:

\begin{bmatrix} s_{1,1} & -\infty & -\infty & -\infty \\ s_{2,1} & s_{2,2} & -\infty & -\infty \\ s_{3,1} & s_{3,2} & s_{3,3} & -\infty \\ s_{4,1} & s_{4,2} & s_{4,3} & s_{4,4} \end{bmatrix}

經過 softmax 之后:

A = \begin{bmatrix} 1 & 0 & 0 & 0 \\ \text{non-zero} & \text{non-zero} & 0 & 0 \\ \text{non-zero} & \text{non-zero} & \text{non-zero} & 0 \\ \text{non-zero} & \text{non-zero} & \text{non-zero} & \text{non-zero} \end{bmatrix}

最終,每個 token 只會關注它自己和它之前的 token,完全忽略未來的 token!

3. Masked Attention 計算下三角部分的值時,如何保證未來信息不會泄露?

換句話說,我們需要證明 Masked Attention 計算出的下三角部分的值(即歷史 token 之間的注意力分數)不會受到未來 token 的影響

1. 問題重述

Masked Attention 的核心計算是:

\text{Attention}(Q, K, V) = \text{softmax}(\frac{Q K^T}{\sqrt{d_k}} + \text{mask}) V

其中:

  • Q, K, V 是整個序列的矩陣。

  • QK^T計算的是所有 token 之間的注意力分數。

  • Mask 確保 softmax 后未來 token 的注意力分數變為 0。

這個問題可以分解成兩個關鍵點:

  1. 未來 token 是否影響了下三角部分的 Q 或 K?

  2. 即使未來 token 參與了 Q, K 計算,為什么它們不會影響下三角的注意力分數?

2. 未來 token 是否影響了 Q 或 K?

我們先看 Transformer 計算 Q, K, V 的方式:

Q = X W_Q, \quad K = X W_K, \quad V = X W_V

這里:

  • X 是整個輸入序列的表示。

  • W_Q, W_K, W_V是相同的投影矩陣,作用于所有 token。

由于 每個 token 的 Q, K, V 只取決于它自己,并不會在計算時使用未來 token 的信息,所以:

  • 計算第 i?個 token 的 Q_i, K_i, V_i時,并沒有用到 X_{i+1}, X_{i+2}, \dots,所以未來 token 并不會影響當前 token 的 Q, K, V

結論 1未來 token 不會影響當前 token 的 Q 和 K。

3. Masked Attention 如何確保下三角部分不包含未來信息?

即使 Q, K 沒有未來信息,我們仍然要證明 計算出的注意力分數不會受到未來信息影響

我們來看注意力計算:

\frac{Q K^T}{\sqrt{d_k}}

這是一個 所有 token 之間的相似度矩陣,即:

S = \begin{bmatrix} Q_1 \cdot K_1^T & Q_1 \cdot K_2^T & Q_1 \cdot K_3^T & Q_1 \cdot K_4^T \\ Q_2 \cdot K_1^T & Q_2 \cdot K_2^T & Q_2 \cdot K_3^T & Q_2 \cdot K_4^T \\ Q_3 \cdot K_1^T & Q_3 \cdot K_2^T & Q_3 \cdot K_3^T & Q_3 \cdot K_4^T \\ Q_4 \cdot K_1^T & Q_4 \cdot K_2^T & Q_4 \cdot K_3^T & Q_4 \cdot K_4^T \end{bmatrix}

然后,我們應用 因果 Mask(Causal Mask)

S' = S + \text{mask}

Mask 讓右上角(未來 token 相關的部分)變成 -\infty

\begin{bmatrix} S_{1,1} & -\infty & -\infty & -\infty \\ S_{2,1} & S_{2,2} & -\infty & -\infty \\ S_{3,1} & S_{3,2} & S_{3,3} & -\infty \\ S_{4,1} & S_{4,2} & S_{4,3} & S_{4,4} \end{bmatrix}

然后計算 softmax:

A = \text{softmax}(S')

由于 e^{-\infty} = 0,所有未來 token 相關的注意力分數都變成 0

A = \begin{bmatrix} 1 & 0 & 0 & 0 \\ \text{non-zero} & \text{non-zero} & 0 & 0 \\ \text{non-zero} & \text{non-zero} & \text{non-zero} & 0 \\ \text{non-zero} & \text{non-zero} & \text{non-zero} & \text{non-zero} \end{bmatrix}

最后,我們計算:

\text{Output} = A V

由于未來 token 的注意力權重是 0,它們的 V 在計算中被忽略。因此,下三角部分(歷史 token 之間的注意力)完全不受未來 token 影響。

結論 2未來 token 的信息不會影響下三角部分的 Attention 計算。

4. 為什么 Masked Attention 能防止未來信息泄露?

你可能會問:

即使有 Mask,計算 Attention 之前,我們不是還是用到了整個序列的 Q, K, V 嗎?未來 token 的 Q, K, V 不是已經算出來了嗎?

的確,每個 token 的 Q, K, V 是獨立計算的,但 Masked Attention 確保了:

  1. 計算 Q, K, V 時,每個 token 只依賴于它自己的輸入

    • Q_i, K_i, V_i只來自 token i,不會用到未來的信息

    • 未來的 token 并不會影響當前 token 的 Q, K, V

  2. Masked Softmax 阻止了未來 token 的影響

    • 雖然 Q, K, V 都計算了,但 Masking 讓未來 token 的注意力分數變為 0,確保計算出的 Attention 結果不包含未來信息。

最終,當前 token 只能看到過去的信息,未來的信息被完全屏蔽!

5. 訓練時使用 Masked Attention 的必要性

Masked Attention 的一個關鍵作用是 讓訓練階段和推理階段保持一致

  • 訓練時:模型學習如何根據 歷史 token 預測 下一個 token,確保生成文本時符合自然語言流向。

  • 推理時:模型生成每個 token 后,仍然只能訪問過去的 token,而不會看到未來的 token。

如果 訓練時沒有 Masked Attention,模型會學習到“作弊”策略,直接利用未來信息進行預測。但在推理時,模型無法“偷看”未來的信息,導致生成質量急劇下降。

6. 結論

Masked Attention 是 LLM 訓練的核心機制之一,其作用在于:

  • 確保當前 token 只能訪問過去的 token,不會泄露未來信息
  • 讓訓練階段與推理階段保持一致,避免模型在推理時“失效”
  • 利用因果 Mask 讓 Transformer 具備自回歸能力,學會按序生成文本

Masked Attention 本質上是 Transformer 訓練過程中對信息流動的嚴格約束,它確保了 LLM 能夠正確學習自回歸生成任務,是大模型高質量文本生成的基礎。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/73957.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/73957.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/73957.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【自學筆記】PHP語言基礎知識點總覽-持續更新

提示:文章寫完后,目錄可以自動生成,如何生成可參考右邊的幫助文檔 文章目錄 1. PHP 簡介2. PHP 環境搭建3. 基本語法變量與常量數據類型運算符 4. 控制結構條件語句循環語句 5. 函數函數定義與調用作用域 6. 數組7. 字符串8. 表單處理9. 會話…

css選擇最后結尾的元素DOM

前言 選中最后一個元素&#xff0c;實際使用非常頻繁。 解決方案 使用 CSS 提供的選擇器&#xff0c;即可完成。 如下代碼示例&#xff0c;兩種選擇器均可實現。 <p>...</p>p:last-child{ background:#ff0000; }p:nth-last-child(1){background:#ff0000; }p&…

Axios 相關的面試題

在跟著視頻教程學習項目的時候使用了axios發送請求&#xff0c;但是只是跟著把代碼粘貼上去&#xff0c;一些語法規則根本不太清楚&#xff0c;但是根據之前的博客學習了fetch了之后&#xff0c;一看axios的介紹就明白了。所以就直接展示axios的面試題吧 本文主要內容&#xff…

瑞芯微RKRGA(librga)Buffer API 分析

一、Buffer API 簡介 在瑞芯微官方的 librga 庫的手冊中&#xff0c;有兩組配置 buffer 的API&#xff1a; importbuffer 方式&#xff1a; importbuffer_virtualaddr importbuffer_physicaladdr importbuffer_fd wrapbuffer 方式&#xff1a; wrapbuffer_virtualaddr wrapb…

C語言:多線程

多線程概述 定義 多線程是指在一個程序中可以同時運行多個不同的執行路徑&#xff08;線程&#xff09;&#xff0c;這些線程可以并發或并行執行。并發是指多個線程在宏觀上同時執行&#xff0c;但在微觀上可能是交替執行的&#xff1b;并行則是指多個線程真正地同時執行&…

Linux線程池實現

1.線程池實現 全部代碼&#xff1a;whb-helloworld/113 1.喚醒線程 一個是喚醒全部線程&#xff0c;一個是喚醒一個線程。 void WakeUpAllThread(){LockGuard lockguard(_mutex);if (_sleepernum)_cond.Broadcast();LOG(LogLevel::INFO) << "喚醒所有的休眠線程&q…

微信小程序逆向開發

一.wxapkg文件 如何查看微信小程序包文件&#xff1a; 回退一級 點擊進入這個目錄 這個就是我們小程序對應的文件 .wxapkg概述 .wxapkg是微信小程序的包文件格式&#xff0c;且其具有獨特的結構和加密方式。它不僅包含了小程序的源代碼&#xff0c;還包括了圖像和其他資源文…

多輸入多輸出 | Matlab實現CPO-LSTM冠豪豬算法優化長短期記憶神經網絡多輸入多輸出預測

多輸入多輸出 | Matlab實現CPO-LSTM冠豪豬算法優化長短期記憶神經網絡多輸入多輸出預測 目錄 多輸入多輸出 | Matlab實現CPO-LSTM冠豪豬算法優化長短期記憶神經網絡多輸入多輸出預測預測效果基本介紹程序設計參考資料 預測效果 基本介紹 Matlab實現CPO-LSTM冠豪豬算法優化長短期…

視頻編碼器的抉擇:x264、x265、libaom、vvenc 對比測試實驗

264、x265、libaom、vvenc 對比測試實驗 測試機器配置&#xff1a;Apple M1 Pro -16G編碼器版本&#xff08;選擇自己編譯&#xff09;&#xff1a;所有源碼都是當前最新更新的狀態&#xff0c;此外各類編碼具體的編譯過程可參考我的相關系列博客。 編碼器GitHubx264git clon…

【二刷代碼隨想錄】雙指針-數組相關題型、推薦習題

一、雙指針-數組 相關題型與常用思路 1、單個數組 &#xff08;1&#xff09;原地移除元素類 如推薦習題中的&#xff08;1&#xff09;、&#xff08;2&#xff09;、&#xff08;3&#xff09;&#xff0c;都屬于此類。引入雙指針 pre、last &#xff0c;用 pre 指針表明數…

Level DB --- TableCache

TableCache 是Level DB 中重要的類&#xff0c;Level DB 中多層&#xff08;multi level&#xff09;&#xff0c;且每一層&#xff08;level&#xff09;有多個 key-value file&#xff0c;TableCache正是用來緩存多層以及多層中的file數據&#xff0c;更快速地檢索。 table …

搜索-BFS

馬上藍橋杯了&#xff0c;最近刷了廣搜&#xff0c;感覺挺有意思的&#xff0c;廣搜題類型都差不多&#xff0c;模板也一樣&#xff0c;大家寫的時候可以直接套模板 這里給大家講一個比較經典的廣搜題-迷宮 題目問問能否走到 (n,m) 位置&#xff0c;假設最后一個點是我們的&…

智能預測維護:讓設備“未卜先知”,減少宕機煩惱

智能預測維護:讓設備“未卜先知”,減少宕機煩惱 1. 引言:設備維護的痛點與出路 在工業生產和自動化領域,設備故障一直是令人頭疼的問題。設備一旦故障,輕則影響生產效率,重則造成嚴重損失,甚至帶來安全隱患。傳統的設備維護方式主要有兩種: 被動維護(Reactive Maint…

安卓的布局方式

一、RelativeLayout 相對布局 特點&#xff1a;每個組件相對其他的某一個組件進行定位。 (一)主要屬性 1、設置和父組件的對齊&#xff1a; alignParentTop &#xff1a; 設置為true&#xff0c;代表和父布局頂部對齊。 其他對齊只需要改變后面的Top為 Left、Right 或者Bottom&…

SSM中藥分類管理系統

&#x1f345;點贊收藏關注 → 添加文檔最下方聯系方式咨詢本源代碼、數據庫&#x1f345; 本人在Java畢業設計領域有多年的經驗&#xff0c;陸續會更新更多優質的Java實戰項目希望你能有所收獲&#xff0c;少走一些彎路。&#x1f345;關注我不迷路&#x1f345; 項目視頻 SS…

epoch、batch、batch size、step、iteration深度學習名詞含義詳細介紹

卷積神經網絡訓練中的三個核心概念&#xff1a;Epoch、Batch Size 和迭代次數 在深度學習中&#xff0c;理解一些基本的術語非常重要&#xff0c;這些術語對模型的訓練過程、效率以及最終性能都有很大影響。以下是一些常見術語的含義介紹&#xff1a; 1. Epoch&#xff08;周…

React(七):Redux

Redux基本使用 純函數&#xff1a;1.函數內部不能依賴函數外部變量&#xff1b;2.不能產生副作用&#xff0c;在函數內部改變函數外部的變量 React只幫我們解決了DOM的渲染過程&#xff0c;State還是要由我們自己來管理——redux可幫助我們進行管理 Redux三大特點 1.單一數…

《Android低內存設備性能優化實戰:深度解析Dalvik虛擬機參數調優》

1. 痛點分析&#xff1a;低內存設備的性能困局 現象描述&#xff1a;大應用運行時頻繁GC導致卡頓 根本原因&#xff1a;Dalvik默認內存參數與硬件資源不匹配 解決方向&#xff1a;動態調整堆內存參數以平衡性能與資源消耗 2. 核心調優參數全景解析 關鍵參數矩陣&#xff1…

STC89C52單片機學習——第38節: [17-2] 紅外遙控紅外遙控電機

寫這個文章是用來學習的,記錄一下我的學習過程。希望我能一直堅持下去,我只是一個小白,只是想好好學習,我知道這會很難&#xff0c;但我還是想去做&#xff01; 本文寫于&#xff1a;2025.03.30 51單片機學習——第38節: [17-2] 紅外遙控&紅外遙控電機 前言開發板說明引用…

計算機組成原理————計算機運算方法精講<1>原碼表示法

第一部分:無符號數和有符號數的概念 1.無符號數 計算機中的數均存放在寄存器當中,通常稱寄存器的位數為機器字長,所謂無符號數,就是指沒有fu5號的數,在寄存器中的每一位均可用來存放數值,當存放有符號數時,需要留出位置存放符號,機器字長相同時,無符號數與有符號數所…