基于DeepSpeed Chat詳解 PPO 算法中的actor_loss_fn及其核心參數

詳解 PPO 算法中的 actor_loss_fn 及其核心參數


1. 引言

在強化學習中,PPO(Proximal Policy Optimization,近端策略優化)算法是一種經典且高效的策略優化方法。它通過重要性采樣(Importance Sampling)和策略剪切(Clipping)來平衡策略改進的幅度模型訓練的穩定性

這篇博客將重點解析 DeepSpeed Chat 的 actor_loss_fn 函數,代碼鏈接: https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py#L44,結合代碼詳細說明其作用,并通過一個具體的例子,說明模型生成的 response 是如何對應到 logprobsold_logprobsadvantagesmask 這幾個參數的。


2. 代碼解析

下面是 actor_loss_fn 函數的實現:

def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):# 計算 log 概率比log_ratio = (logprobs - old_logprobs) * maskratio = torch.exp(log_ratio)# 策略梯度損失(兩種形式)pg_loss1 = -advantages * ratiopg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)# 選擇損失中的最大值,并加權求和(masked)pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()return pg_loss
2.1 參數說明
  • logprobs:當前策略(即 Actor 網絡)在生成的 response 上計算得到的 log 概率。這是動態更新的,隨著 Actor 策略的更新而變化。
  • old_logprobs:舊策略(之前保存的 Actor 策略)在相同 response 上計算得到的 log 概率。這是固定的,不隨當前策略變化。
  • advantages:舊策略計算出的優勢函數值,表示每個 token 的相對價值。這也是固定的。
  • mask:一個布爾型張量,用于標記有效的 token(即非填充部分)。例如,response 的部分 token 有效,其余部分被填充為無效。
2.2 核心邏輯
  1. 計算 Log 概率比
    log_ratio = ( logprobs ? old_logprobs ) ? mask \text{log\_ratio} = (\text{logprobs} - \text{old\_logprobs}) \cdot \text{mask} log_ratio=(logprobs?old_logprobs)?mask
    通過當前策略和舊策略的 log 概率差,計算重要性比率的對數。

  2. 將 Log 概率比轉換為比例
    ratio = exp ? ( log_ratio ) \text{ratio} = \exp(\text{log\_ratio}) ratio=exp(log_ratio)
    轉換為實際的概率比。

  3. 定義兩種策略梯度損失

    • 未剪切的損失
      pg_loss1 = ? advantages ? ratio \text{pg\_loss1} = -\text{advantages} \cdot \text{ratio} pg_loss1=?advantages?ratio
    • 剪切后的損失
      pg_loss2 = ? advantages ? clip ( ratio , 1 ? ? , 1 + ? ) \text{pg\_loss2} = -\text{advantages} \cdot \text{clip}(\text{ratio}, 1-\epsilon, 1+\epsilon) pg_loss2=?advantages?clip(ratio,1??,1+?)
      通過限制 ( ratio \text{ratio} ratio) 在 ( [ 1 ? ? , 1 + ? ] [1-\epsilon, 1+\epsilon] [1??,1+?]) 范圍內,避免策略更新幅度過大,防止訓練不穩定。
  4. 選取最大損失
    對每個 token,取兩種損失中的較大值,以確保更新過程不會偏離策略優化目標。

  5. 加權平均損失
    使用 mask 對損失加權,避免填充值對計算的干擾。


3. 參數與 Response 的對應關系

為了更好地理解 logprobsold_logprobsadvantagesmask,我們通過一個具體例子來說明:

假設輸入
  • Prompt"Translate the following text into French: 'Hello, how are you?'"
  • Response(生成的句子):"Bonjour, comment ?a va?"
具體參數計算過程
  1. logprobs
    當前策略對 response 每個 token 的生成概率:

    • logprobs 可能為:
      [-1.2, -0.8, -1.5, -0.6, -1.0]  # 對應 "Bonjour,", "comment", "?a", "va", "?"
      
    • 它是通過當前 Actor 模型的 softmax 輸出取對數后得到的。
  2. old_logprobs
    舊策略對相同 response 的生成概率:

    • old_logprobs 可能為:
      [-1.0, -0.9, -1.4, -0.7, -0.9]  # 對應 "Bonjour,", "comment", "?a", "va", "?"
      
  3. advantages
    每個 token 的優勢值,表示生成該 token 的相對價值:

    • 例如:
      [0.5, 0.8, 0.3, 0.6, 0.4]  # 對應 "Bonjour,", "comment", "?a", "va", "?"
      
    • 這些值通過舊策略與環境交互(或獎勵模型)計算得到,表示每個 token 對整體策略的貢獻。
  4. mask
    一個布爾型標記,有效部分為 1,填充部分為 0:

    • 例如,假設 response 中沒有填充值:
      [1, 1, 1, 1, 1]
      

