雙深度Q網絡(Double DQN)基礎解析與python實例:訓練穩定倒立擺

目錄

1. 前言

2. Double DQN的核心思想

3. Double DQN 實例:倒立擺

4. Double DQN的關鍵改進點

5. 雙重網絡更新策略

6. 總結


1. 前言

在強化學習領域,深度Q網絡(DQN)開啟了利用深度學習解決復雜決策問題的新篇章。然而,標準DQN存在一個顯著問題:Q值的過估計。為解決這一問題,Double DQN應運而生,它通過引入兩個網絡來減少Q值的過估計,從而提高策略學習的穩定性和性能。本文將深入淺出地介紹Double DQN的核心思想,并通過一個完整python實現案例,幫助大家全面理解強化這一學習算法。

2. Double DQN的核心思想

標準DQN使用同一個網絡同時選擇動作和評估動作價值,這容易導致Q值的過估計。Double DQN通過將動作選擇和價值評估分離到兩個不同的網絡來解決這個問題:

  1. 一個網絡(在線網絡)用于選擇當前狀態下的最佳動作

  2. 另一個網絡(目標網絡)用于評估這個動作的價值

這種分離減少了自舉過程中動作選擇和價值評估的關聯性,從而有效降低了Q值的過估計。

結構如下:

3. Double DQN 實例:倒立擺

接下來,我們將實現一個完整的Double DQN,解決CartPole平衡問題。這個例子包含了所有關鍵組件:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
import random
from collections import deque# 1. 定義DQN網絡結構
class DQN(nn.Module):def __init__(self, state_dim, action_dim):super(DQN, self).__init__()self.fc1 = nn.Linear(state_dim, 128)self.fc2 = nn.Linear(128, 128)self.fc3 = nn.Linear(128, action_dim)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return x# 2. 經驗回放緩沖區
class ReplayBuffer:def __init__(self, capacity):self.buffer = deque(maxlen=capacity)def add(self, state, action, reward, next_state, done):self.buffer.append((state, action, reward, next_state, done))def sample(self, batch_size):samples = random.sample(self.buffer, batch_size)states, actions, rewards, next_states, dones = zip(*samples)return states, actions, rewards, next_states, donesdef __len__(self):return len(self.buffer)# 3. Double DQN代理
class DoubleDQNAgent:def __init__(self, state_dim, action_dim):self.policy_net = DQN(state_dim, action_dim)self.target_net = DQN(state_dim, action_dim)self.target_net.load_state_dict(self.policy_net.state_dict())self.target_net.eval()self.optimizer = optim.Adam(self.policy_net.parameters(), lr=0.001)self.replay_buffer = ReplayBuffer(10000)self.batch_size = 64self.gamma = 0.99  # 折扣因子self.epsilon = 1.0  # 探索率self.epsilon_decay = 0.995self.min_epsilon = 0.01self.action_dim = action_dim# 根據ε-greedy策略選擇動作def select_action(self, state):if random.random() < self.epsilon:return random.randint(0, self.action_dim - 1)else:with torch.no_grad():return self.policy_net(torch.FloatTensor(state)).argmax().item()# 存儲經驗def store_transition(self, state, action, reward, next_state, done):self.replay_buffer.add(state, action, reward, next_state, done)# 更新網絡def update(self):if len(self.replay_buffer) < self.batch_size:return# 從經驗回放中采樣states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)# 轉換為PyTorch張量states = torch.FloatTensor(states)actions = torch.LongTensor(actions)rewards = torch.FloatTensor(rewards)next_states = torch.FloatTensor(next_states)dones = torch.FloatTensor(dones)# 計算當前Q值current_q = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)# 計算目標Q值(使用Double DQN方法)# 使用策略網絡選擇動作,目標網絡評估價值with torch.no_grad():# 從策略網絡中選擇最佳動作policy_actions = self.policy_net(next_states).argmax(dim=1)# 從目標網絡中評估這些動作的值next_q = self.target_net(next_states).gather(1, policy_actions.unsqueeze(1)).squeeze(1)target_q = rewards + self.gamma * next_q * (1 - dones)# 計算損失并優化loss = nn.MSELoss()(current_q, target_q)self.optimizer.zero_grad()loss.backward()self.optimizer.step()# 更新目標網絡(軟更新)for target_param, policy_param in zip(self.target_net.parameters(), self.policy_net.parameters()):target_param.data.copy_(0.001 * policy_param.data + 0.999 * target_param.data)# 減少探索率self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)# 訓練過程def train_double_dqn():# 創建環境env = gym.make('CartPole-v1')state_dim = env.observation_space.shape[0]action_dim = env.action_space.n# 創建代理agent = DoubleDQNAgent(state_dim, action_dim)# 訓練參數episodes = 500max_steps = 500# 訓練循環for episode in range(episodes):state, _ = env.reset()total_reward = 0for step in range(max_steps):action = agent.select_action(state)next_state, reward, done, _, _ = env.step(action)# 修改獎勵以加速學習reward = reward if not done else -10agent.store_transition(state, action, reward, next_state, done)agent.update()total_reward += rewardstate = next_stateif done:break# 每10個episodes更新一次目標網絡if episode % 10 == 0:agent.target_net.load_state_dict(agent.policy_net.state_dict())print(f"Episode: {episode + 1}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.2f}")env.close()# 執行訓練
if __name__ == "__main__":train_double_dqn()

