從代碼學習深度強化學習 - PPO PyTorch版

文章目錄

  • 前言
  • PPO 算法簡介
    • 從 TRPO 到 PPO
    • PPO 的兩種形式:懲罰與截斷
  • 代碼實踐:PPO 解決離散動作空間問題 (CartPole)
    • 環境與工具函數
    • 定義策略與價值網絡
    • PPO 智能體核心實現
    • 訓練與結果
  • 代碼實踐:PPO 解決連續動作空間問題 (Pendulum)
    • 環境準備
    • 適用于連續動作的網絡
    • PPO 智能體 (連續版)
    • 訓練與結果
  • 總結


前言

歡迎來到深度強化學習(DRL)的世界!在眾多 DRL 算法中,Proximal Policy Optimization (PPO) 無疑是最受歡迎和廣泛應用的算法之一。它由 OpenAI 在 2017 年提出,以其出色的性能、相對簡單的實現和穩定的訓練過程而著稱,成為了許多研究和應用的基準算法。

本篇博客旨在通過一個完整的 PyTorch 實現,帶您從代碼層面深入理解 PPO 算法。我們將不僅僅是看公式,更是要“動手”,一步步構建、訓練和分析 PPO 智能體。為了全面掌握其應用,我們將分別在經典的離散動作空間(CartPole-v1)和連續動作空間(Pendulum-v1)兩個環境中進行實踐。

無論您是 DRL 的初學者,還是希望鞏固 PPO 知識的實踐者,相信通過這篇代碼驅動的教程,您都能對 PPO 有一個更具體、更深刻的認識。

完整代碼:下載鏈接


PPO 算法簡介

在深入代碼之前,我們先快速回顧一下 PPO 的核心思想。

從 TRPO 到 PPO

PPO 的思想源于 TRPO(Trust Region Policy Optimization)。TRPO 旨在通過限制每次策略更新的步長,確保更新后的策略不會與舊策略偏離太遠,從而保證學習的穩定性。它的優化目標如下:

TRPO 通過一個 KL 散度的約束來限制策略更新的區域,但這個約束的計算過程非常復雜,涉及泰勒展開、共軛梯度、線性搜索等,導致其實現難度大,運算量也非常可觀。

PPO 的出現正是為了解決這個問題。它繼承了 TRPO 的核心思想,即在更新策略時不要“步子邁得太大”,但采用了更簡單、更易于實現的方法。

PPO 的兩種形式:懲罰與截斷

PPO 主要有兩種形式:PPO-PenaltyPPO-Clip

  1. PPO-Penalty (懲罰)
    它將 TRPO 的 KL 散度約束作為一個懲罰項直接放入目標函數中,變成一個無約束的優化問題,并通過一個動態調整的系數 β 來控制懲罰的力度。

  2. PPO-Clip (截斷)
    這是更常用的一種形式,也是我們代碼將要實現的版本。它直接在目標函數中進行截斷(clip),以保證新的參數和舊的參數的差距不會太大。

    其核心思想在于 clip 函數。我們定義一個比率 r(θ) 為新策略與舊策略輸出同一動作的概率之比。

    • 優勢函數 A > 0 時(即當前動作優于平均水平),我們希望增大這個動作的概率,但 r(θ) 的上限被截斷在 1+ε,防止策略更新過于激進。
    • 優勢函數 A < 0 時(即當前動作劣于平均水平),我們希望減小這個動作的概率,但 r(θ) 的下限被截斷在 1-ε,同樣是為了限制更新幅度。

    下圖直觀地展示了 PPO-Clip 的目標函數 L^Clip 與概率比 r(θ) 的關系:

大量的實驗表明,PPO-Clip 的性能通常比 PPO-Penalty 更好且更穩定。因此,我們的代碼實踐將專注于 PPO-Clip 的實現。

理論鋪墊結束,讓我們開始編碼吧!

