詳解 PPO 算法中的 actor_loss_fn
及其核心參數
1. 引言
在強化學習中,PPO(Proximal Policy Optimization,近端策略優化)算法是一種經典且高效的策略優化方法。它通過重要性采樣(Importance Sampling)和策略剪切(Clipping)來平衡策略改進的幅度與模型訓練的穩定性。
這篇博客將重點解析 DeepSpeed Chat 的 actor_loss_fn
函數,代碼鏈接: https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py#L44,結合代碼詳細說明其作用,并通過一個具體的例子,說明模型生成的 response 是如何對應到 logprobs
、old_logprobs
、advantages
和 mask
這幾個參數的。
2. 代碼解析
下面是 actor_loss_fn
函數的實現:
def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):# 計算 log 概率比log_ratio = (logprobs - old_logprobs) * maskratio = torch.exp(log_ratio)# 策略梯度損失(兩種形式)pg_loss1 = -advantages * ratiopg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)# 選擇損失中的最大值,并加權求和(masked)pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()return pg_loss
2.1 參數說明
logprobs
:當前策略(即 Actor 網絡)在生成的 response 上計算得到的 log 概率。這是動態更新的,隨著 Actor 策略的更新而變化。old_logprobs
:舊策略(之前保存的 Actor 策略)在相同 response 上計算得到的 log 概率。這是固定的,不隨當前策略變化。advantages
:舊策略計算出的優勢函數值,表示每個 token 的相對價值。這也是固定的。mask
:一個布爾型張量,用于標記有效的 token(即非填充部分)。例如,response 的部分 token 有效,其余部分被填充為無效。
2.2 核心邏輯
-
計算 Log 概率比
log_ratio = ( logprobs ? old_logprobs ) ? mask \text{log\_ratio} = (\text{logprobs} - \text{old\_logprobs}) \cdot \text{mask} log_ratio=(logprobs?old_logprobs)?mask
通過當前策略和舊策略的 log 概率差,計算重要性比率的對數。 -
將 Log 概率比轉換為比例
ratio = exp ? ( log_ratio ) \text{ratio} = \exp(\text{log\_ratio}) ratio=exp(log_ratio)
轉換為實際的概率比。 -
定義兩種策略梯度損失
- 未剪切的損失:
pg_loss1 = ? advantages ? ratio \text{pg\_loss1} = -\text{advantages} \cdot \text{ratio} pg_loss1=?advantages?ratio - 剪切后的損失:
pg_loss2 = ? advantages ? clip ( ratio , 1 ? ? , 1 + ? ) \text{pg\_loss2} = -\text{advantages} \cdot \text{clip}(\text{ratio}, 1-\epsilon, 1+\epsilon) pg_loss2=?advantages?clip(ratio,1??,1+?)
通過限制 ( ratio \text{ratio} ratio) 在 ( [ 1 ? ? , 1 + ? ] [1-\epsilon, 1+\epsilon] [1??,1+?]) 范圍內,避免策略更新幅度過大,防止訓練不穩定。
- 未剪切的損失:
-
選取最大損失
對每個 token,取兩種損失中的較大值,以確保更新過程不會偏離策略優化目標。 -
加權平均損失
使用mask
對損失加權,避免填充值對計算的干擾。
3. 參數與 Response 的對應關系
為了更好地理解 logprobs
、old_logprobs
、advantages
和 mask
,我們通過一個具體例子來說明:
假設輸入
- Prompt:
"Translate the following text into French: 'Hello, how are you?'"
- Response(生成的句子):
"Bonjour, comment ?a va?"
具體參數計算過程
-
logprobs
當前策略對response
每個 token 的生成概率:logprobs
可能為:[-1.2, -0.8, -1.5, -0.6, -1.0] # 對應 "Bonjour,", "comment", "?a", "va", "?"
- 它是通過當前 Actor 模型的 softmax 輸出取對數后得到的。
-
old_logprobs
舊策略對相同response
的生成概率:old_logprobs
可能為:[-1.0, -0.9, -1.4, -0.7, -0.9] # 對應 "Bonjour,", "comment", "?a", "va", "?"
-
advantages
每個 token 的優勢值,表示生成該 token 的相對價值:- 例如:
[0.5, 0.8, 0.3, 0.6, 0.4] # 對應 "Bonjour,", "comment", "?a", "va", "?"
- 這些值通過舊策略與環境交互(或獎勵模型)計算得到,表示每個 token 對整體策略的貢獻。
- 例如:
-
mask
一個布爾型標記,有效部分為 1,填充部分為 0:- 例如,假設
response
中沒有填充值:[1, 1, 1, 1, 1]
- 例如,假設
4. 為什么引入 logprobs
和策略剪切?
在 PPO 中,策略更新會參考舊策略,但如果更新幅度太大(即 ( ratio \text{ratio} ratio) 偏離 1 太遠),可能會導致模型“訓歪”:
- 過度更新:策略過快改變,導致新策略偏離舊策略太遠,破壞學習穩定性。
- 欠更新:策略改變過小,無法顯著優化目標函數。
通過引入 logprobs
和剪切機制(( clip \text{clip} clip)),我們可以控制每次更新的幅度,確保:
- 更新不偏離舊策略太遠。
- 改進的同時保留策略的多樣性和探索能力。
5. 實例代碼
以下是基于 PyTorch 的實現,模擬計算 actor_loss
:
import torch# 模擬輸入
logprobs = torch.tensor([-1.2, -0.8, -1.5, -0.6, -1.0]) # 當前策略的 log 概率
old_logprobs = torch.tensor([-1.0, -0.9, -1.4, -0.7, -0.9]) # 舊策略的 log 概率
advantages = torch.tensor([0.5, 0.8, 0.3, 0.6, 0.4]) # 每個 token 的優勢值
mask = torch.tensor([1, 1, 1, 1, 1]) # 有效部分的掩碼
cliprange = 0.2 # 剪切范圍# 計算 log 概率比和比率
log_ratio = (logprobs - old_logprobs) * mask
ratio = torch.exp(log_ratio)# 策略梯度損失
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()# 輸出結果
print("Log Ratios:", log_ratio)
print("Ratios:", ratio)
print("Policy Gradient Loss:", pg_loss.item())
6. 總結
通過這篇博客,我們詳細介紹了 actor_loss_fn
的實現邏輯及其關鍵參數在 PPO 中的作用:
logprobs
和old_logprobs
是新舊策略的橋梁,通過重要性采樣連接兩者。advantages
衡量生成 token 的相對價值。mask
避免無效 token 對損失的影響。- 引入剪切機制確保訓練穩定。
通過具體例子和代碼演示,我們更直觀地理解了這些參數與 PPO 的關系,以及它們在策略優化中的關鍵作用。
【1】pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum() 解釋
要理解這句代碼 pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
,我們可以通過一個具體的例子,模擬實際計算過程,特別關注 mask
的作用和 mask.sum()
的意義。
代碼結構解析
-
pg_loss1
和pg_loss2
- 兩種策略梯度損失的計算:
pg_loss1
是未剪切的策略梯度損失。pg_loss2
是剪切后的策略梯度損失(通過torch.clamp
限制比率的范圍)。
torch.max(pg_loss1, pg_loss2)
:對每個 token,取兩種損失中的較大值。
- 兩種策略梯度損失的計算:
-
乘以
mask
mask
是一個布爾型張量,用于標記哪些 token 是有效的,哪些是無效的。- 乘以
mask
的作用是:將無效的 token 對應的損失置為 0,避免它們對整體損失的計算產生干擾。
-
torch.sum(... * mask)
- 計算有效 token 的總損失,忽略無效 token。
-
/ mask.sum()
- 歸一化:用有效 token 的數量(
mask.sum()
)對總損失進行歸一化,確保損失是平均的,而不是直接被無效 token 的數量拉低。
- 歸一化:用有效 token 的數量(
實際數值模擬
假設我們有以下數據:
輸入數據
-
pg_loss1
:未剪切的損失pg_loss1 = [-0.5, -0.8, -0.3, -1.2, -0.6] # 每個 token 的損失
-
pg_loss2
:剪切后的損失pg_loss2 = [-0.4, -0.9, -0.5, -0.8, -0.7]
-
mask
:有效 token 標記mask = [1, 1, 1, 0, 0] # 前 3 個 token 有效,后兩個 token 無效
計算步驟
-
torch.max(pg_loss1, pg_loss2)
對每個 token,取pg_loss1
和pg_loss2
中的較大值(注意:負數值較小的絕對值是較大值)。max_loss = torch.max(pg_loss1, pg_loss2)= [-0.4, -0.8, -0.3, -0.8, -0.6]
-
乘以
mask
通過與mask
相乘,無效 token(mask = 0
)對應的損失被置為 0:masked_loss = max_loss * mask= [-0.4, -0.8, -0.3, 0.0, 0.0]
-
torch.sum(masked_loss)
求有效 token 的總損失:total_loss = torch.sum(masked_loss)= -0.4 + (-0.8) + (-0.3) + 0.0 + 0.0= -1.5
-
mask.sum()
統計有效 token 的數量:mask_sum = mask.sum()= 1 + 1 + 1 + 0 + 0= 3
-
歸一化損失
將總損失除以有效 token 數量,計算平均損失:pg_loss = total_loss / mask_sum= -1.5 / 3= -0.5
乘以 mask
的作用
在這個例子中,mask
的作用是:
- 將無效 token(
mask = 0
)對應的損失設置為 0,防止這些無效 token 對總損失產生影響。- 如果沒有乘以
mask
,則pg_loss
會錯誤地包含無效 token 的損失。
- 如果沒有乘以
- 確保只有有效 token 的損失被納入計算。
歸一化 (/ mask.sum()
) 的作用
在訓練過程中,不同 batch 的有效 token 數量可能不一樣。例如:
- 一個 batch 的
mask
有效部分是 3 個 token。 - 另一個 batch 的
mask
有效部分是 5 個 token。
如果不對總損失歸一化,那么:
- 有效 token 數量多的 batch 會對損失計算產生更大的權重。
- 歸一化后,每個 token 的貢獻是平均的,確保訓練的穩定性。
總結
通過這個例子,可以清楚地看到:
torch.max(pg_loss1, pg_loss2)
:用于選擇更大的損失,平衡未剪切和剪切更新之間的效果。- 乘以
mask
:屏蔽無效 token 的損失,專注于有效部分。 - 歸一化 (
/ mask.sum()
):確保損失與有效 token 的數量無關,保持不同 batch 的訓練一致性。
最終,這段代碼實現了 PPO 策略更新時的核心邏輯:通過限制每次更新幅度,穩定地優化 Actor 策略,同時屏蔽無效 token 對損失的干擾。
【2】 為什么在 torch.max(pg_loss1, pg_loss2)
中使用 max
而不是 min
?
PPO(Proximal Policy Optimization)算法中的損失函數設計目標是 限制策略更新幅度,確保策略的穩定性和性能提升。這里的 max
并非簡單地取大值,而是通過裁剪(clipping)機制實現 選擇對策略有利的損失,同時約束更新幅度。
PPO 的核心損失函數回顧
PPO的目標函數為:
L CLIP ( θ ) = E [ min ? ( r t ( θ ) A ^ t , clip ( r t ( θ ) , 1 ? ? , 1 + ? ) A ^ t ) ] L^{\text{CLIP}}(\theta) = \mathbb{E} \left[ \min \left( r_t(\theta) \hat{A}_t, \, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_t \right) \right] LCLIP(θ)=E[min(rt?(θ)A^t?,clip(rt?(θ),1??,1+?)A^t?)]
其中:
- ( r t ( θ ) = π θ ( a t ∣ s t ) π old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\text{old}}(a_t | s_t)} rt?(θ)=πold?(at?∣st?)πθ?(at?∣st?)? ):新舊策略的比例。
- ( A ^ t \hat{A}_t A^t? ):優勢值(Advantage)。
- ( ? \epsilon ? ):裁剪范圍(比如 ( 0.2 0.2 0.2 ))。
- clip:將比例 ( r t r_t rt? ) 裁剪到 ( [ 1 ? ? , 1 + ? ] [1 - \epsilon, 1 + \epsilon] [1??,1+?] )。
PPO 的損失函數通過裁剪機制,限制策略更新幅度,防止模型發生 過大更新,從而訓練不穩定。
在實現中,pg_loss1
和 pg_loss2
分別為:
pg_loss1 = - advantages * ratio
(未裁剪的損失)pg_loss2 = - advantages * clipped_ratio
(裁剪后的損失)
為什么使用 max
而不是 min
?
-
損失的符號:
PPO 使用的是 負的優勢值(-Advantage) 來計算損失,因為優化目標是 最小化損失,而優勢值的正負決定了動作的好壞。- 當 ( A ^ t > 0 \hat{A}_t > 0 A^t?>0 ):表示動作優于基準,應該增加概率。
- 當 ( A ^ t < 0 \hat{A}_t < 0 A^t?<0 ):表示動作劣于基準,應該減少概率。
-
取
max
的意義:
在兩種情況下,PPO 的目標是:- 對于好的動作(優勢 ( > 0 > 0 >0 )),我們允許更新比例 ( r t r_t rt? ) 稍微增大,但不能超出裁剪范圍。
- 對于壞的動作(優勢 ( < 0 < 0 <0 )),我們允許更新比例 ( r t r_t rt? ) 稍微減小,但不能超出裁剪范圍。
通過
max
,我們實現了以下效果:- 當策略偏離較少時,使用未裁剪的損失(
pg_loss1
)。 - 當策略偏離較多時,使用裁剪后的損失(
pg_loss2
),限制更新幅度。
-
為什么不是
min
?
如果使用min
,會選擇較小的損失值,反而可能導致:- 對好的動作,更新不足(不增大概率)。
- 對壞的動作,更新過度(減少概率過大)。
這將違背 PPO 的設計目標,無法有效地限制策略更新幅度,導致訓練不穩定。
數值示例分析
假設我們有以下輸入數據:
- 優勢值(Advantage):
[2.0, -1.0]
- 比例(ratio):
[1.1, 0.7]
- 裁剪范圍:
[0.8, 1.2]
根據 PPO 公式:
-
未裁剪的損失(
pg_loss1
):
pg_loss1 = ? advantages × ratio \text{pg\_loss1} = - \text{advantages} \times \text{ratio} pg_loss1=?advantages×ratio -
裁剪后的損失(
pg_loss2
):
pg_loss2 = ? advantages × clip ( ratio , 1 ? ? , 1 + ? ) \text{pg\_loss2} = - \text{advantages} \times \text{clip}(\text{ratio}, 1 - \epsilon, 1 + \epsilon) pg_loss2=?advantages×clip(ratio,1??,1+?)
代碼實現
import torch# 輸入數據
advantages = torch.tensor([2.0, -1.0]) # 優勢值
ratio = torch.tensor([1.1, 0.7]) # 更新比例
cliprange = 0.2 # 裁剪范圍# 計算未裁剪的損失
pg_loss1 = -advantages * ratio# 計算裁剪后的損失
clipped_ratio = torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss2 = -advantages * clipped_ratio# 取 max
final_loss = torch.max(pg_loss1, pg_loss2)print("pg_loss1 (未裁剪):", pg_loss1)
print("pg_loss2 (裁剪后):", pg_loss2)
print("Final Loss (max):", final_loss)
輸出結果
pg_loss1 (未裁剪): tensor([-2.2000, 0.7000])
pg_loss2 (裁剪后): tensor([-2.2000, 0.8000])
Final Loss (max): tensor([-2.2000, 0.8000])
解釋:
-
第一個樣本(優勢 > 0):
pg_loss1 = -2.2
(未裁剪)pg_loss2 = -2.2
(裁剪后的損失與未裁剪相等)- 取
max
,最終損失為-2.2
。
-
第二個樣本(優勢 < 0):
pg_loss1 = 0.7
(未裁剪)pg_loss2 = 0.8
(裁剪后的比例被限制為 0.8)- 取
max
,最終損失為0.8
,表示更新受到了裁剪限制。
結論
max(pg_loss1, pg_loss2)
的作用:確保在未裁剪損失和裁剪損失之間取較大值,限制策略的更新幅度,穩定訓練過程。- 損失函數關系:
- 對于 優勢 ( > 0 ),我們防止概率增大過多。
- 對于 優勢 ( < 0 ),我們防止概率減小過多。
- 使用
min
會導致訓練不穩定,無法有效地限制更新幅度,違反 PPO 的核心設計理念。
PPO 損失函數前面是否有負號取決于損失函數的最優化目標,即最大化策略性能還是最小化損失。
【3】為什么有的資料ppo的損失函數前面有負號,-min(),而有的材料前面沒有負號?
1. 從優化目標的角度解釋
- 強化學習的目標:最大化策略的累積回報 ( R t R_t Rt? )。
由于優化器(例如 Adam)默認是最小化損失,我們需要對最大化目標取負數,從而轉換為最小化損失問題。
因此:
- 如果目標是最大化策略性能,則公式為:
maximize? L CLIP ( θ ) \text{maximize } L^{\text{CLIP}}(\theta) maximize?LCLIP(θ) - 為了與優化器配合(最小化),我們取負號:
minimize? ? L CLIP ( θ ) \text{minimize } -L^{\text{CLIP}}(\theta) minimize??LCLIP(θ)
帶負號的形式:
actor_loss = ? min ? ( r t ? Adv t , clip ( r t , 1 ? ? , 1 + ? ) ? Adv t ) \text{actor\_loss} = - \min \left( r_t \cdot \text{Adv}_t, \, \text{clip}(r_t, 1 - \epsilon, 1 + \epsilon) \cdot \text{Adv}_t \right) actor_loss=?min(rt??Advt?,clip(rt?,1??,1+?)?Advt?)
2. 從損失函數的數值角度
- 優勢值 ( Adv t \text{Adv}_t Advt? ):
- 當 ( Adv t > 0 \text{Adv}_t > 0 Advt?>0 ):表示當前動作優于基準,應該增大概率。
- 當 ( Adv t < 0 \text{Adv}_t < 0 Advt?<0 ):表示當前動作劣于基準,應該減小概率。
如果前面有負號:
- 最大化優勢值:對于 ( Adv t > 0 \text{Adv}_t > 0 Advt?>0 ),損失會減小(概率增大);對于 ( Adv t < 0 \text{Adv}_t < 0 Advt?<0 ),損失會增大(概率減小)。
如果前面沒有負號:
- 直接是獎勵優化形式,適合展示最大化回報的目標,但在實際代碼中優化器不直接適配這種形式。
3. 兩種形式的適用場景
- 帶負號的形式:通常出現在代碼實現中,因為優化器默認執行最小化損失。
- 不帶負號的形式:出現在理論公式或論文中,強調最大化目標。
結論
帶負號的形式是為了將“最大化回報”轉化為優化器的最小化損失問題,而不帶負號的形式僅僅表示理論上最大化優勢目標。實際實現中,大多數代碼會包含負號。
【4】把pg_loss1和2的負號去掉,最后pg_loss取min再加負號,這樣可以嗎?
下面是我問GPT的問題:
def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):## policy gradient losslog_ratio = (logprobs - old_logprobs) * maskratio = torch.exp(log_ratio)pg_loss1 = advantages * ratio pg_loss2 = advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)pg_loss = -torch.sum(torch.min(pg_loss1, pg_loss2) * mask) / mask.sum() return pg_loss
代碼修改成這樣可以嗎?把pg_loss1和2的負號去掉,最后pg_loss取min,再加負號,這樣可以嗎?或者不可以嗎?為什么
修改后的代碼是可以的!在 PPO 的優化目標中,加負號 和 取 min
的組合恰好符合 PPO 損失函數的設計。以下會詳細解釋為什么這樣修改是合理的。
1. PPO 損失函數的原理
PPO 的目標是最大化期望回報,同時限制策略更新的幅度,以保證訓練的穩定性。其目標函數為:
L PPO ( θ ) = E t [ min ? ( r t ( θ ) ? Adv t , clip ( r t ( θ ) , 1 ? ? , 1 + ? ) ? Adv t ) ] L^{\text{PPO}}(\theta) = \mathbb{E}_t \left[ \min\left(r_t(\theta) \cdot \text{Adv}_t, \, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \cdot \text{Adv}_t\right) \right] LPPO(θ)=Et?[min(rt?(θ)?Advt?,clip(rt?(θ),1??,1+?)?Advt?)]
其中:
- ( r t ( θ ) = π θ ( a t ∣ s t ) π old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)} rt?(θ)=πold?(at?∣st?)πθ?(at?∣st?)? ) 是新舊策略的概率比。
- ( Adv t \text{Adv}_t Advt? ) 是優勢函數,表示當前動作相對于平均水平的好壞。
- 裁剪操作:通過 ( clip ( r t ( θ ) , 1 ? ? , 1 + ? ) \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) clip(rt?(θ),1??,1+?) ) 限制比率的變化幅度。
為了最大化 ( L PPO L^{\text{PPO}} LPPO ),通常將其轉化為一個最小化問題,即在前面加上負號:
actor_loss = ? E t [ min ? ( r t ( θ ) ? Adv t , clip ( r t ( θ ) , 1 ? ? , 1 + ? ) ? Adv t ) ] \text{actor\_loss} = -\mathbb{E}_t \left[ \min\left(r_t(\theta) \cdot \text{Adv}_t, \, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \cdot \text{Adv}_t\right) \right] actor_loss=?Et?[min(rt?(θ)?Advt?,clip(rt?(θ),1??,1+?)?Advt?)]
2. 代碼分析
修改后的代碼:
log_ratio = (logprobs - old_logprobs) * mask
ratio = torch.exp(log_ratio)
pg_loss1 = advantages * ratio
pg_loss2 = advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)
pg_loss = -torch.sum(torch.min(pg_loss1, pg_loss2) * mask) / mask.sum()
return pg_loss
主要修改點:
- 去掉了
pg_loss1
和pg_loss2
的負號。 - 在最終計算
pg_loss
時,加了負號。 - 使用
torch.min(pg_loss1, pg_loss2)
。
分析:
-
去掉
pg_loss1
和pg_loss2
的負號:
這里pg_loss1
和pg_loss2
都表示未裁剪和裁剪后的損失,原始 PPO 損失的優化目標是最大化優勢乘以比率。直接使用優勢和比率(不加負號)可以保留目標的最大化形式。 -
加負號到
pg_loss
前面:
PPO 的目標是最大化 ( L PPO L^{\text{PPO}} LPPO ),但優化器通常是最小化損失函數,因此需要加負號將最大化目標轉化為最小化問題。 -
使用
torch.min(pg_loss1, pg_loss2)
:pg_loss1
:未裁剪的情況 ( Adv t ? r t ( θ ) \text{Adv}_t \cdot r_t(\theta) Advt??rt?(θ) )。pg_loss2
:裁剪后的情況 ( Adv t ? clip ( r t ( θ ) , 1 ? ? , 1 + ? ) \text{Adv}_t \cdot \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) Advt??clip(rt?(θ),1??,1+?) )。- 取
min
:選取兩者中較小的值,目的是在不損害回報的情況下,限制策略更新幅度,從而保證訓練的穩定性。
3. 為什么這樣可以?
PPO 的損失函數最終形式是:
actor_loss = ? E t [ min ? ( r t ( θ ) ? Adv t , clip ( r t ( θ ) , 1 ? ? , 1 + ? ) ? Adv t ) ] \text{actor\_loss} = -\mathbb{E}_t \left[ \min\left(r_t(\theta) \cdot \text{Adv}_t, \, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \cdot \text{Adv}_t\right) \right] actor_loss=?Et?[min(rt?(θ)?Advt?,clip(rt?(θ),1??,1+?)?Advt?)]
與修改后的代碼一一對應:
pg_loss1
表示未裁剪的部分 ( Adv t ? r t ( θ ) \text{Adv}_t \cdot r_t(\theta) Advt??rt?(θ) )。pg_loss2
表示裁剪后的部分 ( Adv t ? clip ( r t ( θ ) , 1 ? ? , 1 + ? ) \text{Adv}_t \cdot \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) Advt??clip(rt?(θ),1??,1+?) )。torch.min(pg_loss1, pg_loss2)
:實現了公式中的min
。-torch.sum(...) / mask.sum()
:加了負號,將最大化目標轉化為最小化損失,并進行歸一化處理。
這樣寫在數學意義和計算結果上都是正確的,符合 PPO 的原理。
4. 總結
- 修改后的代碼是合理的,符合 PPO 損失函數的設計。
- 去掉
pg_loss1
和pg_loss2
的負號,并在最后的pg_loss
上加負號,可以正確地實現最大化期望回報的目標。 - 使用
torch.min
選擇最保守的更新方向,保證 PPO 的穩定性。
這種寫法與標準的 PPO 損失函數完全一致,功能正確且符合原理。
DeepSpeed-Chat
中的 PPO 實現,和你提出的代碼雖然看起來形式有所不同,但兩者最終達到的效果是一樣的。從理論上講,它們都是在實現 PPO 損失函數,差別在于寫法的具體細節。
我們一步步來對比分析:
【5】那為什么DeepSpeed-Chat采用前者?在pg_loss1前面加負號?
1. DeepSpeed-Chat 中的代碼
在 DeepSpeed-Chat PPO Trainer 中,損失函數是這樣寫的:
log_ratio = (logprobs - old_logprobs)
ratio = torch.exp(log_ratio)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
損失函數分析
-advantages * ratio
:表示未裁剪的策略梯度損失(帶負號是為了最小化損失)。-advantages * torch.clamp(...)
:表示裁剪后的策略梯度損失(帶負號)。torch.max(pg_loss1, pg_loss2)
:選取兩者中 最大 的值(注意是負數中的最大值),它實現了 PPO 的min
損失 的效果。
2. 你修改的代碼
你的修改代碼為:
pg_loss1 = advantages * ratio
pg_loss2 = advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)
pg_loss = -torch.sum(torch.min(pg_loss1, pg_loss2) * mask) / mask.sum()
損失函數分析
advantages * ratio
和advantages * torch.clamp(...)
:沒有負號,表示直接的策略梯度目標。torch.min(pg_loss1, pg_loss2)
:取兩者中最小值。- 負號
-
加在最終損失上:將最大化目標轉化為最小化。
3. 兩者的數學等價性
從數學角度看:
-
DeepSpeed 的寫法 使用
torch.max(pg_loss1, pg_loss2)
和 負的優勢:- ( pg_loss1 = ? Adv t ? r t ( θ ) \text{pg\_loss1} = - \text{Adv}_t \cdot r_t(\theta) pg_loss1=?Advt??rt?(θ) )
- ( pg_loss2 = ? Adv t ? clip ( r t ( θ ) , 1 ? ? , 1 + ? ) \text{pg\_loss2} = - \text{Adv}_t \cdot \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) pg_loss2=?Advt??clip(rt?(θ),1??,1+?) )
- 取最大值:因為兩者都是負數,取負數中的最大值,相當于取正數中的最小值(原始的
min
)。
-
你的寫法 直接取
min
,然后在結果前加負號:- ( pg_loss1 = Adv t ? r t ( θ ) \text{pg\_loss1} = \text{Adv}_t \cdot r_t(\theta) pg_loss1=Advt??rt?(θ) )
- ( pg_loss2 = Adv t ? clip ( r t ( θ ) , 1 ? ? , 1 + ? ) \text{pg\_loss2} = \text{Adv}_t \cdot \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) pg_loss2=Advt??clip(rt?(θ),1??,1+?) )
- 取最小值,再加負號,得到負的損失。
為什么 DeepSpeed 選擇前者?
雖然兩種寫法是等價的,但 DeepSpeed-Chat 選擇前者的原因主要有以下幾點:
-
代碼風格和一致性
DeepSpeed 的代碼風格更偏向于將負號提前應用到pg_loss1
和pg_loss2
上。這樣可以讓損失計算的邏輯和大部分框架保持一致,方便閱讀和維護。 -
梯度計算的數值穩定性
帶負號的損失(pg_loss1
和pg_loss2
)在早期階段就應用負號,可能減少后續計算過程中的誤差傳播,更加數值穩定。 -
易于對接優化器
優化器通常最小化損失函數。如果損失的定義是負的優勢乘以比例,這樣的形式可以直接傳入優化器,避免在外部再引入負號。 -
便于調試和日志記錄
在調試過程中,直接觀察pg_loss1
和pg_loss2
的值(已經是負值)會更加直觀,可以更容易判斷未裁剪和裁剪部分的差異。
總結
- 兩者數學上等價,因為取負號和取最小值/最大值的順序不同,但本質上實現了同樣的損失函數。
- DeepSpeed 選擇前者是出于代碼風格、一致性、數值穩定性和易于調試的考量。
- 如果你修改后的代碼符合 PPO 的原理,且沒有引入其他數值問題(如梯度異常),理論上是完全可行的。
對于實際框架而言,寫法的選擇更多是代碼規范和工程設計的考慮,而不是數學上的根本區別。
后記
2024年12月14日12點36分于上海,在GPT4o大模型輔助下完成。