從代碼學習深度強化學習 - 目標導向的強化學習-HER算法 PyTorch版

文章目錄

  • 1. 前言:當一個任務有多個目標
  • 2. 目標導向的強化學習 (GoRL) 簡介
  • 3. HER算法:化失敗為成功的智慧
  • 4. 代碼實踐:用PyTorch實現HER+DDPG
    • 4.1 自定義環境 (WorldEnv)
    • 4.2 智能體與算法 (DDPG)
    • 4.3 HER的核心:軌跡經驗回放
    • 4.4 主流程與訓練
  • 5. 訓練結果與分析
  • 6. 總結


1. 前言:當一個任務有多個目標

經典的深度強化學習算法,如 PPO、SAC 等,在各自擅長的任務中都取得了非常好的效果。但它們通常都局限在解決單個任務上,換句話說,訓練好的算法,在運行時也只能完成一個特定的任務。

想象一個場景:我們想讓一個機械臂能把桌子上的任何一個物體重置到任意一個指定位置。對于傳統強化學習而言,如果目標物體的初始位置和目標位置每次都變化,那么這就是一個全新的任務。即便任務的“格式”——抓取并移動——是一樣的,但策略本身可能需要重新訓練。這顯然效率極低。

為了解決這類問題,目標導向的強化學習 (Goal-Oriented Reinforcement Learning, GoRL) 應運而生。它的核心思想是學習一個通用策略,這個策略能夠根據給定的目標 (goal) 來執行相應的動作,從而用一個模型解決一系列結構相同但目標不同的復雜任務。

然而,在諸如機械臂抓取等真實場景中,獎勵往往是稀疏的。只有當機械臂成功將物體放到指定位置時,才會獲得正獎勵,否則獎勵一直為0或-1。在訓練初期,智能體很難通過隨機探索完成任務并獲得獎勵,導致學習效率極低。

為了解決稀疏獎勵下的學習難題,OpenAI 在2017年提出了事后經驗回放 (Hindsight Experience Replay, HER) 算法。HER 的思想極為巧妙:即使我們沒有完成預設的目標,但我們總歸是完成了“某個”目標。 通過這種“事后諸葛亮”的方式,將失敗的經驗轉化為成功的學習樣本,從而極大地提升了在稀疏獎勵環境下的學習效率。

本文將從 HER 的基本概念出發,結合一個完整的 PyTorch 代碼實例,帶你深入理解 HER 是如何與 DDPG 等經典算法結合,并有效解決目標導向的強化學習問題的。

完整代碼:下載鏈接

2. 目標導向的強化學習 (GoRL) 簡介

在目標導向的強化學習中,傳統的馬爾可夫決策過程 (MDP) 被擴展了。除了狀態 S、動作 A、轉移概率 P 之外,還引入了目標空間 G。策略 π 不僅依賴于當前狀態 s,還依賴于目標 g,即 π(a|s, g)

獎勵函數 r 也與目標相關,記為 r_g。在本文的設定中,狀態 s 包含了智能體自身的信息(例如坐標),而目標 g 則是狀態空間中的一個特定子集(例如一個目標坐標)。我們使用一個映射函數 φ 將狀態 s 映射到其對應的目標 g

在 GoRL 中,一個常見的挑戰是稀疏獎勵。例如,只有當智能體達到的狀態 s' 對應的目標 φ(s') 與我們期望的目標 g 足夠接近時,才給予獎勵。這可以用以下公式表示:

其中,δ_g 是一個很小的閾值。這意味著,在絕大多數情況下,智能體得到的獎勵都是-1,學習信號非常微弱。

3. HER算法:化失敗為成功的智慧

HER 的核心思想在于重新利用失敗的軌跡

假設智能體在一次任務(一個 episode)中,目標是 g,但最終沒有達到,整個軌跡獲得的獎勵都是-1。這條“失敗”的軌跡對于學習如何達到目標 g 幾乎沒有幫助。

但 HER 會這樣想:雖然智能體沒有達到目標 g,但它在軌跡的最后達到了某個狀態 s_T。這個狀態 s_T 自身就可以被看作是一個目標,我們稱之為“事后目標” g' = φ(s_T)。如果我們把這次任務的目標“篡改”為 g',那么這條軌跡就變成了一條成功的軌跡!因為智能體確實達到了 g'

