【DQN】基于pytorch的強化學習算法Demo

目錄

  • 簡介
  • 代碼

簡介

DQN(Deep Q-Network)是一種基于深度神經網絡的強化學習算法,于2013年由DeepMind提出。它的目標是解決具有離散動作空間的強化學習問題,并在多個任務中取得了令人矚目的表現。

DQN的核心思想是使用深度神經網絡來逼近狀態-動作值函數(Q函數),將當前狀態作為輸入,輸出每個可能動作的Q值估計。通過不斷迭代和更新網絡參數,DQN能夠逐步學習到最優的Q函數,并根據Q值選擇具有最大潛在回報的動作。

DQN的訓練過程中采用了兩個關鍵技術:經驗回放和目標網絡。經驗回放是一種存儲并重復使用智能體經歷的經驗的方法,它可以破壞數據之間的相關性,提高訓練的穩定性。目標網絡用于解決訓練過程中的估計器沖突問題,通過固定一個與訓練網絡參數較為獨立的目標網絡來提供穩定的目標Q值,從而減少訓練的不穩定性。

DQN還采用了一種策略稱為epsilon-貪心策略來在探索和利用之間進行權衡。初始時,智能體以較高的概率選擇隨機動作(探索),隨著訓練的進行,該概率逐漸降低,讓智能體更多地依靠Q值選擇最佳動作(利用)。

DQN在許多復雜任務中取得了顯著的成果,特別是在Atari游戲等需要視覺輸入的任務中。它的成功在很大程度上得益于深度神經網絡的強大擬合能力和經驗回放的效果,使得智能體能夠通過與環境的交互進行自主學習。

代碼

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gym# Hyper Parameters
BATCH_SIZE = 32
LR = 0.01                   # learning rate
EPSILON = 0.9               # greedy policy
GAMMA = 0.9                 # reward discount
TARGET_REPLACE_ITER = 100   # target update frequency
MEMORY_CAPACITY = 2000
env = gym.make('CartPole-v1',render_mode="human")
#env = gym.make('CartPole-v0')
env = env.unwrapped
N_ACTIONS = env.action_space.n
N_STATES = env.observation_space.shape[0]
ENV_A_SHAPE = 0 if isinstance(env.action_space.sample(), int) else env.action_space.sample().shape     # to confirm the shapeclass Net(nn.Module):def __init__(self, ):super(Net, self).__init__()self.fc1 = nn.Linear(N_STATES, 50)self.fc1.weight.data.normal_(0, 0.1)   # initializationself.out = nn.Linear(50, N_ACTIONS)self.out.weight.data.normal_(0, 0.1)   # initializationdef forward(self, x):x = self.fc1(x)x = F.relu(x)actions_value = self.out(x)return actions_valueclass DQN(object):def __init__(self):self.eval_net, self.target_net = Net(), Net()self.learn_step_counter = 0                                     # for target updatingself.memory_counter = 0                                         # for storing memoryself.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))     # initialize memoryself.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)self.loss_func = nn.MSELoss()def choose_action(self, x):x = torch.unsqueeze(torch.FloatTensor(x), 0)# input only one sampleif np.random.uniform() < EPSILON:   # greedyactions_value = self.eval_net.forward(x)action = torch.max(actions_value, 1)[1].data.numpy()action = action[0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)  # return the argmax indexelse:   # randomaction = np.random.randint(0, N_ACTIONS)action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)return actiondef store_transition(self, s, a, r, s_):transition = np.hstack((s, [a, r], s_))# replace the old memory with new memoryindex = self.memory_counter % MEMORY_CAPACITYself.memory[index, :] = transitionself.memory_counter += 1def learn(self):# target parameter updateif self.learn_step_counter % TARGET_REPLACE_ITER == 0:self.target_net.load_state_dict(self.eval_net.state_dict())self.learn_step_counter += 1# sample batch transitionssample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)b_memory = self.memory[sample_index, :]b_s = torch.FloatTensor(b_memory[:, :N_STATES])b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES+1].astype(int))b_r = torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2])b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:])# q_eval w.r.t the action in experienceq_eval = self.eval_net(b_s).gather(1, b_a)  # shape (batch, 1)q_next = self.target_net(b_s_).detach()     # detach from graph, don't backpropagateq_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)   # shape (batch, 1)loss = self.loss_func(q_eval, q_target)self.optimizer.zero_grad()loss.backward()self.optimizer.step()dqn = DQN()  # 創建 DQN 對象print('\nCollecting experience...')
for i_episode in range(400):  # 進行 400 個回合的訓練s, info = env.reset()  # 環境重置,獲取初始狀態 s 和其他信息ep_r = 0  # 初始化本回合的總獎勵 ep_r 為 0while True:env.render()  # 顯示環境,通過調用 render() 方法,可以將當前環境的狀態以圖形化的方式呈現出來.a = dqn.choose_action(s)  # 根據當前狀態選擇動作 a# 下一個狀態(nextstate):返回智能體執行動作a后環境的下一個狀態。在示例中,它存儲在變量s_中。獎勵(reward):返回智能體執行動作a后在環境中獲得的獎勵。在示例中,它存儲在變中。# 完成標志(doneflag):返回一個布爾值,指示智能體是否已經完成了當前環境。在示例中,它存儲在變量done中。# 截斷標志(truncatedflag):返回一個布爾值,表示當前狀態是否是由于達到了最大時間步驟或其他特定條件而被截斷。在示例中,它存儲在變量truncated中。# 其他信息(info):返回一個包含其他輔助信息的字典或對象。在示例中,它存儲在變量info中。# 執行動作,獲取下一個狀態 s_,獎勵 r,done 標志位,以及其他信息s_, r, done, truncated, info = env.step(a)# 修改獎勵值#根據智能體在x方向和theta方向上與目標位置的偏離程度,計算兩個獎勵值r1和r2。具體計算方法是將每個偏離程度除以相應的閾值,然后減去一個常數(0.8和0.5)得到獎勵值。這樣,如果智能體在這兩個方向上的偏離程度越小,獎勵值越高。x, x_dot, theta, theta_dot = s_  # 從 s_ 中提取參數r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8  # 根據 x 的偏離程度計算獎勵 r1r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5  # 根據 theta 的偏離程度計算獎勵 r2r = r1 + r2  # 組合兩個獎勵成為最終的獎勵 rdqn.store_transition(s, a, r, s_)  # 存儲狀態轉換信息到經驗池ep_r += r  # 更新本回合的總獎勵if dqn.memory_counter > MEMORY_CAPACITY:  # 當經驗池中的樣本數量超過閾值 MEMORY_CAPACITY 時進行學習dqn.learn()if done:  # 如果本回合結束print('Ep: ', i_episode,'| Ep_r: ', round(ep_r, 2))  # 打印本回合的回合數和總獎勵if done:  # 如果任務結束break  # 跳出當前回合的循環s = s_  # 更新狀態,準備進行下一步動作選擇

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/165866.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/165866.shtml
英文地址,請注明出處:http://en.pswp.cn/news/165866.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

