DeepSpeed-Chat RLHF 階段代碼解讀(0) —— 原始 PPO 代碼解讀

為了理解 DeepSpeed-Chat RLHF 的 RLHF 全部過程,這個系列會分三篇文章分別介紹:
原始 PPO 代碼解讀RLHF 獎勵函數代碼解讀RLHF PPO 代碼解讀
這是系列的第一篇文章,我們來一步一步的看 PPO 算法的代碼實現,對于 PPO 算法原理不太了解的同學,可以參考之前的文章:
深度強化學習(DRL)算法 2 —— PPO 之 Clipped Surrogate Objective 篇
深度強化學習(DRL)算法 2 —— PPO 之 GAE 篇

Clipped Surrogate 函數實現

# code from cleanrl: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py
for start in range(0, args.batch_size, args.minibatch_size):end = start + args.minibatch_sizemb_inds = b_inds[start:end]_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])logratio = newlogprob - b_logprobs[mb_inds]ratio = logratio.exp()mb_advantages = b_advantages[mb_inds]if args.norm_adv:mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)# Policy losspg_loss1 = -mb_advantages * ratiopg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)pg_loss = torch.max(pg_loss1, pg_loss2).mean()

Clipped Surrogate 函數的實現很簡單,這里不再贅述,理解算法原理,代碼自然而然就可以看懂,核心是 get_action_and_value 函數的理解。

def get_action_and_value(self, x, action=None):logits = self.actor(x)# probs 相當于計算 softmaxprobs = Categorical(logits=logits)if action is None:action = probs.sample()# probs.log_prob(action) 計算的是 p(a|s) 的 log 形式,方便計算 Clipped Surrogate 函數里的 ratioreturn action, probs.log_prob(action), probs.entropy(), self.critic(x) 

GAE 實現

直接來看 gae 可能比較抽象,我們先來看蒙特卡洛方法實現的優勢估計,對蒙特卡洛方法不熟悉的同學,可以參考之前的文章。
深度強化學習(DRL)算法 附錄 3 —— 蒙特卡洛方法(MC)和時序差分(TD)
兩種方法都采用了反向迭代(因為反向迭代更好計算)的方式來實現優勢估計。

# code from cleanrl: https://github.com/vwxyzjn/cleanrl/commit/b7088a41e5e6f0f5f6940fd29054a35118083b28
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)returns = torch.zeros_like(rewards).to(device)
for t in reversed(range(args.num_steps)):if t == args.num_steps - 1:nextnonterminal = 1.0 - next_donenext_return = last_valueelse:nextnonterminal = 1.0 - dones[t+1]next_return = returns[t+1]returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
advantages = returns - values

上面的代碼做了什么事情呢,last_value 對應最后的 step(對應 step t) 產生的期望回報,如果 step t-1 整個流程沒有結束,那么 t-1 時刻的期望回報就是 reward(t-1) + args.gamma * nextnonterminal * next_return,這樣一步一步往后推,就可以計算每一個 step 的期望回報,從而得到每一步的優勢,還沒理解的話,看下面每個時間步的拆解。關于 last_value 的使用,這里由于沒有后續的回報可以累積,所以直接使用 last_value 作為最后一個時間步的回報。關于下面為啥用 return[t-1] 替換原始公式的 value[t-1],這樣計算的話就相當于蒙特卡洛方法的優勢估計,如果next_return = returns[t+1] 改成 next_value = values[t+1] 就相當于 TD(1) 的優勢估計。

# t
return(t) = v(t)
# t - 1
return(t-1) = reward(t-1) + gamma * return(t) = reward(t-1) + gamma * return(t)
# t - 2
return(t-2) = reward(t-2) + gamma * return(t-1) = reward(t-2) + gamma * (reward(t-1) + gamma * return(t))
......
# 我們可以看到一步一步往前推,最后就得到蒙特卡洛方法的優勢估計

理解了上面講的蒙特卡洛方法實現的優勢估計,再來看 gae 的實現,我們可以看到代碼實現上十分的相似,只是多了 delta 的計算,這里的 delta 對應的就是之前 PPO GAE 篇里介紹的 delta。

# code from cleanrl: https://github.com/vwxyzjn/cleanrl/commit/b7088a41e5e6f0f5f6940fd29054a35118083b28
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
for t in reversed(range(args.num_steps)):if t == args.num_steps - 1:nextnonterminal = 1.0 - next_donenextvalues = last_valueelse:nextnonterminal = 1.0 - dones[t+1]nextvalues = values[t+1]delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
returns = advantages + values