4. 為什么引入 logprobs 和策略剪切?

在 PPO 中,策略更新會參考舊策略,但如果更新幅度太大(即 ( ratio \text{ratio} ratio) 偏離 1 太遠),可能會導致模型“訓歪”:

  • 過度更新:策略過快改變,導致新策略偏離舊策略太遠,破壞學習穩定性。
  • 欠更新:策略改變過小,無法顯著優化目標函數。

通過引入 logprobs 和剪切機制(( clip \text{clip} clip)),我們可以控制每次更新的幅度,確保:

  • 更新不偏離舊策略太遠。
  • 改進的同時保留策略的多樣性和探索能力。

5. 實例代碼

以下是基于 PyTorch 的實現,模擬計算 actor_loss

import torch# 模擬輸入
logprobs = torch.tensor([-1.2, -0.8, -1.5, -0.6, -1.0])  # 當前策略的 log 概率
old_logprobs = torch.tensor([-1.0, -0.9, -1.4, -0.7, -0.9])  # 舊策略的 log 概率
advantages = torch.tensor([0.5, 0.8, 0.3, 0.6, 0.4])  # 每個 token 的優勢值
mask = torch.tensor([1, 1, 1, 1, 1])  # 有效部分的掩碼
cliprange = 0.2  # 剪切范圍# 計算 log 概率比和比率
log_ratio = (logprobs - old_logprobs) * mask
ratio = torch.exp(log_ratio)# 策略梯度損失
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()# 輸出結果
print("Log Ratios:", log_ratio)
print("Ratios:", ratio)
print("Policy Gradient Loss:", pg_loss.item())

6. 總結

通過這篇博客,我們詳細介紹了 actor_loss_fn 的實現邏輯及其關鍵參數在 PPO 中的作用:

  • logprobsold_logprobs 是新舊策略的橋梁,通過重要性采樣連接兩者。
  • advantages 衡量生成 token 的相對價值。
  • mask 避免無效 token 對損失的影響。
  • 引入剪切機制確保訓練穩定。

通過具體例子和代碼演示,我們更直觀地理解了這些參數與 PPO 的關系,以及它們在策略優化中的關鍵作用。

【1】pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum() 解釋

要理解這句代碼 pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum(),我們可以通過一個具體的例子,模擬實際計算過程,特別關注 mask 的作用和 mask.sum() 的意義。


代碼結構解析

  1. pg_loss1pg_loss2

    • 兩種策略梯度損失的計算:
      • pg_loss1 是未剪切的策略梯度損失。
      • pg_loss2 是剪切后的策略梯度損失(通過 torch.clamp 限制比率的范圍)。
    • torch.max(pg_loss1, pg_loss2):對每個 token,取兩種損失中的較大值。
  2. 乘以 mask

    • mask 是一個布爾型張量,用于標記哪些 token 是有效的,哪些是無效的。
    • 乘以 mask 的作用是:將無效的 token 對應的損失置為 0,避免它們對整體損失的計算產生干擾。
  3. torch.sum(... * mask)

    • 計算有效 token 的總損失,忽略無效 token。
  4. / mask.sum()

    • 歸一化:用有效 token 的數量(mask.sum())對總損失進行歸一化,確保損失是平均的,而不是直接被無效 token 的數量拉低。

實際數值模擬

假設我們有以下數據:

輸入數據
  • pg_loss1:未剪切的損失

    pg_loss1 = [-0.5, -0.8, -0.3, -1.2, -0.6]  # 每個 token 的損失
    
  • pg_loss2:剪切后的損失

    pg_loss2 = [-0.4, -0.9, -0.5, -0.8, -0.7]
    
  • mask:有效 token 標記

    mask = [1, 1, 1, 0, 0]  # 前 3 個 token 有效,后兩個 token 無效
    
計算步驟
  1. torch.max(pg_loss1, pg_loss2)
    對每個 token,取 pg_loss1pg_loss2 中的較大值(注意:負數值較小的絕對值是較大值)。

    max_loss = torch.max(pg_loss1, pg_loss2)= [-0.4, -0.8, -0.3, -0.8, -0.6]
    
  2. 乘以 mask
    通過與 mask 相乘,無效 token(mask = 0)對應的損失被置為 0:

    masked_loss = max_loss * mask= [-0.4, -0.8, -0.3, 0.0, 0.0]
    
  3. torch.sum(masked_loss)
    求有效 token 的總損失:

    total_loss = torch.sum(masked_loss)= -0.4 + (-0.8) + (-0.3) + 0.0 + 0.0= -1.5
    
  4. mask.sum()
    統計有效 token 的數量:

    mask_sum = mask.sum()= 1 + 1 + 1 + 0 + 0= 3
    
  5. 歸一化損失
    將總損失除以有效 token 數量,計算平均損失:

    pg_loss = total_loss / mask_sum= -1.5 / 3= -0.5
    

