【模型訓練篇】VeRL的使用 - RL(PPO)與源碼

繼續學習字節家的VeRL,今天來看看VeRL的RL,是VeRL系列的第三篇文章(話說近期好多大事兒,我司發布了Longcat、韓立結嬰、阿里周五發布了QWen-Next都是好東西啊,學不過來了damn)

  • 底層分布式能力基礎Ray(點擊查看):VeRL分布式能力的基礎,框架Ray
  • VeRL的原理(點擊查看):HybridFlow
  • VeRL的使用(點擊查看):普通RL(PPO)
  • VeRL的使用,Agentic RL(多輪RL)
  • VeRL的魔改

前兩篇文章分別介紹了VeRL的分布式基礎和其底層原理,下面就以RL的PPO為例,同時結合源碼,看看具體的使用。

安裝

  • 使用docker的話,verl提供了諸多版本可以使用,例如純凈的只包含Verl/CUDA/PyTorch等依賴的base鏡像,也有整合了vLLM/SGLang/FSDP/Megatron的application鏡像
  • 手動安裝的話,要從CUDA/cuDNN等基礎庫開始,一定會遇到沖突(嗯,一定…)

使用

  1. 首先在Ray的Head節點上執行 ray start --head --dashboard-host=0.0.0.0,之后會得到兩個address:
  • 一個是集群內head/worker之間通信用的 GCS address
  • 一個是提交與查看任務/資源監控/查看日志的dashboard地址(使用VSCode插件進行debug的地址也是它)
  1. 然后在每個Ray Worker節點上執行 ray start --address=gcs_address
  2. 最后提交job任務ray job submit --address=dashboard_address -- python3 -m verl.trainer.main_ppo trainer.n_gpus_per_node=8 ... 就可以在dashboard里看到各種信息了

啟動后,整體架構如圖,前兩篇文章介紹過了,就不贅述了:

  • 其中driver進行代表single-controller
  • 其他的 actor/critic/rollout/ref/reward 那些 workers 代表 multi-controller,均對應著各自的 resource group

在這里插入圖片描述

下面直接看源碼。

源碼

首先是入口函數,即main_ppo.py,主要做定義、初始化:在這里插入圖片描述

  1. 初始化 Ray cluster 環境
  2. 通過 @ray.remote 定義了一個 遠程執行的 class TaskRunner
  3. 定義 actor/rollout worker:通過配置指定使用 fsdpmegatron,并構建 mappingrole_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls)
  4. 定義 criticworker
  5. 將上述兩個worker映射到resourece資源上:mapping[Role.ActorRollout] = global_pool_idmapping[Role.Critic] = global_pool_id
  6. 定義 rewardworkerrole_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) 的同時映射資源 mapping[Role.RewardModel] = "global_pool"
  7. 定義 refworkerrole_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls) 的同時映射資源 mapping[Role.RefPolicy] = "global_pool"
  8. 執行PPO workflow:加載模型、準備dataset、構建RayPPOTrainer、執行 RayPPOTrainer.init_workers()、執行 RayPPOTrainer.fit()
# Initialize the PPO trainer.
trainer = RayPPOTrainer(config=config,tokenizer=tokenizer,processor=processor,role_worker_mapping=self.role_worker_mapping,resource_pool_manager=resource_pool_manager,ray_worker_group_cls=ray_worker_group_cls,reward_fn=reward_fn,val_reward_fn=val_reward_fn,train_dataset=train_dataset,val_dataset=val_dataset,collate_fn=collate_fn,train_sampler=train_sampler,
)
# Initialize the workers of the trainer.
trainer.init_workers()
# Start the training process.
trainer.fit()

然后執行的是核心的RayPPOTrainer,主要就是倆函數,一個是init_workers(),一個是fit() 在這里插入圖片描述

先看init_workers()

  1. 根據config配置的資源創建resource pool
  2. 創建hybrid_engine,這是actorrollout的 colocate的復合體
resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout],config=self.config.actor_rollout_ref,role="actor_rollout",
)
self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
  1. 創建critic
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
critic_cfg = omega_conf_to_dataclass(self.config.critic)
critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls
  1. 創建ref
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy],config=self.config.actor_rollout_ref,role="ref",
)
self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls
  1. 創建reward,下面設置用的是reward modelfunction
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls
  1. 創建各自的wroker groupWorkerGroup是一組Wroker的抽象集合,使得driver可以和底層的多個worker進行交互:
