PyTorch深度學習實戰(46)——深度Q學習

PyTorch深度學習實戰(46)——深度Q學習

    • 0. 前言
    • 1. 深度 Q 學習
    • 2. 網絡架構
    • 3. 實現深度 Q 學習模型進行 CartPole 游戲
    • 小結
    • 系列鏈接

0. 前言

我們已經學習了如何構建一個 Q 表,通過在多個 episode 中重復進行游戲獲取與給定狀態-動作組合相對應的值。然而,當狀態空間是連續時,可能的狀態空間數會變得非常巨大。在本節中,我們將學習如何使用神經網絡在沒有 Q 表的情況下估計狀態-動作組合的 Q 值,因此稱為深度 Q 學習 (deep Q-learning)。

1. 深度 Q 學習

與 Q 表相比,深度 Q 學習利用神經網絡將任意給定的狀態-動作(其中狀態可以是連續或離散的)組合映射到相應 Q 值。
在本節中,將使用 Gym 中的 CartPole 環境,智能體的任務是盡可能長時間地平衡 CartPoleCartPole 環境如下圖所示:

CartPole-v0

當小車向右移動時,桿向左移動,反之亦然,CartPole 環境中的每個狀態都由四個觀測值定義,其名稱及其最小值和最大值如下:

狀態最小值最大值
Cart position-2.42.4
Cart velocity-infinf
Pole angle-41.8°41.8°
Pole velocity at the tip-infinf

需要注意的是,表示狀態的所有觀測值都具有連續值,用于 CartPole 平衡游戲的深度 Q 學習的工作原理如下:

  1. 獲取輸入值(游戲圖像/游戲元數據)
  2. 通過網絡傳遞輸入值,網絡的輸出與可能的動作數相同
  3. 輸出層預測在給定狀態下采取某個動作對應的 Q 值

2. 網絡架構

網絡架構使用狀態(四個觀測值)作為輸入,在當前狀態下采取左/右動作的 Q 值作為輸出。神經網絡訓練策略如下:

  1. 在探索階段,執行輸出層中具有最高值的隨機動作
  2. 將動作、下一個狀態、獎勵和指示游戲是否完成的標志存儲在內存中
  3. 如果游戲沒有完成,計算在給定狀態下采取行動的 Q 值,即獎勵 + 折扣因子 x 下一個狀態中所有動作的最大可能 Q 值
  4. 修改采取動作的Q值,而其他狀態-動作組合的 Q 值保持不變
  5. 多次執行步驟 14 并存儲經驗
  6. 擬合模型,將狀態作為輸入,動作值作為預期輸出(來自內存和回放經驗),并最小化 MSE 損失
  7. 在降低探索率的同時在多個 episode 上重復上述步驟

3. 實現深度 Q 學習模型進行 CartPole 游戲

根據以上策略,使用 PyTorch 編寫深度 Q 學習模型,進行 CartPole 游戲。

(1) 導入相關庫:

import gym
import numpy as np
import cv2
from collections import deque
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from collections import namedtuple, deque
import torch
import torch.nn.functional as F
import torch.optim as optim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

(2) 定義環境:

env = gym.make('CartPole-v1')

(3) 定義網絡架構:

class DQNetwork(nn.Module):def __init__(self, state_size, action_size):super(DQNetwork, self).__init__()self.fc1 = nn.Linear(state_size, 24)self.fc2 = nn.Linear(24, 24)self.fc3 = nn.Linear(24, action_size)def forward(self, state):       x = F.relu(self.fc1(state))x = F.relu(self.fc2(x))x = self.fc3(x)return x

該架構在兩個隱藏層中僅包含 24 個單元,輸出層包含與可能動作數相同的單元。

(4) 定義 Agent 類。

定義 __init__ 方法,其中包含各種參數、網絡的定義:

class Agent():def __init__(self, state_size, action_size):self.state_size = state_sizeself.action_size = action_sizeself.seed = random.seed(0)## hyperparametersself.buffer_size = 2000self.batch_size = 64self.gamma = 0.99self.lr = 0.0025self.update_every = 4 # Q-Networkself.local = DQNetwork(state_size, action_size).to(device)self.optimizer = optim.Adam(self.local.parameters(), lr=self.lr)# Replay memoryself.memory = deque(maxlen=self.buffer_size) self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])self.t_step = 0