企業數字化轉型的作用是什么?_光點科技

在當今快速變化的商業環境中&#xff0c;數字化轉型已成為企業發展的重要策略。企業數字化轉型指的是利用數字技術改造傳統業務模式和管理方式&#xff0c;以提升效率、增強競爭力和創造新的增長機會。 提升運營效率&#xff1a;數字化轉型通過引入自動化工具和智能系統&#x…

指數退避重試

指數退避重試&#xff08;Exponential Backoff and Retry&#xff09;是一種網絡通信中常用的錯誤處理和重試策略。它通常用于處理臨時性的故障&#xff0c;例如網絡延遲、服務器過載或臨時性的錯誤&#xff0c;以提高系統的可靠性和穩定性。 基本思想是&#xff0c;當發生一個…

NX二次開發UF_CSYS_ask_wcs 函數介紹

文章作者&#xff1a;里海 來源網站&#xff1a;https://blog.csdn.net/WangPaiFeiXingYuan UF_CSYS_ask_wcs Defined in: uf_csys.h int UF_CSYS_ask_wcs(tag_t * wcs_id ) overview 概述 Gets the object identifier of the coordinate system to which the work coordin…

JMeter壓測常見面試問題

1、JMeter可以模擬哪些類型的負載&#xff1f; JMeter可以模擬各種類型的負載&#xff0c;包括但不限于Web應用程序、API、數據庫、FTP、SMTP、JMS、SOAP / RESTful Web服務等。這使得JMeter成為一個功能強大且靈活的壓力測試工具。 2、如何配置JMeter來進行分布式壓力測試&a…

在華為昇騰開發板安裝gdal-python

作者:朱金燦 來源:clever101的專欄 為什么大多數人學不會人工智能編程?>>> 在華為昇騰開發板安裝gdal-python分為兩步:編譯gdal庫和下載gdal對應的python包。 1.編譯gdal庫 首先下載gdal庫,。在linux(arm架構)上編譯的gdal庫及其第三方庫源碼,內含一個編譯…

智慧法院 | RPA+AI打造智慧執行助手,解決“案多人少”現實難題

為深化政法智能化建設&#xff0c;加強“智慧治理”“智慧法院”“智慧檢務”“智慧警務”“智慧司法”等信息平臺建設&#xff0c;深入實施大數據戰略&#xff0c;實現科技創新成果同政法工作深度融合。法制日報社于今年3月繼續舉辦了2023政法智能化建設創新案例及論文征集宣傳…

Unity UGUI的HorizontalLayoutGroup(水平布局)組件

Horizontal Layout Group | Unity UI | 1.0.0 1. 什么是HorizontalLayoutGroup組件&#xff1f; HorizontalLayoutGroup是Unity UGUI中的一種布局組件&#xff0c;用于在水平方向上對子物體進行排列和布局。它可以根據一定的規則自動調整子物體的位置和大小&#xff0c;使它…

Shell腳本:Linux Shell腳本學習指南(第二部分Shell編程)二

第二部分&#xff1a;Shell編程&#xff08;二&#xff09; 十一、Shell數組&#xff1a;Shell數組定義以及獲取數組元素 和其他編程語言一樣&#xff0c;Shell 也支持數組。數組&#xff08;Array&#xff09;是若干數據的集合&#xff0c;其中的每一份數據都稱為元素&#…

