PPO和GRPO算法

????????verl 是現在非常火的 rl 框架,而且已經支持了多個 rl 算法(ppo、grpo 等等)。

????????過去對 rl 的理解很粗淺(只知道有好多個角色,有的更新權重,有的不更新),也曾硬著頭皮看了一些論文和知乎,依然有很多細節不理解,現在準備跟著 verl 的代碼梳理一遍兩個著名的 rl 算法,畢竟代碼不會隱藏任何細節!

????????雖然 GRPO 算法是基于 PPO 算法改進來的,但是畢竟更簡單,所以我先從 GRPO 的流程開始學習,然后再看 PPO。

GRPO 論文中的展示的總體流程:

論文中這張圖主要展示了 GRPO 和 PPO 的區別,隱藏了其他的細節。

圖中只能注意到以下幾個關鍵點:

  • 沒有 Value Model 和輸出 v(value)

  • 同一個 q 得出了一組的 o(從 1 到 G)

  • 計算 A(Advantage) 的算法從 GAE 變成了 Group Computation

  • KL 散度計算不作用于 Reward Model,而是直接作用于 Policy Model

????????其他細節看不懂,結合論文也依然比較抽象,因為我完全沒有 RL 的知識基礎,下文中我們結合代碼會再一次嘗試理解。

????????下面是我根據 verl 代碼自己 DIY 的流程圖(幫助理解):

01?第一步:Rollout

????????第一步是 rollout,rollout 是一個強化學習專用詞匯,指的是從一個特定的狀態按照某個策略進行一些列動作和狀態轉移。

????????在 LLM 語境下,“某個策略”就是 actor model 的初始狀態,“進行一些列動作”指的就是推理,即輸入 prompt 輸出 response 的過程。

verl/trainer/ppo/ray_trainer.py:

gen_batch_output?= self.actor_rollout_wg.generate_sequences(gen_batch)

????????其背后的實現一般就是是 vllm 或 sglang 這些常見推理框架的離線推理功能,這部分功能相對獨立我們先不展開。

權重同步

????????一個值得注意的細節是代碼里面的?rollout_sharding_manager?實現,它負責每一個大 step 結束后把剛剛訓練好的 actor model 參數更新到 vllm 或 sglang。

????????這樣下一個大 step 的 rollout 采用的就是最新的模型權重(最新的策略)了。

????????這是每一個大 step 里面真正要做的第一件事,在真正執行 rollout 之前。

????????verl/workers/fsdp_workers.py:

class?ActorRolloutRefWorker(Worker):? ?# ...? ??@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)? ?? def?generate_sequences(self,?prompts:?DataProto):? ? ? ?# ...? ? ? ? with?self.rollout_sharding_manager:? ? ? ? ? ??# ...? ? ? ? ? ? prompts =?self.rollout_sharding_manager.preprocess_data(prompts)? ? ? ? ? ?output =?self.rollout.generate_sequences(prompts=prompts)? ? ? ? ? ? output =?self.rollout_sharding_manager.postprocess_data(output)

rollout_sharding_manager?的基類是?BaseShardingManager。

verl/workers/sharding_manager/base.py:

class?BaseShardingManager:? ?def?__enter__(self):? ? ? ? pass? ??def?__exit__(self, exc_type, exc_value, traceback):? ? ? ? pass? ??def?preprocess_data(self,?data:?DataProto) ->?DataProto:? ? ? ??return?data? ??def?postprocess_data(self,?data:?DataProto) ->?DataProto:? ? ? ??return?data

??BaseShardingManager?的派生類在各自的?__enter__?方法中實現了把 Actor Model 的權重 Sync 到 Rollout 實例的邏輯,以保證被?with self.rollout_sharding_manager?包裹的預處理和推理邏輯都是用的最新 Actor Model 權重。

推理 N 次

????????此外,GRPO 算法要求對每一個 prompt 都生成多個 response,后續才能根據組間對比得出相對于平均的優勢(Advantage)。

verl/trainer/config/ppo_trainer.yaml:

actor_rollout_ref:??rollout:? ??# number of responses (i.e. num sample times)? ?n:?1?# >?1?for grpo

????????在?_build_rollout?的時候?actor_rollout_ref.rollout.n?被傳給了?vLLMRollout?或其他的 Rollout 實現中,從而推理出?n?組 response。

verl/workers/fsdp_workers.py:

class?ActorRolloutRefWorker(Worker):? ??def?_build_rollout(self, trust_remote_code=False):? ? ? ??# ...? ? ? ??elif?rollout_name ==?"vllm":? ? ? ? ? ??# ...? ? ? ? ? ??if?vllm_mode ==?"customized":? ? ? ? ? ? ? ? rollout = vLLMRollout(? ? ? ? ? ? ? ? ? ?actor_module=self.actor_module_fsdp,? ? ? ? ? ? ? ? ? ?               config=self.config.rollout,? ? ? ? ? ? ? ? ? ?tokenizer=self.tokenizer,? ? ? ? ? ? ? ? ? ? 
model_hf_config=self.actor_model_config,? ? ? ? ? ? ? ?)

02?第二步:計算 log prob

????????log 是 logit,prob 是 probability,合起來就是對數概率,舉一個簡單的例子來說明什么是 log prob:

詞表僅有?5?個詞:? ? 
<pad> (ID?0)? ? 
你好 (ID?1)? ? 
世界 (ID?2)? ?
! (ID?3)? ? 
嗎 (ID?4)
prompt:你好
prompt?tokens: [1]
response:世界!
response?tokens: [2,3]
模型前向傳播得到完整的 logits 張量:
[? ? [-1.0,?0.5,?2.0, -0.5, -1.5], ? ?// 表示 “你好” 后接 “世界” 概率最高,數值為 2.0? ? [-2.0, -1.0,?0.1,?3.0,?0.2] ? ? ?// 表示 “你好世界” 后接 “!” 概率最高,數值為 3.0]
對每個 logit 計算 softmax 得到:
[? ? [-3.65, -2.15, -0.64, -3.15, -4.08],? ? [-4.34, -3.32, -2.20, -0.20, -2.10]]
提取實際 response 對應的數值:得到 log_probs:
[-0.64, -0.20]

總結下來:

  • 首先計算 prompt + response(來自 rollout)的完整 logits,即每一個 token 的概率分布

  • 截取 response 部分的 logits

  • 對每一個 logits 計算 log_sofmax(先 softmax,然后取對數),取出最終預測的 token 對應的 log_sofmax

  • 最終輸出 old_log_probs, size = [batchsize, seq_len]

????????此處你可能會有一個疑惑:在上一步 Rollout 的時候我們不是已經進行過完整 batch 的推理了么?

????????為什么現在還要重復進行一次 forward 來計算 log_prob,而不是在 generate 的過程中就把 log_prob 保存下來?

答:因為 generate_sequences 階段為了高效推理,不會保存每一個 token 的 log_prob,相反只關注整個序列的 log_prob。因此需要重新算一遍。

答:另外,vllm 官方 Q&A 中提到了 vllm 框架并不保證 log_probs 的穩定性。因為 pytorch 的 numerical instability 與 vllm 的并發批處理策略導致每一個 token 的 logits/log_probs 結果會略有不同,假如某一個 token 位采樣了不同 token id,那么這個誤差在后續還會被繼續累加。我們在訓練過程需要保證 log_probs 的穩定性,因此需要根據已經確定的 token id(即 response)再次 forward 一遍。

old log prob

verl/workers/fsdp_workers.py:

old_log_prob?= self.actor_rollout_wg.compute_log_prob(batch)

????????指 Actor Model 對整個 batch 的數據(prompt + response)進行 forward 得到的 log_prob

????????此處的 “old” 是相對于后續的 actor update 階段,因為現在 actor model 還沒有更新,所以依然采用的是舊策略 (ps:當前 step 的“舊策略”也是上一個大 step 的“新策略”)

ref log prob

verl/trainer/ppo/ray_trainer.py:

ref_log_prob?= self.ref_policy_wg.compute_ref_log_prob(batch)

????????指 Ref Model 對整個 batch 的數據(prompt + response)進行 forward 得到的 log_prob。

????????通常 Ref Model 就是整個強化學習開始之前 Actor Model 最初的模樣,換句話說第一個大 step 開始的時候 Actor Model == Ref Model,且 old_log_prob == ref_log_prob。

????????Ref Model 的作用是在后續計算 policy loss 之前,計算 KL 散度并作用于 policy loss,目的是讓 actor model 不要和最初的 ref model 相差太遠。

03第三步:advantage

????????advantage 是對一個策略的好壞最直接的評價,其背后就是 Reward Model,甚至也許不是一個 Model,而是一個粗暴的 function,甚至一個 sandbox 把 prompt+response 執行后得出的結果。

????????在 verl 中允許使用上述多種 Reward 方案中的一種或多種,并把得出的 score 做合。

verl/trainer/ppo/ray_trainer.py:

# compute reward model score
if?self.use_rm:? ? reward_tensor =?self.rm_wg.compute_rm_score(batch)? ? batch = batch.union(reward_tensor)
if?self.config.reward_model.launch_reward_fn_async:? ? future_reward = compute_reward_async.remote(batch,?self.config,?self.tokenizer)
else:? ?reward_tensor, reward_extra_infos_dict =?compute_reward(batch,?self.reward_fn)

然后用這個 score 計算最終的 advantage。

verl/trainer/ppo/ray_trainer.py:

# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get(? ??"norm_adv_by_std_in_grpo",?True) ?
# GRPO adv normalization factorbatch = compute_advantage(? ? batch,? ? 
adv_estimator=self.config.algorithm.adv_estimator,? ?gamma=self.config.algorithm.gamma,? ? 
lam=self.config.algorithm.lam,? ? 
num_repeat=self.config.actor_rollout_ref.rollout.n,? ? norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,)