通過這種方式,HER 能夠從任何軌跡中都提取出有價值的學習信號,將稀疏的獎勵變得稠密。

在具體實現時,HER 會從一條完整的軌跡中,隨機采樣一個時間步 (s_t, a_t, r_t, s_{t+1}),然后根據一定策略選擇一個新的目標 g' 來替換原始目標 g,并根據新目標重新計算獎勵 r'

HER 提出了幾種選擇新目標 g' 的策略,其中最常用也最直觀的是 future 策略:在當前時間步 t 之后,從該軌跡中隨機選擇一個未來狀態 s_k (k > t),將其對應的 φ(s_k) 作為新的目標 g'

這種方法保證了新目標是在當前狀態之后可以達到的,使得學習過程更加穩定和有效。

HER 作為一個通用的技巧,可以與任何 off-policy 的強化學習算法(如 DQN, DDPG, SAC)結合。在本文的實踐中,我們將它與 DDPG 算法相結合。

4. 代碼實踐:用PyTorch實現HER+DDPG

接下來,我們通過一個完整的 PyTorch 代碼項目來學習 HER 的實現。任務非常直觀:在一個二維平面上,智能體需要從原點 (0, 0) 移動到一個隨機生成的目標點。

4.1 自定義環境 (WorldEnv)

首先,我們定義一個簡單的二維世界環境。

  • 狀態空間: 4維向量 [agent_x, agent_y, goal_x, goal_y]
  • 動作空間: 2維向量 [move_x, move_y],每個分量的范圍是 [-1, 1]
  • 目標: 在每個 episode 開始時,在 [3.5, 4.5] x [3.5, 4.5] 區域內隨機生成一個目標點。
  • 獎勵: 如果智能體與目標的距離小于閾值 0.15,獎勵為 0;否則為 -1
  • 終止條件: 達到目標,或達到最大步數 50
# 自定義環境
import numpy as np
import random
from typing import Tupleclass WorldEnv:"""二維世界環境類,用于目標導向的強化學習任務智能體需要從起始位置移動到隨機生成的目標位置"""def __init__(self) -> None:"""初始化環境參數"""# 距離閾值,當智能體與目標的距離小于等于此值時認為任務完成 (標量)self.distance_threshold: float = 0.15# 動作邊界,限制每個動作分量的取值范圍為[-1, 1] (標量)self.action_bound: float = 1.0# 地圖邊界,智能體活動范圍為[0, 5] x [0, 5] (標量)self.map_bound: float = 5.0# 最大步數,防止無限循環 (標量)self.max_steps: int = 50# 當前狀態,智能體在二維平面上的坐標 (2維向量)self.state: np.ndarray = None# 目標位置,智能體需要到達的目標坐標 (2維向量)self.goal: np.ndarray = None# 當前步數計數器 (標量)self.count: int = 0def reset(self) -> np.ndarray:"""重置環境到初始狀態Returns:np.ndarray: 包含當前狀態和目標位置的觀測向量 (4維向量: [state_x, state_y, goal_x, goal_y])"""# 在目標區域[3.5, 4.5] x [3.5, 4.5]內隨機生成目標位置 (2維向量)goal_x = 4.0 + random.uniform(-0.5, 0.5)goal_y = 4.0 + random.uniform(-0.5, 0.5)self.goal = np.array([goal_x, goal_y])# 設置智能體初始位置為原點 (2維向量)self.state = np.array([0.0, 0.0])# 重置步數計數器 (標量)self.count = 0# 返回包含狀態和目標的觀測向量 (4維向量)return np.hstack((self.state, self.goal))def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool]:"""執行一個動作并返回下一個狀態、獎勵和是否結束Args:action (np.ndarray): 智能體的動作,包含x和y方向的移動量 (2維向量)Returns:Tuple[np.ndarray, float, bool]: - 下一個觀測狀態 (4維向量: [state_x, state_y, goal_x, goal_y])- 獎勵值 (標量)- 是否結束標志 (布爾值)"""# 將動作限制在有效范圍內[-action_bound, action_bound] (2維向量)action = np.clip(action, -self.action_bound, self.action_bound)# 計算執行動作后的新位置,并確保在地圖邊界內[0, map_bound] (標量)new_x = max(0.0, min(self.map_bound, self.state[0] + action[0]))new_y = max(0.0, min(self.map_bound, self.state[1] + action[1]))# 更新智能體位置 (2維向量)self.state = np.array([new_x, new_y])# 增加步數計數 (標量)self.count += 1# 計算當前位置與目標位置之間的歐幾里得距離 (標量)distance = np.sqrt(np.sum(np.square(self.state - self.goal)))# 計算獎勵:如果距離大于閾值則給予負獎勵-1.0,否則給予0獎勵 (標量)reward = -1.0 if distance > self.distance_threshold else 0.0# 判斷是否結束:距離足夠近或達到最大步數 (布爾值)if distance <= self.distance_threshold or self.count >= self.max_steps:done = Trueelse:done = False# 返回新的觀測狀態、獎勵和結束標志# 觀測狀態包含當前位置和目標位置 (4維向量)return np.hstack((self.state, self.goal)), reward, done

