知識蒸餾 Knowledge Distillation 序列的聯合概率 分解成 基于歷史的條件概率的連乘序列
flyfish
代碼實踐
論文 Generalized Knowledge Distillation (GKD)
On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes
自回歸分解 將 “序列的聯合概率” 分解成 “基于歷史的條件概率的連乘序列”
自回歸
以句子“你現在讀的這句話”為例
首先明確句子的token時序順序(從左到右,依次生成):
第1個token:你 → 第2個token:現 → 第3個token:在 → 第4個token:讀 → 第5個token:的 → 第6個token:這 → 第7個token:句 → 第8個token:話
自回歸的核心邏輯是:每個“當前token”只依賴“它前面所有已生成的歷史token”(即左側的、早于它的token),與右側未生成的token無關。具體依賴關系如下:
當前token | 生成順序 | 依賴的“歷史token”(僅左側已生成的) | 不依賴的內容(右側未生成的) |
---|---|---|---|
你 | 第1步 | 無(第一個生成的token,僅依賴輸入prompt,如“請說一句話:”) | 現、在、讀、的、這、句、話 |
現 | 第2步 | 僅依賴第1個token“你” | 在、讀、的、這、句、話 |
在 | 第3步 | 依賴第1-2個token“你、現” | 讀、的、這、句、話 |
讀 | 第4步 | 依賴第1-3個token“你、現、在” | 的、這、句、話 |
的 | 第5步 | 依賴第1-4個token“你、現、在、讀” | 這、句、話 |
這 | 第6步 | 依賴第1-5個token“你、現、在、讀、的” | 句、話 |
句 | 第7步 | 依賴第1-6個token“你、現、在、讀、的、這” | 話 |
話 | 第8步 | 依賴第1-7個token“你、現、在、讀、的、這、句” | 無(最后一個token) |
生成過程
語言模型生成“你現在讀的這句話”的過程,就像人寫字“一筆接一筆、從左到右”:
- 第一步:看到prompt(比如“請寫一個日常場景的短句”),先寫出第一個字“你”——此時只需要考慮“prompt要求”,不需要想后面要寫什么;
- 第二步:寫完“你”后,思考“‘你’后面接什么字合理”,選擇“現”——只看已經寫好的“你”,不看還沒寫的“在、讀、的…”;
- 第三步:寫完“你、現”后,思考“‘你現’后面接什么字合理”,選擇“在”——只看已經寫好的“你、現”,不看后面的“讀、的…”;
- 以此類推,直到寫完最后一個字“話”——每一步的選擇,都只基于“左邊已經寫好的內容”,右側的內容是“下一步才會生成的”,當前步驟完全無法預知,自然也無法依賴。
這里為了簡單介紹一個字是一個token,實際會有多種token算法
自回歸的“自”(自身歷史),特指**“自身左側已生成的時序歷史”,而非“整個句子的所有內容”。原則是:
文本生成是“單向時序過程”,后面的token還沒被模型生成,當前token只能“回頭看左邊的歷史”,不能“超前看右邊的未來”自回歸語言生成是嚴格按“從左到右”的時序展開**,當前token只能依賴“它左邊已經生成的歷史token”,絕不可能依賴“右邊還沒生成的token”。
條件鏈式法則的序列形式
自回歸分解的公式
P(y1,…,yL∣x)=∏n=1LP?(yn∣y<n,x)P(y_1,\dots,y_L\mid x)=\prod_{n=1}^{L} P\!\left(y_n \mid y_{<n},x\right) P(y1?,…,yL?∣x)=n=1∏L?P(yn?∣y<n?,x)
本質是條件概率鏈式法則的直接推廣。
回顧基礎鏈式法則:對于隨機變量序列 Y1,Y2,…,YLY_1,Y_2,\dots,Y_LY1?,Y2?,…,YL? 和給定的條件變量 XXX,它們的條件聯合概率可以分解為一系列條件概率的乘積。這里的關鍵是:
- 每一步的條件概率 P(yn∣y<n,x)P(y_n \mid y_{<n},x)P(yn?∣y<n?,x) 只依賴于“之前所有已出現的變量 y<n=(y1,…,yn?1)y_{<n} = (y_1,\dots,y_{n-1})y<n?=(y1?,…,yn?1?)”和“外部條件 xxx”;
- 整個序列的聯合概率被拆解為 L個“局部條件概率”的連乘,將高維聯合概率的計算轉化為低維條件概率的乘積,大幅降低了建模難度。
為什么自回歸分解是“最優解”?
語言模型的核心任務是建模文本序列的概率分布(如“一句話是否合理”“下一個詞是什么”),而自回歸分解完美適配語言的“時序特性”,主要原因有三點:
1. 符合語言的“順序生成特性”
人類語言天然是時序序列:說話/寫作時,我們總是“先說出第一個詞,再根據第一個詞說第二個詞,以此類推”。自回歸分解恰好模擬了這一過程——每個詞 yny_nyn? 的生成只依賴于“已經說過的詞 y<ny_{<n}y<n?”和“上下文提示 xxx”,與人類語言生成的直覺一致。
2. 降低建模復雜度
直接建模整個序列的聯合概率 P(y1,…,yL∣x)P(y_1,\dots,y_L \mid x)P(y1?,…,yL?∣x) 幾乎不可能:對于詞匯表大小為 VVV 的語言,序列長度為 LLL 的可能組合有 VLV^LVL 種,無法直接枚舉或存儲。
而自回歸分解將問題轉化為逐一生成每個位置的詞,只需建模 P(yn∣y<n,x)P(y_n \mid y_{<n},x)P(yn?∣y<n?,x)——這個條件概率的“輸入維度”是“已生成序列 y<ny_{<n}y<n? + 提示 xxx”,輸出是“下一個詞 yny_nyn? 的概率分布”,可以用神經網絡(如Transformer)高效建模。
3. 支持“增量生成”
語言模型的核心功能之一是生成文本(如續寫、翻譯),自回歸分解天然支持“增量生成”:
- 第一步:根據提示 xxx 生成第一個詞 y1y_1y1?,即采樣自 P(y1∣x)P(y_1 \mid x)P(y1?∣x);
- 第二步:根據 xxx 和 y1y_1y1? 生成 y2y_2y2?,采樣自 P(y2∣y1,x)P(y_2 \mid y_1, x)P(y2?∣y1?,x);
- …
- 第L步:根據 xxx 和 y1,…,yL?1y_1,\dots,y_{L-1}y1?,…,yL?1? 生成 yLy_LyL?,最終得到完整序列。
自回歸語言模型的生成邏輯——每一步只需要“看前面的詞”,就能繼續生成下一個詞。
語言模型如何“學習”這個分解?
語言模型的訓練目標,本質上是通過海量文本數據,學習自回歸分解中的每個條件概率 P(yn∣y<n,x)P(y_n \mid y_{<n},x)P(yn?∣y<n?,x)。
具體來說:
- 訓練數據是大量文本序列(如句子、段落),每個序列可視為 (x,y1,…,yL)(x, y_1,\dots,y_L)(x,y1?,…,yL?)(其中 xxx 可能是序列的前綴,或空提示);
- 模型通過“預測下一個詞”的任務學習條件概率:給定 xxx 和 y<ny_{<n}y<n?,模型輸出對 yny_nyn? 的概率分布,通過交叉熵損失優化,讓模型的預測盡可能接近真實文本中 yny_nyn? 的出現概率;
- 訓練完成后,模型就能對“任意前綴序列”輸出下一個詞的合理概率分布,從而支持生成連貫的文本。
一句話的自回歸分解
以生成句子“我愛吃蘋果”(假設 xxx 為空提示,即生成獨立句子)為例,其概率可分解為:
P(我,愛,吃,蘋果)=P(我)×P(愛∣我)×P(吃∣我,愛)×P(蘋果∣我,愛,吃)P(\text{我},\text{愛},\text{吃},\text{蘋果}) = P(\text{我}) \times P(\text{愛} \mid \text{我}) \times P(\text{吃} \mid \text{我},\text{愛}) \times P(\text{蘋果} \mid \text{我},\text{愛},\text{吃}) P(我,愛,吃,蘋果)=P(我)×P(愛∣我)×P(吃∣我,愛)×P(蘋果∣我,愛,吃)
- P(我)P(\text{我})P(我):第一個詞是“我”的概率(語言模型中,通常會給句首詞一個先驗分布);
- P(愛∣我)P(\text{愛} \mid \text{我})P(愛∣我):在“我”之后接“愛”的概率(比如“我”后面更可能接“愛”“是”“想”等詞,而不是“蘋果”“跑步”);
- 以此類推,每個步驟的概率都依賴于前面的詞,最終乘積就是整個句子的聯合概率。
自回歸分解是將“序列的聯合概率”分解成“基于歷史的條件概率的連乘序列”,而這個分解過程是條件概率鏈式法則——但它不是“任意的鏈式法則”,而是加了“自回歸約束”的鏈式法則應用。
第一步:先明確“自回歸分解分解的是什么?”
自回歸分解的核心目標,是把“一個token序列(比如句子)的整體概率”拆解開,變成“每一步生成token的局部概率”,方便語言模型計算和預測。
比如,對于一個長度為n
的token序列 X = [x?, x?, x?, ..., x?]
(x?
是第一個token,x?
是最后一個token),我們要分解的是這個序列的聯合概率 P(X) = P(x?, x?, x?, ..., x?)
。
為什么要分解?因為直接計算“整個序列同時出現”的聯合概率非常困難(可能性太多),但分解成“每一步基于前面內容的條件概率”后,模型只需要預測“當前token在歷史token下的概率”,難度會大幅降低。
第二步:自回歸分解如何用“條件概率鏈式法則”?
條件概率鏈式法則是概率論的基礎規則:多個隨機變量的聯合概率,可分解為第一個變量的邊緣概率,乘以第二個變量在第一個變量條件下的概率,再乘以第三個變量在第一、二個變量條件下的概率……以此類推。
對序列 X = [x?, x?, ..., x?]
,通用的條件概率鏈式法則是這樣的:
P(x1,x2,...,xn)=P(x1)×P(x2∣x1)×P(x3∣x1,x2)×P(x4∣x1,x2,x3)×...×P(xn∣x1,x2,...,xn?1)P(x?, x?, ..., x?) = P(x?) × P(x? | x?) × P(x? | x?, x?) × P(x? | x?, x?, x?) × ... × P(x? | x?, x?, ..., x???) P(x1?,x2?,...,xn?)=P(x1?)×P(x2?∣x1?)×P(x3?∣x1?,x2?)×P(x4?∣x1?,x2?,x3?)×...×P(xn?∣x1?,x2?,...,xn?1?)
而自回歸分解,就是完全遵循這個鏈式法則,但額外加了一個“自回歸約束”:
約束:對于第k
個token x?
,它的條件依賴只能是“它前面所有已生成的歷史token(x?到x???)”,絕對不能依賴“它后面未生成的token(x???到x?)”。
自回歸分解沒有“發明新的分解規則”,而是把“條件概率鏈式法則”直接用在了“時序序列”上——因為語言生成是“從左到右、先有歷史再有當前”的單向過程,后面的token還沒生成,自然無法成為當前token的依賴,這就和鏈式法則中“x?只依賴x?到x???”的形式完美匹配。
第三步:用具體例子看“自回歸分解的結果”
還是用之前的句子“你現在讀的這句話”,對應的token序列 X = [x?=你, x?=現, x?=在, x?=讀, x?=的, x?=這, x?=句, x?=話]
。
根據自回歸分解(基于條件鏈式法則),這個序列的聯合概率會被分解成以下條件概率的連乘序列:
P(你,現,在,讀,的,這,句,話)=P(你)(第一個token,無歷史,只看邊緣概率)×P(現∣你)(第二個token,依賴前1個歷史token“你”)×P(在∣你,現)(第三個token,依賴前2個歷史token“你、現”)×P(讀∣你,現,在)(第四個token,依賴前3個歷史token)×P(的∣你,現,在,讀)(第五個token,依賴前4個)×P(這∣你,現,在,讀,的)(第六個token,依賴前5個)×P(句∣你,現,在,讀,的,這)(第七個token,依賴前6個)×P(話∣你,現,在,讀,的,這,句)(第八個token,依賴前7個)\begin{align*} P(你,現,在,讀,的,這,句,話) &= P(你) \quad \text{(第一個token,無歷史,只看邊緣概率)} \\ &\times P(現 \mid 你) \quad \text{(第二個token,依賴前1個歷史token“你”)} \\ &\times P(在 \mid 你,現) \quad \text{(第三個token,依賴前2個歷史token“你、現”)} \\ &\times P(讀 \mid 你,現,在) \quad \text{(第四個token,依賴前3個歷史token)} \\ &\times P(的 \mid 你,現,在,讀) \quad \text{(第五個token,依賴前4個)} \\ &\times P(這 \mid 你,現,在,讀,的) \quad \text{(第六個token,依賴前5個)} \\ &\times P(句 \mid 你,現,在,讀,的,這) \quad \text{(第七個token,依賴前6個)} \\ &\times P(話 \mid 你,現,在,讀,的,這,句) \quad \text{(第八個token,依賴前7個)} \\ \end{align*} P(你,現,在,讀,的,這,句,話)?=P(你)(第一個token,無歷史,只看邊緣概率)×P(現∣你)(第二個token,依賴前1個歷史token“你”)×P(在∣你,現)(第三個token,依賴前2個歷史token“你、現”)×P(讀∣你,現,在)(第四個token,依賴前3個歷史token)×P(的∣你,現,在,讀)(第五個token,依賴前4個)×P(這∣你,現,在,讀,的)(第六個token,依賴前5個)×P(句∣你,現,在,讀,的,這)(第七個token,依賴前6個)×P(話∣你,現,在,讀,的,這,句)(第八個token,依賴前7個)?
這個連乘序列,就是“自回歸分解”的最終結果——它完全是條件概率鏈式法則在“語言生成時序”下的直接體現,每一項都是“當前token基于歷史的條件概率”,沒有任何超出歷史的依賴。
自回歸分解與條件鏈式法則的關系
結論 | 具體解釋 |
---|---|
分解的對象 | 序列的聯合概率(如P(x?,x?,...,x?) ) |
分解的工具 | 條件概率鏈式法則(概率論基礎規則) |
分解的約束 | 自回歸約束:x? 僅依賴x?~x??? (左側歷史,無未來依賴) |
分解的結果 | 一個條件概率的連乘序列(每一項對應一步生成的概率) |