在本文中,我們將深入探討Deepseek采用的策略優化方法GRPO,并順帶介紹一些強化學習(Reinforcement Learning, RL)的基礎知識,包括PPO等關鍵概念。
策略函數(policy)
在強化學習中, a t ∣ s t a_t \mid s_t at?∣st? 表示在狀態 s t s_t st? 下采取動作 a t a_t at? 的條件概率。具體來說,它是由策略函數 π \pi π 決定的。
詳細說明
s t s_t st?
- s t s_t st? 表示在時間步 t t t 時的狀態(state)。
- 狀態是環境對智能體的當前描述,例如在游戲中可能是角色的位置、速度等信息。
a t a_t at?
- a t a_t at? 表示在時間步 t t t 時智能體采取的動作(action)。
- 動作是智能體在給定狀態下可以執行的操作,例如在游戲中可能是“向左移動”或“跳躍”。
π ( a t ∣ s t ) \pi(a_t \mid s_t) π(at?∣st?)
- π ( a t ∣ s t ) \pi(a_t \mid s_t) π(at?∣st?) 是策略函數(policy),表示在狀態 s t s_t st? 下選擇動作 a t a_t at? 的概率。
- 如果是確定性策略, π ( a t ∣ s t ) \pi(a_t \mid s_t) π(at?∣st?) 會直接輸出一個確定的動作;如果是隨機策略,它會輸出一個動作的概率分布。
r t ( θ ) = π θ ( a t ∣ s t ) π θ old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)} rt?(θ)=πθold??(at?∣st?)πθ?(at?∣st?)?
- 在 PPO 中, r t ( θ ) r_t(\theta) rt?(θ) 是新策略 π θ \pi_\theta πθ? 和舊策略 π θ old \pi_{\theta_{\text{old}}} πθold?? 在狀態 s t s_t st? 下選擇動作 a t a_t at? 的概率比。
- 這個比值用于衡量策略更新的幅度,并通過裁剪機制限制其變化范圍,確保訓練的穩定性。
舉例說明
假設我們有一個簡單的游戲環境:
-
狀態 s t s_t st?:角色的位置。
-
動作 a t a_t at?:可以執行的動作是“向左”或“向右”。
-
策略 π ( a t ∣ s t ) \pi(a_t \mid s_t) π(at?∣st?):在某個位置 s t s_t st? 下,策略可能以 70% 的概率選擇“向左”,以 30% 的概率選擇“向右”。
在 PPO 中,我們會比較新舊策略在相同狀態 s t s_t st? 下選擇相同動作 a t a_t at? 的概率,從而計算概率比 r t ( θ ) r_t(\theta) rt?(θ),并用于優化目標函數。
總結
a t ∣ s t a_t \mid s_t at?∣st? 表示在狀態 s t s_t st? 下選擇動作 a t a_t at? 的條件概率,由策略函數 π \pi π 決定。在 PPO 中,這一概率用于計算新舊策略的比值,從而控制策略更新的幅度。
近端策略優化(PPO)
PPO(Proximal Policy Optimization) 是一種用于強化學習的策略優化算法,由 OpenAI 提出。它通過限制策略更新的幅度,確保訓練過程的穩定性。
核心思想
PPO 的核心在于限制策略更新的幅度,避免因更新過大導致性能下降。它通過引入“裁剪”機制,控制新舊策略之間的差異。
公式
PPO 的替代目標函數 J P P O ( θ ) \mathcal{J}_{PPO}(\theta) JPPO?(θ) 用于優化策略 π θ \pi_\theta πθ?,公式如下:
J P P O ( θ ) = E [ q ~ P ( Q ) , o ~ π θ o l d ( O ∣ q ) ] 1 ∣ o ∣ ∑ t = 1 ∣ o ∣ min ? [ π θ ( o t ∣ q , o < t ) π θ o l d ( o t ∣ q , o < t ) A t , clip ( π θ ( o t ∣ q , o < t ) π θ o l d ( o t ∣ q , o < t ) , 1 ? ε , 1 + ε ) A t ] \mathcal{J}_{PPO}(\theta) = \mathbb{E}_{[q \sim P(Q), o \sim \pi_{\theta_{old}}(O|q)]} \frac{1}{|o|} \sum_{t=1}^{|o|} \min \left[ \frac{\pi_\theta(o_{t} | q, o_{<t})}{\pi_{\theta_{old}}(o_{t} | q, o_{<t})} A_{t}, \text{clip} \left( \frac{\pi_\theta(o_{t} | q, o_{<t})}{\pi_{\theta_{old}}(o_{t} | q, o_{<t})}, 1 - \varepsilon, 1 + \varepsilon\right) A_{t} \right] JPPO?(θ)=E[q~P(Q),o~πθold??(O∣q)]?∣o∣1?t=1∑∣o∣?min[πθold??(ot?∣q,o<t?)πθ?(ot?∣q,o<t?)?At?,clip(πθold??(ot?∣q,o<t?)πθ?(ot?∣q,o<t?)?,1?ε,1+ε)At?]
其中:
期望符號 E \mathbb{E} E 表示對查詢 q q q 和輸出 o o o 的期望:
-
q ~ P ( Q ) q \sim P(Q) q~P(Q): 查詢 q q q 從分布 P ( Q ) P(Q) P(Q) 中采樣。
-
o ~ π θ o l d ( O ∣ q ) o \sim \pi_{\theta_{old}}(O|q) o~πθold??(O∣q): 輸出 o o o 由舊策略 π θ o l d \pi_{\theta_{old}} πθold?? 生成。
1 ∣ o ∣ ∑ t = 1 ∣ o ∣ \frac{1}{|o|} \sum_{t=1}^{|o|} ∣o∣1?∑t=1∣o∣? 對輸出 o o o 的每個時間步 t t t 求平均:
- ∣ o ∣ |o| ∣o∣ 是輸出序列的長度。
其核心目標函數為:
L C L I P ( θ ) = E t [ min ? ( r t ( θ ) A ^ t , clip ( r t ( θ ) , 1 ? ? , 1 + ? ) A ^ t ) ] L^{CLIP}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_t \right) \right] LCLIP(θ)=Et?[min(rt?(θ)A^t?,clip(rt?(θ),1??,1+?)A^t?)]
其中:
-
r t ( θ ) = π θ ( a t ∣ s t ) π θ old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} rt?(θ)=πθold??(at?∣st?)πθ?(at?∣st?)? 是新舊策略的概率比。
-
A ^ t \hat{A}_t A^t? 是優勢函數,衡量動作的相對好壞。
-
? \epsilon ? 是裁剪參數,通常為 0.1 或 0.2。
步驟
- 采樣:使用當前策略與環境交互,收集數據,在語言模型中,可以類比為生成補全(generating completions)。
- 計算優勢值:基于收集的數據計算優勢值函數 A ^ t \hat{A}_t A^t?。
- 優化目標函數:通過梯度上升優化目標函數 L C L I P ( θ ) L^{CLIP}(\theta) LCLIP(θ)。
- 更新策略:重復上述步驟,直到策略收斂。
優點
- 穩定性:通過裁剪機制,避免策略更新過大。
- 高效性:相比 TRPO,PPO 實現更簡單,計算效率更高。
補充
在強化學習中,策略的目標是最大化期望回報,而不是最小化損失。所以,在PPO中使用的是梯度上升,原因在于它的優化目標是最大化目標函數(如強化學習中的期望回報),而不是最小化損失函數(如分類或回歸問題)。
Advantage(優勢函數)
定義
Advantage函數用于衡量在某個狀態(State)下,采取某個動作(Action)相對于平均表現的優劣程度。它的數學定義為:
A ( s , a ) = Q ( s , a ) ? V ( s ) A(s, a) = Q(s, a) - V(s) A(s,a)=Q(s,a)?V(s), 其中:
-
Q ( s , a ) Q(s, a) Q(s,a) 是動作值函數,表示在狀態 s s s 下采取動作 a a a 后,未來累積回報的期望。
-
V ( s ) V(s) V(s) 是狀態值函數,表示在狀態 s s s 下,按照當前策略采取動作后,未來累積回報的期望。
-
A ( s , a ) A(s, a) A(s,a) 是優勢函數,表示在狀態 s s s 下采取動作 a a a 比平均表現好多少(或差多少)。
作用
-
Advantage函數用于指導策略更新:
- 如果 A ( s , a ) > 0 A(s, a) > 0 A(s,a)>0,說明動作 a a a 比平均表現更好,策略應該更傾向于選擇這個動作;
- 如果 A ( s , a ) < 0 A(s, a) < 0 A(s,a)<0,說明動作 a a a 比平均表現更差,策略應該減少選擇這個動作的概率。
-
在PPO等算法中,Advantage函數通常通過**GAE(Generalized Advantage Estimation)**來估計。
直觀理解
Advantage函數就像一個“評分”,告訴模型某個動作在當前狀態下是好還是壞,以及好(或壞)的程度。
KL Penalty(KL散度懲罰)
定義
KL Penalty是基于**KL散度(Kullback-Leibler Divergence)**的一種正則化手段。KL散度用于衡量兩個概率分布之間的差異。在強化學習中,KL Penalty通常用于限制當前策略 π θ \pi_{\theta} πθ? 和參考策略 π ref \pi_{\text{ref}} πref? 之間的差異。其數學定義為:
KL?Penalty = D KL ( π ref ∥ π θ ) \text{KL Penalty} = D_{\text{KL}}(\pi_{\text{ref}} \| \pi_{\theta}) KL?Penalty=DKL?(πref?∥πθ?)
其中:
-
π θ \pi_{\theta} πθ? 是當前策略(由模型參數 θ \theta θ 決定)。
-
π ref \pi_{\text{ref}} πref? 是參考策略(通常是更新前的策略或某個基線策略)。
-
D KL D_{\text{KL}} DKL? 是KL散度,用于衡量兩個策略之間的差異。
作用
- KL Penalty用于防止策略更新過大,確保當前策略不會偏離參考策略太遠。這樣可以避免訓練過程中的不穩定現象(如策略崩潰)。
- 在PPO等算法中,KL Penalty通常被添加到目標函數中,作為正則化項。
直觀理解
KL Penalty就像一個“約束”,告訴模型在更新策略時不要“步子邁得太大”,以免失去穩定性。
Advantage和KL Penalty的關系
-
Advantage 用于指導策略更新,告訴模型哪些動作更好。
-
KL Penalty 用于約束策略更新,防止策略變化過大。
-
在PPO等算法中,Advantage和KL Penalty共同作用,既鼓勵模型選擇更好的動作,又確保更新過程穩定可靠。
舉例說明
假設我們訓練一個機器人走迷宮:
-
Advantage:機器人發現“向右轉”比“向左轉”更容易找到出口,于是Advantage函數會給“向右轉”一個正的值,鼓勵策略更傾向于選擇“向右轉”。
-
KL Penalty:為了防止策略突然變得只選擇“向右轉”而忽略其他可能性,KL Penalty會限制策略的變化幅度,確保策略更新是平滑的。
總結
-
Advantage(優勢函數):衡量某個動作比平均表現好多少,用于指導策略更新。
-
KL Penalty(KL散度懲罰):限制策略更新的幅度,確保訓練過程的穩定性。
群體相對策略優化(GRPO)
GRPO 是一種在線學習算法(online learning algorithm),這意味著它通過使用訓練過程中由訓練模型自身生成的數據來迭代改進。GRPO 的目標直覺是最大化生成補全(completions)的優勢函數(advantage),同時確保模型保持在參考策略(reference policy)附近。
其目標函數為:
J GRPO ( θ ) = E q ~ P ( Q ) , { o i } i = 1 G ~ π old ( O ∣ q ) [ 1 G ∑ i = 1 G 1 ∣ o i ∣ ∑ t = 1 ∣ o i ∣ ( r i , t ( θ ) A ^ i , t ? β D KL ( π θ ∣ ∣ π ref ) ) ] J_{\text{GRPO}}(\theta) = \mathbb{E}_{q \sim P(Q), \{o_i\}_{i=1}^G \sim \pi_{\text{old}}(O|q)} \left[ \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left( r_{i,t}(\theta) \hat{A}_{i,t} - \beta D_{\text{KL}}(\pi_\theta || \pi_{\text{ref}}) \right) \right] JGRPO?(θ)=Eq~P(Q),{oi?}i=1G?~πold?(O∣q)? ?G1?i=1∑G?∣oi?∣1?t=1∑∣oi?∣?(ri,t?(θ)A^i,t??βDKL?(πθ?∣∣πref?)) ?
為了理解 GRPO 的工作原理,可以將其分解為四個主要步驟:
-
生成補全(Generating completions)
-
計算優勢值(Computing the advantage)
-
估計KL散度(Estimating the KL divergence)
-
計算損失(Computing the loss)
1. 生成補全(Generating completions)
在每一個訓練步驟中,我們從提示(prompts)中采樣一個批次(batch),并為每個提示生成一組 G G G 個補全(completions)(記為 o i o_i oi?)。
2. 計算優勢值(Computing the advantage)
對于每一個 G G G 序列,使用獎勵模型(reward model)計算其獎勵(reward)。為了與獎勵模型的比較性質保持一致——通常獎勵模型是基于同一問題的輸出之間的比較數據集進行訓練的——優勢的計算反映了這些相對比較。其歸一化公式如下:
A ^ i , t = r i ? mean ( r ) std ( r ) \hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})} A^i,t?=std(r)ri??mean(r)?
這種方法賦予了該方法其名稱:群體相對策略優化(Group Relative Policy Optimization, GRPO)
GRPO通過優化PPO算法,解決了計算優勢值時需要同時依賴獎勵模型(reward model)和價值模型(value model)的問題,成功移除了value model(價值模型),顯著降低了推理時的內存占用和時間開銷。**Advantage(優勢值)**的核心價值在于為模型輸出提供更精準的評估,不僅衡量答案的絕對質量,還通過相對比較(與其他回答的對比)來更全面地定位其優劣。
3. 估計KL散度(Estimating the KL divergence)
在實際算法實現中,直接計算KL散度可能會面臨一些挑戰:
- 計算復雜度高:KL散度的定義涉及對兩個概率分布的對數比值的期望計算。對于復雜的策略分布,直接計算KL散度可能需要大量的計算資源;
- 數值穩定性:在實際計算中,直接計算KL散度可能會遇到數值不穩定的問題,尤其是當兩個策略的概率分布非常接近時,對數比值可能會趨近于零或無窮大。近似器可以通過引入一些數值穩定性的技巧(如截斷或平滑)來避免這些問題;
- 在線學習:在強化學習中,策略通常需要在每一步或每幾步更新一次。如果每次更新都需要精確計算KL散度,可能會導致訓練過程變得非常緩慢。近似器可以快速估計KL散度,從而支持在線學習和實時更新。
Schulman et al. (2020) 提出的近似器可以根據當前策略和參考策略的差異動態調整估計的精度,從而在保證計算效率的同時,盡可能減少估計誤差,其定義如下:
D KL [ π θ ∥ π ref ] = π ref ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) ? log ? π ref ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) ? 1 \mathbb{D}_{\text{KL}}\left[\pi_\theta \|\pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1 DKL?[πθ?∥πref?]=πθ?(oi,t?∣q,oi,<t?)πref?(oi,t?∣q,oi,<t?)??logπθ?(oi,t?∣q,oi,<t?)πref?(oi,t?∣q,oi,<t?)??1
這個近似器的核心思想是通過對當前策略和參考策略的概率比值的簡單變換來估計KL散度。具體來說:
- 第一項: π ref ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} πθ?(oi,t?∣q,oi,<t?)πref?(oi,t?∣q,oi,<t?)? 是參考策略與當前策略的概率比值。
- 第二項: log ? π ref ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} logπθ?(oi,t?∣q,oi,<t?)πref?(oi,t?∣q,oi,<t?)? 是對數概率比值。
- 第三項: ? 1 -1 ?1 是一個常數項,用于調整近似器的偏差。
這個近似器的優勢在于它只需要計算當前策略和參考策略的概率比值,而不需要直接計算KL散度的積分或期望。因此,它可以在保證一定精度的同時,顯著降低計算復雜度。
近似器的直觀理解
這個近似器的設計靈感來自于泰勒展開。KL散度可以看作是兩個分布之間的某種“距離”,而這個近似器通過一階或二階近似來估計這個距離。具體來說:
- 當 π θ \pi_\theta πθ? 和 π ref \pi_{\text{ref}} πref? 非常接近時, π ref π θ ≈ 1 \frac{\pi_{\text{ref}}}{\pi_\theta} \approx 1 πθ?πref??≈1,此時 log ? π ref π θ ≈ 0 \log \frac{\pi_{\text{ref}}}{\pi_\theta} \approx 0 logπθ?πref??≈0,近似器的值趨近于零,符合KL散度的性質。
- 當 π θ \pi_\theta πθ? 和 π ref \pi_{\text{ref}} πref? 差異較大時,近似器會給出一個較大的正值,反映出兩個分布之間的差異。
4. 計算損失(Computing the loss)
這一步的目標是最大化優勢,同時確保模型保持在參考策略附近。因此,損失定義如下:
L GRPO ( θ ) = ? 1 G ∑ i = 1 G 1 ∣ o i ∣ ∑ t = 1 ∣ o i ∣ [ π θ ( o i , t ∣ q , o i , < t ) [ π θ ( o i , t ∣ q , o i , < t ) ] no?grad A ^ i , t ? β D KL [ π θ ∥ π ref ] ] \mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right] LGRPO?(θ)=?G1?i=1∑G?∣oi?∣1?t=1∑∣oi?∣?[[πθ?(oi,t?∣q,oi,<t?)]no?grad?πθ?(oi,t?∣q,oi,<t?)?A^i,t??βDKL?[πθ?∥πref?]]
其中第一項表示縮放后的優勢,第二項通過KL散度懲罰與參考策略的偏離。
在原始論文中,該公式被推廣為在每次生成后通過利用**裁剪替代目標(clipped surrogate objective)**進行多次更新:
L GRPO ( θ ) = ? 1 G ∑ i = 1 G 1 ∣ o i ∣ ∑ t = 1 ∣ o i ∣ [ min ? ( π θ ( o i , t ∣ q , o i , < t ) π θ old ( o i , t ∣ q , o i , < t ) A ^ i , t , clip ( π θ ( o i , t ∣ q , o i , < t ) π θ old ( o i , t ∣ q , o i , < t ) , 1 ? ? , 1 + ? ) A ^ i , t ) ? β D KL [ π θ ∥ π ref ] ] \mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right] LGRPO?(θ)=?G1?i=1∑G?∣oi?∣1?t=1∑∣oi?∣?[min(πθold??(oi,t?∣q,oi,<t?)πθ?(oi,t?∣q,oi,<t?)?A^i,t?,clip(πθold??(oi,t?∣q,oi,<t?)πθ?(oi,t?∣q,oi,<t?)?,1??,1+?)A^i,t?)?βDKL?[πθ?∥πref?]]
其中 clip ( ? , 1 ? ? , 1 + ? ) \text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon) clip(?,1??,1+?) 通過將策略比率限制在 1 ? ? 1 - \epsilon 1?? 和 1 + ? 1 + \epsilon 1+? 之間,確保更新不會過度偏離參考策略。
在很多代碼實現,比如Huggingface的TRL中,與原始論文一樣每次生成只進行一次更新,因此可以將損失簡化為第一種形式。
總結
GRPO通過優化PPO算法,移除了價值模型,降低了計算開銷,同時利用群體相對優勢函數和KL散度懲罰,確保策略更新既高效又穩定。
想象一下,你是個銷售員,這個月業績10萬塊,PPO算法就像個精明的老會計,拿著算盤噼里啪啦一頓算,考慮市場行情、產品類型,最后得出結論:“嗯,這10萬還算靠譜,但GAE一算,發現你的優勢值還不夠高,還得再加把勁啊”
而GRPO呢,就像老板直接搞了個“內卷大賽”,把所有銷售員拉到一個群里,每天曬業績:“你10萬,他15萬,她20萬……”老板還時不時發個紅包,刺激大家繼續卷。你的10萬塊在群里瞬間被淹沒,老板搖搖頭:“你這水平,還得加把勁啊!”
GRPO這招絕了,它把PPO的“算盤”扔了,省了不少計算功夫,直接搞“內卷PK”,用KL散度懲罰來確保大家別躺平。這樣一來,策略更新既快又穩,老板再也不用擔心有人摸魚了,畢竟大家都在拼命卷,誰敢松懈?
總結一下:PPO是“單打獨斗看實力”,GRPO是“內卷大賽拼到死”,最后GRPO還省了算盤錢,老板笑得合不攏嘴,而我們只能默默加班,繼續卷。