for resource_pool, class_dict in self.resource_pool_to_cls.items():worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool,ray_cls_with_init=worker_dict_cls,**wg_kwargs,)spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())all_wg.update(spawn_wg)if self.use_critic:self.critic_wg = all_wg["critic"]self.critic_wg.init_model()if self.use_reference_policy and not self.ref_in_actor: # 需要關注self.ref_policy_wg = all_wg["ref"]self.ref_policy_wg.init_model()if self.use_rm:self.rm_wg = all_wg["rm"]self.rm_wg.init_model()

這里需要注意的是:

  • actorrollout進行colocate的目的:是在rollout和train兩個階段間高效更新參數權重
  • 但是否也同樣也colocate ref,取決于是否用了LoRA,因為refactor它們的base基座模型一樣,只不過actor lora多了一層lora的適配層,也就是BA矩陣,所以如果用LoRA,可以把rollout/actor/ref同時colocate到一起,更省資源

之后再看fit(),其實就是標準的PPO實現了,下面提取出關鍵信息:

for prompt in dataloader:output = actor_rollout_ref_wg.generate_sequences(prompt) # old_log_prob = actor_rollout_ref_wg.compute_log_prob(output)ref_log_prob = actor_rollout_ref_wg.compute_ref_log_prob(output)values = critic_wg.compute_values(output)rewards = reward_wg.compute_scores(output)advantages = compute_advantages(values, rewards)output = output.union(old_log_prob).union(ref_log_prob).union(values).union(rewards).union(advantages)actor_rollout_ref_wg.update_actor(output)critic.update_critic(output)

另外,關于driverwroker的數據交互,大致可以分成3步:

  1. driver把數據按DP數量進行切分
  2. 把數據分發給每個worker
  3. 每個worker再將執行的結果進行整合,所以VeRL這里搞了一個語法糖@register
class ActorRolloutRefWorker(Worker):@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)def generate_sequences(self, prompts: DataProto):prompts = prompts.to(torch.cuda.current_device())

上面的注解@register裝飾了方法generate_sequence,包含了 dispatch_mode對應的:

  • dispatch_func:把輸入dispatch到worker group中的各個worker
  • collect_func:把worker group的各個worker的response collect到一起

VeRL的各種參數,有詳細解釋,也有展示的圖片


下篇文章介紹下如何使用VeRL進行Agentic RL,也就是多輪RL。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/922337.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/922337.shtml
英文地址,請注明出處:http://en.pswp.cn/news/922337.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

QML Charts組件之折線圖的鼠標交互

目錄前言相關系列代碼示例詳解(LineSeriesDemo3.qml)功能概覽運行效果代碼說明工程下載參考前言 接上文(QML Charts組件之折線圖的基礎屬性),本文將重點介紹LineSeries的鼠標交互,包括:鼠標拖拽…

二值信號量——學習筆記12

本文是筆者在學習 正點原子官方 的《【正點原子】手把手教你學FreeRTOS實時系統》系列視頻時整理的筆記。 視頻講解清晰透徹,非常感謝UP主的無私奉獻!原課程鏈接如下: 👉 B站視頻鏈接:??????【正點原子】手把手教…

裸機開發 時鐘配置,EPIT