乘以 mask 的作用

在這個例子中,mask 的作用是:

  1. 將無效 token(mask = 0)對應的損失設置為 0,防止這些無效 token 對總損失產生影響。
    • 如果沒有乘以 mask,則 pg_loss 會錯誤地包含無效 token 的損失。
  2. 確保只有有效 token 的損失被納入計算。

歸一化 (/ mask.sum()) 的作用

在訓練過程中,不同 batch 的有效 token 數量可能不一樣。例如:

  • 一個 batch 的 mask 有效部分是 3 個 token。
  • 另一個 batch 的 mask 有效部分是 5 個 token。

如果不對總損失歸一化,那么:

  • 有效 token 數量多的 batch 會對損失計算產生更大的權重。
  • 歸一化后,每個 token 的貢獻是平均的,確保訓練的穩定性。

總結

通過這個例子,可以清楚地看到:

  1. torch.max(pg_loss1, pg_loss2):用于選擇更大的損失,平衡未剪切和剪切更新之間的效果。
  2. 乘以 mask:屏蔽無效 token 的損失,專注于有效部分。
  3. 歸一化 (/ mask.sum()):確保損失與有效 token 的數量無關,保持不同 batch 的訓練一致性。

最終,這段代碼實現了 PPO 策略更新時的核心邏輯:通過限制每次更新幅度,穩定地優化 Actor 策略,同時屏蔽無效 token 對損失的干擾。

【2】 為什么在 torch.max(pg_loss1, pg_loss2) 中使用 max 而不是 min

PPO(Proximal Policy Optimization)算法中的損失函數設計目標是 限制策略更新幅度,確保策略的穩定性和性能提升。這里的 max 并非簡單地取大值,而是通過裁剪(clipping)機制實現 選擇對策略有利的損失,同時約束更新幅度。


PPO 的核心損失函數回顧

PPO的目標函數為:
L CLIP ( θ ) = E [ min ? ( r t ( θ ) A ^ t , clip ( r t ( θ ) , 1 ? ? , 1 + ? ) A ^ t ) ] L^{\text{CLIP}}(\theta) = \mathbb{E} \left[ \min \left( r_t(\theta) \hat{A}_t, \, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_t \right) \right] LCLIP(θ)=E[min(rt?(θ)A^t?,clip(rt?(θ),1??,1+?)A^t?)]

其中:

  • ( r t ( θ ) = π θ ( a t ∣ s t ) π old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\text{old}}(a_t | s_t)} rt?(θ)=πold?(at?st?)πθ?(at?st?)? ):新舊策略的比例。
  • ( A ^ t \hat{A}_t A^t? ):優勢值(Advantage)。
  • ( ? \epsilon ? ):裁剪范圍(比如 ( 0.2 0.2 0.2 ))。
  • clip:將比例 ( r t r_t rt? ) 裁剪到 ( [ 1 ? ? , 1 + ? ] [1 - \epsilon, 1 + \epsilon] [1??,1+?] )。

PPO 的損失函數通過裁剪機制,限制策略更新幅度,防止模型發生 過大更新,從而訓練不穩定。

在實現中,pg_loss1pg_loss2 分別為:

  • pg_loss1 = - advantages * ratio(未裁剪的損失)
  • pg_loss2 = - advantages * clipped_ratio(裁剪后的損失)

為什么使用 max 而不是 min

  1. 損失的符號
    PPO 使用的是 負的優勢值(-Advantage) 來計算損失,因為優化目標是 最小化損失,而優勢值的正負決定了動作的好壞。

    • 當 ( A ^ t > 0 \hat{A}_t > 0 A^t?>0 ):表示動作優于基準,應該增加概率。
    • 當 ( A ^ t < 0 \hat{A}_t < 0 A^t?<0 ):表示動作劣于基準,應該減少概率。
  2. max 的意義
    在兩種情況下,PPO 的目標是:

    • 對于好的動作(優勢 ( > 0 > 0 >0 )),我們允許更新比例 ( r t r_t rt? ) 稍微增大,但不能超出裁剪范圍。
    • 對于壞的動作(優勢 ( < 0 < 0 <0 )),我們允許更新比例 ( r t r_t rt? ) 稍微減小,但不能超出裁剪范圍。

    通過 max,我們實現了以下效果:

    • 當策略偏離較少時,使用未裁剪的損失(pg_loss1)。
    • 當策略偏離較多時,使用裁剪后的損失(pg_loss2),限制更新幅度。
  3. 為什么不是 min
    如果使用 min,會選擇較小的損失值,反而可能導致:

    • 對好的動作,更新不足(不增大概率)。
    • 對壞的動作,更新過度(減少概率過大)。
      這將違背 PPO 的設計目標,無法有效地限制策略更新幅度,導致訓練不穩定。

數值示例分析