代碼實踐:PPO 解決離散動作空間問題 (CartPole)

我們將從經典的 CartPole-v1 環境開始,它要求智能體通過向左或向右施加力來保持桿子豎直不倒,是一個典型的離散動作空間問題(動作:0-向左,1-向右)。

環境與工具函數

首先,我們定義一些通用的工具函數并初始化環境。這里的核心是 compute_advantage 函數,它實現了廣義優勢估計(GAE),這是一種在偏差和方差之間取得平衡的優勢函數計算方法,對于穩定策略梯度算法的訓練至關重要。

PPO離散動作.ipynb

"""
強化學習工具函數集
包含廣義優勢估計(GAE)和數據平滑處理功能
"""import torch
import numpy as npdef compute_advantage(gamma, lmbda, td_delta):"""計算廣義優勢估計(Generalized Advantage Estimation,GAE)GAE是一種在強化學習中用于減少策略梯度方差的技術,通過對時序差分誤差進行指數加權平均來估計優勢函數,平衡偏差和方差的權衡。參數:gamma (float): 折扣因子,維度: 標量取值范圍[0,1],決定未來獎勵的重要性lmbda (float): GAE參數,維度: 標量  取值范圍[0,1],控制偏差-方差權衡lmbda=0時為TD(0)單步時間差分,lmbda=1時為蒙特卡洛方法用采樣到的獎勵-狀態價值估計td_delta (torch.Tensor): 時序差分誤差序列,維度: [時間步數]包含每個時間步的TD誤差值返回:torch.Tensor: 廣義優勢估計值,維度: [時間步數]與輸入td_delta維度相同的優勢函數估計數學公式:A_t^GAE(γ,λ) = Σ_{l=0}^∞ (γλ)^l * δ_{t+l}其中 δ_t = r_t + γV(s_{t+1}) - V(s_t) 是TD誤差"""# 將PyTorch張量轉換為NumPy數組進行計算# td_delta維度: [時間步數] -> [時間步數]td_delta = td_delta.detach().numpy() # 因為A用來求g的,需要梯度,防止梯度向下傳播# 初始化優勢值列表,用于存儲每個時間步的優勢估計# advantage_list維度: 最終為[時間步數]advantage_list = []# 初始化當前優勢值,從序列末尾開始反向計算# advantage維度: 標量advantage = 0.0# 從時間序列末尾開始反向遍歷TD誤差# 反向計算是因為GAE需要利用未來的信息# delta維度: 標量(td_delta中的單個元素)for delta in td_delta[::-1]:  # [::-1]實現序列反轉# GAE遞歸公式:A_t = δ_t + γλA_{t+1}# gamma * lmbda * advantage: 來自未來時間步的衰減優勢值# delta: 當前時間步的TD誤差# advantage維度: 標量advantage = gamma * lmbda * advantage + delta# 將計算得到的優勢值添加到列表中# advantage_list維度: 逐步增長到[時間步數]advantage_list.append(advantage)# 由于是反向計算,需要將結果列表反轉回正確的時間順序# advantage_list維度: [時間步數](時間順序已恢復)advantage_list.reverse()# 將NumPy列表轉換回PyTorch張量并返回# 返回值維度: [時間步數]return torch.tensor(advantage_list, dtype=torch.float)def moving_average(data, window_size):"""計算移動平均值,用于平滑獎勵曲線該函數通過滑動窗口的方式對時間序列數據進行平滑處理,可以有效減少數據中的噪聲,使曲線更加平滑美觀。常用于強化學習中對訓練過程的獎勵曲線進行可視化優化。參數:data (list): 原始數據序列,維度: [num_episodes]包含需要平滑處理的數值數據(如每輪訓練的獎勵值)window_size (int): 移動窗口大小,維度: 標量決定了平滑程度,窗口越大平滑效果越明顯但也會導致更多的數據點丟失返回:list: 移動平均后的數據,維度: [len(data) - window_size + 1]返回的數據長度會比原數據少 window_size - 1 個元素這是因為需要足夠的數據點來計算第一個移動平均值示例:>>> data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # 維度: [10]>>> smoothed = moving_average(data, 3)       # window_size = 3>>> print(smoothed)  # 輸出: [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]  維度: [8]"""# 邊界檢查:如果數據長度小于窗口大小,直接返回原數據# 這種情況下無法計算移動平均值# data維度: [num_episodes], window_size維度: 標量if len(data) < window_size:return data# 初始化移動平均值列表# moving_avg維度: 最終為[len(data) - window_size + 1]moving_avg = []# 遍歷數據,計算每個窗口的移動平均值# i的取值范圍: 0 到 len(data) - window_size# 循環次數: len(data) - window_size + 1# 每次循環處理一個滑動窗口位置for i in range(len(data) - window_size + 1):# 提取當前窗口內的數據切片# window_data維度: [window_size]# 包含從索引i開始的連續window_size個元素# 例如:當i=0, window_size=3時,提取data[0:3]window_data = data[i:i + window_size]# 計算當前窗口內數據的算術平均值# np.mean(window_data)維度: 標量# 將平均值添加到結果列表中moving_avg.append(np.mean(window_data))# 返回移動平均后的數據列表# moving_avg維度: [len(data) - window_size + 1]return moving_avg``````python
"""
強化學習環境初始化模塊
用于創建和配置OpenAI Gym環境
"""import gym# 環境配置
# 定義要使用的強化學習環境名稱
# CartPole-v1是經典的平衡桿控制問題:
# - 狀態空間:4維連續空間(車位置、車速度、桿角度、桿角速度)
# - 動作空間:2維離散空間(向左推車、向右推車)
# - 目標:保持桿子平衡盡可能長的時間
# env_name維度: 標量(字符串)
env_name = 'CartPole-v1'# 創建強化學習環境實例
# gym.make()函數根據環境名稱創建對應的環境對象
# 該環境對象包含了狀態空間、動作空間、獎勵函數等定義
# env維度: gym.Env對象(包含狀態空間[4]和動作空間[2]的環境實例)
# env.observation_space.shape: (4,) - 觀測狀態維度
# env.action_space.n: 2 - 離散動作數量
env = gym.make(env_name)

