1st author: Eric Zelikman
paper: STaR: Bootstrapping Reasoning With Reasoning | OpenReview NeurIPS 2022
code: ezelikman/STaR: Code for STaR: Bootstrapping Reasoning With Reasoning (NeurIPS 2022)
1. 當語言模型學會自我進化
Zelikman 等人提出的 STaR (Self-Taught Reasoner),旨在解決當前大型語言模型在復雜推理任務中,高質量“思維鏈”(Chain-of-Thought, CoT)數據獲取的困境。目前主流方法要么依賴昂貴的人工標注,要么犧牲準確性采用少樣本(few-shot)推理。STaR 獨辟蹊徑,提出一種迭代式自學習框架,讓模型能夠利用少量帶推理過程的樣本和大量無推理過程的問答對,自我生成并篩選推理過程,從而引導自身逐步提升推理能力。
1.1 問題根源與解法思路
當前,我們深知 CoT 能顯著提升 LLM 在數學、常識問答等任務上的表現。然而,構建大規模、高質量的 CoT 數據集是瓶頸:
- 人工標注 (Manual Annotation): 成本高昂,難以覆蓋所有領域。
- 模板化生成 (Template-based): 適用范圍窄,依賴預先設計的啟發式規則或已知解法。
- 少樣本提示 (Few-shot Prompting with CoT): 雖然靈活,但通常性能遠不如在完整 CoT 數據集上微調的模型。
STaR 的核心思想可以概括為一個 “生成 - 篩選 - 學習” 的迭代循環:
-
生成 (Generate): 利用現有的少量 CoT 樣本作為提示 (prompt),引導模型為大量無 CoT 的問題生成推理過程和答案。
- 形式化地,給定問題 x i x_i xi?,模型 M M M 生成推理 r i r_i ri? 和答案 y ^ i \hat{y}_i y^?i?: ( r i , y ^ i ) ~ M ( x i ∣ P ) (r_i, \hat{y}_i) \sim M(x_i | P) (ri?,y^?i?)~M(xi?∣P), 其中 P P P 是少量 CoT 樣本集合。
-
篩選 (Filter): 只保留那些最終能導出正確答案 y i y_i yi? 的推理過程。
- 構建訓練集 D c o r r e c t = { ( x i , r i , y i ) ∣ y ^ i = y i } D_{correct} = \{ (x_i, r_i, y_i) | \hat{y}_i = y_i \} Dcorrect?={(xi?,ri?,yi?)∣y^?i?=yi?}。
-
學習 (Learn): 在篩選后的高質量 ( x i , r i , y i ) (x_i, r_i, y_i) (xi?,ri?,yi?) 數據上微調 (fine-tune) 基礎 LLM。
-
重復 (Repeat): 使用微調后的新模型,重復上述過程,期望能解決更復雜的問題,生成更高質量的推理。
1.2. 關鍵:“反思”與“合理化” (Rationalization)
上述循環存在一個問題:如果模型在某些問題上持續失敗,它將無法從這些失敗中獲得新的學習信號。為了解決這個問題,STaR 引入了一個巧妙的機制—— “合理化” (Rationalization):
- 對于模型未能正確解答的問題 ( x j , y j ) (x_j, y_j) (xj?,yj?),STaR 會將正確答案 y j y_j yj? 作為提示的一部分,再次引導模型生成一個能夠“解釋”或“推導出”這個正確答案的推理過程 r j r a t r_j^{rat} rjrat?。
- ( r j r a t , y ^ j r a t ) ~ M ( x j , hint = y j ∣ P ′ ) (r_j^{rat}, \hat{y}_j^{rat}) \sim M(x_j, \text{hint}=y_j | P') (rjrat?,y^?jrat?)~M(xj?,hint=yj?∣P′),其中 P ′ P' P′ 可能是包含答案提示格式的樣本。
- 如果這個“事后諸葛亮”式的推理 y ^ j r a t \hat{y}_j^{rat} y^?jrat? 確實能導出 y j y_j yj?,則將 ( x j , r j r a t , y j ) (x_j, r_j^{rat}, y_j) (xj?,rjrat?,yj?) 加入訓練集。
- 注意: 在微調時,這個“提示”信息(即正確答案 y j y_j yj?)并不會包含在輸入中,模型被訓練得仿佛是它自己獨立思考出 r j r a t r_j^{rat} rjrat? 的。
這個“合理化”步驟,本質上是讓模型學會 “向答案學習推理” ,從而攻克原本難以解決的難題,并擴大了有效訓練數據的規模。
1.3. 數學視角的 STaR
論文指出,STaR 的過程可以視為一種對強化學習(RL)策略梯度目標的近似。
將模型 M M M 視為一個生成 ( r , y ^ ) (r, \hat{y}) (r,y^?) 的策略。給定問題 x x x 和真實答案 y y y,我們可以定義一個獎勵函數,例如指示函數 I ( y ^ = y ) \mathbb{I}(\hat{y} = y) I(y^?=y),當模型生成的答案 y ^ \hat{y} y^? 與真實答案 y y y 相同時,獎勵為 1,否則為 0。
目標是最大化期望獎勵:
J ( M , X , Y ) = ∑ i E r i , y ^ i ~ p M ( x i ) [ I ( y ^ i = y i ) ] J(M, X, Y) = \sum_i \mathbb{E}_{r_i, \hat{y}_i \sim p_M(x_i)} [\mathbb{I}(\hat{y}_i = y_i)] J(M,X,Y)=i∑?Eri?,y^?i?~pM?(xi?)?[I(y^?i?=yi?)]
其梯度可以寫作:
? J ( M , X , Y ) = ∑ i E r i , y ^ i ~ p M ( x i ) [ I ( y ^ i = y i ) ? ? log ? p M ( y ^ i , r i ∣ x i ) ] \nabla J(M, X, Y) = \sum_i \mathbb{E}_{r_i, \hat{y}_i \sim p_M(x_i)} [\mathbb{I}(\hat{y}_i = y_i) \cdot \nabla \log p_M(\hat{y}_i, r_i | x_i)] ?J(M,X,Y)=i∑?Eri?,y^?i?~pM?(xi?)?[I(y^?i?=yi?)??logpM?(y^?i?,ri?∣xi?)]
STaR 的做法可以理解為:
- Greedy Decoding: 通過貪心解碼(或低 temperature 采樣)來近似采樣 ( r i , y ^ i ) (r_i, \hat{y}_i) (ri?,y^?i?),以減少方差(但可能引入偏差)。
- Filtering as Reward: I ( y ^ i = y i ) \mathbb{I}(\hat{y}_i = y_i) I(y^?i?=yi?) 項使得只有導出正確答案的 ( r i , y ^ i ) (r_i, \hat{y}_i) (ri?,y^?i?) 對梯度有貢獻,這正是 STaR 中“篩選”步驟的體現。
- Supervised Fine-tuning: 對篩選出的樣本進行微調,可以看作是在這個近似的策略梯度上進行多步優化。
“合理化”步驟則可以看作是從一個不同的、加入了提示(hint)的“教師”分布 p M ( r ∣ x , y ) p_M(r | x, y) pM?(r∣x,y) 中采樣高質量的軌跡,用于豐富訓練數據,幫助模型探索更優的策略空間。
2. 算法流程
STaR 算法偽代碼:
論文中給出了清晰的算法流程圖(Figure 1)和偽代碼(Algorithm 1)。我們可以將其邏輯概括如下:
Algorithm 1: STaR
輸入:
- M pretrained M_{\text{pretrained}} Mpretrained?: 預訓練的大語言模型
- D = { ( x i , y i ) } i = 1 N D = \{(x_i, y_i)\}_{i=1}^N D={(xi?,yi?)}i=1N?: 問題-答案數據集
- P few_shot P_{\text{few\_shot}} Pfew_shot?: 少量帶推理過程的示例
初始化:
M 0 ← M pretrained M_0 \leftarrow M_{\text{pretrained}} M0?←Mpretrained? // 復制原始模型
循環迭代 n = 1 n = 1 n=1 到 Max_Iterations \text{Max\_Iterations} Max_Iterations:
-
生成步驟 (Rationale Generation):
- 初始化生成的推理集合: G e n e r a t e d _ R a t i o n a l e s ← { } Generated\_Rationales \leftarrow \{\} Generated_Rationales←{}
- 對每個樣本 ( x i , y i ) ∈ D (x_i, y_i) \in D (xi?,yi?)∈D:
- ( r gen i , y gen ^ i ) ← M n ? 1 . generate ( prompt ( P few_shot , x i ) ) (r_{\text{gen}_i}, y_{\hat{\text{gen}}_i}) \leftarrow M_{n-1}.\text{generate}(\text{prompt}(P_{\text{few\_shot}}, x_i)) (rgeni??,ygen^?i??)←Mn?1?.generate(prompt(Pfew_shot?,xi?))
- 添加 ( x i , y i , r gen i , y gen ^ i ) (x_i, y_i, r_{\text{gen}_i}, y_{\hat{\text{gen}}_i}) (xi?,yi?,rgeni??,ygen^?i??) 到 G e n e r a t e d _ R a t i o n a l e s Generated\_Rationales Generated_Rationales
-
生成推理過濾步驟:
- 初始化正確推理集合: D n correct ← { } D_n^{\text{correct}} \leftarrow \{\} Dncorrect?←{}
- 對每個 ( x i , y i , r gen i , y gen ^ i ) ∈ G e n e r a t e d _ R a t i o n a l e s (x_i, y_i, r_{\text{gen}_i}, y_{\hat{\text{gen}}_i}) \in Generated\_Rationales (xi?,yi?,rgeni??,ygen^?i??)∈Generated_Rationales:
- 若 y gen ^ i = y i y_{\hat{\text{gen}}_i} = y_i ygen^?i??=yi?:
- 添加 ( x i , r gen i , y i ) (x_i, r_{\text{gen}_i}, y_i) (xi?,rgeni??,yi?) 到 D n correct D_n^{\text{correct}} Dncorrect?
- 若 y gen ^ i = y i y_{\hat{\text{gen}}_i} = y_i ygen^?i??=yi?:
-
錯誤答案合理化步驟:
- 初始化合理化推理集合: R a t i o n a l i z e d _ R a t i o n a l e s ← { } Rationalized\_Rationales \leftarrow \{\} Rationalized_Rationales←{}
- 對每個 ( x i , y i , r gen i , y gen ^ i ) ∈ G e n e r a t e d _ R a t i o n a l e s (x_i, y_i, r_{\text{gen}_i}, y_{\hat{\text{gen}}_i}) \in Generated\_Rationales (xi?,yi?,rgeni??,ygen^?i??)∈Generated_Rationales:
- 若 y gen ^ i ≠ y i y_{\hat{\text{gen}}_i} \neq y_i ygen^?i??=yi?:
- ( r rat i , y rat ^ i ) ← M n ? 1 . generate ( prompt ( P few_shot_with_hint , x i , hint = y i ) ) (r_{\text{rat}_i}, y_{\hat{\text{rat}}_i}) \leftarrow M_{n-1}.\text{generate}(\text{prompt}(P_{\text{few\_shot\_with\_hint}}, x_i, \text{hint}=y_i)) (rrati??,yrat^i??)←Mn?1?.generate(prompt(Pfew_shot_with_hint?,xi?,hint=yi?))
- 添加 ( x i , y i , r rat i , y rat ^ i ) (x_i, y_i, r_{\text{rat}_i}, y_{\hat{\text{rat}}_i}) (xi?,yi?,rrati??,yrat^i??) 到 R a t i o n a l i z e d _ R a t i o n a l e s Rationalized\_Rationales Rationalized_Rationales
- 若 y gen ^ i ≠ y i y_{\hat{\text{gen}}_i} \neq y_i ygen^?i??=yi?:
-
合理化推理過濾步驟:
- 初始化合理化訓練集: D n rationalized ← { } D_n^{\text{rationalized}} \leftarrow \{\} Dnrationalized?←{}
- 對每個 ( x i , y i , r rat i , y rat ^ i ) ∈ R a t i o n a l i z e d _ R a t i o n a l e s (x_i, y_i, r_{\text{rat}_i}, y_{\hat{\text{rat}}_i}) \in Rationalized\_Rationales (xi?,yi?,rrati??,yrat^i??)∈Rationalized_Rationales:
- 若 y rat ^ i = y i y_{\hat{\text{rat}}_i} = y_i yrat^i??=yi?:
- 添加 ( x i , r rat i , y i ) (x_i, r_{\text{rat}_i}, y_i) (xi?,rrati??,yi?) 到 D n rationalized D_n^{\text{rationalized}} Dnrationalized?
- 若 y rat ^ i = y i y_{\hat{\text{rat}}_i} = y_i yrat^i??=yi?:
-
合并數據與微調:
- 合并訓練集: D n train ← D n correct ∪ D n rationalized D_n^{\text{train}} \leftarrow D_n^{\text{correct}} \cup D_n^{\text{rationalized}} Dntrain?←Dncorrect?∪Dnrationalized?
- 若 D n train = { } D_n^{\text{train}} = \{\} Dntrain?={} 或 性能達到平臺期:
- 終止迭代
- 微調模型: M n ← fine_tune ( M pretrained , D n train ) M_n \leftarrow \text{fine\_tune}(M_{\text{pretrained}}, D_n^{\text{train}}) Mn?←fine_tune(Mpretrained?,Dntrain?)
- 關鍵: 每次從預訓練模型微調,而非上一輪模型,避免過擬合
- 更新模型: M n ? 1 ← M n M_{n-1} \leftarrow M_n Mn?1?←Mn?
輸出: 基于驗證性能選擇最優模型 M n M_n Mn?
3. 實驗剖析
3.1. 參數與設置細節
- 基礎模型 (Base Model): GPT-J (6B 參數 )。選擇 GPT-J 是因為其開源且具備一定的推理能力基礎。
- 迭代次數 (Iterations): 實驗中通常運行到性能飽和為止。
- 訓練步數 (Training Steps per Iteration): 初始迭代訓練步數較少(如 40 步),后續迭代中逐步增加(如每輪增加 20%)。這種漸進式增加訓練強度的方法,有助于模型在早期穩定學習,后期充分利用數據。
- Few-shot Prompts:
- Rationale Generation: 使用少量 ( 如 10 個 ) 固定的、高質量的 CoT 示例。
- Rationalization: 使用類似的 CoT 示例,但格式上會明確包含“正確答案提示”。例如,在 CommonsenseQA (CQA) 的例子中(Figure 2),提示是
(b) grocery cart (CORRECT)
。
- 數據集 (Datasets):
- Arithmetic ( 算術 ): n 位數加法。評估模型對符號操作和步驟記憶的能力。
- CommonsenseQA (CQA): 常識問答選擇題。評估自然語言理解和常識推理。
- GSM8K (Grade School Math): 小學數學應用題。評估結合算術和文本理解的復雜推理。
- 關鍵trick:從預訓練模型重新微調: 每次迭代收集到新的訓練數據后,STaR 從原始的預訓練模型 M p r e t r a i n e d M_pretrained Mp?retrained 開始微調,而不是在上一次迭代的模型 M n ? 1 M_{n-1} Mn?1? 基礎上繼續微調。這可以有效防止災難性遺忘和過擬合到特定迭代生成的數據噪聲。
3.2. 實驗結果亮點:
-
顯著優于基線:
- Arithmetic: STaR 能夠從幾乎為零的 few-shot 準確率(2 位數加法 < 1%)通過迭代學習達到很高的準確率(16 次迭代后整體 89.5%)。對比直接在無推理過程的 10,000 個樣本上微調(76.3%),優勢明顯。
- CommonsenseQA (CQA):
- STaR (72.5% 準確率 ) 顯著超過 few-shot CoT GPT-J (36.6%) 和直接微調 GPT-J (60.0%)。
- STaR (GPT-J 6B) 的性能逼近了在完整數據集上微調的 GPT-3 (175B,論文中提到的是 30x 更大的模型,應指 PaLM 或類似規模,CQA 上的 GPT-3 Finetuned 結果為 73.0%)。這表明 STaR 能夠有效地從小模型中“榨取”出強大的推理能力。
- Rationalization 的作用: 在 CQA 上,加入 Rationalization 后,準確率從 68.8% (STaR without rationalization) 提升到 72.5%,證明了其在解決難題和提升上限方面的價值。
- GSM8K: STaR (10.7%) 同樣遠超 few-shot CoT GPT-J (3.1%) 和直接微調 GPT-J (5.8%)。
-
Rationalization 的加速與提升效果:
- Arithmetic ( Figure 4): 有 Rationalization 的 STaR ( Figure 4b) 相比無 Rationalization ( Figure 4a),在早期迭代中對多位數加法的學習速度更快,性能提升更平滑。無 Rationalization 的版本則呈現階梯式提升,即模型通常在掌握 (n-1) 位數加法后才能較好地學習 n 位數加法。
- CommonsenseQA: Rationalization 帶來了約 3.7% 的絕對提升。
-
數據效率: STaR 通常只使用了訓練集的一部分(例如 CQA 上約 86.7% 的數據,其中 8.5% 來自 Rationalization),但取得了比使用完整數據集直接微調更好的性能。這說明 STaR 生成的 CoT 數據質量較高。
-
推理質量的提升 (Case Study & Human Evaluation on CQA):
- Case Study: 展示了 STaR 能夠生成比 few-shot CoT 更合理、更連貫的推理過程,即使原始 few-shot 也能答對問題。(論文提到 Figure 7 展示, 但是作者可能忘了放這張圖)
- Human Evaluation: 眾包評估者認為 STaR 生成的推理過程比 few-shot CoT 生成的推理過程更具說服力(30% more likely to rank STaR higher, p=0.039)。甚至,STaR 生成的推理比一些人工標注的推理更受青睞(74% more likely, p < 0.001),這可能反映了眾包標注本身的質量波動,但也側面印證了 STaR 生成推理的潛力。
3.4. 對實驗結果的初步思考
-
STaR 的成功,很大程度上依賴于基礎 LLM 已具備一定的“潛在”推理能力。Few-shot CoT 能夠激活這種能力,而 STaR 通過迭代微調,將這種“潛在”能力強化并泛化。
-
Rationalization 機制非常精妙,它相當于給模型提供了一個 “目標導向的逆向工程” 機會,讓模型思考“為了得到這個答案,我應該如何推理?”
-
“從頭微調”策略是控制訓練穩定性和避免過擬合的關鍵。
4. 挑戰、局限性與未來展望
4.1. STaR 面臨的挑戰與局限性
-
對初始 Few-shot 樣本的依賴與敏感性:
- STaR 的啟動依賴于少量高質量的 CoT 樣本。這些樣本的質量和風格可能會顯著影響后續生成的推理過程的質量和多樣性。
- 如果初始樣本包含偏見或不完善的推理模式,STaR 可能會放大這些問題。
-
“正確答案但錯誤推理”的問題 (Filtering Imperfection):
- STaR 的核心篩選機制是基于最終答案的正確性。這意味著,如果模型通過一個錯誤的、不相關的或者僅僅是“碰巧”正確的推理過程得到了正確答案,這樣的樣本依然會被用于微調。
- 在多項選擇題(如 CQA)中,隨機猜對的概率不低 (20%),這使得該問題尤為突出。雖然論文提到一些簡單啟發式方法(如語義相似度)可以將隨機猜測提升到約 30%,但 STaR 的目標是學習真正的推理。
- 這種“噪聲”數據可能會污染訓練集,限制模型學習到真正魯棒和泛化的推理能力。
-
Rationalization 的提示工程:
- 如何設計“合理化”步驟中的提示(即如何將正確答案作為 hint 融入問題)可能并非易事,尤其對于更復雜的任務結構。
- 論文中對算術和 CQA 的提示方式相對直接,但其普適性有待驗證。
-
計算成本:
- 迭代生成、篩選和微調的過程計算成本較高。盡管比標注大規模 CoT 數據集便宜,但對于資源受限的研究者而言仍是一個考量。
- 每次都從預訓練模型重新微調,雖然能避免過擬合,但也增加了訓練時間。
-
溫度參數 (Temperature) 的影響:
- 論文提到,嘗試使用更高的溫度進行采樣以增加數據多樣性,結果適得其反,導致模型性能下降,尤其是在結構化任務(如算術)中,生成的“思維鏈”會變得無意義。
- 這表明 STaR 依賴于模型在低溫度下生成相對“自信”且連貫的推理。如何平衡探索(高溫度)和利用(低溫度)仍然是一個開放問題。
-
可解釋性與忠實度 (Faithfulness):
- 雖然 STaR 生成的推理看起來更合理,但我們無法保證這些推理過程真正反映了模型內部的“思考”過程。模型可能只是學會了生成“看起來像那么回事”的文本。
- 這是所有基于生成 CoT 方法的共同挑戰。
-
偏見放大 (Bias Amplification):
- 如果原始數據集或 few-shot 樣本中存在偏見(例如,CQA 中的性別偏見),STaR 的迭代學習過程可能會放大這些偏見,因為它傾向于強化那些能“成功”解決訓練集問題的模式,即使這些模式是基于偏見的。
- 論文提到了一些初步的積極跡象(如模型在性別無關問題中似乎忽略了性別信息),但這需要更深入的研究。
-
對小模型和簡單任務的適用性:
- STaR 的成功依賴于基礎模型具備一定的 few-shot 推理能力。對于非常小的模型或無法通過 few-shot 激活推理能力的簡單任務,STaR 可能難以啟動。論文提到 GPT-2 在算術任務上無法通過 STaR 自舉。
- 對于正確率本身就很高(例如二元決策)的任務,錯誤答案的樣本過少,Rationalization 的作用會減弱。
4.2. STaR 的深遠意義與未來展望
-
邁向模型自我改進的重要一步: STaR 最核心的貢獻在于展示了一種讓 LLM 通過自身的生成和推理來學習和改進自身推理能力的有效途徑。這為實現更自主、更少依賴人工監督的 AI 系統提供了新的思路。
-
數據高效的推理能力獲取: STaR 證明了可以利用大量無標注數據和少量有標注數據,以一種自舉的方式生成高質量的推理訓練數據,這對于解決許多領域標注數據稀缺的問題具有重要價值。
-
對“思維鏈”研究的推動: STaR 強調了推理過程本身作為學習信號的重要性。未來的研究可以探索更精細化的推理過程評估方法(超越最終答案的正確性),例如引入 token 級驗證器(如論文 [9] Cobbe et al. 中用于數學問題的驗證器)。
-
結合強化學習的潛力: 論文中已將 STaR 與 RL 的策略梯度聯系起來。未來可以探索更直接的 RL 方法,例如使用模型自身生成的推理作為軌跡,并設計更復雜的獎勵函數來指導學習,或者結合基于模型的 RL 來規劃推理步驟。
-
探索更優的 Rationalization 機制:
- 如何更有效地從“失敗”中學習?除了提供正確答案,是否可以提供更細致的反饋或引導?
- 研究不同類型的“提示”對 Rationalization 效果的影響。
-
處理“正確答案但錯誤推理”:
- 開發自動檢測或過濾不合理推理的方法。
- 引入“負面樣本”學習,即明確告訴模型哪些推理是錯誤的。
-
跨任務和跨領域泛化: STaR 在特定任務上學習到的推理能力,能否更好地泛化到新的、未見過的任務或領域?
-
與人類反饋的結合: STaR 的迭代過程可以與人類反饋回路(Human-in-the-loop)相結合,讓人類專家在關鍵的篩選或 Rationalization 步驟提供指導,進一步提升學習效率和推理質量。
4.3. 總結
STaR 是一項具有開創性的工作,它巧妙地利用了大型語言模型自身的生成能力,通過迭代式的“生成 - 篩選(含合理化)- 學習”循環,實現了在復雜推理任務上顯著的性能提升,并且在某些任務上逼近了遠大于自身規模的模型。它不僅為解決 CoT 數據獲取難題提供了有效方案,更為重要的是,它展示了語言模型 “自我教育”和“自我進化” 的巨大潛力。盡管存在一些挑戰和局限,STaR 無疑為 LLM 的能力邊界探索和未來發展開辟了新的道路。