04第四步:actor update(小循環)

????????在 PPOTrainer 中簡單地一行調用,背后可是整個 GRPO 算法中最關鍵的步驟:

actor_output?= self.actor_rollout_wg.update_actor(batch)

????????在這里,會把上面提到的整個 batch 的數據再根據?actor_rollout_ref.actor.ppo_mini_batch_size?配置的值拆分成很多個 mini batch。

????????然后對每一個 mini batch 數據進行一輪 forward + backward + optimize step,也就是小 step。

new log prob

????????每一個小 step 中首先會對 mini batch 的數據計算(new)log_prob,第一個小 step 得到的值還是和 old_log_prob 一模一樣的。

pg_loss

????????然后通過輸入所有 Group 的 Advantage 以新舊策略的概率比例(old_log_prob 和 log_prob),得出 pg_loss(Policy Gradient),這是最終用于 backward 的 policy loss 的基礎部分。

????????再次描述一下 pg_loss 的意義,即衡量當前策略(log_prob)相比于舊策略(old_log_prob),在當前優勢函數(advantage)指導下的改進程度。

verl/workers/actor/dp_actor.py:

pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(? ? old_log_prob=old_log_prob,? ? 
log_prob=log_prob,? ? 
advantages=advantages,? ? 
response_mask=response_mask,? ? 
cliprange=clip_ratio,? ? 
cliprange_low=clip_ratio_low,? ? 
cliprange_high=clip_ratio_high,? ? 
clip_ratio_c=clip_ratio_c,? ? 
loss_agg_mode=loss_agg_mode,)

entropy loss

????????entropy?指策略分布的熵 (Entropy):策略對選擇下一個動作(在這里是下一個 token)的不確定性程度。

????????熵越高,表示策略輸出的概率分布越均勻,選擇各個動作的概率越接近,策略的探索性越強;熵越低,表示策略越傾向于選擇少數幾個高概率的動作,確定性越強。

? entropy_loss?指 entropy 的 平均值,是一個標量,表示探索性高低。

verl/workers/actor/dp_actor.py:

if?entropy_coeff !=?0:? ?entropy_loss?= agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)? ?# compute policy loss? ??policy_loss?= pg_loss - entropy_loss * entropy_coeff
else:? ?policy_loss?= pg_loss

計算 KL 散度

????????這里用到了前面 Ref Model 推出的 ref_log_prob,用這個來計算 KL 并作用于最后的 policy_loss,保證模型距離 Ref Model(初始的模型)偏差不會太大。

verl/workers/actor/dp_actor.py:

if?self.config.use_kl_loss:? ? ref_log_prob = data["ref_log_prob"]? ?# compute kl loss? ? kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type? ? )? ? kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode? ? )? ? policy_loss = policy_loss + kl_loss *?self.config.kl_loss_coef? ? metrics["actor/kl_loss"] = kl_loss.detach().item()? ? metrics["actor/kl_coef"] =?self.config.kl_loss_coef

反向計算

verl/workers/actor/dp_actor.py:

loss.backward()

