LEARNING DYNAMICS OF LLM FINETUNING
一句話總結
作者將LLM的學習動力機制拆解成AKG三項,并分別觀察了SFT和DPO訓練過程中??正梯度信號??和??負梯度信號??的變化及其帶來的影響,并得到以下結論:
- ??SFT通過梯度相似性間接提升無關回答置信度??;
- ??DPO負梯度回傳會引發confidence(softmax score) 向模型訓練前confidence最高的token聚集:DPO中正樣本與 負樣本之間的margin雖然在增大,但正樣本的confidence的上漲幅度卻沒有訓練前模型最中意的答案的confidence漲得快。
吐槽一下,這篇非常不適合LLM伴讀和粗讀,因為符號多,符號的描述寫的很松散,圖表的講解非常碎(一張圖的解釋可能會橫跨兩個subsection),我開始用LLM伴讀給我弄的非常混亂,自己粗讀發現跳過一兩段可能一個符號什么意思就不知道了。最后老老實實一行一行讀。雖然作者的工作非常詳實,但是……只能說這種寫法不坑老實人吧。
↓↓↓這里先補充一下Learning Dynamic的分析邏輯↓↓
學習動力機制(Learning Dynamic)指什么?
Learning Dynamic,簡單來說,就是研究每次模型經過一個batch的數據訓練后,參數發生了變化,這個變化如何影響模型的表現。具體來說,哪些方面變好了?哪些方面變差了?
用來描述哪些方面變好,哪些方面變差的工具就是先做一個觀察數據集,這個數據集是靜態的。
觀察一個batch的訓練后,觀察數據集里的<哪些樣本>的預測結果有<什么樣>的變化,就是研究Learning Dynamic。
為了銜接論文里的公式,這里把上面學習動力機制研究的公式寫出來
第一步:定義,參數更新對特定樣本的輸出的影響
f ( x ; θ ) = π ( z ( x ; θ ) ) f(x;θ)=π(z(x;θ)) f(x;θ)=π(z(x;θ)) π π π是最后一層的softmax函數, x o x_o xo?指的是觀察集的樣本。
—>參數更新 Δ θ Δθ Δθ 會導致輸入 x x x的輸出變化是
Δ f ( x o ; θ ) ≈ ? θ f ( x o ; θ ) ? Δ θ Δf(x_o;θ)≈?_θf(x_o;θ)?Δθ Δf(xo?;θ)≈?θ?f(xo?;θ)?Δθ
第二步:展開 第一項
? θ f ( x o ; θ ) = ? π ( z ( x o ; θ ) ) ? z ? softmax雅可比? J π ( z ) ? ? θ z ( x o ; θ ) ? logits的梯度 \nabla_\theta f(x_o;\theta) = \underbrace{\frac{\partial \pi(z(x_o;\theta))}{\partial z}}_{\text{softmax雅可比 } J_\pi(z)} \cdot \underbrace{\nabla_\theta z(x_o;\theta)}_{\text{logits的梯度}} ?θ?f(xo?;θ)=softmax雅可比?Jπ?(z) ?z?π(z(xo?;θ))????logits的梯度 ?θ?z(xo?;θ)??
第三步: Δ θ Δθ Δθ 把梯度替換成 當批次樣本 { x 1 , . . . , x i , . . . x n } \{x_1,...,x_i,...x_n\} {x1?,...,xi?,...xn?} 帶來的梯度
把 Δ θ Δθ Δθ 替換成 Δ θ = ? η ? θ L ( θ ) Δθ=?η? _{θ}L(θ) Δθ=?η?θ?L(θ) 即學習率乘以梯度
? θ L ( θ ) = 1 n ∑ i = 1 n ( π ( z ( x i ; θ ) ) ? y i ) ? ? ? θ z ( x i ; θ ) ? 通過鏈式法則: ? L ? z ? ? z ? θ \nabla_\theta L(\theta) = \frac{1}{n} \sum_{i=1}^n \underbrace{\left( \pi(z(x_i;\theta)) - y_i \right)^\top \cdot \nabla_\theta z(x_i;\theta)}_{\text{通過鏈式法則:} \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial \theta}} ?θ?L(θ)=n1?i=1∑n?通過鏈式法則:?z?L???θ?z? (π(z(xi?;θ))?yi?)???θ?z(xi?;θ)??
第四步:兩個展開式帶回源式
Δ f ( x ; θ ) ≈ ? η n ∑ i = 1 n [ J π ( z ( x ; θ ) ) ? ? θ z ( x ; θ ) ] ? [ ( π ( z ( x i ; θ ) ) ? y i ) ? ? ? θ z ( x i ; θ ) ] \Delta f(x;\theta) \approx -\frac{\eta}{n} \sum_{i=1}^n \left[ J_\pi(z(x;\theta)) \cdot \nabla_\theta z(x;\theta) \right] \cdot \left[ \left( \pi(z(x_i;\theta)) - y_i \right)^\top \cdot \nabla_\theta z(x_i;\theta) \right] Δf(x;θ)≈?nη?i=1∑n?[Jπ?(z(x;θ))??θ?z(x;θ)]?[(π(z(xi?;θ))?yi?)???θ?z(xi?;θ)]
第五步:把gradient乘積寫成NTK項
Δ f ( x ; θ ) ≈ ? η n J π ( z ( x ; θ ) ) ∑ i = 1 n ? ? θ z ( x ; θ ) , ? θ z ( x i ; θ ) ? ? NTK項? K ( x , x i ) ? ( π ( z ( x i ; θ ) ) ? y i ) Δf(x;θ) ≈ -\frac{η}{n} J_π(z(x;θ)) \sum_{i=1}^n \underbrace{\left\langle \nabla_θ z(x;θ), \nabla_θ z(x_i;θ) \right\rangle}_{\text{NTK項 } \mathbf{K}(x, x_i)} \cdot \left( π(z(x_i;θ)) - y_i \right) Δf(x;θ)≈?nη?Jπ?(z(x;θ))i=1∑n?NTK項?K(x,xi?) ??θ?z(x;θ),?θ?z(xi?;θ)????(π(z(xi?;θ))?yi?)
最終
Δ f ( x o ; θ ) ≈ ? η n ? J π ( z ( x o ; θ ) ) ? softmax非線性效應 ? ∑ i = 1 n K ( x o , x i ) ? 樣本相關性 ? ( π ( z ( x i ; θ ) ) ? y i ) ? 訓練樣本誤差 \Delta f(x_o;\theta) \approx -\frac{\eta}{n} \cdot \underbrace{J_\pi(z(x_o;\theta))}_{\text{softmax非線性效應}} \cdot \sum_{i=1}^n \underbrace{\mathbf{K}(x_o, x_i)}_{\text{樣本相關性}} \cdot \underbrace{\left( \pi(z(x_i;\theta)) - y_i \right)}_{\text{訓練樣本誤差}} Δf(xo?;θ)≈?nη??softmax非線性效應 Jπ?(z(xo?;θ))???i=1∑n?樣本相關性 K(xo?,xi?)???訓練樣本誤差 (π(z(xi?;θ))?yi?)??
最終這個式子里面的第一項、第二項和第三項分別對應了文章公式3的 AKG。
其中K的部分,代表了在函數空間中的樣本相關性(不是兩個樣本乍看之下像不像,而是通過梯度核(NTK)映射之后像不像。文章在分析SFT時會一直用這個角度
G是訓練樣本誤差帶來的影響,這個在DPO分析中還會被拆成正負兩項(因為DPO用的是margin)
關鍵細節與結論
1. G項給SFT帶來的是訓練Token的confidence上升
這里的confidence指的是Token的輸出概率(softmax之后的score)
看上圖中左側第一張圖,灰色是本步驟之前的這個樣本的詞表上的所有Score的曲線。藍色就是這一步訓練之后,詞表上所有詞的Score的變化。
這是很符合直覺的,也是論文的前菜.(但是這張圖特別不好,把后面要說的東西的分析也放這兒了,看著很亂。大概是因為投稿篇幅的關系,作者壓進來了)
2. SFT中 出現在訓練集里的y,會在掛羊頭賣狗肉的情況下也很自信
文中figure3的這張圖,展示的是
π ( y u + ∣ x u ) π(y_{u^+}|x_u) π(yu+?∣xu?) -虛線, π ( y j ∣ x u ) π(y_{j}|x_u) π(yj?∣xu?)-藍色線, 和 π ( y t e s t + ∣ x u ) π(y_{test^+}|x_u) π(ytest+?∣xu?)-橙色線 在訓練過程中的變化。( π π π是softmax函數,這三條線都對應的是模型分)
其中 π ( y u + ∣ x u ) π(y_{u^+}|x_u) π(yu+?∣xu?) 就是訓練集中樣本A 對應的答案A的概率;
π ( y j ∣ x u ) π(y_{j}|x_u) π(yj?∣xu?)是訓練集中樣本A作為context(或者說prompt),后面接樣本B的正確答案y_j的時候的score
π ( y t e s t + ∣ x u ) π(y_{test^+}|x_u) π(ytest+?∣xu?) 是訓練集中樣本A作為context,后面接測試集中某個樣本的正確答案的時候的score。
如果按照我們的希望🔽
給定樣本A的prompt,樣本B的正確答案作為output的時候,這個答案是錯的,他的confidence不應該上漲。
但在作者的實驗記錄下,這一數值上漲了。這就是作者認為SFT中一部分幻覺的來源
3. K項給SFT帶來的是意思相近和格式相近的答案的confidence的上升
作者在分析的時候使用的是DPO數據集(Antropic-HH 和UltraFeedback),盡管他在訓練SFT的時候僅適用了 DPO數據集中的prefer樣本 y u + y_{u^+} yu+?,但在分析中他同時分析了less prefer(負樣本) y u ? y_{u^-} yu??。
下圖中, y u + y_{u^+} yu+? 是DPO樣本集中<正樣本>
y u ? y_{u^-} yu?? 是DPO樣本集中<負樣本>
y g p t s + y_{gpts^+} ygpts+? 是用GPT針對DPO樣本集中<正樣本>做的同意改寫樣本
y g p t s ? y_{gpts^-} ygpts?? 是用GPT針對DPO樣本集中<負樣本>做的同意改寫樣本
y h u m y_{hum} yhum? 是隨機的一個句子(跟訓練樣本無關)
y j ≠ u + y_{{j≠u}^+} yj=u+? 是訓練集中另一個樣本的<正樣本>
y t e s t + y_{{test}^+} ytest+? 是測試集中某個樣本的<正樣本>
y r n d y_{{rnd}} yrnd? 是隨機英文詞組成的字符串(連句子都不是)
基于以上符號意義,觀察上圖,藍色部分confidence上漲,橙色就是下降。上圖中左側多個列都顯示,每訓練25步 [ x u ; y u ] [x_u;y_u] [xu?;yu?]就觀察一次,發現一些語義或者格式相似( y g p t s + y_{gpts^+} ygpts+?)的樣本的confidence也上漲了。
同時亂序的字符串( y h u m y_{hum} yhum?和 y r n d y_{rnd} yrnd?)訓練的時候,其他有正常語義的答案的confidence都會降低。因為,用K來衡量的距離跟這些亂序的答案的距離是很大的,Learning Dynamic的角度上看這些confidence也確實應該下降。
》最炸裂的其實是這個圖的最后一列 DPO這列,這列引入的是下面一個結論
4. DPO中的負梯度會給所有未被訓練的答案打負分
作者實驗中發現,如果下圖左1所示,在用DPO訓練很多個epoch的過程中,根據訓練樣本正確答案改寫的正樣本 y g p t s + y_{gpts^+} ygpts+? (語義相同)和 y g p t f + y_{gptf^+} ygptf+? (格式相同)的confidence都在持續下降。
這是與我們的直覺相反的。
同時,上圖左2顯示,不僅DPO中正樣本的改寫樣本的confidence在持續下滑,負樣本的confidence也在持續下滑(仔細看數軸,這組斜率更大)
看上圖的左4,在訓練中的正樣本和負樣本的confidence的變化是:正樣本的confidence先上漲后下滑,到第四個epoch,連正樣本的confidence都比訓練前要低了;而負樣本的confidence全程咔咔下跌。
那正負樣本的confidence都跌了,誰漲了呢?(畢竟是用softmax轉換過的score,有人跌就有人漲)
答:看第左5,圖中的黑線是模型在DPO之前(如果按照RL的說法,可以說是reference model)最prefer的答案的概率。
整個訓練過程,reference model中概率最高的token的概率漲的最猛,比訓練的正樣本y還猛。
5. 這種錯位的虹吸效應的源頭是DPO的負梯度影響
文中的大部分拆解都沿用了經典的拆分方案。DPO上做了一個變化,把G分成了正負兩部分。
這里有個比較討厭的東西 χ u ? χ^-_u χu?? 這個符號其實是附錄中下面高亮部分的意思,就是 x u x_u xu?是一樣的,y不一樣。這么表示其實有點討厭,一來是符號不在公式附近(在附錄B),二來這么寫其實挺容易讓人誤解的。
前面這種奇怪的現象主要來自于負梯度的回傳,也就是 G ? G^- G?的部分。
在附錄中,作者還展示了reference model對 y u + y_{u^+} yu+?分布形狀不同的時候 ,DPO的負梯度帶來的影響差別。
上面這個分析有個需要注意的點,就是這個DPO是off-policy的(雖然我們說DPO的時候通常就是off-policy的),即完全是靜態的樣本集,正樣本和負樣本都不是通過play逐步變化的。
那如果是on-policy的,也就是隨著模型訓練,概率分布有變化的時候(或者 y u + = y u ? y_{u^+}=y_{u^*} yu+?=yu?? 或 y u ? = y u ? y_{u^-}=y_{u^*} yu??=yu?? , u ? u^* u?指的就是reference model最中意的答案。)負梯度帶來的影響是什么樣的?
灰線是訓練前,藍線是訓練后
看圖的第三行,在原先分布是有幾個特殊token的概率較高的情況下, y u ? y_{u^-} yu??直接作用到reference模型概率最高的token上時,其他概率相對較高的token的概率被直接抬起來很多。
看圖的第四行,在原先分布是有幾個特殊token的概率較高的情況下, y u ? y_{u^-} yu??直接作用到reference模型概率很低的token時,原先概率最高的token的概率被大大提升了。而不是所有概率高的token的概率都提升了,而是效用集中到一個token上。
評價
多于1個epoch問題就會逐漸凸顯
在作者原文的大部分實驗圖中,都能看到,盡管DPO帶來的擠壓效應在多個epoch之后就顯得非常病態。在實際訓練過程中,其實也是這樣的,不管是SFT還是DPO,在已知–>“全參數微調可以讓模型在一個epoch之內完成記憶”<–這個前提下,訓練超過1個epoch都要非常小心。而且在準備樣本的過程中也要考慮,樣本實際重復的情況來算實際epoch數。
定點制作DPO數據效果會很好
我這里提供的是論文作者論述以外,一個自己之前實踐上的經驗:制作兩個只有關鍵知識點不同,其他描述完全相同的樣本,比如“上海迪士尼上午9點開園”和“上海迪士尼上午10點開園”,這個會很大程度上提高模型的知識矯正(更新)效率。比unlearning還快。
完全靜態的DPO樣本集,會很快打亂SFT中已經訓練好的模式
同樣是跟作者的認知比較像的,如果先訓練SFT把想要的模式注入到LLM里,然后用DPO來訓,前面SFT的勝利成果可能在很小的樣本量下就不行了。原因就在于SFT注入的本身就是一個模型原來并不prefer的格式。而且注入的點可能覆蓋面并不夠大(小樣本的時候)