定義 step 函數,該函數從內存中獲取數據并通過調用 learn 函數將其擬合到模型中:

    def step(self, state, action, reward, next_state, done):# Save experience in replay memoryself.memory.append(self.experience(state, action, reward, next_state, done)) # Learn every update_every time steps.self.t_step = (self.t_step + 1) % self.update_everyif self.t_step == 0:# If enough samples are available in memory, get random subset and learnif len(self.memory) > self.batch_size:experiences = self.sample_experiences()self.learn(experiences, self.gamma)

定義 act 函數,該函數在給定狀態的情況下預測動作:

    def act(self, state, eps=0.):# Epsilon-greedy action selectionif random.random() > eps:state = torch.from_numpy(state).float().unsqueeze(0).to(device)self.local.eval()with torch.no_grad():action_values = self.local(state)self.local.train()return np.argmax(action_values.cpu().data.numpy())else:return random.choice(np.arange(self.action_size))

在以上代碼中,我們在確定要采取的行動時使用探索-利用策略。

定義 learn 函數用于擬合模型,使其在給定狀態時預測動作值:

    def learn(self, experiences, gamma): states, actions, rewards, next_states, dones = experiences# Get expected Q values from local modelQ_expected = self.local(states).gather(1, actions)# Get max predicted Q values (for next states) from local modelQ_targets_next = self.local(next_states).detach().max(1)[0].unsqueeze(1)# Compute Q targets for current states Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))# Compute lossloss = F.mse_loss(Q_expected, Q_targets)# Minimize the lossself.optimizer.zero_grad()loss.backward()self.optimizer.step()

在以上代碼中,獲取采樣經驗并預測我們執行的動作的 Q 值。此外,由于我們已經知道下一個狀態,可以預測下一個狀態下動作的最佳 Q 值。因此,我們可以得到與在給定狀態下采取的動作相對應的目標值。最后,計算在當前狀態下采取的動作的 Q 值的期望值 (Q_targets) 和預測值 (Q_expected) 之間的誤差。

定義 sample_experiences 函數以便從內存中采樣經驗:

    def sample_experiences(self):experiences = random.sample(self.memory, k=self.batch_size)        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)        return (states, actions, rewards, next_states, dones)

(5) 定義智能體對象:

agent = Agent(env.observation_space.shape[0], env.action_space.n)

(6) 訓練模型。

初始化列表:

scores = [] # list containing scores from each episode
scores_window = deque(maxlen=100) # last 100 scores
n_episodes=5000
max_t=5000
eps_start=1.0
eps_end=0.001
eps_decay=0.9995
eps = eps_start

在每個 episode 中重置環境并獲取狀態的形狀,此外,整形狀態維度形狀,以便可以將其傳遞給網絡:

for i_episode in range(1, n_episodes+1):state = env.reset()state_size = env.observation_space.shape[0]state = np.reshape(state, [1, state_size])score = 0

循環通過 max_t 個時間步,確定要執行的動作,并使用 step 方法執行,使用 np.reshape 整形狀態張量,并將整形后的狀態傳遞給神經網絡:

    for i in range(max_t):action = agent.act(state, eps)next_state, reward, done, _ = env.step(action)next_state = np.reshape(next_state, [1, state_size])

通過指定 agent.step 在當前狀態之上擬合模型,并將狀態重置為下一個狀態,以便在下一次迭代中使用。

如果前 10 步的得分平均值大于 450,則存儲相關數據并停止訓練:

        reward = reward if not done or score == 499 else -10agent.step(state, action, reward, next_state, done)state = next_statescore += rewardif done:break scores_window.append(score) # save most recent score scores.append(score) # save most recent scoreeps = max(eps_end, eps_decay*eps) # decrease epsilonprint('\rEpisode {}\tReward {} \tAverage Score: {:.2f} \tEpsilon: {}'.format(i_episode,score,np.mean(scores_window), eps), end="")if i_episode % 100 == 0:print('\rEpisode {}\tAverage Score: {:.2f} \tEpsilon: {}'.format(i_episode, np.mean(scores_window), eps))if i_episode>10 and np.mean(scores[-10:])>450:break