????????持續循環小 step,直到遍歷完所有的 mini batch,Actor Model 就完成了本輪的訓練,會在下一個大 step 前把權重 sync 到 Rollout實例當中,準備處理下一個大 batch 數據。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/84003.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/84003.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/84003.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

PyTorch——優化器(9)

優化器根據梯度調整參數&#xff0c;以達到降低誤差 import torch.optim import torchvision from torch import nn from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear from torch.utils.data import DataLoader# 加載CIFAR10測試數據集&#xff0c;設置tr…

c++學習-this指針

1.基本概念 非靜態成員函數都會默認傳遞this指針&#xff08;靜態成員函數屬于類本身&#xff0c;不屬于某個實例對象&#xff09;&#xff0c;方便訪問對象對類成員變量和 成員函數。 2.基本使用 編譯器實際處理類成員函數&#xff0c;this是第一個隱藏的參數&#xff0c;類…

【Oracle】數據倉庫

個人主頁&#xff1a;Guiat 歸屬專欄&#xff1a;Oracle 文章目錄 1. 數據倉庫概述1.1 為什么需要數據倉庫1.2 Oracle數據倉庫架構1.3 Oracle數據倉庫關鍵技術 2. 數據倉庫建模2.1 維度建模基礎2.2 星形模式設計2.3 雪花模式設計2.4 緩慢變化維度&#xff08;SCD&#xff09;處…

css-塞貝爾曲線

文章目錄 1、定義2、使用和解釋 1、定義 cubic-bezier() 函數定義了一個貝塞爾曲線(Cubic Bezier)語法&#xff1a;cubic-bezier(x1,y1,x2,y2) 2、使用和解釋 x1,y1,x2,y2&#xff0c;表示兩個點的坐標P1(x1,y1),P2(x2,y2)將以一條直線放在范圍只有 1 的坐標軸中&#xff0c;并…

函數式接口實現分頁查詢

你提供的 PageResult 類是一個非常完整、功能齊全的分頁結果封裝類&#xff0c;它包含了&#xff1a; 當前頁數據&#xff08;list&#xff09;總記錄數&#xff08;totalCount&#xff09;總頁數&#xff08;totalPage&#xff09;當前頁碼&#xff08;pageNo&#xff09;每頁…

Global Security Markets 第 10 章衍生品知識點總結?

一、衍生品的定義與本質 衍生品&#xff0c;作為一種金融工具&#xff0c;其價值并非獨立存在&#xff0c;而是緊密依賴于其他資產&#xff0c;如常見的股票、債券、商品&#xff0c;或者市場變量&#xff0c;像利率、匯率、股票指數等。這意味著衍生品的價格波動&#xff0c;…

DJango知識-模型類

一.項目創建 在想要將項目創鍵的目錄下,輸入cmd (進入命令提示符)在cmd中輸入:Django-admin startproject 項目名稱 (創建項目)cd 項目名稱 (進入項目)Django-admin startapp 程序名稱 (創建程序)python manage.py runserver 8080 (運行程序)將彈出的網址復制到瀏覽器中…

八股學習-JS的閉包

一.閉包的定義 閉包是指函數和其周圍的詞法環境的引用的組合。 簡單來說&#xff0c;就是函數可以記住并訪問其在定義時的作用域內的變量&#xff0c;即使該函數在其它作用域調用。 也就是說&#xff0c;閉包讓你可以在一個內層函數中訪問到其外層函數的作用域。 function …

qt使用筆記二:main.cpp詳解

Qt中main.cpp文件詳解 main.cpp是Qt應用程序的入口文件&#xff0c;包含程序的啟動邏輯。下面我將詳細解析其結構和功能。 基本結構 一個典型的Qt main.cpp 文件結構如下&#xff1a; #include <QApplication> // 或者 QGuiApplication/QCoreApplication #include &…

如何構建船舵舵角和船的航向之間的動力學方程?它是一個一階慣性環節嗎?

提問 船舵和船的航向之間的動力學方程是什么&#xff1f;是一個一階慣性環節嗎&#xff1f; 回答 船舵和船的航向&#xff08;航向角&#xff09;之間的動力學關系并不是一個簡單的一階慣性環節&#xff0c;雖然在某些簡化控制模型中可以近似為一階系統。實際上&#xff0c;…

抖去推--短視頻矩陣系統源碼開發

