在使用 Proximal Policy Optimization(PPO)對語言模型進行強化學習微調(如 RLHF)時,大家經常會問:
- 策略網絡的動作概率是怎么來的?
- 價值網絡的得分是如何計算的?
- 獎勵從哪里來?損失函數怎么構建?
- 微調后的舊軌跡還能用嗎?
這篇文章將以語言模型強化學習微調為例,結合實際實現和數學公式,深入解析 PPO 的關鍵計算流程。
1?? 策略網絡:如何計算動作概率?
策略網絡 πθ(a∣s)\pi_\theta(a|s)πθ?(a∣s) 用于給出狀態 sss 下采取動作 aaa 的概率。
對于語言模型(如 GPT)來說:
- 狀態 sss:Prompt(如“請介紹量子計算”)
- 動作 aaa:生成的回答(如“量子計算是一種…”)
策略網絡的輸出是 token 級別的 logits,經 softmax 后得到概率:
outputs = model(input_ids)
logits = outputs.logits # [batch_size, seq_len, vocab_size]
probs = torch.softmax(logits, dim=-1) # 得到 token 概率
對于一個完整回答,其概率為:
πθ(a1:T∣s)=∏t=1Tπθ(at∣s,a<t) \pi_\theta(a_{1:T} | s) = \prod_{t=1}^T \pi_\theta(a_t | s, a_{<t}) πθ?(a1:T?∣s)=t=1∏T?πθ?(at?∣s,a<t?)
該概率在 PPO 中用于計算策略概率比:
rt=πθ(at∣st)πθold(at∣st) r_t = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} rt?=πθold??(at?∣st?)πθ?(at?∣st?)?
2?? 價值網絡:如何計算狀態得分?
價值網絡 V?(s)V_\phi(s)V??(s) 預測的是狀態 sss 的期望累計獎勵,即該 prompt + 回復的“好壞”。
實現方式通常是共享模型底座 + 線性輸出層:
hidden_states = outputs.hidden_states # [batch_size, seq_len, hidden_dim]
value = value_head(hidden_states).squeeze(-1) # 每個 token 對應一個值
通常使用最后一個 token 的 value 作為整段文本的狀態值:
V?(s)=value(last_token)
V_\phi(s) = \text{value}(\text{last\_token})
V??(s)=value(last_token)
也可以做 mean pooling 等方式。
3?? 獎勵函數:怎么定義?
PPO 是一個基于獎勵優化的強化學習算法。對于語言模型,一般使用人工偏好、打分器、獎勵模型(RM)來計算獎勵 RRR。
示例方式:
- 高質量回答獎勵高,例如 R=+4R = +4R=+4
- 差的回答獎勵低,例如 R=+1R = +1R=+1
- 或者使用兩個回復的相對排序值差距(ranking loss)
PPO 使用獎勵和預測值來計算優勢函數(Advantage):
A^t=Rt?V?(st) \hat{A}_t = R_t - V_\phi(s_t) A^t?=Rt??V??(st?)
也可以用 GAE(廣義優勢估計)進一步平滑優勢值。
4?? PPO 策略損失函數:如何構建?
核心損失函數如下(Clipped Surrogate Objective):
Lpolicy=?Et[min?(rtA^t,clip(rt,1??,1+?)A^t)] L^{\text{policy}} = -\mathbb{E}_t \left[ \min \left( r_t \hat{A}_t, \text{clip}(r_t, 1 - \epsilon, 1 + \epsilon) \hat{A}_t \right) \right] Lpolicy=?Et?[min(rt?A^t?,clip(rt?,1??,1+?)A^t?)]
解釋:
- rtr_trt? 是策略概率比
- A^t\hat{A}_tA^t? 是優勢函數
- ?\epsilon? 是截斷系數(常用 0.2)
這個損失保證了策略更新不能偏離舊策略太遠,防止訓練不穩定。
🔍 第一次微調時,rt=1r_t = 1rt?=1:
由于初始時當前策略與舊策略相同,有:
rt=πθ(at∣st)πθold(at∣st)=1 r_t = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} = 1 rt?=πθold??(at?∣st?)πθ?(at?∣st?)?=1
所以第一次策略更新實際變成:
Lpolicy=?A^t L^{\text{policy}} = -\hat{A}_t Lpolicy=?A^t?
相當于標準的策略梯度算法。
5?? PPO 價值損失函數:如何構建?
價值網絡使用均方誤差損失來擬合獎勵:
Lvalue=Et[(V?(st)?Rt)2] L^{\text{value}} = \mathbb{E}_t \left[ \left( V_\phi(s_t) - R_t \right)^2 \right] Lvalue=Et?[(V??(st?)?Rt?)2]
也可以加上 value clipping:
Lvalue-clipped=max?((V?(st)?Rt)2,(clip(V?(st),Vold??,Vold+?)?Rt)2) L^{\text{value-clipped}} = \max\left( (V_\phi(s_t) - R_t)^2, (\text{clip}(V_\phi(s_t), V_{\text{old}} - \epsilon, V_{\text{old}} + \epsilon) - R_t)^2 \right) Lvalue-clipped=max((V??(st?)?Rt?)2,(clip(V??(st?),Vold???,Vold?+?)?Rt?)2)
6?? 總損失函數:包含 entropy 獎勵
完整的 PPO 損失函數通常為:
L=Lpolicy+c1?Lvalue?c2?H(πθ) L = L^{\text{policy}} + c_1 \cdot L^{\text{value}} - c_2 \cdot H(\pi_\theta) L=Lpolicy+c1??Lvalue?c2??H(πθ?)
- H(πθ)H(\pi_\theta)H(πθ?):策略的熵,用于鼓勵探索(entropy bonus)
- c1,c2c_1, c_2c1?,c2?:超參數,通常取 0.5 和 0.01
熵越高表示策略更隨機,防止策略過早收斂到確定動作。
7?? 微調后,舊軌跡還能繼續用嗎?
不能。PPO 是 on-policy 算法。
每輪策略更新后,舊軌跡(state, action, reward, old prob)就過時了,必須重新采樣:
- 舊策略生成的樣本反映不了當前策略的行為
- 若繼續使用,會引入策略偏移(policy mismatch)
因此,PPO 的標準訓練循環是:
- 用當前策略生成軌跡
- 固定軌跡,訓練 N 個 epoch
- 更新策略后丟棄舊軌跡
- 重復采樣新數據
? 總結回顧
項目 | 內容說明 |
---|---|
策略概率 | 模型輸出 logits → softmax 得到 token 概率 |
策略損失 | PPO clipped objective,基于概率比和優勢函數 |
價值得分 | Value head 輸出一個實數,預測狀態期望獎勵 |
獎勵函數 | 來自人工打分或獎勵模型,指導優勢函數計算 |
是否復用軌跡 | ? 不能復用舊軌跡,策略更新后必須重新采樣 |
🔚 寫在最后
理解 PPO 中策略概率、價值得分、損失函數之間的關系,是成功實現 RLHF、SFT + RL 微調語言模型的基礎。
這些原理不只是公式,更影響著你訓練是否穩定、樣本是否有效、微調是否收斂。