4.2 智能體與算法 (DDPG)

我們選擇 DDPG (深度確定性策略梯度) 作為基礎的 off-policy 算法。DDPG 包含一個 Actor (策略網絡) 和一個 Critic (Q值網絡),非常適合處理連續動作空間問題。

  • PolicyNet: Actor 網絡,輸入狀態 s(包含目標 g),輸出一個確定性的動作 a
  • QValueNet: Critic 網絡,輸入狀態 s 和動作 a,輸出該狀態-動作對的Q值。
  • DDPG: 算法主類,集成了 Actor 和 Critic,并包含目標網絡、優化器、軟更新和 update 邏輯。這里的實現是標準的 DDPG。
# 要訓練的智能體和采用的算法
import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, Anyclass PolicyNet(torch.nn.Module):"""策略網絡(Actor網絡)用于輸出連續動作空間中的動作值"""def __init__(self, state_dim: int, hidden_dim: int, action_dim: int, action_bound: float) -> None:"""初始化策略網絡Args:state_dim (int): 狀態空間維度 (標量)hidden_dim (int): 隱藏層神經元數量 (標量)action_dim (int): 動作空間維度 (標量)action_bound (float): 動作邊界值,動作取值范圍為[-action_bound, action_bound] (標量)"""super(PolicyNet, self).__init__()# 第一個全連接層:狀態維度 -> 隱藏層維度self.fc1 = torch.nn.Linear(state_dim, hidden_dim)# 第二個全連接層:隱藏層維度 -> 隱藏層維度self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)# 輸出層:隱藏層維度 -> 動作維度(本環境中動作維度為2)self.fc3 = torch.nn.Linear(hidden_dim, action_dim)# 動作邊界,用于將輸出限制在有效范圍內 (標量)self.action_bound = action_bounddef forward(self, x: torch.Tensor) -> torch.Tensor:"""前向傳播計算動作輸出Args:x (torch.Tensor): 輸入狀態 (batch_size, state_dim)Returns:torch.Tensor: 輸出動作,范圍在[-action_bound, action_bound] (batch_size, action_dim)"""# 通過兩個隱藏層,使用ReLU激活函數 (batch_size, hidden_dim)x = F.relu(self.fc2(F.relu(self.fc1(x))))# 輸出層使用tanh激活函數,將輸出限制在[-1, 1],然后乘以action_bound# 得到范圍在[-action_bound, action_bound]的動作 (batch_size, action_dim)return torch.tanh(self.fc3(x)) * self.action_boundclass QValueNet(torch.nn.Module):"""Q值網絡(Critic網絡)用于評估給定狀態和動作的Q值"""def __init__(

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/94057.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/94057.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/94057.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

前端 H5分片上傳 vue實現大文件

用uniapp開發APP上傳視頻文件&#xff0c;大文件可以上傳成功&#xff0c;但是一旦打包為H5的代碼&#xff0c;就會一提示鏈接超時&#xff0c;我的代碼中是實現的上傳到阿里云 如果需要看全文的私信我 官方開發文檔地址 前端&#xff1a;使用分片上傳的方式上傳大文件_對象…

