pytorch深度Q網絡

?人工智能例子匯總:AI常見的算法和例子-CSDN博客?

DQN 引入了深度神經網絡來近似Q函數,解決了傳統Q-learning在處理高維狀態空間時的瓶頸,尤其是在像 Atari 游戲這樣的復雜環境中。DQN的核心思想是使用神經網絡 Q(s,a;θ)Q(s, a; \theta)Q(s,a;θ) 來近似 Q 值函數,其中 θ\thetaθ 是神經網絡的參數。

DQN 的關鍵創新包括:

  1. 經驗回放(Experience Replay):在強化學習中,當前的學習可能會依賴于最近的經驗,容易導致學習過程的不穩定。經驗回放通過將智能體的經歷存儲到一個回放池中,然后隨機抽取批量數據進行訓練,這樣可以打破數據之間的相關性,使得訓練更加穩定。

  2. 目標網絡(Target Network):在Q-learning中,Q值的更新依賴于下一個狀態的最大Q值。為了避免Q值更新時過度依賴當前網絡的輸出(導致不穩定),DQN引入了目標網絡。目標網絡的結構與行為網絡相同,但它的參數更新頻率較低,這使得Q值更新更加穩定。

DQN算法流程

  1. 初始化Q網絡:初始化Q網絡的參數 θ\thetaθ,以及目標網絡的參數 θ?\theta^-θ?(通常與Q網絡相同)。
  2. 行為選擇:基于當前的Q網絡來選擇動作(通常使用ε-greedy策略,即以ε的概率選擇隨機動作,否則選擇當前Q值最大的動作)。
  3. 執行動作并存儲經驗:執行所選動作,觀察獎勵,并記錄狀態轉移 (st,at,rt+1,st+1)(s_t, a_t, r_{t+1}, s_{t+1})(st?,at?,rt+1?,st+1?)。
  4. 經驗回放:從回放池中隨機抽取一個小批量的經驗數據。
  5. 計算Q值目標:對于每個樣本,計算目標值 y=rt+1+γmax?a′Q(st+1,a′;θ?)y = r_{t+1} + \gamma \max_{a'} Q(s_{t+1}, a'; \theta^-)y=rt+1?+γmaxa′?Q(st+1?,a′;θ?)。
  6. 更新Q網絡:通過最小化損失函數 L(θ)=1N∑(y?Q(st,at;θ))2L(\theta) = \frac{1}{N} \sum (y - Q(s_t, a_t; \theta))^2L(θ)=N1?∑(y?Q(st?,at?;θ))2 來更新Q網絡的參數。
  7. 周期性更新目標網絡:每隔一段時間,將Q網絡的參數復制到目標網絡。

DQN的應用

DQN在多個領域取得了重要應用,尤其是在強化學習任務中:

  • Atari 游戲:DQN 在多個經典的 Atari 游戲上成功展示了其能力,比如《Breakout》和《Pong》等。
  • 機器人控制:利用DQN,機器人可以在復雜的環境中自主學習如何執行任務。
  • 自動駕駛:在自動駕駛領域,DQN可以用來訓練智能體通過道路、避開障礙物等。

例子:

這里我們手動實現一個非常簡單的環境:一個1D平衡問題,類似于一個可以左右移動的棒球,目標是讓它保持在某個位置上。