假設我們有以下輸入數據:

  • 優勢值(Advantage):[2.0, -1.0]
  • 比例(ratio)[1.1, 0.7]
  • 裁剪范圍[0.8, 1.2]

根據 PPO 公式:

  1. 未裁剪的損失pg_loss1):
    pg_loss1 = ? advantages × ratio \text{pg\_loss1} = - \text{advantages} \times \text{ratio} pg_loss1=?advantages×ratio

  2. 裁剪后的損失pg_loss2):
    pg_loss2 = ? advantages × clip ( ratio , 1 ? ? , 1 + ? ) \text{pg\_loss2} = - \text{advantages} \times \text{clip}(\text{ratio}, 1 - \epsilon, 1 + \epsilon) pg_loss2=?advantages×clip(ratio,1??,1+?)


代碼實現

import torch# 輸入數據
advantages = torch.tensor([2.0, -1.0])  # 優勢值
ratio = torch.tensor([1.1, 0.7])  # 更新比例
cliprange = 0.2  # 裁剪范圍# 計算未裁剪的損失
pg_loss1 = -advantages * ratio# 計算裁剪后的損失
clipped_ratio = torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss2 = -advantages * clipped_ratio# 取 max
final_loss = torch.max(pg_loss1, pg_loss2)print("pg_loss1 (未裁剪):", pg_loss1)
print("pg_loss2 (裁剪后):", pg_loss2)
print("Final Loss (max):", final_loss)

輸出結果

pg_loss1 (未裁剪): tensor([-2.2000,  0.7000])
pg_loss2 (裁剪后): tensor([-2.2000,  0.8000])
Final Loss (max): tensor([-2.2000,  0.8000])

解釋

  1. 第一個樣本(優勢 > 0)

    • pg_loss1 = -2.2(未裁剪)
    • pg_loss2 = -2.2(裁剪后的損失與未裁剪相等)
    • max,最終損失為 -2.2
  2. 第二個樣本(優勢 < 0)

    • pg_loss1 = 0.7(未裁剪)
    • pg_loss2 = 0.8(裁剪后的比例被限制為 0.8)
    • max,最終損失為 0.8,表示更新受到了裁剪限制。

結論

  • max(pg_loss1, pg_loss2) 的作用:確保在未裁剪損失和裁剪損失之間取較大值,限制策略的更新幅度,穩定訓練過程。
  • 損失函數關系
    • 對于 優勢 ( > 0 ),我們防止概率增大過多。
    • 對于 優勢 ( < 0 ),我們防止概率減小過多。
  • 使用 min 會導致訓練不穩定,無法有效地限制更新幅度,違反 PPO 的核心設計理念。

PPO 損失函數前面是否有負號取決于損失函數的最優化目標,即最大化策略性能還是最小化損失。


【3】為什么有的資料ppo的損失函數前面有負號,-min(),而有的材料前面沒有負號?

1. 從優化目標的角度解釋

  • 強化學習的目標:最大化策略的累積回報 ( R t R_t Rt? )。
    由于優化器(例如 Adam)默認是最小化損失,我們需要對最大化目標取負數,從而轉換為最小化損失問題。

因此:

  • 如果目標是最大化策略性能,則公式為:
    maximize? L CLIP ( θ ) \text{maximize } L^{\text{CLIP}}(\theta) maximize?LCLIP(θ)
  • 為了與優化器配合(最小化),我們取負號:
    minimize? ? L CLIP ( θ ) \text{minimize } -L^{\text{CLIP}}(\theta) minimize??LCLIP(θ)

帶負號的形式
actor_loss = ? min ? ( r t ? Adv t , clip ( r t , 1 ? ? , 1 + ? ) ? Adv t ) \text{actor\_loss} = - \min \left( r_t \cdot \text{Adv}_t, \, \text{clip}(r_t, 1 - \epsilon, 1 + \epsilon) \cdot \text{Adv}_t \right) actor_loss=?min(rt??Advt?,clip(rt?,1??,1+?)?Advt?)


2. 從損失函數的數值角度

  • 優勢值 ( Adv t \text{Adv}_t Advt? )
    • 當 ( Adv t > 0 \text{Adv}_t > 0 Advt?>0 ):表示當前動作優于基準,應該增大概率。
    • 當 ( Adv t < 0 \text{Adv}_t < 0 Advt?<0 ):表示當前動作劣于基準,應該減小概率。

如果前面有負號:

  • 最大化優勢值:對于 ( Adv t > 0 \text{Adv}_t > 0 Advt?>0 ),損失會減小(概率增大);對于 ( Adv t < 0 \text{Adv}_t < 0 Advt?<0 ),損失會增大(概率減小)。

如果前面沒有負號:

  • 直接是獎勵優化形式,適合展示最大化回報的目標,但在實際代碼中優化器不直接適配這種形式。