定義策略與價值網絡

PPO 是一種 Actor-Critic 架構的算法。我們需要定義兩個網絡:

  • 策略網絡 (PolicyNet):作為 Actor,輸入狀態,輸出一個動作的概率分布。
  • 價值網絡 (ValueNet):作為 Critic,輸入狀態,輸出該狀態的價值估計 V(s)。
"""
PPO(Proximal Policy Optimization)算法實現
包含策略網絡、價值網絡和PPO智能體的完整定義
"""import torch
import torch.nn.functional as F
import numpy as npclass PolicyNet(torch.nn.Module):"""策略網絡(Actor Network)用于輸出動作概率分布,指導智能體如何選擇動作"""def __init__(self, state_dim, hidden_dim, action_dim):"""初始化策略網絡參數:state_dim (int): 狀態空間維度,維度: 標量對于CartPole-v1環境,state_dim=4hidden_dim (int): 隱藏層神經元數量,維度: 標量控制網絡的表達能力action_dim (int): 動作空間維度,維度: 標量對于CartPole-v1環境,action_dim=2"""super(PolicyNet, self).__init__()# 第一層全連接層:狀態輸入 -> 隱藏層# 輸入維度: [batch_size, state_dim] -> 輸出維度: [batch_size, hidden_dim]self.fc1 = torch.nn.Linear(state_dim, hidden_dim)# 第二層全連接層:隱藏層 -> 動作概率# 輸入維度: [batch_size, hidden_dim] -> 輸出維度: [batch_size, action_dim]self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):"""前向傳播過程參數:x (torch.Tensor): 輸入狀態,維度: [batch_size, state_dim]返回:torch.Tensor: 動作概率分布,維度: [batch_size, action_dim]每行為一個狀態對應的動作概率分布,概率和為1"""# 第一層 + ReLU激活函數# x維度: [batch_size, state_dim] -> [batch_size, hidden_dim]x = F.relu(self.fc1(x))# 第二層 + Softmax激活函數,輸出概率分布# x維度: [batch_size, hidden_dim] -> [batch_size, action_dim]# dim=1表示在第1維(動作維度)上進行softmax,確保每行概率和為1return F.softmax(self.fc2(x), dim=1)class ValueNet(torch.nn.Module):"""價值網絡(Critic Network)用于估計狀態價值函數V(s),評估當前狀態的好壞"""def __init__(self, state_dim, hidden_dim):"""初始化價值網絡參數:state_dim (int): 狀態空間維度,維度: 標量對于CartPole-v1環境,state_dim=4hidden_dim (int): 隱藏層神經元數量,維度: 標量控制網絡的表達能力"""super(ValueNet, self).__init__()# 第一層全連接層:狀態輸入 -> 隱藏層# 輸入維度: [batch_size, state_dim] -> 輸出維度: [batch_size, hidden_dim]self.fc1 = torch.nn.Linear(state_dim, hidden_dim)# 第二層全連接層:隱藏層 -> 狀態價值(標量)# 輸入維度: [batch_size, hidden_dim] -> 輸出維度: [batch_size, 1]self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x):"""前向傳播過程參數:x (torch.Tensor): 輸入狀態,維度: [batch_size, state_dim]返回:torch.Tensor: 狀態價值估計,維度: [batch_size, 1]每行為一個狀態對應的價值估計"""# 第一層 + ReLU激活函數# x維度: [batch_size, state_dim] -> [batch_size, hidden_dim]x = F.relu(self.fc1(x))# 第二層,輸出狀態價值(無激活函數,可以輸出負值)# x維度: [batch_size, hidden_dim] -> [batch_size, 1]return self.fc2(x)