一、開發短視頻矩陣系統的源碼需要以下步驟&#xff1a; 確定系統需求&#xff1a; 根據客戶的具體業務目標&#xff0c;明確系統需實現的核心功能模塊&#xff0c;例如用戶注冊登錄、視頻內容上傳與管理、多維度視頻瀏覽與推薦、用戶互動&#xff08;評論、點贊、分享&#xf…

Windows 下搭建 Zephyr 開發環境

1. 系統要求 操作系統&#xff1a;Windows 10/11&#xff08;64位&#xff09;磁盤空間&#xff1a;至少 8GB 可用空間&#xff08;Zephyr 及其工具鏈較大&#xff09;權限&#xff1a;管理員權限&#xff08;部分工具需要&#xff09; 2. 安裝必要工具 winget安裝依賴工具&am…

三分算法與DeepSeek輔助證明是單峰函數

前置 單峰函數有唯一的最大值&#xff0c;最大值左側的數值嚴格單調遞增&#xff0c;最大值右側的數值嚴格單調遞減。 單谷函數有唯一的最小值&#xff0c;最小值左側的數值嚴格單調遞減&#xff0c;最小值右側的數值嚴格單調遞增。 三分的本質 三分和二分一樣都是通過不斷縮…

安全月報 | 傲盾DDoS攻擊防御2025年5月簡報

引言 在2025年5月&#xff0c;全球數字化進程高歌猛進&#xff0c;各行各業深度融入數字浪潮&#xff0c;人工智能、物聯網、大數據等前沿技術蓬勃發展&#xff0c;進一步夯實了數字經濟的基石。然而&#xff0c;在這看似繁榮的數字生態背后&#xff0c;網絡安全威脅正以驚人的…

【Spring】Spring哪些源碼解決了哪些問題P1

歡迎來到啾啾的博客&#x1f431;。 記錄學習點滴。分享工作思考和實用技巧&#xff0c;偶爾也分享一些雜談&#x1f4ac;。 有很多很多不足的地方&#xff0c;歡迎評論交流&#xff0c;感謝您的閱讀和評論&#x1f604;。 目錄 Spring是怎么處理請求的&#xff1f;Spring請求方…

堅持每日Codeforces三題挑戰:Day 4 - 題目詳解(2025-06-07,難度:1000, 1100, 1400)

前言&#xff1a; 此文章主要是記錄每天的codeforces刷題&#xff0c;還有就是給其他打算法競賽的人一點點點點小小的幫助&#xff08;畢竟現在實力比較菜&#xff0c;題目比較簡單&#xff0c;但我還是會認真寫題解&#xff09;。 之前忙學校事情&#xff0c;懈怠了一段時間…

6.7本日總結

一、英語 復習默寫list10list19&#xff0c;07年第3篇閱讀 二、數學 學習線代第一講&#xff0c;寫15講課后題 三、408 學習計組第二章&#xff0c;寫計組習題 四、總結 本周結束線代第一講和計組第二章&#xff0c;之后學習計網4.4&#xff0c;學完計網4.4之后開操作系…

PGSR : 基于平面的高斯濺射高保真表面重建【全流程分析與測試!】【2025最新版!!】

【PGSR】: 基于平面的高斯濺射高保真表面重建 前言 三維表面重建是計算機視覺和計算機圖形學領域的核心問題之一。隨著Neural Radiance Fields (NeRF)和3D Gaussian Splatting (3DGS)技術的發展&#xff0c;從多視角RGB圖像重建高質量三維表面成為了研究熱點。今天我們要深入…

從認識AI開始-----AutoEncoder:生成模型的起點

前言 從15年開始&#xff0c;在深度學習的重要模型中&#xff0c;AutoEncoder&#xff08;自編碼器&#xff09;可以說是打開生成模型世界的起點。它不僅是壓縮與重建的工具&#xff0c;更是VAE、GAN、DIffusion等復雜生成模型的思想起源。其實AutoEncoder并不復雜&#xff0c;…

解決MySQL8.4報錯ERROR 1524 (HY000): Plugin ‘mysql_native_password‘ is not loaded

最近使用了MySQL8.4 , 服務啟動成功,但是就是無法登陸,并且報錯: ERROR 1524 (HY000): Plugin mysql_native_password is not loaded 使用如下的命令也報錯 mysql -u root -p -P 3306 問題分析: 在MySQL 8.0版本中,默認的認證插件從mysql_native_password變更為cachi…