PyTorch 深度學習實戰(14):Deep Deterministic Policy Gradient (DDPG) 算法

在上一篇文章中,我們介紹了 Proximal Policy Optimization (PPO) 算法,并使用它解決了 CartPole 問題。本文將深入探討 Deep Deterministic Policy Gradient (DDPG) 算法,這是一種用于連續動作空間的強化學習算法。我們將使用 PyTorch 實現 DDPG 算法,并應用于經典的 Pendulum 問題。


一、DDPG 算法基礎

DDPG 是一種基于 Actor-Critic 框架的算法,專門用于解決連續動作空間的強化學習問題。它結合了深度 Q 網絡(DQN)和策略梯度方法的優點,能夠高效地處理高維狀態和動作空間。

1. DDPG 的核心思想

  • 確定性策略

    • DDPG 使用確定性策略(Deterministic Policy),即給定狀態時,策略網絡直接輸出一個確定的動作,而不是動作的概率分布。

  • 目標網絡

    • DDPG 使用目標網絡(Target Network)來穩定訓練過程,類似于 DQN 中的目標網絡。

  • 經驗回放

    • DDPG 使用經驗回放緩沖區(Replay Buffer)來存儲和重用過去的經驗,從而提高數據利用率。

2. DDPG 的優勢

  • 適用于連續動作空間

    • DDPG 能夠直接輸出連續動作,適用于機器人控制、自動駕駛等任務。

  • 訓練穩定

    • 通過目標網絡和經驗回放,DDPG 能夠穩定地訓練策略網絡和價值網絡。

  • 高效采樣

    • DDPG 可以重復使用舊策略的采樣數據,從而提高數據利用率。

3. DDPG 的算法流程

  1. 使用當前策略采樣一批數據。

  2. 使用目標網絡計算目標 Q 值。

  3. 更新 Critic 網絡以最小化 Q 值的誤差。

  4. 更新 Actor 網絡以最大化 Q 值。

  5. 更新目標網絡。

  6. 重復上述過程,直到策略收斂。


二、Pendulum 問題實戰

我們將使用 PyTorch 實現 DDPG 算法,并應用于 Pendulum 問題。目標是控制擺桿使其保持直立。

1. 問題描述

Pendulum 環境的狀態空間包括擺桿的角度和角速度。動作空間是一個連續的扭矩值,范圍在 ?2,2 之間。智能體每保持擺桿直立一步,就會獲得一個負的獎勵,目標是最大化累積獎勵。

2. 實現步驟

  1. 安裝并導入必要的庫。

  2. 定義 Actor 網絡和 Critic 網絡。

  3. 定義 DDPG 訓練過程。

  4. 測試模型并評估性能。

3. 代碼實現