import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt# 自定義環境
class SimpleEnv:def __init__(self):self.state = 0.0  # 初始狀態self.goal = 10.0  # 目標位置self.done = Falsedef reset(self):self.state = 0.0self.done = Falsereturn self.statedef step(self, action):if self.done:return self.state, 0, self.done  # 游戲結束,不再變化# 通過動作修改狀態self.state += action  # 動作是 -1、0、1,控制移動方向reward = -abs(self.state - self.goal)  # 獎勵是距離目標位置的負值# 如果距離目標很近,就結束if abs(self.state - self.goal) < 0.1:self.done = Truereward = 10  # 達到目標時獎勵較高return self.state, reward, self.done# Q網絡定義
class QNetwork(nn.Module):def __init__(self, input_dim, output_dim):super(QNetwork, self).__init__()self.fc = nn.Linear(input_dim, 24)self.fc2 = nn.Linear(24, output_dim)def forward(self, x):x = torch.relu(self.fc(x))x = self.fc2(x)return x# DQN智能體
class DQN:def __init__(self, env, gamma=0.99, epsilon=0.1, batch_size=32, learning_rate=1e-3):self.env = envself.gamma = gammaself.epsilon = epsilonself.batch_size = batch_sizeself.learning_rate = learning_rateself.input_dim = 1  # 因為環境狀態是一個單一的數值self.output_dim = 3  # 動作空間大小:-1, 0, 1self.q_network = QNetwork(self.input_dim, self.output_dim)self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.learning_rate)self.criterion = nn.MSELoss()def select_action(self, state):if random.random() < self.epsilon:return random.choice([-1, 0, 1])  # 隨機選擇動作state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)with torch.no_grad():q_values = self.q_network(state)# 將動作值 -1, 0, 1 轉換為索引 0, 1, 2action_idx = torch.argmax(q_values, dim=1).item()action_map = [-1, 0, 1]  # -1 -> 0, 0 -> 1, 1 -> 2return action_map[action_idx]def update(self, state, action, reward, next_state, done):state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)next_state = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0)# 將動作 -1, 0, 1 轉換為索引 0, 1, 2action_map = [-1, 0, 1]action_idx = action_map.index(action)action = torch.tensor(action_idx, dtype=torch.long).unsqueeze(0)reward = torch.tensor(reward, dtype=torch.float32).unsqueeze(0)# 確保done是Python標準bool類型done = torch.tensor(done, dtype=torch.float32).unsqueeze(0)# 計算目標Q值with torch.no_grad():next_q_values = self.q_network(next_state)next_q_value = next_q_values.max(1)[0]target_q_value = reward + self.gamma * next_q_value * (1 - done)# 獲取當前Q值q_values = self.q_network(state)action_q_values = q_values.gather(1, action.unsqueeze(1)).squeeze(1)# 計算損失并更新Q網絡loss = self.criterion(action_q_values, target_q_value)self.optimizer.zero_grad()loss.backward()self.optimizer.step()def train(self, num_episodes=200):rewards = []best_reward = -float('inf')  # 初始最好的獎勵設為負無窮best_episode = 0for episode in range(num_episodes):state = self.env.reset()  # 獲取初始狀態total_reward = 0done = Falsewhile not done:action = self.select_action([state])next_state, reward, done = self.env.step(action)total_reward += reward# 更新Q網絡self.update([state], action, reward, [next_state], done)state = next_staterewards.append(total_reward)# 記錄最佳獎勵和對應的episodeif total_reward > best_reward:best_reward = total_rewardbest_episode = episodeprint(f"Episode {episode}, Total Reward: {total_reward}")# 打印最佳結果print(f"Best Reward: {best_reward} at Episode {best_episode}")# 繪制獎勵圖plt.plot(rewards)plt.title('Total Rewards per Episode')plt.xlabel('Episode')plt.ylabel('Total Reward')# 在最佳位置添加標記plt.scatter(best_episode, best_reward, color='red', label=f"Best Reward at Episode {best_episode}")plt.legend()plt.show()# 初始化環境和DQN智能體
env = SimpleEnv()
dqn = DQN(env)# 訓練智能體
dqn.train()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/67789.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/67789.shtml
英文地址,請注明出處:http://en.pswp.cn/web/67789.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Baklib構建高效協同的基于云的內容中臺解決方案

內容概要 隨著云計算技術的飛速發展&#xff0c;內容管理的方式也在不斷演變。企業面臨著如何在數字化轉型過程中高效管理和協同處理內容的新挑戰。為應對這些挑戰&#xff0c;引入基于云的內容中臺解決方案顯得尤為重要。 Baklib作為創新型解決方案提供商&#xff0c;致力于…