PPO 智能體核心實現

這是我們 PPO 算法的核心。PPO 類封裝了 Actor 和 Critic,并實現了 take_action(動作選擇)和 update(網絡更新)兩個關鍵方法。請特別關注 update 函數,它完整地實現了 PPO-Clip 的目標函數計算和參數更新邏輯。

class PPO:"""PPO(Proximal Policy Optimization)算法實現采用截斷方式防止策略更新過大,確保訓練穩定性"""def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,lmbda, epochs, eps, gamma, device):"""初始化PPO智能體參數:state_dim (int): 狀態空間維度,維度: 標量hidden_dim (int): 隱藏層神經元數量,維度: 標量action_dim (int): 動作空間維度,維度: 標量actor_lr (float): Actor網絡學習率,維度: 標量critic_lr (float): Critic網絡學習率,維度: 標量lmbda (float): GAE參數λ,維度: 標量,取值范圍[0,1]epochs (int): 每次更新的訓練輪數,維度: 標量eps (float): PPO截斷參數ε,維度: 標量,通常取0.1-0.3gamma (float): 折扣因子γ,維度: 標量,取值范圍[0,1]device (torch.device): 計算設備(CPU或GPU),維度: 標量"""# 初始化Actor網絡(策略網絡)# 網絡參數維度:fc1權重[state_dim, hidden_dim], fc2權重[hidden_dim, action_dim]self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)#

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/88580.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/88580.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/88580.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

PortsWiggerLab: Blind OS command injection with output redirection

實驗目的This lab contains a blind OS command injection vulnerability in the feedback function.The application executes a shell command containing the user-supplied details. The output from the command is not returned in the response. However, you can use o…

星云穿越與超光速飛行特效的前端實現原理與實踐