以下是完整的代碼實現:

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt
?
# 設置 Matplotlib 支持中文顯示
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
?
# 檢查 GPU 是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用設備: {device}")
?
# 環境初始化
env = gym.make('Pendulum-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
?
# 隨機種子設置
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
?
?
# 定義 Actor 網絡
class Actor(nn.Module):def __init__(self, state_dim, action_dim, max_action):super(Actor, self).__init__()self.fc1 = nn.Linear(state_dim, 512)self.ln1 = nn.LayerNorm(512)  # 層歸一化self.fc2 = nn.Linear(512, 512)self.ln2 = nn.LayerNorm(512)self.fc3 = nn.Linear(512, action_dim)self.max_action = max_action
?def forward(self, x):x = F.relu(self.ln1(self.fc1(x)))x = F.relu(self.ln2(self.fc2(x)))return self.max_action * torch.tanh(self.fc3(x))
?
?
# 定義 Critic 網絡
class Critic(nn.Module):def __init__(self, state_dim, action_dim):super(Critic, self).__init__()self.fc1 = nn.Linear(state_dim + action_dim, 256)self.fc2 = nn.Linear(256, 256)self.fc3 = nn.Linear(256, 1)
?def forward(self, x, u):x = F.relu(self.fc1(torch.cat([x, u], 1)))x = F.relu(self.fc2(x))x = self.fc3(x)return x
?
?
# 添加OU噪聲類
class OUNoise:def __init__(self, action_dim, mu=0, theta=0.15, sigma=0.2):self.mu = mu * np.ones(action_dim)self.theta = thetaself.sigma = sigmaself.reset()
?def reset(self):self.state = np.copy(self.mu)
?def sample(self):dx = self.theta * (self.mu - self.state) + self.sigma * np.random.randn(len(self.state))self.state += dxreturn self.state
?
?
# 定義 DDPG 算法
class DDPG:def __init__(self, state_dim, action_dim, max_action):self.actor = Actor(state_dim, action_dim, max_action).to(device)self.actor_target = Actor(state_dim, action_dim, max_action).to(device)self.actor_target.load_state_dict(self.actor.state_dict())self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-4)
?self.critic = Critic(state_dim, action_dim).to(device)self.critic_target = Critic(state_dim, action_dim).to(device)self.critic_target.load_state_dict(self.critic.state_dict())self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=1e-3)self.noise = OUNoise(action_dim, sigma=0.2)  # 示例:Ornstein-Uhlenbeck噪聲
?self.max_action = max_actionself.replay_buffer = deque(maxlen=1000000)self.batch_size = 64self.gamma = 0.99self.tau = 0.005self.noise_sigma = 0.5  # 初始噪聲強度self.noise_decay = 0.995
?self.actor_lr_scheduler = optim.lr_scheduler.StepLR(self.actor_optimizer, step_size=100, gamma=0.95)self.critic_lr_scheduler = optim.lr_scheduler.StepLR(self.critic_optimizer, step_size=100, gamma=0.95)
?def select_action(self, state):state = torch.FloatTensor(state).unsqueeze(0).to(device)self.actor.eval()with torch.no_grad():action = self.actor(state).cpu().data.numpy().flatten()self.actor.train()return action
?def train(self):if len(self.replay_buffer) < self.batch_size:return
?# 從經驗回放緩沖區中采樣batch = random.sample(self.replay_buffer, self.batch_size)state = torch.FloatTensor(np.array([transition[0] for transition in batch])).to(device)action = torch.FloatTensor(np.array([transition[1] for transition in batch])).to(device)reward = torch.FloatTensor(np.array([transition[2] for transition in batch])).reshape(-1, 1).to(device)next_state = torch.FloatTensor(np.array([transition[3] for transition in batch])).to(device)done = torch.FloatTensor(np.array([transition[4] for transition in batch])).reshape(-1, 1).to(device)
?# 計算目標 Q 值next_action = self.actor_target(next_state)target_Q = self.critic_target(next_state, next_action)target_Q = reward + (1 - done) * self.gamma * target_Q
?# 更新 Critic 網絡current_Q = self.critic(state, action)critic_loss = F.mse_loss(current_Q, target_Q.detach())self.critic_optimizer.zero_grad()critic_loss.backward()self.critic_optimizer.step()
?# 更新 Actor 網絡actor_loss = -self.critic(state, self.actor(state)).mean()self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()
?# 更新目標網絡for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
?def save(self, filename):torch.save(self.actor.state_dict(), filename + "_actor.pth")torch.save(self.critic.state_dict(), filename + "_critic.pth")
?def load(self, filename):self.actor.load_state_dict(torch.load(filename + "_actor.pth"))self.critic.load_state_dict(torch.load(filename + "_critic.pth"))
?
?
# 訓練流程
def train_ddpg(env, agent, episodes=500):rewards_history = []moving_avg = []
?for ep in range(episodes):state,_ = env.reset()episode_reward = 0done = False
?while not done:action = agent.select_action(state)next_state, reward, done, _, _ = env.step(action)agent.replay_buffer.append((state, action, reward, next_state, done))state = next_stateepisode_reward += rewardagent.train()
?rewards_history.append(episode_reward)moving_avg.append(np.mean(rewards_history[-50:]))
?if (ep + 1) % 50 == 0:print(f"Episode: {ep + 1}, Avg Reward: {moving_avg[-1]:.2f}")
?return moving_avg, rewards_history
?
?
# 訓練啟動
ddpg_agent = DDPG(state_dim, action_dim, max_action)
moving_avg, rewards_history = train_ddpg(env, ddpg_agent)
?
# 可視化結果
plt.figure(figsize=(12, 6))
plt.plot(rewards_history, alpha=0.6, label='single round reward')
plt.plot(moving_avg, 'r-', linewidth=2, label='moving average (50 rounds)')
plt.xlabel('episodes')
plt.ylabel('reward')
plt.title('DDPG training performance on Pendulum-v1')
plt.legend()
plt.grid(True)
plt.show()

三、代碼解析

  1. Actor 和 Critic 網絡

    • Actor 網絡輸出連續動作,通過 tanh 函數將動作限制在 ?max_action,max_action 范圍內。

    • Critic 網絡輸出狀態-動作對的 Q 值。

  2. DDPG 訓練過程

    • 使用當前策略采樣一批數據。

    • 使用目標網絡計算目標 Q 值。

    • 更新 Critic 網絡以最小化 Q 值的誤差。

    • 更新 Actor 網絡以最大化 Q 值。

    • 更新目標網絡。

  3. 訓練過程

    • 在訓練過程中,每 50 個 episode 打印一次平均獎勵。

    • 訓練結束后,繪制訓練過程中的總獎勵曲線。


四、運行結果

運行上述代碼后,你將看到以下輸出:

  • 訓練過程中每 50 個 episode 打印一次平均獎勵。

  • 訓練結束后,繪制訓練過程中的總獎勵曲線。


