Think Only When You Need with Large Hybrid-Reasoning Models
- 2 Large Hybrid-Reasoning Models
- 2.1 Problem Formulation
- 關鍵定義與目標
- 核心挑戰與解決方案
- 2.2 第一階段:混合微調(Hybrid Fine-Tuning, HFT)
- 核心設計
- 數據構建
- 數據集統計
- 優化目標(Optimize Objective)
- 關鍵技術點
- 階段輸出
- 2.3 第二階段:混合組策略優化(Hybrid Group Policy Optimization, HGPO)
- 無Critic模型架構
- 計算優化特性
- 算法框架
- 采樣策略(Sampling Strategy)
- 獎勵計算與分配(Reward Scoring and Assignment)
- 優勢估計(Advantage Estimation)
- 優化目標(Optimization Objective)
- 算法特性
- 2.4 混合推理能力評估
- 評估流程
Think Only When You Need with Large Hybrid-Reasoning Models一文指出,當前的大型推理模型(LRMs)通過生成冗長的思維過程(如標記為 <think> 的中間步驟)顯著提升了推理能力,但這種方式在處理簡單查詢時會帶來不必要的計算開銷和延遲。為解決這一問題,作者提出了大型混合推理模型(LHRMs),這是第一種能夠根據用戶查詢的上下文信息自適應決定是否進行深入思考的模型。
為實現這一目標,作者設計了一個兩階段的訓練流程:
-
混合微調(HFT):作為冷啟動階段,通過結合推理密集型(Thinking)和直接回答(No-Thinking)數據,使模型初步支持兩種推理模式。
-
混合組策略優化(HGPO):一種在線強化學習方法,通過隱式學習選擇適當的思考模式,同時生成更有用且無害的響應。
此外,作者提出了“混合準確率”(Hybrid Accuracy)這一新指標,用于定量評估模型的混合推理能力。實驗結果表明,LHRMs能夠根據查詢的難度和類型自適應地選擇思考模式,在推理和通用任務上均優于現有的LRMs和LLMs,同時顯著提升了效率。
本篇博客聚焦文章的方法部分。
2 Large Hybrid-Reasoning Models
2.1 Problem Formulation
本節正式定義了大型混合推理模型(LHRMs)的核心問題,即如何根據輸入查詢動態選擇最優推理模式(Thinking或No-Thinking)以最大化任務特定效用。
關鍵定義與目標
-
輸入與模式:
- 輸入查詢記為 qqq
- 提供兩種推理模式:
- 思考模式(?\vdash?):生成顯式推理步驟(如中間計算或邏輯鏈)
- 非思考模式(?\nprec?):直接生成最終答案無需中間步驟
-
條件分布:
- 每種模式對應一個答案空間 A\mathcal{A}A 上的條件概率分布:
P(a∣q,m),m∈M={?,?}(1)\mathcal{P}(a \mid q, m), \quad m \in \mathcal{M} = \{\vdash, \nprec\} \quad (1) P(a∣q,m),m∈M={?,?}(1)
- 每種模式對應一個答案空間 A\mathcal{A}A 上的條件概率分布:
-
最優模式選擇:
- 對每個查詢 qqq,選擇能最大化期望效用 U(q,a)\mathcal{U}(q,a)U(q,a) 的模式 m?(q)m^*(q)m?(q):
m?(q)=arg?max?m∈MEa~P(a∣q,m)[U(q,a)](2)m^*(q) = \arg\max_{m\in\mathcal{M}} \mathbb{E}_{a\sim\mathcal{P}(a|q,m)}\Big[\mathcal{U}(q,a)\Big] \quad (2) m?(q)=argm∈Mmax?Ea~P(a∣q,m)?[U(q,a)](2)
- 對每個查詢 qqq,選擇能最大化期望效用 U(q,a)\mathcal{U}(q,a)U(q,a) 的模式 m?(q)m^*(q)m?(q):
-
全局優化目標:
- 學習策略 π:Q→M\pi: \mathcal{Q}\rightarrow\mathcal{M}π:Q→M 以最大化跨任務分布的期望效用:
max?π1N∑i=1NEDi~Θ,Di?Ui[Ea~P(a∣q,π(q)),q~Di[Ui(q,a)]](3)\max_{\pi} \frac{1}{N}\sum_{i=1}^N \mathbb{E}_{\mathcal{D}_i\sim\Theta, \mathcal{D}_i\Leftrightarrow\mathcal{U}_i}\Bigg[\mathbb{E}_{a\sim\mathcal{P}(a|q,\pi(q)), q\sim\mathcal{D}_i}\Big[\mathcal{U}_i(q,a)\Big]\Bigg] \quad (3) πmax?N1?i=1∑N?EDi?~Θ,Di??Ui??[Ea~P(a∣q,π(q)),q~Di??[Ui?(q,a)]](3)
其中 Θ={(Di,Ui)}i=1N\Theta = \{(\mathcal{D}_i,\mathcal{U}_i)\}_{i=1}^NΘ={(Di?,Ui?)}i=1N? 表示不同任務的數據分布和效用函數對。
- 學習策略 π:Q→M\pi: \mathcal{Q}\rightarrow\mathcal{M}π:Q→M 以最大化跨任務分布的期望效用:
核心挑戰與解決方案
-
策略學習(C1):
- 通過兩階段訓練實現:
- 階段I:混合微調(HFT)冷啟動
- 階段II:混合組策略優化(HGPO)強化學習
- 通過兩階段訓練實現:
-
評估指標(C2):
- 提出混合準確率 Hacc\mathcal{H}_{\text{acc}}Hacc? 量化模式選擇能力
2.2 第一階段:混合微調(Hybrid Fine-Tuning, HFT)
本節詳細介紹了LHRMs訓練流程的第一階段——混合微調(HFT),這是模型冷啟動的關鍵步驟。
核心設計
數據構建
HFT使用混合格式的監督微調數據集,包含兩類數據:
-
思考模式數據:
- 來源:數學(MATH)、編程(Code)和科學領域的高質量數據集
- 處理方式:
- 使用DeepSeek-R1生成答案
- 人工驗證正確性
- 添加
<think>
和</think>
標簽標記推理步驟 - 示例:
<think> 首先分析約束條件...然后推導可能的解... </think> 最終答案是$\boxed{17}$
-
非思考模式數據:
- 來源:WildChat-1M中的簡單查詢
- 處理方式:
- 使用FastText分類器過濾復雜推理任務
- 添加
<no_think>
和</no_think>
標簽 - 示例:
<no_think> 當然,請問您需要什么幫助? </no_think>
數據集統計
類別 | 數據量 | 平均token長度 | 主要來源 |
---|---|---|---|
思考模式 | 631,325 | 575 | SYNTHETIC-1, OpenMath |
非思考模式 | 674,908 | 4,897 | WildChat-1M, OASST2 |
總計 | 1,694,586 | - | - |
優化目標(Optimize Objective)
HFT階段通過標準的語言建模目標訓練模型,使其能夠基于上文預測下一個token。對于構建的數據集DHFT={(xi,yi)}i=1N\mathcal{D}_{\text{HFT}} = \{(x^i, y^i)\}_{i=1}^NDHFT?={(xi,yi)}i=1N?,其優化目標定義為:
LHFT(θ)=?E(x,y)~DHFT[∑t=1∣y∣log?πθ(yt∣x,y1:t?1)](4)\mathcal{L}_{\text{HFT}}(\theta) = -\mathbb{E}_{(x,y)\sim\mathcal{D}_{\text{HFT}}} \left[ \sum_{t=1}^{|y|} \log \pi_\theta(y_t \mid x, y_{1:t-1}) \right] \quad (4) LHFT?(θ)=?E(x,y)~DHFT???t=1∑∣y∣?logπθ?(yt?∣x,y1:t?1?)?(4)
其中:
- θ\thetaθ:模型參數
- (x,y)(x,y)(x,y):輸入-輸出對
- πθ\pi_\thetaπθ?:模型參數化的概率分布
關鍵技術點
-
防模式崩潰設計:
- 對同一查詢同時提供兩種格式的答案
- 示例:
# 思考模式 "計算2+2": "<think>2加2等于4</think>"# 非思考模式 "計算2+2": "<no_think>4</no_think>"
-
數據平衡策略:
- 思考模式與非思考模式樣本比例 ≈ 1:1
- 每個batch內兩種模式均勻混合
-
訓練配置:
- 優化器:AdamW(lr=1e-4)
- 批次大小:128
- 序列長度:32k tokens
- 訓練時長:7B模型約2.5天(4×NVIDIA H100節點)
階段輸出
HFT階段產出的模型πθHFT\pi_{\theta_{\text{HFT}}}πθHFT??具備:
- 同時支持兩種推理模式的能力
- 穩定的模式切換基礎
- 為第二階段RL訓練提供優質初始化
2.3 第二階段:混合組策略優化(Hybrid Group Policy Optimization, HGPO)
本節詳細介紹訓練流程的第二階段——混合組策略優化(HGPO),這是一種創新的強化學習算法,用于優化模型的自適應推理能力。
HGPO的完整流程如圖2和算法1所示,通過以下創新設計降低計算成本:
無Critic模型架構
-
核心設計:
- 摒棄傳統強化學習中的critic(價值函數)模型
- 采用多樣本估計替代價值函數計算
-
采樣機制:
- 對提示集P\mathcal{P}P中的每個問題qqq
- 從舊策略πθHFT\pi_{\theta_{\text{HFT}}}πθHFT??中采樣兩組輸出:
- 思考模式組:N/2N/2N/2個含推理過程的響應
- 非思考模式組:N/2N/2N/2個直接答案
計算優化特性
設計選擇 | 傳統RL | HGPO | 優勢 |
---|---|---|---|
價值估計 | Critic模型預測 | 多樣本直接統計 | 減少40%訓練內存 |
梯度計算 | 依賴價值函數導數 | 零階策略梯度 | 避免梯度沖突問題 |
模式切換成本 | 需要重訓練critic | 動態樣本重加權 | 支持在線模式切換 |
算法框架
采樣策略(Sampling Strategy)
對于每個查詢q∈Pq \in \mathcal{P}q∈P,從初始策略πθHFT\pi_{\theta_{\text{HFT}}}πθHFT??中按兩種模式分別采樣N/2N/2N/2個候選響應:
{oi?}i=1N/2~πθHFT(?∣q,m=?),{oi?}i=1N/2~πθHFT(?∣q,m=?)(5)\{o_i^\vdash\}_{i=1}^{N/2} \sim \pi_{\theta_{\text{HFT}}}(\cdot \mid q, m=\vdash), \quad \{o_i^\nprec\}_{i=1}^{N/2} \sim \pi_{\theta_{\text{HFT}}}(\cdot \mid q, m=\nprec) \quad (5) {oi??}i=1N/2?~πθHFT??(?∣q,m=?),{oi??}i=1N/2?~πθHFT??(?∣q,m=?)(5)
完整候選集定義為:
O(q)={oi?}i=1N/2∪{oi?}i=1N/2(6)\mathcal{O}(q) = \{o_i^\vdash\}_{i=1}^{N/2} \cup \{o_i^\nprec\}_{i=1}^{N/2} \quad (6) O(q)={oi??}i=1N/2?∪{oi??}i=1N/2?(6)
實現細節:
- 默認N=4N=4N=4(每種模式2個樣本)
- 溫度系數τ=0.7\tau=0.7τ=0.7控制多樣性
- 禁止重復采樣機制
獎勵計算與分配(Reward Scoring and Assignment)
使用獎勵函數R?R_\phiR??對候選輸出評分,生成兩組獎勵值:
R?={r(oi?)}i=1N/2,R?={r(oi?)}i=1N/2(7)\mathcal{R}^\vdash = \{r(o_i^\vdash)\}_{i=1}^{N/2}, \quad \mathcal{R}^\nprec = \{r(o_i^\nprec)\}_{i=1}^{N/2} \quad (7) R?={r(oi??)}i=1N/2?,R?={r(oi??)}i=1N/2?(7)
計算各模式平均獎勵:
Rˉ?=2N∑i=1N/2r(oi?),Rˉ?=2N∑i=1N/2r(oi?)(8)\bar{\mathcal{R}}^\vdash = \frac{2}{N}\sum_{i=1}^{N/2} r(o_i^\vdash), \quad \bar{\mathcal{R}}^\nprec = \frac{2}{N}\sum_{i=1}^{N/2} r(o_i^\nprec) \quad (8) Rˉ?=N2?i=1∑N/2?r(oi??),Rˉ?=N2?i=1∑N/2?r(oi??)(8)
定義兩種獎勵類型:
- 組間獎勵(Inter-group):
rinter(oim)={1,if?m=arg?max?m′∈{?,?}{Rˉ?,Rˉ?+δ}0,otherwise(9a)r_{\text{inter}}(o_i^m) = \begin{cases} 1, & \text{if } m = \arg\max_{m'\in\{\vdash,\nprec\}} \{\bar{\mathcal{R}}^\vdash, \bar{\mathcal{R}}^\nprec + \delta\} \\ 0, & \text{otherwise} \end{cases} \quad (9a) rinter?(oim?)={1,0,?if?m=argmaxm′∈{?,?}?{Rˉ?,Rˉ?+δ}otherwise?(9a) - 組內獎勵(Intra-group):
rintra(oim)={1,if?i=arg?max?j∈{1,...,N/2}rjm0,otherwise(9b)r_{\text{intra}}(o_i^m) = \begin{cases} 1, & \text{if } i = \arg\max_{j\in\{1,...,N/2\}} r_j^m \\ 0, & \text{otherwise} \end{cases} \quad (9b) rintra?(oim?)={1,0,?if?i=argmaxj∈{1,...,N/2}?rjm?otherwise?(9b)
關鍵參數:
- δ\deltaδ:模式偏好邊際(默認0.2)
- 規則型獎勵用于數學/編程等確定性任務
- 參數化獎勵模型用于開放域任務
δ\deltaδ這個參數的出現提供了一種可以控制模型思考偏好的方法,在具體工程實現中,可以基于任務種類設置不同的δ\deltaδ達到控制長短的目的
優勢估計(Advantage Estimation)
采用GRPO優勢估計器:
Ait=[rintra(oi)?mean(rintra(oj))std(rintra(oj))]?Intra-group+1{oit∈Φ}?α[rinter(oi)?mean(rinter(oj))std(rinter(oj))]?Inter-group(10)A_i^t = \underbrace{\left[\frac{r_{\text{intra}}(o_i) - \text{mean}(r_{\text{intra}}(o_j))}{\text{std}(r_{\text{intra}}(o_j))}\right]}_{\text{Intra-group}} + \underbrace{\mathbb{1}\{o_i^t \in \Phi\} \cdot \alpha \left[\frac{r_{\text{inter}}(o_i) - \text{mean}(r_{\text{inter}}(o_j))}{\text{std}(r_{\text{inter}}(o_j))}\right]}_{\text{Inter-group}} \quad (10) Ait?=Intra-group[std(rintra?(oj?))rintra?(oi?)?mean(rintra?(oj?))?]??+Inter-group1{oit?∈Φ}?α[std(rinter?(oj?))rinter?(oi?)?mean(rinter?(oj?))?]??(10)
其中:
- Φ={<think>,<no_think>}\Phi = \{\text{<think>}, \text{<no\_think>}\}Φ={<think>,<no_think>}為模式標記集合
- α=1.0\alpha=1.0α=1.0為平衡系數
優化目標(Optimization Objective)
最大化以下目標函數:
JHGPO(θ)=Eq~P,{oim}~πθHFT[1N∑i=1N∑t=1∣o∣[min?(πθ(oim,t∣q,oim,<t)πθHFT(oim,t∣q,oim,<t)Ait,clip(πθ(oim,t∣q,oim,<t)πθHFT(oim,t∣q,oim,<t),1??,1+?)Ait)?βDKL(πθ∣∣πref)]](11)\mathcal{J}_{\text{HGPO}}(\theta) = \mathbb{E}_{q\sim\mathcal{P}, \{o_i^m\}\sim\pi_{\theta_{\text{HFT}}}}\Bigg[ \frac{1}{N}\sum_{i=1}^N \sum_{t=1}^{|o|} \bigg[ \min\Bigg( \frac{\pi_\theta(o_i^{m,t}|q,o_i^{m,<t})}{\pi_{\theta_{\text{HFT}}}(o_i^{m,t}|q,o_i^{m,<t})} A_i^t, \\ \text{clip}\Bigg(\frac{\pi_\theta(o_i^{m,t}|q,o_i^{m,<t})}{\pi_{\theta_{\text{HFT}}}(o_i^{m,t}|q,o_i^{m,<t})}, 1-\epsilon, 1+\epsilon\Bigg) A_i^t \bigg) - \beta \mathbb{D}_{\text{KL}}(\pi_\theta || \pi_{\text{ref}}) \bigg] \Bigg] \quad (11) JHGPO?(θ)=Eq~P,{oim?}~πθHFT???[N1?i=1∑N?t=1∑∣o∣?[min(πθHFT??(oim,t?∣q,oim,<t?)πθ?(oim,t?∣q,oim,<t?)?Ait?,clip(πθHFT??(oim,t?∣q,oim,<t?)πθ?(oim,t?∣q,oim,<t?)?,1??,1+?)Ait?)?βDKL?(πθ?∣∣πref?)]](11)
KL散度項展開為:
DKL(πθ∣∣πref)=πref(oim∣q)πθ(oim∣q)?log?πref(oim∣q)πθ(oim∣q)?1(12)\mathbb{D}_{\text{KL}}(\pi_\theta || \pi_{\text{ref}}) = \frac{\pi_{\text{ref}}(o_i^m|q)}{\pi_\theta(o_i^m|q)} - \log \frac{\pi_{\text{ref}}(o_i^m|q)}{\pi_\theta(o_i^m|q)} - 1 \quad (12) DKL?(πθ?∣∣πref?)=πθ?(oim?∣q)πref?(oim?∣q)??logπθ?(oim?∣q)πref?(oim?∣q)??1(12)
訓練配置:
- 學習率:1×10?61\times10^{-6}1×10?6(恒定)
- 批次大小:256(微批次8)
- KL系數β=0.001\beta=0.001β=0.001
- 裁剪閾值?=0.5\epsilon=0.5?=0.5
- 訓練時長:2天(4×H100)
算法特性
-
雙重獎勵機制:
- 組間獎勵引導模式選擇
- 組內獎勵優化內容質量
-
策略約束:
- KL懲罰項防止過度偏離初始策略
- 重要性采樣裁剪保證穩定性
-
零階優化:
無需價值函數模型,直接基于樣本獎勵優化
2.4 混合推理能力評估
為更全面地評估LHRMs的性能(超越傳統下游任務指標),文章提出新指標混合準確率(Hybrid Accuracy, Hacc\mathcal{H}_{acc}Hacc?),用于量化模型選擇正確推理模式的能力。
評估流程
給定任務提示集P={pi}i=1K\mathcal{P} = \{p_i\}_{i=1}^KP={pi?}i=1K?:
- 對每個pip_ipi?,模型在?\vdash?和?\nprec?模式下各生成NNN個響應
- 使用獎勵模型R?R_\phiR??對響應評分,計算各模式平均得分Rˉ?\bar{\mathcal{R}}^\vdashRˉ?和Rˉ?\bar{\mathcal{R}}^\nprecRˉ?
- 確定基準模式mgtm_{gt}mgt?:
- 若∣Rˉ??Rˉ?∣>?|\bar{\mathcal{R}}^\vdash - \bar{\mathcal{R}}^\nprec| > \epsilon∣Rˉ??Rˉ?∣>?,選擇高分模式
- 否則選擇響應更短的模式
- 模型自主選擇模式mpm_pmp?,計算匹配比例:
Hacc=1K∑i=1K1[Equal(mgt,mp)]s.t.mgt,mp∈{?,?}(13)\mathcal{H}_{acc} = \frac{1}{K}\sum_{i=1}^K \mathbb{1}\left[\text{Equal}(m_{gt}, m_p)\right] \quad \text{s.t.} \quad m_{gt}, m_p \in \{\vdash, \nprec\} \quad (13) Hacc?=K1?i=1∑K?1[Equal(mgt?,mp?)]s.t.mgt?,mp?∈{?,?}(13)
關鍵參數:
- ?\epsilon?:模式得分差異閾值(默認0.05)
- NNN:每種模式采樣數(默認4)