KnowPAT
KnowPAT(Knowledgeable Preference AlignmenT) 出自2023年11月的論文《Knowledgeable Preference Alignment for LLMs in Domain-specific Question Answering》,主要針對領域問答來進行知識對齊。
在領域問答有兩個挑戰:希望輸出滿足用戶的要求、輸出充分利用領域知識庫。為了解決這些挑戰,提出了如下圖的三階段的KnowPAT框架。
假設有一個QA數據集 D = ( q i , a i ) ∣ i = 1 , 2 , … , N \mathcal{D} = {(q_i, a_i) | i=1,2,\ldots,N} D=(qi?,ai?)∣i=1,2,…,N, q i q_i qi?和 a i a_i ai?是問答對,在論文中是對應的云端產品使用相關問答對,是由人工收集和標注的。
如果直接在數據集 D \mathcal{D} D上微調LLM M \mathcal{M} M(即通常所說的SFT),設prompt 模板為 I \mathcal{I} I,則優化目標如下(式中的 a i , j a_{i, j} ai,j?是 a i a_i ai?的第j個token, P M P_{\mathcal{M}} PM?是模型 M \mathcal{M} M預測的token概率)。
L f t = ? 1 ∣ a i ∣ ∑ j = 1 ∣ a i ∣ log ? P M ( a i , j ∣ I , q i , a i , < j ) \mathcal{L}_{f t}=-\frac{1}{\left|a_i\right|} \sum_{j=1}^{\left|a_i\right|} \log P_{\mathcal{M}}\left(a_{i, j} \mid \mathcal{I}, q_i, a_{i,<j}\right) Lft?=?∣ai?∣1?j=1∑∣ai?∣?logPM?(ai,j?∣I,qi?,ai,<j?)
對于領域相關任務,一般會有一個領域知識庫(domain KB) B \mathcal{B} B,現在流行的RAG就是領域領域知識庫來讓LLM在領域相關問題上回答更準確的一種解決方法。而KnowPAT采用的是如下三部分的框架來利用領域知識。
無監督知識檢索
設有語義相似度檢索器 H \mathcal{H} H,對于每個問題 q i q_i qi?從KB B \mathcal{B} B中檢索出top-k條最相似的知識并記為 K \mathcal{K} K?, 相似性以檢索器編碼后向量間的余弦相似度來衡量。
偏好數據集構建
偏好數據集分為風格偏好數據集(style preference set, SPS) P s \mathcal{P}_s Ps?和知識偏好數據(knowledge preference set, KPS) P k \mathcal{P}_k Pk?。
風格偏好數據集 P s \mathcal{P}_s Ps?構建過程:
- 選擇l-1個不同的LLM記為 M 1 , M 2 , … , M l ? 1 \mathcal{M}_1,\mathcal{M}_2,\ldots,\mathcal{M}_{l-1} M1?,M2?,…,Ml?1?,不同LLM的文本理解和表達能力不一樣,所以可以生成不同風格的回答。
- 將上一步LLM生成的l-1個回答和金標準回答構成長度為l的風格偏好數據集 P s = { b 1 , b 2 , … , b l } \mathcal{P}_s = \{b_1, b_2,\ldots,b_l \} Ps?={b1?,b2?,…,bl?}。
- 為了與知識偏好數據集的長度一致,論文中取l為4,選了3個模型:ChatGPT、ChatGLM-6B、Vicuna-7B。
- 設金標準回答為 b 1 b_1 b1?,ChatGPT生成的回答為 b 2 b_2 b2?、ChatGLM-6B生成的回答為 b 3 b_3 b3?、Vicuna-7B生成的回答為 b 4 b_4 b4?,作者使用規則來確定這四個回答的偏好分數,認為三個模型的能力ChatGPT>ChatGLM>Vicuna,所以這四個回答的偏好分數順序為 r 1 > r 2 > r 3 > r 4 r_1 > r_2 > r_3 > r_4 r1?>r2?>r3?>r4?。
知識偏好數據集 P k \mathcal{P}_k Pk?構建過程:
- 對于問題a從知識庫KB中檢索出3個知識組合 K 1 \mathcal{K_1} K1?、 K 2 \mathcal{K_2} K2?、 K 3 \mathcal{K_3} K3?, K 1 \mathcal{K_1} K1?是top-k最相似的知識, KaTeX parse error: Undefined control sequence: \O at position 16: \mathcal{K_2}= \?O?是空集表示不包括任何檢索知識, K 3 \mathcal{K_3} K3??表示top-k+1至top 2k相似的知識。
- 將不同的知識組合與prompt模板 I \mathcal{I} I一起輸入到LLM M \mathcal{M} M生成答案,生成的三個答案與金標準一起組成知識偏好數據 P k = { c 1 , c 2 , c 3 , c 4 } \mathcal{P}_k = \{c_1, c_2, c_3,c_4 \} Pk?={c1?,c2?,c3?,c4?}。
- 設金標準回答為 c 1 c_1 c1?,使用 K 1 \mathcal{K_1} K1?生成的回答為 c 2 c_2 c2?、使用 K 2 \mathcal{K_2} K2?生成的回答為 c 3 c_3 c3?、使用 K 3 \mathcal{K_3} K3?生成的回答為 c 4 c_4 c4?,作者發現與問題不那么相似的知識很容易誤導LLM,所以這四個回答的偏好分數順序為 r 1 > r 2 > r 3 > r 4 r_1 > r_2 > r_3 > r_4 r1?>r2?>r3?>r4?。
微調和偏好對齊
前面構建的偏好數據集里偏好分數 r i r_i ri?代表了偏好度,希望模型 M \mathcal{M} M能夠對齊偏好。模型在給定prompt模板和問題 q i q_i qi?后對每個回答token的平均對數似然如下式 S i S_i Si?表示,分數越高表示模型認為回答有更高的概率:
S i = ? 1 ∣ a i ∣ ∑ j = 1 ∣ a i ∣ log ? P M ( a i , j ∣ I , q i , a i , < j ) \mathcal{S}_{i}=-\frac{1}{\left|a_i\right|} \sum_{j=1}^{\left|a_i\right|} \log P_{\mathcal{M}}\left(a_{i, j} \mid \mathcal{I}, q_i, a_{i,<j}\right) Si?=?∣ai?∣1?j=1∑∣ai?∣?logPM?(ai,j?∣I,qi?,ai,<j?)
KnowPAT先設計了如下的對齊目標,目的是為了對比偏好答案和非偏好答案,偏好分數只用來決定不同答案的順序。式中的 σ \sigma σ是sigmoid函數。
L a l i g n = ? ∑ i = 1 ∣ P ∣ ? 1 ( log ? σ ( S i ) + log ? ∑ r j < r i σ ( ? S j ) ) \mathcal{L}_{align}=- \sum_{i=1}^{|\mathcal{P}|-1} \left( \log \sigma (\mathcal{S}_i) + \log \sum_{r_j < r_i}\sigma (-\mathcal{S}_j) \right ) Lalign?=?i=1∑∣P∣?1? ?logσ(Si?)+logrj?<ri?∑?σ(?Sj?) ?
考慮到不同的回答的文本質量和偏好等級不一樣,作者設計了如下式的自適應權重來控制每個偏好回答的影響,式中的 S m a x S_{max} Smax?和 S m i n S_{min} Smin?是偏好數據集里的最大和最小偏好分數。
μ i = S i ? S m i n S m a x ? S m i n \mu_i = \frac {S_i - S_{min}}{S_{max} - S_{min}} μi?=Smax??Smin?Si??Smin??
使用自適應權重后,不同偏好分數的回答的影響可以動態調整,對齊目標相應地變為下式:
L a l i g n = ∑ i = 1 ∣ P ∣ ? 1 μ i ( log ? ( 1 + e ? S i ) + log ? ∑ r j < r i log ? ( 1 + e S j ) ) \mathcal{L}_{align}= \sum_{i=1}^{|\mathcal{P}|-1} \mu_i \left( \log (1 + e^{-\mathcal{S}_i} )+ \log \sum_{r_j < r_i}\log ( 1 + e^{ \mathcal{S}_j}) \right ) Lalign?=i=1∑∣P∣?1?μi? ?log(1+e?Si?)+logrj?<ri?∑?log(1+eSj?) ?
KnowPAT的訓練目標為對齊損失和微調目標之和,超參數 λ \lambda λ作為對齊損失的系數, P ? 1 \mathcal{P}-1 P?1用來歸一化對齊損失。
L = L f t + λ ∣ P ∣ ? 1 L a l i g n \mathcal{L} = \mathcal{L}_{ft} + \frac{\lambda} {|\mathcal{P}| -1} \mathcal{L}_{align} L=Lft?+∣P∣?1λ?Lalign?
注:1. 有一點疑問是前面構建了兩個偏好數據集,微調里沒有詳細說明是一起訓練還是分別訓練,只寫了一句看起來像是分別訓練的話:For each preference set constructed in the previous section, the model is trained and optimized with such an objective. 2. 風格偏好數據集與RRHF的數據構建思路是一樣的,論文代碼也是基于RRHF的,不過對齊目標函數有所區別
參考資料
- KnowPAT: arxiv, github