這里通過反向迭代的方式計算 GAE advantage,可能理解上比較抽象,舉個例子,就很好理解了:

# advantage(t)
adv[t] = lastgaelam = rewards[t] + gamma * values[t+1] - values[t]
# t-1
adv[t-1] = lastgaelam = rewards[t-1] + gamma * values[t] - values[t-1] + gamma * lambda * lastgaelam
# t-2
adv[t-2] = lastgaelam = rewards[t-2] + gamma * values[t-1] - values[t-2] + gamma * lambda * lastgaelam
...

可以看到,逐項展開,每一時刻的 GAE Advantage 和 PPO GAE 篇里介紹的公式是一模一樣的,這里 GAE 就是一種數學公式表達,核心思想是 n step 的優勢估計的加權平均,通過數學技巧恰好是上面的形式。

參考

  1. The 37 Implementation Details of Proximal Policy Optimization · The ICLR Blog Track (iclr-blog-track.github.io)
  2. HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/716554.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/716554.shtml
英文地址,請注明出處:http://en.pswp.cn/news/716554.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

部署若依前后端分離項目,連接數據庫失敗

部署若依前后端分離項目,連接數據庫失敗,異常如下: 解決方案:application配置文件里,連接數據庫的參數useSSL的值改為false

leetcode 長度最小的子數組

在本題中,我們可以知道,是要求數組中組成和為target的最小子數組的長度。所以,我們肯定可以想到用兩層for循環進行遍歷,然后枚舉所有的結果進行挑選,但這樣時間復雜度過高。 我們可以采用滑動窗口,其實就是…

編寫dockerfile掛載卷、數據容器卷

編寫dockerfile掛載卷 編寫dockerfile文件 [rootwq docker-test-volume]# vim dockerfile1 [rootwq docker-test-volume]# cat dockerfile1 FROM centosVOLUME ["volume01","volume02"]CMD echo "------end------" CMD /bin/bash [rootwq dock…

2024 年廣東省職業院校技能大賽(高職組)“云計算應用”賽項樣題 2

#需要資源或有問題的,可私博主!!! #需要資源或有問題的,可私博主!!! #需要資源或有問題的,可私博主!!! 某企業根據自身業務需求&#…

每日OJ題_牛客_合法括號序列判斷