五、總結

本文介紹了 DDPG 算法的基本原理,并使用 PyTorch 實現了一個簡單的 DDPG 模型來解決 Pendulum 問題。通過這個例子,我們學習了如何使用 DDPG 算法進行連續動作空間的策略優化。

在下一篇文章中,我們將探討更高級的強化學習算法,如 Twin Delayed DDPG (TD3)。敬請期待!

代碼實例說明

  • 本文代碼可以直接在 Jupyter Notebook 或 Python 腳本中運行。

  • 如果你有 GPU,可以將模型和數據移動到 GPU 上運行,例如:actor = actor.to('cuda')state = state.to('cuda')

希望這篇文章能幫助你更好地理解 DDPG 算法!如果有任何問題,歡迎在評論區留言討論。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/72394.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/72394.shtml
英文地址,請注明出處:http://en.pswp.cn/web/72394.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【深度學習與大模型基礎】第5章-線性相關與生成子空間

線性相關是指一組向量中&#xff0c;至少有一個向量可以表示為其他向量的線性組合。具體來說&#xff0c;對于向量組 v1,v2,…,vn&#xff0c;如果存在不全為零的標量 c1,c2,…,cn使得&#xff1a; c1v1c2v2…cnvn0 則稱這些向量線性相關。否則&#xff0c;它們線性無關。 舉…

【Agent實戰】貨物上架位置推薦助手(RAG方式+結構化prompt(CoT)+API工具結合ChatGPT4o能力Agent項目實踐)

本文原創作者:姚瑞南 AI-agent 大模型運營專家,先后任職于美團、獵聘等中大廠AI訓練專家和智能運營專家崗;多年人工智能行業智能產品運營及大模型落地經驗,擁有AI外呼方向國家專利與PMP項目管理證書。(轉載需經授權) 目錄 結論 效果圖示 1.prompt 2. API工具封…

Go語言入門基礎詳解

一、語言歷史背景 Go語言由Google工程師Robert Griesemer、Rob Pike和Ken Thompson于2007年設計&#xff0c;2009年正式開源。設計目標&#xff1a; 兼具Python的開發效率與C的執行性能內置并發支持&#xff08;goroutine/channel&#xff09;簡潔的類型系統現代化的包管理跨…

HarmonyOS NEXT開發進階(十二):build-profile.json5 文件解析

文章目錄 一、前言二、Hvigor腳本文件三、任務與任務依賴圖四、多模塊管理4.1 靜態配置模塊 五、分模塊編譯六、配置多目標產物七、配置APP多目標構建產物八、定義 product 中包含的 target九、拓展閱讀 一、前言 編譯構建工具DevEco Hvigor&#xff08;以下簡稱Hvigor&#x…

基于SSM + JSP 的圖書商城系統

基于SSM的圖書商城 網上書城、圖書銷售系統、圖書銷售平臺 &#xff5c;Java&#xff5c;SSM&#xff5c;HTML&#xff5c;JSP&#xff5c; 項目采用技術&#xff1a; ①&#xff1a;開發環境&#xff1a;IDEA、JDK1.8、Maven、Tomcat ②&#xff1a;技術棧&#xff1a;Java、…

色板在數據可視化中的創新應用

色板在數據可視化中的創新應用&#xff1a;基于色彩感知理論的優化實踐 引言 在數據可視化領域&#xff0c;色彩編碼系統的設計已成為決定信息傳遞效能的核心要素。根據《Nature》期刊2024年發布的視覺認知研究&#xff0c;人類大腦對色彩的識別速度比形狀快40%&#xff0c;色…

K8S學習之基礎二十七:k8s中daemonset控制器

k8s中DaemonSet控制器 ? DaemonSet控制器確保k8s集群中&#xff0c;所有節點都運行一個相同的pod&#xff0c;當node節點增加時&#xff0c;新節點也會自動創建一個pod&#xff0c;當node節點從集群移除&#xff0c;對應的pod也會自動刪除。刪除DaemonSet也會刪除創建的pod。…

PyTorch 系列教程:使用CNN實現圖像分類

圖像分類是計算機視覺領域的一項基本任務&#xff0c;也是深度學習技術的一個常見應用。近年來&#xff0c;卷積神經網絡&#xff08;cnn&#xff09;和PyTorch庫的結合由于其易用性和魯棒性已經成為執行圖像分類的流行選擇。 理解卷積神經網絡&#xff08;cnn&#xff09; 卷…

Spring Cloud Stream - 構建高可靠消息驅動與事件溯源架構

一、引言 在分布式系統中&#xff0c;傳統的 REST 調用模式往往導致耦合&#xff0c;難以滿足高并發和異步解耦的需求。消息驅動架構&#xff08;EDA, Event-Driven Architecture&#xff09;通過異步通信、事件溯源等模式&#xff0c;提高了系統的擴展性與可觀測性。 作為 S…