"""
Episode 100     Average Score: 12.65 ge Epsilon: 0.951217530242334.9512175302423344
...
Episode 2700    Average Score: 116.56 e Epsilon: 0.259152752655221145915275265522114
Episode 2712    Reward 500.0    Average Score: 159.01   Epsilon: 0.2576021050410192
"""

(7) 繪制隨著 episode 的增加的分數變化情況如下:

import matplotlib.pyplot as plt
plt.plot(scores)
plt.title('Scores over increasing episodes')
plt.show()

請添加圖片描述

從上圖中可以看出,在第 2000episode 之后,該模型在進行 CartPole 游戲時能夠獲得較高分。

小結

深度 Q 學習是一種結合了深度學習和強化學習的方法,通過深度神經網絡逼近 Q 值函數,在解決大規模、連續狀態空間問題方面具有優勢,并在多個領域展示了強大的學習和決策能力。在本節中,介紹了深度 Q 學習的基本概念,并學習了如何使用 PyTorch 實現深度 Q 學習進行 CartPole 游戲。

系列鏈接

PyTorch深度學習實戰(1)——神經網絡與模型訓練過程詳解
PyTorch深度學習實戰(2)——PyTorch基礎
PyTorch深度學習實戰(3)——使用PyTorch構建神經網絡
PyTorch深度學習實戰(4)——常用激活函數和損失函數詳解
PyTorch深度學習實戰(5)——計算機視覺基礎
PyTorch深度學習實戰(6)——神經網絡性能優化技術
PyTorch深度學習實戰(7)——批大小對神經網絡訓練的影響
PyTorch深度學習實戰(8)——批歸一化
PyTorch深度學習實戰(9)——學習率優化
PyTorch深度學習實戰(10)——過擬合及其解決方法
PyTorch深度學習實戰(11)——卷積神經網絡
PyTorch深度學習實戰(12)——數據增強
PyTorch深度學習實戰(13)——可視化神經網絡中間層輸出
PyTorch深度學習實戰(14)——類激活圖
PyTorch深度學習實戰(15)——遷移學習
PyTorch深度學習實戰(16)——面部關鍵點檢測
PyTorch深度學習實戰(17)——多任務學習
PyTorch深度學習實戰(18)——目標檢測基礎
PyTorch深度學習實戰(19)——從零開始實現R-CNN目標檢測
PyTorch深度學習實戰(20)——從零開始實現Fast R-CNN目標檢測
PyTorch深度學習實戰(21)——從零開始實現Faster R-CNN目標檢測
PyTorch深度學習實戰(22)——從零開始實現YOLO目標檢測
PyTorch深度學習實戰(23)——從零開始實現SSD目標檢測
PyTorch深度學習實戰(24)——使用U-Net架構進行圖像分割
PyTorch深度學習實戰(25)——從零開始實現Mask R-CNN實例分割
PyTorch深度學習實戰(26)——多對象實例分割
PyTorch深度學習實戰(27)——自編碼器(Autoencoder)
PyTorch深度學習實戰(28)——卷積自編碼器(Convolutional Autoencoder)
PyTorch深度學習實戰(29)——變分自編碼器(Variational Autoencoder, VAE)
PyTorch深度學習實戰(30)——對抗攻擊(Adversarial Attack)
PyTorch深度學習實戰(31)——神經風格遷移
PyTorch深度學習實戰(32)——Deepfakes
PyTorch深度學習實戰(33)——生成對抗網絡(Generative Adversarial Network, GAN)
PyTorch深度學習實戰(34)——DCGAN詳解與實現
PyTorch深度學習實戰(35)——條件生成對抗網絡(Conditional Generative Adversarial Network, CGAN)
PyTorch深度學習實戰(36)——Pix2Pix詳解與實現
PyTorch深度學習實戰(37)——CycleGAN詳解與實現
PyTorch深度學習實戰(38)——StyleGAN詳解與實現
PyTorch深度學習實戰(39)——小樣本學習(Few-shot Learning)
PyTorch深度學習實戰(40)——零樣本學習(Zero-Shot Learning)
PyTorch深度學習實戰(41)——循環神經網絡與長短期記憶網絡
PyTorch深度學習實戰(42)——圖像字幕生成
PyTorch深度學習實戰(43)——手寫文本識別
PyTorch深度學習實戰(44)——基于 DETR 實現目標檢測
PyTorch深度學習實戰(45)——強化學習

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/45231.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/45231.shtml
英文地址,請注明出處:http://en.pswp.cn/web/45231.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Hypertable install of rhel6.0