3. 兩種形式的適用場景

  • 帶負號的形式:通常出現在代碼實現中,因為優化器默認執行最小化損失。
  • 不帶負號的形式:出現在理論公式或論文中,強調最大化目標。

結論

帶負號的形式是為了將“最大化回報”轉化為優化器的最小化損失問題,而不帶負號的形式僅僅表示理論上最大化優勢目標。實際實現中,大多數代碼會包含負號。


【4】把pg_loss1和2的負號去掉,最后pg_loss取min再加負號,這樣可以嗎?

下面是我問GPT的問題:

def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):## policy gradient losslog_ratio = (logprobs - old_logprobs) * maskratio = torch.exp(log_ratio)pg_loss1 = advantages * ratio pg_loss2 = advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)pg_loss = -torch.sum(torch.min(pg_loss1, pg_loss2) * mask) / mask.sum() return pg_loss

代碼修改成這樣可以嗎?把pg_loss1和2的負號去掉,最后pg_loss取min,再加負號,這樣可以嗎?或者不可以嗎?為什么

修改后的代碼是可以的!在 PPO 的優化目標中,加負號min 的組合恰好符合 PPO 損失函數的設計。以下會詳細解釋為什么這樣修改是合理的。


1. PPO 損失函數的原理

PPO 的目標是最大化期望回報,同時限制策略更新的幅度,以保證訓練的穩定性。其目標函數為:
L PPO ( θ ) = E t [ min ? ( r t ( θ ) ? Adv t , clip ( r t ( θ ) , 1 ? ? , 1 + ? ) ? Adv t ) ] L^{\text{PPO}}(\theta) = \mathbb{E}_t \left[ \min\left(r_t(\theta) \cdot \text{Adv}_t, \, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \cdot \text{Adv}_t\right) \right] LPPO(θ)=Et?[min(rt?(θ)?Advt?,clip(rt?(θ),1??,1+?)?Advt?)]
其中:

  • ( r t ( θ ) = π θ ( a t ∣ s t ) π old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)} rt?(θ)=πold?(at?st?)πθ?(at?st?)? ) 是新舊策略的概率比。
  • ( Adv t \text{Adv}_t Advt? ) 是優勢函數,表示當前動作相對于平均水平的好壞。
  • 裁剪操作:通過 ( clip ( r t ( θ ) , 1 ? ? , 1 + ? ) \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) clip(rt?(θ),1??,1+?) ) 限制比率的變化幅度。

為了最大化 ( L PPO L^{\text{PPO}} LPPO ),通常將其轉化為一個最小化問題,即在前面加上負號:
actor_loss = ? E t [ min ? ( r t ( θ ) ? Adv t , clip ( r t ( θ ) , 1 ? ? , 1 + ? ) ? Adv t ) ] \text{actor\_loss} = -\mathbb{E}_t \left[ \min\left(r_t(\theta) \cdot \text{Adv}_t, \, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \cdot \text{Adv}_t\right) \right] actor_loss=?Et?[min(rt?(θ)?Advt?,clip(rt?(θ),1??,1+?)?Advt?)]


2. 代碼分析

修改后的代碼:

log_ratio = (logprobs - old_logprobs) * mask
ratio = torch.exp(log_ratio)
pg_loss1 = advantages * ratio
pg_loss2 = advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)
pg_loss = -torch.sum(torch.min(pg_loss1, pg_loss2) * mask) / mask.sum()
return pg_loss
主要修改點:
  1. 去掉了 pg_loss1pg_loss2 的負號。
  2. 在最終計算 pg_loss 時,加了負號。
  3. 使用 torch.min(pg_loss1, pg_loss2)
分析:
  1. 去掉 pg_loss1pg_loss2 的負號
    這里 pg_loss1pg_loss2 都表示未裁剪和裁剪后的損失,原始 PPO 損失的優化目標是最大化優勢乘以比率。直接使用優勢和比率(不加負號)可以保留目標的最大化形式。

  2. 加負號到 pg_loss 前面
    PPO 的目標是最大化 ( L PPO L^{\text{PPO}} LPPO ),但優化器通常是最小化損失函數,因此需要加負號將最大化目標轉化為最小化問題。

  3. 使用 torch.min(pg_loss1, pg_loss2)

    • pg_loss1:未裁剪的情況 ( Adv t ? r t ( θ ) \text{Adv}_t \cdot r_t(\theta) Advt??rt?(θ) )。
    • pg_loss2:裁剪后的情況 ( Adv t ? clip ( r t ( θ ) , 1 ? ? , 1 + ? ) \text{Adv}_t \cdot \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) Advt??clip(rt?(θ),1??,1+?) )。
    • min:選取兩者中較小的值,目的是在不損害回報的情況下,限制策略更新幅度,從而保證訓練的穩定性。

3. 為什么這樣可以?