DeepSeek-R1 論文. Reinforcement Learning 通過強化學習激勵大型語言模型的推理能力

論文鏈接&#xff1a; [2501.12948] DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning 實在太長&#xff0c;自行扔到 Model 里&#xff0c;去翻譯去提問吧。 工作原理&#xff1a; 主要技術&#xff0c;就是訓練出一些專有用途小模型&…

C++泛型編程指南03-CTAD

文章目錄 C17 自定義類型推斷指引&#xff08;CTAD&#xff09;深度解析一、基礎概念1. 核心作用2. 工作原理 二、標準庫中的 CTAD 應用1. 容器類型推導2. 智能指針推導3. 元組類型推導 三、自定義推導指引語法1. 基本語法結構2. 典型應用場景 四、推導指引設計模式1. 迭代器范…

deepseek+vscode自動化測試腳本生成

近幾日Deepseek大火,我這里也嘗試了一下,確實很強。而目前vscode的AI toolkit插件也已經集成了deepseek R1,這里就介紹下在vscode中利用deepseek幫助我們完成自動化測試腳本的實踐分享 安裝AI ToolKit并啟用Deepseek 微軟官方提供了一個針對AI輔助的插件,也就是 AI Toolk…

電介質超表面中指定渦旋的非線性生成

渦旋光束在眾多領域具有重要應用&#xff0c;但傳統光學器件產生渦旋光束的方式限制了其在集成系統中的應用。超表面的出現為渦旋光束的產生帶來了新的可能性&#xff0c;尤其是在非線性領域&#xff0c;盡管近些年來已經有一些研究&#xff0c;但仍存在諸多問題&#xff0c;如…

基于Springboot+mybatis+mysql+html圖書管理系統2

基于Springbootmybatismysqlhtml圖書管理系統2 一、系統介紹二、功能展示1.用戶登陸2.用戶主頁3.圖書查詢4.還書5.個人信息修改6.圖書管理&#xff08;管理員&#xff09;7.學生管理&#xff08;管理員&#xff09;8.廢除記錄&#xff08;管理員&#xff09; 三、數據庫四、其它…

重構字符串(767)

767. 重構字符串 - 力扣&#xff08;LeetCode&#xff09; 解法&#xff1a; class Solution { public:string reorganizeString(string s){string res;//因為1 < s.length < 500 &#xff0c; uint64_t 類型足夠uint16_t n s.size();if (n 0) {return res;}unordere…

本地部署DeepSeek方法

本地部署完成后的效果如下圖&#xff0c;整體與chatgpt類似&#xff0c;只是模型在本地推理。 我們在本地部署主要使用兩個工具&#xff1a; ollamaopen-webui ollama是在本地管理和運行大模型的工具&#xff0c;可以直接在terminal里和大模型對話。open-webui是提供一個類…

游戲引擎 Unity - Unity 啟動(下載 Unity Editor、生成 Unity Personal Edition 許可證)

Unity Unity 首次發布于 2005 年&#xff0c;屬于 Unity Technologies Unity 使用的開發技術有&#xff1a;C# Unity 的適用平臺&#xff1a;PC、主機、移動設備、VR / AR、Web 等 Unity 的適用領域&#xff1a;開發中等畫質中小型項目 Unity 適合初學者或需要快速上手的開…

【開源免費】基于Vue和SpringBoot的公寓報修管理系統(附論文)

本文項目編號 T 186 &#xff0c;文末自助獲取源碼 \color{red}{T186&#xff0c;文末自助獲取源碼} T186&#xff0c;文末自助獲取源碼 目錄 一、系統介紹二、數據庫設計三、配套教程3.1 啟動教程3.2 講解視頻3.3 二次開發教程 四、功能截圖五、文案資料5.1 選題背景5.2 國內…

Haskell語言的多線程編程