4. Double DQN的關鍵改進點

  1. 雙網絡結構:通過將動作選擇(策略網絡)和價值評估(目標網絡)分離,減少了Q值的過估計。

  2. 經驗回放:通過存儲和隨機采樣歷史經驗,打破了數據的相關性,提高了學習穩定性。

  3. ε-greedy策略:平衡探索與利用,隨著訓練進行逐漸減少探索概率。

目標網絡在Double DQN中扮演著非常重要的角色:

  • 它為策略網絡提供穩定的目標Q值

  • 通過延遲更新,減少了目標Q值的波動

  • 與策略網絡共同工作,實現了動作選擇和價值評估的分離

5. 雙重網絡更新策略

在Double DQN中,我們使用了軟更新(soft update)策略來更新目標網絡:

for target_param, policy_param in zip(self.target_net.parameters(), self.policy_net.parameters()):target_param.data.copy_(0.001 * policy_param.data + 0.999 * target_param.data)

這種軟更新方式比傳統的目標網絡定期硬更新(hard update)更平滑,有助于訓練過程的穩定。

6. 總結

本文通過詳細講解Double DQN的原理,并提供了完整的python實現代碼,展示了如何應用這一先進強化學習算法解決實際問題。與標準DQN相比,Double DQN通過引入雙網絡結構,有效解決了Q值過估計問題,提高了策略學習的穩定性和最終性能。Double DQN是強化學習領域的一個重要進步,為后續更高級的算法(如Dueling DQN、C51、Rainbow DQN等)奠定了基礎。通過理解Double DQN的原理和實現,讀者可以為進一步探索復雜強化學習算法打下堅實基礎。在實際應用中,可以根據具體任務調整網絡結構、超參數(如學習率、折扣因子、經驗回放緩沖區大小等)以及探索策略,以獲得最佳性能。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/84515.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/84515.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/84515.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

使用KubeKey快速部署k8s v1.31.8集群

實戰環境涉及軟件版本信息&#xff1a; 使用kubekey部署k8s 1. 操作系統基礎配置 設置主機名、DNS解析、時鐘同步、防火墻關閉、ssh免密登錄等等系統基本設置 dnf install -y curl socat conntrack ebtables ipset ipvsadm 2. 安裝部署 K8s 2.1 下載 KubeKey ###地址 https…

SQL:窗口函數(Window Functions)

目錄 什么是窗口函數&#xff1f; 基本語法結構 為什么要用窗口函數&#xff1f; 常見的窗口函數分類 1?? 排名類函數 2?? 聚合類函數&#xff08;不影響原始行&#xff09; 3?? 值訪問函數 窗口范圍說明&#xff08;ROWS / RANGE&#xff09; 什么是窗口函數&a…

相機內參 opencv