1.rpm 安裝:(如果已存在,會提示沖突,使用--replacefiles) 1.1 編譯環境 安裝gcc gcc-c++ make cmake(在admin machine上,放置rpm包的文件里依次執行下面的語句): sudo rpm -ivh cpp-4.4.6-4.el6.x86_64.rpm --replacefiles sudo rpm -ivh libgcc-4.4.6-4.el6.x86_64.rp…

【學習筆記】無人機(UAV)在3GPP系統中的增強支持(十四)-無人機操控關鍵績效指標(KPI)框架

引言 本文是3GPP TR 22.829 V17.1.0技術報告,專注于無人機(UAV)在3GPP系統中的增強支持。文章提出了多個無人機應用場景,分析了相應的能力要求,并建議了新的服務級別要求和關鍵性能指標(KPIs)。…

第二證券:轉融通是什么意思?什么是轉融通?

轉融通,包含轉融資和轉融券,實質是借錢和借券。轉融通是指證券金融公司借入證券、籌得資金后,再轉借給證券公司,是一假貸聯絡,具體是指證券公司從符合要求的基金處理公司、保險公司、社保基金等組織出資者融券&#xf…

Python應用開發——30天學習Streamlit Python包進行APP的構建(15):優化性能并為應用程序添加狀態

Caching and state 優化性能并為應用程序添加狀態! Caching 緩存 Streamlit 為數據和全局資源提供了強大的緩存原語。即使從網絡加載數據、處理大型數據集或執行昂貴的計算,它們也能讓您的應用程序保持高性能。 本頁僅包含有關 st.cache_data API 的信息。如需深入了解緩…

技術成神之路:設計模式(六)策略模式

1.介紹 策略模式(Strategy Pattern)是一種行為型設計模式,它定義了一系列算法,封裝每一個算法,并使它們可以相互替換。策略模式使得算法的變化獨立于使用算法的客戶端。 2.主要作用 策略模式的主要作用是將算法或行為…

面試問題梳理:項目中防止配置中的密碼泄露-Jasypt

背景 想起面試的時候,面試官問我現在大家用Spring框架,數據庫、ES之類的密碼都是配置在配置文件中的,有很大的安全隱患,你有考慮過怎么解決嘛? 當時我回答是可以在項目啟動的過程中的命令行追加的方式,感覺…

Hello,World!(C++)

題目描述 編寫一個能夠輸出 Hello,World! 的程序。 提示&#xff1a; - 使用英文標點符號&#xff1b; Hello,World! 逗號后面沒有空格。 H 和 W 為大寫字母。 樣例 #1 樣例輸入 #1 無 樣例輸出 #1 Hello,World! &#xff08;1&#xff09; #include<bits/stdc.…

力扣題解( 讓字符串成為回文串的最少插入次數)

1312. 讓字符串成為回文串的最少插入次數 給你一個字符串 s &#xff0c;每一次操作你都可以在字符串的任意位置插入任意字符。 請你返回讓 s 成為回文串的 最少操作次數 。 「回文串」是正讀和反讀都相同的字符串。 思路&#xff1a; 本題要求的是最少插入次數&#xff0c;…

什么叫圖像的雙邊濾波,并附利用OpenCV和MATLB實現雙邊濾波的代碼

雙邊濾波&#xff08;Bilateral Filtering&#xff09;是一種在圖像處理中常用的非線性濾波技術&#xff0c;主要用于去噪和保邊。它在空間域和像素值域上同時進行加權&#xff0c;既考慮了像素之間的空間距離&#xff0c;也考慮了像素值之間的相似度&#xff0c;從而能夠有效地…

手機怎么看WiFi的IP地址

在如今數字化快速發展的時代&#xff0c;無線網絡已成為我們日常生活中不可或缺的一部分。無論是工作、學習還是娛樂&#xff0c;我們可能都離不開WiFi的陪伴。然而&#xff0c;在使用WiFi的過程中&#xff0c;有時我們可能需要查看其IP地址&#xff0c;以便更好地管理我們的網…

【動態規劃】背包問題 {01背包問題;完全背包問題;二維費用背包問題}