Haskell語言的多線程編程 Haskell是一種基于函數式編程范式的編程語言&#xff0c;以其強大的類型系統和懶惰求值著稱。近年來&#xff0c;隨著多核處理器的發展&#xff0c;多線程編程變得日益重要。雖然Haskell最初并不是為了多線程而設計&#xff0c;但它的設計理念和工具集…

《蒼穹外賣》項目學習記錄-Day11訂單統計

根據起始時間和結束時間&#xff0c;先把begin放入集合中用while循環當begin不等于end的時候&#xff0c;讓begin加一天&#xff0c;這樣就把這個區間內的時間放到List集合。 查詢每天的訂單總數也就是查詢的時間段是大于當天的開始時間&#xff08;0點0分0秒&#xff09;小于…

【python】python油田數據分析與可視化(源碼+數據集)【獨一無二】

&#x1f449;博__主&#x1f448;&#xff1a;米碼收割機 &#x1f449;技__能&#x1f448;&#xff1a;C/Python語言 &#x1f449;專__注&#x1f448;&#xff1a;專注主流機器人、人工智能等相關領域的開發、測試技術。 【python】python油田數據分析與可視化&#xff08…

FBX SDK的使用:基礎知識

Windows環境配置 FBX SDK安裝后&#xff0c;目錄下有三個文件夾&#xff1a; include 頭文件lib 編譯的二進制庫&#xff0c;根據你項目的配置去包含相應的庫samples 官方使用案列 動態鏈接 libfbxsdk.dll, libfbxsdk.lib是動態庫&#xff0c;需要在配置屬性->C/C->預…

【單層神經網絡】基于MXNet庫簡化實現線性回歸

寫在前面 同最開始的兩篇文章 完整程序及注釋 導入使用的庫# 基本 from mxnet import autograd, nd, gluon # 模型、網絡 from mxnet.gluon import nn from mxnet import init # 學習 from mxnet.gluon import loss as gloss # 數據集 from mxnet.gluon…

【爬蟲】JS逆向解決某藥的商品價格加密

??????????歡迎來到我的博客?????????? ??作者:秋無之地 ??簡介:CSDN爬蟲、后端、大數據領域創作者。目前從事python爬蟲、后端和大數據等相關工作,主要擅長領域有:爬蟲、后端、大數據開發、數據分析等。 ??歡迎小伙伴們點贊????、收藏??、…

OpenAI開源戰略反思:中國力量推動AI產業變革

在周五的Reddit問答會上&#xff0c;OpenAI首席執行官Sam Altman罕見承認公司正面臨來自中國科技企業的強勁挑戰。這位向來強硬的硅谷領軍者坦言&#xff0c;以深度求索&#xff08;DeepSeek&#xff09;為代表的中國AI公司正在改寫行業游戲規則。 這場歷時三小時的對話揭示了…

一文講解HashMap線程安全相關問題(上)

HashMap不是線程安全的&#xff0c;主要有以下幾個問題&#xff1a; ①、多線程下擴容會死循環。JDK1.7 中的 HashMap 使用的是頭插法插入元素&#xff0c;在多線程的環境下&#xff0c;擴容的時候就有可能導致出現環形鏈表&#xff0c;造成死循環。 JDK 8 時已經修復了這個問…

android java系統彈窗的基礎模板

1、資源文件 app\src\main\res\layout下增加custom_pop_layout.xml 定義彈窗的控件資源。 <?xml version"1.0" encoding"utf-8"?> <androidx.constraintlayout.widget.ConstraintLayout xmlns:android"http://schemas.android.com/apk/…

python學習——常用的內置函數匯總

文章目錄 類型轉換函數數學函數常用的迭代器操作函數常用的其他內置函數 類型轉換函數 數學函數 常用的迭代器操作函數 實操&#xff1a; from cv2.gapi import descr_oflst [55, 42, 37, 2, 66, 23, 18, 99]# (1) 排序操作 asc_lst sorted(lst) # 升序 desc_lst sorted(l…