王者榮耀道具頁面爬蟲(json格式數據)

首先這個和英雄頁面是不一樣的&#xff0c;英雄頁面的圖片鏈接是直接放在源代碼里面的&#xff0c;直接就可以請求到&#xff0c;但是這個源代碼里面是沒有的 雖然在檢查頁面能夠搜索到&#xff0c;但是應該是動態加載的&#xff0c;源碼中搜不到該鏈接 然后就去看看是不是某…

【一起來學kubernetes】12、k8s中的Endpoint詳解

一、Endpoint的定義與作用二、Endpoint的創建與管理三、Endpoint的查看與組成四、EndpointSlice五、Endpoint的使用場景六、Endpoint與Service的關系1、定義與功能2、創建與管理3、關系與交互4、使用場景與特點 七、Endpoint的kubectl命令1. 查看Endpoint2. 創建Endpoint3. 編輯…

結構型模式之橋接模式:解耦抽象和實現

在面向對象設計中&#xff0c;我們經常遇到需要擴展某些功能&#xff0c;但又不能修改現有代碼的情況。為了避免繼承帶來的復雜性和維護難度&#xff0c;橋接模式&#xff08;Bridge Pattern&#xff09;應運而生。橋接模式是一種結構型設計模式&#xff0c;旨在解耦抽象部分和…

如何用Java將實體類轉換為JSON并輸出到控制臺?

在軟件開發的過程中&#xff0c;Java是一種廣泛使用的編程語言&#xff0c;而在眾多應用中&#xff0c;數據的傳輸和存儲經常需要使用JSON格式。JSON&#xff08;JavaScript Object Notation&#xff09;是一種輕量級的數據交換格式&#xff0c;易于人類閱讀和編寫&#xff0c;…

Vue3 開發的 VSCode 插件

1. Volar Vue3 正式版發布&#xff0c;Vue 團隊官方推薦 Volar 插件來代替 Vetur 插件&#xff0c;不僅支持 Vue3 語言高亮、語法檢測&#xff0c;還支持 TypeScript 和基于 vue-tsc 的類型檢查功能。 2. Vue VSCode Snippets 為開發者提供最簡單快速的生成 Vue 代碼片段的方…

C# Enumerable類 之 集合操作

總目錄 前言 在 C# 中&#xff0c;System.Linq.Enumerable 類是 LINQ&#xff08;Language Integrated Query&#xff09;的核心組成部分&#xff0c;它提供了一系列靜態方法&#xff0c;用于操作實現了 IEnumerable 接口的集合。通過這些方法&#xff0c;我們可以輕松地對集合…

51c自動駕駛~合集54

我自己的原文哦~ https://blog.51cto.com/whaosoft/13517811 #Chameleon 快慢雙系統&#xff01;清華&博世最新&#xff1a;無需訓練即可解決復雜道路拓撲 在自動駕駛技術中&#xff0c;車道拓撲提取是實現無地圖導航的核心任務之一。它要求系統不僅能檢測出車道和交…

Spring Cloud Eureka - 高可用服務注冊與發現解決方案

在微服務架構中&#xff0c;服務注冊與發現是確保系統動態擴展和高效通信的關鍵。Eureka 作為 Spring Cloud 生態的核心組件&#xff0c;不僅提供去中心化的服務治理能力&#xff0c;還通過自我保護、健康檢查等機制提升系統的穩定性&#xff0c;使其成為微服務架構中的重要支撐…

Unity屏幕適配——立項時設置

項目類型&#xff1a;2D游戲、豎屏、URP 其他類型&#xff0c;部分原理類似。 1、確定設計分辨率&#xff1a;750*1334 為什么是它&#xff1f; 因為它是 iphone8 的尺寸&#xff0c;寬高比適中。 方便后續適配到真機的 “更長屏” 或 “更寬屏” 2、在場景…

深度學習中LayerNorm與RMSNorm對比

LayerNorm不同于BatchNorm&#xff0c;其與batch大小無關&#xff0c;均值和方差 在 每個樣本的特征維度 C 內計算&#xff0c; 適用于 變長輸入&#xff08;如 NLP 任務中的 Transformer&#xff09; 詳細的BatchNorm在之前的一篇文章進行了詳細的介紹&#xff1a;深度學習中B…

使用WireShark解密https流量

概述 https協議是在http協議的基礎上&#xff0c;使用TLS協議對http數據進行了加密&#xff0c;使得網絡通信更加安全。一般情況下&#xff0c;使用WireShark抓取的https流量&#xff0c;數據都是加密的&#xff0c;無法直接查看。但是可以通過以下兩種方法&#xff0c;解密抓…