一、背包問題概述 背包問題(Knapsackproblem)是?種組合優化的NP完全問題。 問題可以描述為&#xff1a;給定一組物品&#xff0c;每種物品都有自己的重量和價格&#xff0c;在限定的總重量內&#xff0c;我們如何選擇&#xff0c;才能使得物品的總價格最?。 根據物品的個數…

鏈接追蹤系列-07.logstash安裝json_lines插件

進入docker中的logstash 容器內&#xff1a; jelexbogon ~ % docker exec -it 7ee8960c99a31e607f346b2802419b8b819cc860863bc283cb7483bc03ba1420 /bin/sh $ pwd /usr/share/logstash $ ls bin CONTRIBUTORS Gemfile jdk logstash-core modules tools x-pack …

語音識別概述

語音識別概述 一.什么是語音&#xff1f; 語音是語言的聲學表現形式&#xff0c;是人類自然的交流工具。 圖片來源&#xff1a;https://www.shenlanxueyuan.com/course/381 二.語音識別的定義 語音識別&#xff08;Automatic Speech Recognition, ASR 或 Speech to Text, ST…

基于RAG大模型的變電站智慧運維-第十屆Nvidia Sky Hackathon參賽作品

第十屆Nvidia Sky Hackathon參賽作品 1. 項目說明 變電站是用于變電的設施&#xff0c;主要的作用是將電壓轉化&#xff0c;使電能在輸電線路中能夠長距離傳輸。在電力系統中&#xff0c;變電站起到了極為重要的作用&#xff0c;它可以完成電能的負荷分配、電壓的穩定、容錯保…

電影購票小程序論文(設計)開題報告

一、課題的背景和意義 隨著互聯網技術的不斷發展&#xff0c;人們對于購票的需求也越來越高。傳統的購票方式存在著排隊時間長、購票流程繁瑣等問題&#xff0c;而網上購票則能夠有效地解決這些問題。電影購票小程序是網上購票的一種新型應用&#xff0c;它能夠讓用戶隨時隨地…

06.截斷文本 選擇任何鏈接 :root 和 html 有什么區別

截斷文本 對超過一行的文本進行截斷,在末尾添加省略號(…)。 使用 overflow: hidden 防止文本超出其尺寸。使用 white-space: nowrap 防止文本超過一行高度。使用 text-overflow: ellipsis 使得如果文本超出其尺寸,將以省略號結尾。為元素指定固定的 width,以確定何時顯示省略號…

Selenium WebDriver中的顯式等待與隱式等待:深入理解與應用

在自動化測試中&#xff0c;尤其是在使用Selenium WebDriver進行Web應用的自動化測試時&#xff0c;等待元素加載完成是一個常見的需求。Selenium提供了兩種等待機制來處理這一問題&#xff1a;顯式等待&#xff08;Explicit Wait&#xff09;和隱式等待&#xff08;Implicit W…

筆記 4 :linux 0.11 中繼續分析 0 號進程創建一號進程的 fork () 函數

&#xff08;27&#xff09;本條目開始&#xff0c; 開始分析 copy_process () 函數&#xff0c;其又會調用別的函數&#xff0c;故先分析別的函數。 get_free_page &#xff08;&#xff09; &#xff1b; 先 介紹匯編指令 scasb &#xff1a; 以及 指令 sstosd &#xff1a;…

什么是架構設計師?定義、職責和任務,全方位解析需要具備的專業素質

目錄 1. 架構設計師的定義 2. 架構設計師的職責和任務 2.1 系統架構設計 2.1.1 模塊劃分 2.1.2 接口設計 2.1.3 通信方式 2.2 技術選型與決策 2.2.1 技術評估 2.2.2 技術選型 2.2.3 技術決策 2.3 性能優化與調優 2.3.1 性能分析 2.3.2 性能優化 2.3.3 性能調優 …

基于BitMap的工作日間隔計算

背景問題 在我們實際開發過程中&#xff0c;時常會遇到日期的間隔計算&#xff0c;即計算多少工作日之后的日期&#xff0c;在不考慮法定節假日的情況下也不是那么復雜&#xff0c;畢竟周六、周日是相對固定的&#xff0c;Java語言也提供了豐富的類來處理此問題。 然而&#x…