文章目錄 1,引言2,特效設計思路3,技術原理解析1. 星點的三維分布2. 視角推進與星點運動3. 三維到二維的投影4. 星點的視覺表現5. 色彩與模糊處理4,關鍵實現流程圖5,應用場景與優化建議6,總結1,引言 在現代網頁開發中,炫酷的視覺特效不僅能提升用戶體驗,還能為產品增添…

【Linux】C++項目分層架構:核心三層與關鍵輔助

C 項目分層架構全指南&#xff1a;核心三層 關鍵輔助一、核心三層架構 傳統的三層架構&#xff08;或三層體系結構&#xff09;是構建健壯系統的基石&#xff0c;包括以下三層&#xff1a; 1. 表現層&#xff08;Presentation Layer&#xff09; 負責展示和輸入處理&#xff0…

【機器學習】保序回歸平滑校準算法

保序回歸平滑校準算法&#xff08;SIR&#xff09;通過分桶合并線性插值解決廣告預估偏差問題&#xff0c;核心是保持原始排序下糾偏。具體步驟&#xff1a;1&#xff09;按預估分升序分桶&#xff0c;統計每個分桶的后驗CTR&#xff1b;2&#xff09;合并逆序桶重新計算均值&a…

項目開發日記

框架整理學習UIMgr&#xff1a;一、數據結構與算法 1.1 關鍵數據結構成員變量類型說明m_CtrlsList<PageInfo>當前正在顯示的所有 UI 頁面m_CachesList<PageInfo>已打開過、但現在不顯示的頁面&#xff08;緩存池&#xff09; 1.2 算法邏輯查找緩存頁面&#xff1a;…

60 美元玩轉 Li-Fi —— 開源 OpenVLC 平臺入門(附 BeagleBone Black 驅動簡單解析)

60 美元玩轉 Li-Fi —— 開源 OpenVLC 平臺入門&#xff08;附 BeagleBone Black 及驅動解析&#xff09;一、什么是 OpenVLC&#xff1f; OpenVLC 是由西班牙 IMDEA Networks 研究所推出的開源可見光通信&#xff08;VLC / Li-Fi&#xff09;研究平臺。它把硬件、驅動、協議棧…

Python性能優化

Python 以其簡潔和易用性著稱,但在某些計算密集型或大數據處理場景下,性能可能成為瓶頸。幸運的是,通過一些巧妙的編程技巧,我們可以顯著提升Python代碼的執行效率。本文將介紹8個實用的性能優化技巧,幫助你編寫更快、更高效的Python代碼。   一、優化前的黃金法則:先測…

easyui碰到想要去除頂部欄按鈕邊框

只需要加上 plain"true"<a href"javascript:void(0)" class"easyui-linkbutton" iconCls"icon-add" plain"true"onclick"newCheck()">新增</a>

C++字符串詳解:原理、操作及力扣算法實戰

一、C字符串簡介在C中&#xff0c;字符串的處理方式主要有兩種&#xff1a;字符數組&#xff08;C風格字符串&#xff09;和std::string類。雖然字符數組是C語言遺留的底層實現方式&#xff0c;但現代C更推薦使用std::string類&#xff0c;其封裝了復雜的操作邏輯&#xff0c;提…

CMU15445-2024fall-project1踩坑經歷

p1目錄&#xff1a;lRU\_K替換策略LRULRU\_K大體思路SetEvictableRecordAccessSizeEvictRemoveDisk SchedulerBufferPoolNewPageDeletePageFlashPage/FlashAllPageCheckReadPage/CheckWritePagePageGuard并發設計主邏輯感謝CMU的教授們給我們分享了如此精彩的一門課程&#xff…

【C語言進階】帶你由淺入深了解指針【第四期】:數組指針的應用、介紹函數指針

前言上一期講了數組指針的原理&#xff0c;這一期接著上一期講述數組指針的應用以及數組參數、函數參數。首先看下面的代碼進行上一期內容的復習&#xff0c;pc應該是什么類型&#xff1f;char* arr[5] {0}; xxx pc &arr;分析&#xff1a;①首先判斷arr是一個數組&#x…