Linux服務器Systemctl命令詳細使用指南

目錄 1. 基本語法 2. 基礎命令速查表 3. 常用示例 3.1 部署新服務后&#xff0c;設置開機自啟并啟動 3.2 檢查系統中所有失敗的服務并嘗試修復 3.3 查看系統中所有開機自啟的服務 4. 總結 以下是 systemctl 使用指南&#xff0c;涵蓋服務管理、單元操作、運行級別控制、…

【JVM內存結構系列】二、線程私有區域詳解:程序計數器、虛擬機棧、本地方法棧——搞懂棧溢出與線程隔離

上一篇文章我們搭建了JVM內存結構的整體框架,知道程序計數器、虛擬機棧、本地方法棧屬于“線程私有區域”——每個線程啟動時會單獨分配內存,線程結束后內存直接釋放,無需GC參與。這三個區域看似“小眾”,卻是理解線程執行邏輯、排查棧溢出異常的關鍵,也是面試中高頻被問的…

紅帽認證升級華為openEuler證書活動!

如果您有紅帽證書&#xff0c;可以升級以下相應的證書&#xff1a;&#x1f447; 有RHCSA證書&#xff0c;可以99元升級openEuler HCIA 有RHCE證書&#xff0c;可以99元升級openEuler HCIP 有RHCA證書&#xff0c;可以2100元升級openEuler HCIE 現金激勵&#xff1a;&#x1f4…

迭代器模式與幾個經典的C++實現

迭代器模式詳解1. 定義與意圖迭代器模式&#xff08;Iterator Pattern&#xff09; 是一種行為設計模式&#xff0c;它提供一種方法順序訪問一個聚合對象中的各個元素&#xff0c;而又不暴露該對象的內部表示。主要意圖&#xff1a;為不同的聚合結構提供統一的遍歷接口。將遍歷…

epoll 陷阱:隧道中的高級負擔

上周提到了 tun/tap 轉發框架的數據通道結構和優化 tun/tap 轉發性能優化&#xff0c;涉及 RingBuffer&#xff0c;packetization 等核心話題。我也給出了一定的數據結構以及處理邏輯&#xff0c;但竟然沒有高尚的 epoll&#xff0c;本文說說它&#xff0c;因為它不適合。 epo…

微前端架構常見框架

1. iframe 這里指的是每個微應用獨立開發部署,通過 iframe 的方式將這些應用嵌入到父應用系統中,幾乎所有微前端的框架最開始都考慮過 iframe,但最后都放棄,或者使用部分功能,原因主要有: url 不同步。瀏覽器刷新 iframe url 狀態丟失、后退前進按鈕無法使用。 UI 不同…

SQL Server更改日志模式:操作指南與最佳實踐!

全文目錄&#xff1a;開篇語**前言****摘要****概述&#xff1a;SQL Server 的日志模式****日志模式的作用****三種日志模式**1. **簡單恢復模式&#xff08;Simple&#xff09;**2. **完整恢復模式&#xff08;Full&#xff09;**3. **大容量日志恢復模式&#xff08;Bulk-Log…

git的工作使用中實際經驗

老輸入煩人的密碼 每次我git pull的時候都要叫我輸入三次煩人的密碼&#xff0c;問了deepseek也沒有嘗試成功 出現 enter passphrase for key ‘~/.ssh/id_rsa’ 的原因: 在生成key的時候,沒有注意,不小心設置了密碼, 導致每次提交的時候都會提示要輸入密碼, 也就是上面的提示…

科技賦能,寧夏農業繪就塞上新“豐”景

在賀蘭山的巍峨身影下&#xff0c;在黃河水的溫柔滋養中&#xff0c;寧夏這片古老而神奇的土地&#xff0c;正借助農業科技的磅礴力量&#xff0c;實現從傳統農耕到智慧農業的華麗轉身&#xff0c;奏響一曲科技與自然和諧共生的壯麗樂章。一、數字農業&#xff1a;開啟智慧種植…

imx6ull-驅動開發篇36——Linux 自帶的 LED 燈驅動實驗