1.概念時鐘(clock):在電子系統中是一個產生穩定、周期性振蕩信號的電路或組件。這個信號像節拍器或心跳一樣,為數字電路中的各種操作提供同步時序基準。PLL(phase locked loop)鎖相環電路: 倍頻PFD(phase fractional P…

Linux-文本三劍客(grep、sed、awk)

Linux-文本三劍客前言一、grep二、sed三、awk模式 -- 正則表達式關系表達式、運算符表達模式匹配表達式動作 輸出流程控制參數傳遞,awk接受外部變量統計數組的使用分組統計練習常用內置函數前言 grep、sed、awk 被稱為 “文本三劍客”,它們是處理文本文…

主流反爬蟲、反作弊防護與風控對抗手段

文章目錄1. 寫在前面2. 指紋檢測3. 行為驗證3. 加固防護4. 鏈路檢測5. 風控埋點6. 游客注冊7. 數據防護8. 賬號權重9. 反調阻斷【🏠作者主頁】:吳秋霖 【💼作者介紹】:擅長爬蟲與JS加密逆向分析!Python領域優質創作者、…

金蝶云星空插件開發記錄(一)

實現目的:新增供應商保存后,觸發釘釘審批流程,并根據釘釘審批結果回寫是否合格供應商。實現思路:通過BOS平臺供在應商管理界面新增兩個復選框字段:是否釘釘審批、是否合格供應商,若在新建供應商檔案時勾選是…

企業跨區域組網新解:SD-WAN技術打造安全穩定網絡體系

前言在數字化浪潮席卷全球的今天,企業跨區域網絡互聯已成為支撐業務發展的關鍵基礎設施。傳統MPLS專線雖性能穩定,但高昂成本和漫長部署周期令眾多企業望而卻步。SD-WAN技術的出現,正以其智能、靈活和成本效益的優勢,重塑企業組網…

Docker 容器化

引言在解釋docker是什么之前,我們首先應該先了解的是容器化的概念。什么是容器?就是一個沙箱,在這個沙箱中涵蓋了特定應用運行的一切依賴的內容。但他不是一個操作系統,且和底層的操作系統是隔離的。什么是容器化?容器…

LeetCode刷題——hot 100(3)

題目1:矩陣置零題目:問題分析:使用兩個布爾數組來分別記錄哪行哪列出現了0,當出現0的行和列,對應的布爾數組值置為true。再次遍歷數組,當出現行數組和列數組中的值為true,則對應的原數組的值置為…

Ajax-day2(圖書管理)-渲染列表

本篇筆記素材來自“黑馬程序員” 渲染列表圖書管理一、獲取數據二、渲染數據完整代碼圖書管理 Bootstrap 框架渲染列表(查)新增圖書(增)刪除圖書(刪)編輯圖書(改) 自己的圖書數據&a…

MOS管的電路

MOS管的三極都會存在以下三個電容,分別是:Cgs,Cgd,Cds 輸入電容CissCgsCgd 輸出電容CossCgdCds 反向傳輸電容CrssCgd,也叫米勒電容 然而,這三個等效電容是構成串并聯組合關系,他們并不是獨立的,而是相互…

STM32_05_時鐘樹

時鐘 d用來輸入數據,CLK就是我們的時鐘,CPU1s中72000000HZ個時鐘周期STM32的時鐘樹鎖相環HSE時鐘源HSI時鐘源LSE時鐘源LSI時鐘源SystemInit函數SetSysClock函數SetSysClockTo72函數SystemInit()后時鐘頻率大小總結RCC標準庫函數定義變量a&…

C語言---判斷語句

文章目錄1. if 語句2. if...else 語句3. if...else if...else 語句4. switch 語句5. 三元運算符 ( ? : )總結與對比如何選擇C語言中的判斷語句用于根據給定的條件來決定執行哪一段代碼。其核心是條件為真(必須)則執行一段代碼,條件為假&…

[硬件電路-212]:電流的本質確實是電子的移動

1. 微觀機制:電子的定向漂移與熱運動定向漂移(Drift Motion):在導體(如金屬)中,自由電子(價電子)受電場驅動,從負端向正端定向移動,形成宏觀電流。…

雙RFSOC47DR-16通道5GSPS ADC采集模塊

16通道5GSPS ADC采集板卡組成如圖1所示。該板卡的輸入接口為SMA單端輸入,ADC采集和處理采用Xilinx公司的XCZU47DR-2FFVE1156I芯片。板卡需配備4路QSFP28光口輸出,并需要集成網口、DDR4、SD卡、USB調試口。兩塊RF-Soc需確保連接通信功能。板卡的16通道需實…

pytest -- 中文文檔

前言 零基礎1小時快速入門pytest自動化測試教程,全套項目框架實戰pytest配置文件可以改變pytest的運行方式,它是一個固定的文件pytest.ini文件,讀取配置信息,按指定的方式去運行 非test文件 pytest里面有些文件是非test文件 pyt…

硬件開發2-ARM裸機開發3-IMX6ULL - 引入中斷

一、鋪墊引入中斷 → 按鍵1、概要:實現按鍵控制發光二極管和蜂鳴器輸入類型的外設:按鍵(key)2、參考手冊內容完成配置過程(1)key 按鍵原理圖(2)core 內核中命名 -- UART1 CTS&#x…

Ansible的 Playbook 模式詳解

目錄一、Playbook模式1.1 Playbook 的優勢1.2 Playbook 的組成1.3 安裝 httpd 服務案例1.4 Playbook 命令及常用參數1.5 Playbook 的語法 —— 權限相關1. remote_user2. become3. become_method1.6 Playbook 的通知與觸發機制1. notify2. handlers3. 使用示例4. 使用場景1.6 P…

猿輔導Java后臺開發面試題及參考答案

int 與 Integer 的區別是什么?若創建數量龐大的數字時使用 Integer,會對重復數字創建新對象嗎?int 是 Java 中的基本數據類型,直接存儲數值,占用 4 個字節,默認值為 0,不需要通過 new 關鍵字創建…

代碼隨想錄學習摘抄day9(回溯1-11)

一個樸實無華的目錄定義:回溯法也可以叫做回溯搜索法,它是一種搜索的方式。應用場景:回溯法解決的問題都可以抽象為樹形結構代碼模板題型第77題. 組合思路:每次從集合中選取元素,可選擇的范圍隨著選擇的進行而收縮&…