在HTML中CSS三種使用方式

一、行內樣式在標簽<>中輸入style "屬性&#xff1a;屬性值;"。(中等使用頻率)不利于CSS樣式的復用&#xff1b;違背了CSS內容和樣式分離的設計理念&#xff0c;后期難以維護。<p style"color: red">這是div中的p元素</p>二、內部樣式在…

汽車功能安全-軟件單元驗證 (Software Unit Verification)【用例導出方法、輸出物】8

文章目錄1 軟件單元驗證用例導出方法2 測試用例完整性度量標準3 驗證環境要求4 軟件單元驗證的工作產品1 軟件單元驗證用例導出方法 為確保軟件單元測試的測試案例規范符合9.4.2要求&#xff0c;應通過表8所列方法開發測試用例。 表8 軟件單元測試用例的得出方法&#xff1a; …

MySQL內置函數(8)

文章目錄前言一、日期函數二、字符串函數三、數學函數四、其它函數總結前言 其實在之前的幾篇中我們也用到了內置函數&#xff0c;現在我們再來系統學習一下它&#xff01; 一、日期函數 函數名稱描述current_date()獲取當前日期current_time()獲取當前時間current_timestamp(…

蒼穹外賣項目日記(day04)

蒼穹外賣|項目日記(day04) 前言: 今天主要是接口開發, 涉及的新東西不多, 需要注意的只有多表聯查和修改的邏輯,今日難點: 1.菜品的停起售狀態設置 2.套餐的停起售狀態設置 3.動態sql中的 useGeneratedKeys 與 keyProperty 兩個參數 一. 菜品的停起售狀態設置 ? 在菜品的停售中…

React之旅-05 List Key

每個React的初學者&#xff0c;在調試程序時&#xff0c;都會遇到這樣的警告&#xff1a;Warning: Each child in a list should have a unique "key" prop. 如下面的代碼&#xff1a; const list [Learn React, Learn GraphQL];const ListWithoutKey () > (&l…

[特殊字符] 人工智能技術全景:從基礎理論到前沿應用的深度解析

&#x1f680; 人工智能技術全景&#xff1a;從基礎理論到前沿應用的深度解析 在這個AI驅動的時代&#xff0c;理解人工智能的核心技術和應用場景已成為技術人員的必備技能。本文將帶你深入探索AI的發展脈絡、核心技術差異以及在各行業的創新應用。 文章目錄&#x1f680; 人工…

Go語言教程-環境搭建

前言 Go&#xff08;又稱 Golang&#xff09;是由 Google 開發的一種 開源、靜態類型、編譯型 編程語言&#xff0c;于 2009 年正式發布。它旨在解決現代軟件開發中的高并發、高性能和可維護性問題&#xff0c;尤其適合 云計算、微服務、分布式系統 等領域。 Go 語言國際官網…

windows指定某node及npm版本下載

下載并安裝 nvm-windowshttps://github.com/coreybutler/nvm-windows/releases&#xff08;選擇 nvm-setup.zip&#xff09;。打開命令提示符&#xff08;管理員權限&#xff09;&#xff0c;安裝 Node.js v16.15.0&#xff1a; nvm install 16.15.0 nvm use 16.15.0 驗證node版…

每天一個前端小知識 Day 28 - Web Workers / 多線程模型在前端中的應用實踐

Web Workers / 多線程模型在前端中的應用實踐&#x1f9e0; 一、為什么前端需要多線程&#xff1f; 單線程 JS 的瓶頸&#xff1a;瀏覽器主線程不僅負責執行 JS&#xff0c;還要負責&#xff1a; UI 渲染&#xff08;DOM/CSS&#xff09;用戶事件處理&#xff08;點擊、輸入&am…