目錄 合法括號序列判斷 解析代碼 合法括號序列判斷 合法括號序列判斷__牛客網 解析代碼 class Parenthesis {public:bool chkParenthesis(string A, int n){if (n & 1) // 如果n是奇數return false;stack<char> st;for (int i 0; i < n; i) {if (A[i] () {s…

筆記本hp6930p安裝Android-x86補記

在上一篇日記中&#xff08;筆記本hp6930p安裝Android-x86避坑日記-CSDN博客&#xff09;提到hp6930p安裝Android-x86-9.0&#xff0c;無法正常啟動&#xff0c;本文對此再做嘗試&#xff0c;原因是&#xff1a;Android-x86-9.0不支持無線網卡&#xff0c;需要在BIOS中關閉WLAN…

《Docker極簡教程》--Docker的高級特性--Docker Compose的使用

Docker Compose是一個用于定義和運行多容器Docker應用程序的工具。它允許開發人員通過簡單的YAML文件來定義應用程序的服務、網絡和卷等資源&#xff0c;并使用單個命令來啟動、停止和管理整個應用程序的容器。以下是關于Docker Compose的一些關鍵信息和優勢&#xff1a; 定義…

B082-SpringCloud-Eureka

目錄 微服務架構與springcloud架構演變為什么使用微服務微服務的通訊方式架構的選擇springcloud概述場景模擬之基礎架構的搭建模擬微服務之間的服務調用目前遠程調用的問題 eureka注冊中心的作用注冊中心的實現服務提供者注冊到注冊中心 springcloud基于springboot 微服務架構與…

10 計算機結構

馮諾依曼體系結構 馮諾依曼體系結構&#xff0c;也被稱為普林斯頓結構&#xff0c;是一種計算機架構&#xff0c;其核心特點包括將程序指令存儲和數據存儲合并在一起的存儲器結構&#xff0c;程序指令和數據的寬度相同&#xff0c;通常都是16位或32位 我們常見的計算機,筆記本…

在Centos7中用Docker部署gitlab-ce

一、介紹 GitLab Community Edition (GitLab CE) 是一個開源的版本控制系統和協作平臺&#xff0c;用于管理和追蹤軟件開發項目。它提供了一套完整的工具和功能&#xff0c;包括代碼托管、版本控制、問題跟蹤、持續集成、持續交付和協作功能&#xff0c;使團隊能夠更加高效地進…

動態規劃|【路徑問題】|931.下降路徑最小和

目錄 題目 題目解析 思路 1.狀態表示 2.狀態轉移方程 3.初始化 4.填表順序 5.返回值 代碼 題目 931. 下降路徑最小和 給你一個 n x n 的 方形 整數數組 matrix &#xff0c;請你找出并返回通過 matrix 的下降路徑 的 最小和 。 下降路徑 可以從第一行中的任何元素開…

【Vue3】Props的使用詳解

&#x1f497;&#x1f497;&#x1f497;歡迎來到我的博客&#xff0c;你將找到有關如何使用技術解決問題的文章&#xff0c;也會找到某個技術的學習路線。無論你是何種職業&#xff0c;我都希望我的博客對你有所幫助。最后不要忘記訂閱我的博客以獲取最新文章&#xff0c;也歡…

概率基礎——多元正態分布

概率基礎——多元正態分布 介紹 多元正態分布是統計學中一種重要的多維概率分布&#xff0c;描述了多個隨機變量的聯合分布。在多元正態分布中&#xff0c;每個隨機變量都服從正態分布&#xff0c;且不同隨機變量之間可能存在相關性。本文將以二元標準正態分布為例&#xff0…

多線程JUC 第2季 中斷線程

一 中斷線程 1.1 中斷概念 1.在java中&#xff0c;沒有提供一種立即停止一條線程。但卻給了停止線程的協商機制-中斷。 中斷是一種協商機制。中斷的過程完全需要程序員自己實現。也即&#xff0c;如果要中斷一個線程&#xff0c;你需要手動調用該線程的interrupt()方法&…

錄制用戶操作實現自動化任務

先上視頻&#xff01;&#xff01; 流程自動化工具-錄制操作繪制流程 這個想法之前就有了&#xff0c;趁著周末時間給它擼出來。 實現思路 從之前的文章自動化桌面未來展望中已經驗證了錄制繪制流程圖的可行性。基于DOM錄制頁面操作軌跡的思路監聽頁面點擊、輸入事件即可&…

無人機鏡頭穩定的原理和相關算法

無人機的鏡頭穩定主要基于兩個關鍵技術&#xff1a;鏡頭平衡技術和實時電子穩像。無人機鏡頭穩定的原理和相關算法主要是通過鏡頭平衡技術和實時電子穩像技術來保持攝像鏡頭的穩定性&#xff0c;從而拍攝出清晰、穩定的畫面。無人機鏡頭穩定的原理主要是通過傳感器和算法來實現…

Ocr之PaddleOcr模型訓練

目錄 一、系統環境 1 鏡像拉取ppocr 進行部署 2 安裝paddlepaddle 二、訓練前的準備 1 下載源碼 2 預模型下載 3 修改模型訓練文件yml 4 編排訓練集 5 執行腳本進行訓練 6 需要修改文件夾名稱 三、開始訓練 1 執行訓練命令 2 對第一次評估進行解釋 3 引言 五、總…

NestJS使用模板引擎ejs

模板引擎? 模板引擎是一種用于生成動態內容的工具&#xff0c;它通過將預定義的模板與特定數據結合&#xff0c;來生成最終的輸出。? 在NodeJS開發中&#xff0c;我們會使用模板引擎來渲染一些常用的頁面&#xff0c;比如渲染代表404的Not Found 頁面&#xff0c;502的Bad …

異常值檢測-值域法 頭歌代碼解釋

這關做得不是很明白&#xff0c;如果有清楚的同志可以在評論區里面討論 import pandas as pd import numpy as np import matplotlib.pyplot as plt from sklearn.neighbors import LocalOutlierFactor # 導入數據 abc pd.read_csv(deaths.csv) ## 只分析其中的Population和L…

C語言對類型的轉換

C語言對類型的轉換 文章目錄 C語言對類型的轉換整形提升和截斷整形提升整形提升規則整形提升的意義 截斷截斷規則 算數轉換 我們都知道&#xff0c;C語言中內置了多種整形類型&#xff0c;占用空間從大到小&#xff0c;基本滿足各類使用場景&#xff08;比如超長數字的運算就不…