在之前的文章里&#xff0c;我們掌握了無設備樹和有設備樹這兩種 platform 驅動的開發方式。但實際上有現成的&#xff0c;Linux 內核的 LED 燈驅動采用 platform 框架&#xff0c;我們只需要按照要求在設備樹文件中添加相應的 LED 節點即可。本講內容&#xff0c;我們就來學習…

深度學習中主流激活函數的數學原理與PyTorch實現綜述

1. 定義與作用什么是激活函數&#xff1f;激活函數有什么用&#xff1f;答&#xff1a;激活函數&#xff08;Activation Function&#xff09;是一種添加到人工神經網絡中的函數&#xff0c;旨在幫助網絡學習數據中的復雜模式。類似于人類大腦中基于神經元的模型&#xff0c;激…

Linux高效備份:rsync + inotify實時同步

一、rsync 簡介 rsync&#xff08;Remote Sync&#xff09;是 Linux 系統下的數據鏡像備份工具&#xff0c;支持本地復制、遠程同步&#xff08;通過 SSH 或 rsync 協議&#xff09;&#xff0c;是一個快速、安全、高效的增量備份工具。二、rsync 特性 支持鏡像保存整個目錄樹和…

一種通過模板輸出Docx的方法

起因在2個群里都有網友討論這個問題&#xff0c;俺就寫了一個最簡單的例子。其實&#xff0c;我們經常遇到一些Docx的輸出的需求&#xff0c;“用模板文件進行處理”是最簡單的一個方法&#xff0c;如果想預覽也簡單 DevExpress 、Teleric 都可以&#xff0c;而且也支持 Web 、…

探索 List 的奧秘:自己動手寫一個 STL List?

&#x1f4d6;引言大家好&#xff01;今天我們要一起來揭開 C 中 list 容器的神秘面紗——不是直接用 STL&#xff0c;而是親手實現一個簡化版的 list&#xff01;&#x1f389;你是不是曾經好奇過&#xff1a;list 是怎么做到高效插入和刪除的&#xff1f;&#x1f50d;迭代器…

mysql占用高內存排查與解決

mysql占用高內存排查-- 查看當前全局內存使用情況&#xff08;需要啟用 performance_schema&#xff09; SELECT * FROM sys.memory_global_total; -- 查看總內存使用 SELECT * FROM sys.memory_global_by_current_bytes LIMIT 10; -- 按模塊分類查看內存使用排行memory/perfor…

構建真正自動化知識工作的AI代理

引言&#xff1a;新一代生產力范式的黎明 自動化知識工作的人工智能代理&#xff08;AI Agent&#xff09;&#xff0c;或稱“智能體”&#xff0c;正迅速從理論構想演變為重塑各行各業生產力的核心引擎。這些AI代理被定義為能夠感知環境、進行自主決策、動態規劃、調用工具并持…

青少年機器人技術(四級)等級考試試卷-實操題(2021年12月)

更多內容和歷年真題請查看網站&#xff1a;【試卷中心 -----> 電子學會 ----> 機器人技術 ----> 四級】 網站鏈接 青少年軟件編程歷年真題模擬題實時更新 青少年機器人技術&#xff08;四級&#xff09;等級考試試卷-實操題&#xff08;2021年12月&#xff09; …

最新短網址源碼,防封。支持直連、跳轉。 會員無廣

最新短網址源碼&#xff0c;防封。支持直連、跳轉。 會員無廣告1.可將長網址自動縮短為短網址&#xff0c;方便記憶和使用。2.短網址默認為臨時有效&#xff0c;可付費升級為永久有效&#xff0c;接入支付后可自動完成&#xff0c;無需人工操作。3.系統支持設置圖片/文字/跳轉頁…

緩存-變更事件捕捉、更新策略、本地緩存和熱key問題

緩存-基礎知識 熟悉計算機基礎的同學們都知道&#xff0c;服務的存儲大多是多層級的&#xff0c;呈現金字塔類型。通常來說本機存儲比通過網絡通信的外部存儲更快&#xff08;現在也不一定了&#xff0c;因為網絡傳輸速度很快&#xff0c;至少可以比一些過時的本地存儲設備速度…