Navicat 技術指引 | GaussDB服務器對象的創建/設計(編輯)

Navicat Premium&#xff08;16.2.8 Windows版或以上&#xff09; 已支持對GaussDB 主備版的管理和開發功能。它不僅具備輕松、便捷的可視化數據查看和編輯功能&#xff0c;還提供強大的高階功能&#xff08;如模型、結構同步、協同合作、數據遷移等&#xff09;&#xff0c;這…

【華為OD題庫-034】字符串化繁為簡-java

題目 給定一個輸入字符串&#xff0c;字符串只可能由英文字母(a ~ z、A ~ Z)和左右小括號()組成。當字符里存在小括號時&#xff0c;小括號是成對的&#xff0c;可以有一個或多個小括號對&#xff0c;小括號對不會嵌套&#xff0c;小括號對內可以包含1個或多個英文字母也可以不…

Jenkins Ansible 參數構建

首先在Jenkins中創建自由項目 在web端配置完成后在另一臺機子上下載nginx 在gitlab端創建項目并創建文件配置代碼 在有Jenkins的機器上下載Ansible [rootslave1 ~]# yum -y install epel-release [rootslave1 ~]# yum -y install ansible再進入下載nginx機器中克隆gitlab項目…

Android 框架層AIDL 添加接口

文章目錄 AIDL的原理構建AIDL的流程往凍結的AIDL中加接口 AIDL的原理 可以利用ALDL定義客戶端與服務均認可的編程接口&#xff0c;以便二者使用進程間通信 (IPC) 進行相互通信。在 Android 中&#xff0c;一個進程通常無法訪問另一個進程的內存。因此&#xff0c;為進行通信&a…

卷積神經網絡(AlexNet)鳥類識別

文章目錄 一、前言二、前期工作1. 設置GPU&#xff08;如果使用的是CPU可以忽略這步&#xff09;2. 導入數據3. 查看數據 二、數據預處理1. 加載數據2. 可視化數據3. 再次檢查數據4. 配置數據集 三、AlexNet (8層&#xff09;介紹四、構建AlexNet (8層&#xff09;網絡模型五、…

微信小程序image組件圖片設置最大寬度 寬高自適應

問題描述&#xff1a;在使用微信小程序image組件的時候&#xff0c;在不確定圖片寬高情況下 想給一個最大寬度讓圖片自適應&#xff0c;按比例&#xff0c;image的widthfiex和heightFiex并不能滿足&#xff08;只指定最大寬/高并不會生效&#xff09; 問題解決&#xff1a;使用…

居家適老化設計第二十九條---衛生間之花灑

無電源 燈光顯示 無障礙扶手型花灑 以上產品圖片均來源于淘寶 侵權聯系刪除 居家適老化衛生間的花灑通常具有以下特點和功能&#xff1a;1. 高度可調節&#xff1a;適老化衛生間花灑可通過調節高度&#xff0c;滿足不同身高的老年人使用需求&#xff0c;避免彎腰或過高伸展造…

【開源】基于Vue.js的固始鵝塊銷售系統

項目編號&#xff1a; S 060 &#xff0c;文末獲取源碼。 \color{red}{項目編號&#xff1a;S060&#xff0c;文末獲取源碼。} 項目編號&#xff1a;S060&#xff0c;文末獲取源碼。 目錄 一、摘要1.1 項目介紹1.2 項目錄屏 二、功能模塊2.1 數據中心模塊2.2 鵝塊類型模塊2.3 固…

qgis添加xyz柵格瓦片

方式1&#xff1a;手動一個個添加 左側瀏覽器-XYZ Tiles-右鍵-新建連接 例如添加高德瓦片地址 https://wprd01.is.autonavi.com/appmaptile?langzh_cn&size1&style7&x{x}&y{y}&z{z} 雙擊即可呈現 收集到的一些圖源&#xff0c;僅供參考&#xff0c;其中一…

【C++學習手札】模擬實現list

? &#x1f3ac;慕斯主頁&#xff1a;修仙—別有洞天 ??今日夜電波&#xff1a;リナリア—まるりとりゅうが 0:36━━━━━━?&#x1f49f;──────── 3:51 &#x1f504; ?? ? ??…

聊聊httpclient的staleConnectionCheckEnabled

序 本文主要研究一下httpclient的staleConnectionCheckEnabled staleConnectionCheckEnabled org/apache/http/client/config/RequestConfig.java public class RequestConfig implements Cloneable {public static final RequestConfig DEFAULT new Builder().build();pr…

【ARM 嵌入式 編譯 Makefile 系列 18 -- Makefile 中的 export 命令詳細介紹】

文章目錄 Makefile 中的 export 命令詳細介紹Makefile 使用 export導出與未導出變量的區別示例&#xff1a;導出變量以供子 Makefile 使用 Makefile 中的 export 命令詳細介紹 在 Makefile 中&#xff0c;export 命令用于將變量從 Makefile 導出到由 Makefile 啟動的子進程的環…