PPO 的損失函數最終形式是:
actor_loss = ? E t [ min ? ( r t ( θ ) ? Adv t , clip ( r t ( θ ) , 1 ? ? , 1 + ? ) ? Adv t ) ] \text{actor\_loss} = -\mathbb{E}_t \left[ \min\left(r_t(\theta) \cdot \text{Adv}_t, \, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \cdot \text{Adv}_t\right) \right] actor_loss=?Et?[min(rt?(θ)?Advt?,clip(rt?(θ),1??,1+?)?Advt?)]

與修改后的代碼一一對應:

  • pg_loss1 表示未裁剪的部分 ( Adv t ? r t ( θ ) \text{Adv}_t \cdot r_t(\theta) Advt??rt?(θ) )。
  • pg_loss2 表示裁剪后的部分 ( Adv t ? clip ( r t ( θ ) , 1 ? ? , 1 + ? ) \text{Adv}_t \cdot \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) Advt??clip(rt?(θ),1??,1+?) )。
  • torch.min(pg_loss1, pg_loss2):實現了公式中的 min
  • -torch.sum(...) / mask.sum():加了負號,將最大化目標轉化為最小化損失,并進行歸一化處理。

這樣寫在數學意義和計算結果上都是正確的,符合 PPO 的原理。


4. 總結

  • 修改后的代碼是合理的,符合 PPO 損失函數的設計。
  • 去掉 pg_loss1pg_loss2 的負號,并在最后的 pg_loss 上加負號,可以正確地實現最大化期望回報的目標。
  • 使用 torch.min 選擇最保守的更新方向,保證 PPO 的穩定性。

這種寫法與標準的 PPO 損失函數完全一致,功能正確且符合原理。

DeepSpeed-Chat 中的 PPO 實現,和你提出的代碼雖然看起來形式有所不同,但兩者最終達到的效果是一樣的。從理論上講,它們都是在實現 PPO 損失函數,差別在于寫法的具體細節。

我們一步步來對比分析:


【5】那為什么DeepSpeed-Chat采用前者?在pg_loss1前面加負號?

1. DeepSpeed-Chat 中的代碼

在 DeepSpeed-Chat PPO Trainer 中,損失函數是這樣寫的:

log_ratio = (logprobs - old_logprobs)
ratio = torch.exp(log_ratio)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
損失函數分析
  • -advantages * ratio:表示未裁剪的策略梯度損失(帶負號是為了最小化損失)。
  • -advantages * torch.clamp(...):表示裁剪后的策略梯度損失(帶負號)。
  • torch.max(pg_loss1, pg_loss2):選取兩者中 最大 的值(注意是負數中的最大值),它實現了 PPO 的 min 損失 的效果。

2. 你修改的代碼

你的修改代碼為:

pg_loss1 = advantages * ratio
pg_loss2 = advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)
pg_loss = -torch.sum(torch.min(pg_loss1, pg_loss2) * mask) / mask.sum()
損失函數分析
  • advantages * ratioadvantages * torch.clamp(...):沒有負號,表示直接的策略梯度目標。
  • torch.min(pg_loss1, pg_loss2):取兩者中最小值。
  • 負號 - 加在最終損失上:將最大化目標轉化為最小化。

3. 兩者的數學等價性

從數學角度看:

  1. DeepSpeed 的寫法 使用 torch.max(pg_loss1, pg_loss2)負的優勢

    • ( pg_loss1 = ? Adv t ? r t ( θ ) \text{pg\_loss1} = - \text{Adv}_t \cdot r_t(\theta) pg_loss1=?Advt??rt?(θ) )
    • ( pg_loss2 = ? Adv t ? clip ( r t ( θ ) , 1 ? ? , 1 + ? ) \text{pg\_loss2} = - \text{Adv}_t \cdot \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) pg_loss2=?Advt??clip(rt?(θ),1??,1+?) )
    • 取最大值:因為兩者都是負數,取負數中的最大值,相當于取正數中的最小值(原始的 min)。
  2. 你的寫法 直接取 min,然后在結果前加負號:

    • ( pg_loss1 = Adv t ? r t ( θ ) \text{pg\_loss1} = \text{Adv}_t \cdot r_t(\theta) pg_loss1=Advt??rt?(θ) )
    • ( pg_loss2 = Adv t ? clip ( r t ( θ ) , 1 ? ? , 1 + ? ) \text{pg\_loss2} = \text{Adv}_t \cdot \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) pg_loss2=Advt??clip(rt?(θ),1??,1+?) )
    • 取最小值,再加負號,得到負的損失。

為什么 DeepSpeed 選擇前者?