視場角定相機內參 import numpy as np import cv2 import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3Ddef calculate_camera_intrinsics(image_width640, image_height480, fov55, is_horizontalTrue):"""計算相機內參矩陣參數:image_w…

MATLAB 各個工具箱 功能說明

? 想必大家在安裝MATLAB時&#xff0c;或多或少會疑惑應該安裝哪些工具箱。筆者遇到了兩種情況——只安裝了MATLAB主程序&#xff0c;老師讓用MATLAB的時候卻發現沒有安裝對應安裝包&#xff1b;第二次安裝學聰明了&#xff0c;全選安裝&#xff0c;嗯……占用了20多個G。 ?…

學習日記-day14-5.23

完成目標&#xff1a; 學習java下半段課程 知識點&#xff1a; 1.多態轉型 知識點 核心內容 重點 多態轉型 向上轉型&#xff08;父類引用指向子類對象&#xff09; 與向下轉型&#xff08;強制類型轉換&#xff09;的機制與區別 向上轉型自動完成&#xff0c;向下轉型需…

【編程語言】【Java】一篇文章學習java,復習完善知識體系

第一章 Java基礎 1.1 變量與數據類型 1.1.1 基本數據類型 1.1.1.1 整數類型&#xff08;byte、short、int、long&#xff09; 在 Java 中&#xff0c;整數類型用于表示沒有小數部分的數字&#xff0c;不同的整數類型有不同的取值范圍和占用的存儲空間&#xff1a; byte&am…

匯量科技前端面試題及參考答案

數組去重的方法有哪些&#xff1f; 在 JavaScript 中&#xff0c;數組去重是一個常見的操作&#xff0c;有多種方法可以實現這一目標。每種方法都有其適用場景和性能特點&#xff0c;下面將詳細介紹幾種主要的去重方法。 使用 Set 數據結構 Set 是 ES6 引入的一種新數據結構&a…

Git實戰演練,模擬日常使用,快速掌握命令

01 引言 上一期借助Idea&#xff0c;完成了Git倉庫的建立、配置、代碼提交等操作&#xff0c;初步入門了Git的使用。然而日常開發中經常面臨各種各樣的問題&#xff0c;入門級的命令遠遠不夠使用。 這一期&#xff0c;我們將展開介紹Git的日常處理命令&#xff0c;解決日常問…

wordpress主題開發中常用的12個模板文件

在WordPress主題開發中&#xff0c;有多種常用的模板文件&#xff0c;它們負責控制網站不同部分的顯示內容和布局&#xff0c;以下是一些常見的模板文件&#xff1a; 1.index.php 這是WordPress主題的核心模板文件。當沒有其他更具體的模板文件匹配當前頁面時&#xff0c;Wor…

數據庫blog5_數據庫軟件架構介紹(以Mysql為例)

&#x1f33f;軟件的架構 &#x1f342;分類 軟件架構總結為兩種主要類型&#xff1a;一體式架構和分布式架構 ● 一體化架構 一體式架構是一種將所有功能集成到一個單一的、不可分割的應用程序中的架構模式。這種架構通常是一個大型的、復雜的單一應用程序&#xff0c;包含所…

離線服務器算法部署環境配置

本文將詳細記錄我如何為一臺全新的離線服務器配置必要的運行環境&#xff0c;包括基礎編譯工具、NVIDIA顯卡驅動以及NVIDIA-Docker&#xff0c;以便順利部署深度學習算法。 前提條件&#xff1a; 目標離線服務器已安裝操作系統&#xff08;本文以Ubuntu 18.04為例&#xff09…

chromedp -—— 基于 go 的自動化操作瀏覽器庫

chromedp chromedp 是一個用于 Chrome 瀏覽器的自動化測試工具&#xff0c;基于 Go 語言開發&#xff0c;專門用于控制和操作 Chrome 瀏覽器實例。 chromedp 安裝 go get -u github.com/chromedp/chromedp基于chromedp 實現的的簡易學習通刷課系統 目前實現的功能&#xff…

高級特性實戰:死信隊列、延遲隊列與優先級隊列(三)

四、優先級隊列&#xff1a;優先處理重要任務 4.1 優先級隊列概念解析 優先級隊列&#xff08;Priority Queue&#xff09;是一種特殊的隊列數據結構&#xff0c;它與普通隊列的主要區別在于&#xff0c;普通隊列遵循先進先出&#xff08;FIFO&#xff09;的原則&#xff0c;…

python打卡day34

GPU訓練及類的call方法 知識點回歸&#xff1a; CPU性能的查看&#xff1a;看架構代際、核心數、線程數GPU性能的查看&#xff1a;看顯存、看級別、看架構代際GPU訓練的方法&#xff1a;數據和模型移動到GPU device上類的call方法&#xff1a;為什么定義前向傳播時可以直接寫作…

Newtonsoft Json序列化數據不序列化默認數據

問題描述 數據在序列號為json時,一些默認值也序列化了,像旋轉rot都是0、縮放scal都是1,這樣的默認值完全可以去掉,減少和服務器通信數據量 核心代碼 數據結構字段增加[DefaultValue(1.0)]屬性,縮放的默認值為1 public class Vec3DataOne{[DefaultValue(1.0)] public flo…

可增添功能的鼠標右鍵優化工具

軟件介紹 本文介紹一款能優化Windows電腦的軟件&#xff0c;它可以讓鼠標右鍵菜單添加多種功能。 軟件基本信息 這款名為Easy Context Menu的鼠標右鍵菜單工具非常小巧&#xff0c;軟件大小僅1.14MB&#xff0c;打開即可直接使用&#xff0c;無需進行安裝。 添加功能列舉 它…

Gemini 2.5 Pro 一次測試

您好&#xff0c;您遇到的重定向循環問題&#xff0c;即在 /user/messaging、/user/login?return_to/user/messaging 和 /user/login 之間反復跳轉&#xff0c;通常是由于客戶端的身份驗證狀態檢查和頁面重定向邏輯存在沖突或競爭條件。 在分析了您提供的代碼&#xff08;特別…

vue3前端后端地址可配置方案

在開發vue3項目過程中&#xff0c;需要切換不同的服務器部署&#xff0c;代碼中配置的服務需要可靈活配置&#xff0c;不隨著run npm build把網址打包到代碼資源中&#xff0c;不然每次切換都需要重新run npm build。需要一個配置文件可以修改服務地址&#xff0c;而打包的代碼…

大模型微調與高效訓練

隨著預訓練大模型(如BERT、GPT、ViT、LLaMA、CLIP等)的崛起,人工智能進入了一個新的范式:預訓練-微調(Pre-train, Fine-tune)。這些大模型在海量數據上學習到了通用的、強大的表示能力和世界知識。然而,要將這些通用模型應用于特定的下游任務或領域,通常還需要進行微調…

編程技能:字符串函數10,strchr

專欄導航 本節文章分別屬于《Win32 學習筆記》和《MFC 學習筆記》兩個專欄&#xff0c;故劃分為兩個專欄導航。讀者可以自行選擇前往哪個專欄。 &#xff08;一&#xff09;WIn32 專欄導航 上一篇&#xff1a;編程技能&#xff1a;字符串函數09&#xff0c;strncmp 回到目錄…