雖然兩種寫法是等價的,但 DeepSpeed-Chat 選擇前者的原因主要有以下幾點:

  1. 代碼風格和一致性
    DeepSpeed 的代碼風格更偏向于將負號提前應用到 pg_loss1pg_loss2 上。這樣可以讓損失計算的邏輯和大部分框架保持一致,方便閱讀和維護。

  2. 梯度計算的數值穩定性
    帶負號的損失(pg_loss1pg_loss2)在早期階段就應用負號,可能減少后續計算過程中的誤差傳播,更加數值穩定。

  3. 易于對接優化器
    優化器通常最小化損失函數。如果損失的定義是負的優勢乘以比例,這樣的形式可以直接傳入優化器,避免在外部再引入負號。

  4. 便于調試和日志記錄
    在調試過程中,直接觀察 pg_loss1pg_loss2 的值(已經是負值)會更加直觀,可以更容易判斷未裁剪和裁剪部分的差異。


總結

  • 兩者數學上等價,因為取負號和取最小值/最大值的順序不同,但本質上實現了同樣的損失函數。
  • DeepSpeed 選擇前者是出于代碼風格、一致性、數值穩定性和易于調試的考量。
  • 如果你修改后的代碼符合 PPO 的原理,且沒有引入其他數值問題(如梯度異常),理論上是完全可行的。

對于實際框架而言,寫法的選擇更多是代碼規范和工程設計的考慮,而不是數學上的根本區別。

后記

2024年12月14日12點36分于上海,在GPT4o大模型輔助下完成。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/63145.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/63145.shtml
英文地址,請注明出處:http://en.pswp.cn/web/63145.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

D3 基礎1

D3 D3.js (Data-Driven Documents) 是一個基于 JavaScript 的庫&#xff0c;用于生成動態、交互式數據可視化。它通過操作文檔對象模型 (DOM) 來生成數據驅動的圖形。官方網站是 https://d3js.org/ <!DOCTYPE html> <html lang"en"><head><me…

基線檢查:Windows安全基線.【手動 || 自動】

基線定義 基線通常指配置和管理系統的詳細描述&#xff0c;或者說是最低的安全要求&#xff0c;它包括服務和應用程序設置、操作系統組件的配置、權限和權利分配、管理規則等。 基線檢查內容 主要包括賬號配置安全、口令配置安全、授權配置、日志配置、IP通信配置等方面內容&…

Python -- Linux中的Matplotlib圖中無法顯示中文 (中文為方框)

目的 用matplotlib生成的圖中文無法正常顯示 方法 主要原因: 沒找到字體 進入windows系統的C:\Windows\Fonts目錄, 復制自己想要的字體 粘貼到Linux服務器中對應python文件所處的文件夾內 設置字體: 設置好字體文件的路徑在需要對字體設置的地方設置字體 效果 中文正常顯…

快速理解類的加載過程

當程序主動使用某個類時&#xff0c;如果該類還未加載到內存中&#xff0c;則系統會通過如下三個步驟來對該類進行初始化&#xff1a; 1.加載&#xff1a;將class文件字節碼內容加載到內存中&#xff0c;并將這些靜態數據轉換成方法區的運行時數據結構&#xff0c;然后生成一個…

搭建 Elasticsearch 集群:完整教程

本文將詳細介紹如何在 Linux 環境下搭建一個 Elasticsearch 集群&#xff0c;涵蓋環境準備、配置優化、服務啟動等多個環節。 一、環境準備 創建安裝目錄 mkdir /es cd /es解壓 Elasticsearch 安裝包 tar -xzf elasticsearch-7.10.1-linux-x86_64.tar.gz -C /es配置環境變量 編…

寶塔-docker拉取寶塔鏡像,并運行寶塔鏡像

寶塔-拉取寶塔鏡像&#xff0c;并運行鏡像 第1步&#xff1a;查詢 docker search btpanel/baota此docker鏡像由堡塔安全官方發布&#xff0c;鏡像版本為寶塔面板9.2.0正式版和9.0.0_lts 穩定版&#xff0c;鏡像會隨著寶塔面板更新。 目前支持x86_64和arm架構可供下載使用 版本…

使用 Valgrind 檢測 C 程序中的內存問題 -基礎教程

內存泄漏是許多 C 語言程序中的常見問題&#xff0c;它不僅會導致程序性能下降&#xff0c;甚至可能讓系統崩潰。為了檢測和修復這些問題&#xff0c;Valgrind 是一個非常強大的工具&#xff0c;它可以幫助我們分析 C 程序中的內存使用情況&#xff0c;檢測內存泄漏、越界訪問、…

窮舉vs暴搜vs深搜vs回溯vs剪枝專題一>子集

題目&#xff1a; 兩個方法本質就是決策樹的畫法不同 方法一解析&#xff1a; 代碼&#xff1a; class Solution {private List<List<Integer>> ret;//返回結果private List<Integer> path;//記錄路徑&#xff0c;注意返回現場public List<List<Int…

leecode雙指針部分題目

leecode雙指針部分題目 1. 驗證回文串2. 判斷子序列3. 兩數之和 II - 輸入有序數組4. 盛最多水的容器5. 三數之和 1. 驗證回文串 如果在將所有大寫字符轉換為小寫字符、并移除所有非字母數字字符之后&#xff0c;短語正著讀和反著讀都一樣。則可以認為該短語是一個 回文串 。 …

Web 應用如何使用sqlite?使用 sql.js 實現前端 SQLite 數據庫操作

前言 在 Web 應用開發中&#xff0c;前端數據處理的重要性日益增加。為了實現更高效的前端數據管理&#xff0c;特別是在處理結構化數據時&#xff0c;sql.js 提供了一個出色的解決方案。sql.js 是將 SQLite 數據庫編譯為 JavaScript 的庫&#xff0c;允許開發者在瀏覽器環境中…

docker 安裝 mysql8.0容器外無法連接

文章目錄 概要問題描述解決方案其他命令 概要 主要是mysql5.7和mysql8.0的兼容性問題。 排查了很久 其實就是配置文件的一句話的事情 感覺mysql8.0更為嚴謹 這樣可能是考慮杜絕一些漏洞吧 問題描述 在容器內 netstat -an | grep 3306 都不行 在容器外 netstat -an | grep 2…

TCP協議簡單分析和握手揮手過程

TCP介紹 TCP是可靠的傳輸層協議&#xff0c;建立連接之前會經歷3次握手的階段。 確認機制&#xff1a;接受方 收到數據之后會向 發送方 回復ACK重傳機制&#xff1a;發送方 在一定時間內沒有收到 接收方的ACK就會重新發送 握手目的&#xff1a;與端口建立連接 TCP的三次握手 …

VisualStudio vsix插件自動加載

本文介紹如何在Visual Studio擴展中實現PackageRegistration&#xff0c;包括設置UseManagedResourcesOnly為true&#xff0c;允許背景加載&#xff0c;并針對C#、VB、F#項目提供自動裝載&#xff0c;附官方文檔鏈接。增加以下特性即可…… [PackageRegistration(UseManagedRe…

opencv所有常見函數

一、opencv圖像操作 二、opencv圖像的數值運算 三、opencv圖像的放射變換 四、opencv空間域圖像濾波 五、圖像灰度化與直方圖 六、形態學圖像處理 七、閾值處理與邊緣檢測 八、輪廓和模式匹配

【Excel】單元格分列

目錄 分列&#xff08;新手友好&#xff09; 1. 選中需要分列的單元格后&#xff0c;選擇 【數據】選項卡下的【分列】功能。 2. 按照分列向導提示選擇適合的分列方式。 3. 分好就是這個樣子 智能分列&#xff08;進階&#xff09; 高級分列 Tips&#xff1a; 新手推薦基…

【STM32練習】基于STM32的PM2.5環境監測系統

一.項目背景 最近為了完成老師交付的任務&#xff0c;遂重制了一下小項目用STM32做一個小型的環境監測系統。 項目整體示意框圖如下&#xff1a; 二.器件選擇 單片機&#xff08;STM32F103&#xff09;數字溫濕度模塊&#xff08;DHT11&#xff09;液晶顯示模塊&#xff08;0.8…

《開源數據:開啟信息共享與創新的寶藏之門》

《開源數據&#xff1a;開啟信息共享與創新的寶藏之門》 一、開源數據概述&#xff08;一&#xff09;開源數據的定義&#xff08;二&#xff09;開源數據的發展歷程 二、開源數據的優勢&#xff08;一&#xff09;成本效益優勢&#xff08;二&#xff09;靈活性與可定制性&…

ReactPress最佳實踐—搭建導航網站實戰

Github項目地址&#xff1a;https://github.com/fecommunity/easy-blog 歡迎Star。 近期&#xff0c;阮一峰在科技愛好者周刊第 325 期中推薦了一款開源工具——ReactPress&#xff0c;ReactPress一個基于 Next.js 的博客和 CMS 系統&#xff0c;可查看 demo站點。&#xff08;…

2024,大模型殺進“決賽圈”

Henry Chesbrough在著作《通過技術創新盈利勢在必行》中&#xff0c;曾提出過一個創新的“漏斗模型”。開放式創新一開始鼓勵百花齊放&#xff0c;但最終只有10%的技術能夠通過這個漏斗&#xff0c;成功抵達目標市場target market&#xff0c;進入到商業化與產業化的下一個階段…

STM8單片機學習筆記·GPIO的片上外設寄存器

目錄 前言 IC基本定義 三極管基礎知識 單片機引腳電路作用 STM8GPIO工作模式 GPIO外設寄存器 寄存器含義用法 CR1&#xff1a;Control Register 1 CR2&#xff1a;Control Register 2 ODR&#xff1a;Output Data Register IDR&#xff1a;Input Data Register 賦值…