目標導向的強化學習:問題定義與 HER 算法詳解—強化學習(19)

目錄

1、目標導向的強化學習:問題定義

1.1、 核心要素與符號定義

1.2、?核心問題:稀疏獎勵困境

1.3、?學習目標

2、HER(Hindsight Experience Replay)算法

2.1、?HER 的核心邏輯

2.2、?算法步驟(結合 DDPG 舉例)

2.2.1、步驟 1:收集原始經驗

2.2.2、步驟 2:重構經驗(核心!)

2.2.3、步驟 3:替代目標生成策略

2.2.4、步驟 4:策略更新

2.3、?為什么 HER 有效?

2.4、公式總結

3、通俗理解

4、完整代碼

5、實驗結果??


1、目標導向的強化學習:問題定義

目標導向的強化學習(Goal-Conditioned Reinforcement Learning)是一類讓智能體通過學習策略,從初始狀態達到特定目標的任務。與傳統強化學習不同,這類任務的核心是 “目標”—— 智能體的行為需圍繞 “達成目標” 展開,而目標本身可能隨任務變化(如 “機械臂抓取 A 物體”“機械臂抓取 B 物體” 是兩個不同目標)。

1.1、 核心要素與符號定義

  • 狀態(State):環境的觀測信息,記為s \in \mathcal{S}\mathcal{S}是狀態空間)。例如:機械臂的關節角度、物體的坐標。
  • 目標(Goal):智能體需要達成的狀態,記為g \in \mathcal{G}\mathcal{G}是目標空間,通常與狀態空間重合或相關)。例如:機械臂需抓取的物體坐標。
  • 動作(Action):智能體的行為,記為?a \in \mathcal{A}\mathcal{A}?是動作空間)。例如:機械臂關節的旋轉角度。
  • 轉移函數:狀態 - 動作對到下一狀態的映射,記為?s' \sim P(s' | s, a)(P?是狀態轉移概率)。
  • 獎勵函數:衡量 “當前狀態與目標的差距”,記為r(s, a, g)。目標導向任務的獎勵通常僅與 “狀態是否接近目標” 相關,與動作間接相關。

1.2、?核心問題:稀疏獎勵困境

目標導向任務的獎勵函數通常是稀疏的:僅當狀態?s?與目標?g?幾乎一致時,才給予正獎勵;否則獎勵為 0 或負值。 獎勵函數示例(機械臂抓取任務):

  • 智能體在絕大多數嘗試中(如 99% 的交互)都得不到正獎勵,無法判斷 “哪些動作有助于接近目標”;
  • 策略更新缺乏有效信號(梯度難以計算),學習效率極低,甚至無法收斂。

1.3、?學習目標

目標導向強化學習的目標是學習一個目標條件策略?\pi(a | s, g),使得在策略引導下,智能體從任意初始狀態?\(s_0\)?出發,通過執行動作序列a_0, a_1, ..., a_T,最終達到目標?g?的概率最大化。

2、HER(Hindsight Experience Replay)算法

HER 算法是解決目標導向任務中稀疏獎勵問題的經典方法,核心思想是:從 “失敗經驗” 中 “事后重構” 有效獎勵信號—— 即使智能體沒達成原定目標,也能通過修改目標,將 “失敗軌跡” 轉化為 “成功軌跡”,從而提取學習信號

2.1、?HER 的核心邏輯

假設智能體在一次交互中,原定目標是?g,但實際軌跡為\tau = (s_0, a_0, s_1, a_1, ..., s_T),最終狀態?s_T \neq g(失敗)。 HER 的關鍵操作是:從軌跡?\tau中選一個狀態?s_k作為 “替代目標”?\hat{g} = s_k,此時軌跡?\tau對于新目標?\\hat{g}?是 “成功的”(因為?s_T可能接近?\hat{g}),從而可計算有效獎勵。

2.2、?算法步驟(結合 DDPG 舉例)

HER 通常與離線強化學習算法(如 DDPG)結合使用,流程如下:

2.2.1、步驟 1:收集原始經驗

智能體與環境交互,收集軌跡并存儲到經驗回放池?\mathcal{D}。每條經驗是一個五元組:e = (s_t, a_t, r_t, s_{t+1}, g)其中r_t = r(s_t, a_t, g)是基于原定目標?g?的獎勵(可能為 0)。

2.2.2、步驟 2:重構經驗(核心!)

對回放池中的每條原始經驗?e,HER 通過替代目標生成策略選一個新目標?\hat{g},重構出一條 “虛擬成功經驗”?\hat{e}\hat{e} = (s_t, a_t, \hat{r}_t, s_{t+1}, \hat{g})?其中\hat{r}_t = r(s_t, a_t, \hat{g})?是基于新目標?\hat{g}?的獎勵(此時可能為正,因為?\hat{g}來自軌跡,s_{t+1}?可能接近?\hat{g})。

2.2.3、步驟 3:替代目標生成策略

HER 定義了 4 種常用的替代目標生成策略(以軌跡?\tau = (s_0, s_1, ..., s_T)?為例):

  • Final\hat{g} = s_T(選最終狀態);
  • Future\hat{g} = s_k,其中?k \sim \text{Uniform}(t+1, T)(選未來狀態);
  • Random\hat{g} = s_k,其中?k \sim \text{Uniform}(0, T)(隨機選一個狀態);
  • Episode\hat{g}?從同回合的其他軌跡中隨機選一個狀態(適用于多目標任務)。

2.2.4、步驟 4:策略更新

將原始經驗?e?和重構經驗\hat{e}?一起放入回放池,用離線算法(如 DDPG)更新策略。 以 DDPG 的 Critic 網絡更新為例:

2.3、?為什么 HER 有效?

  • 解決稀疏性:通過重構經驗,將 “0 獎勵” 轉化為 “正獎勵”,使獎勵信號密集化;
  • 利用失敗經驗:原本無用的失敗軌跡被轉化為有效學習樣本,提高數據利用率;
  • 通用兼容:HER 是 “經驗回放增強技術”,可與 DDPG、SAC 等多種算法結合,無需修改算法核心。

2.4、公式總結

  • 原始經驗:e = (s_t, a_t, r(s_t, a_t, g), s_{t+1}, g)
  • 重構經驗:\hat{e} = (s_t, a_t, r(s_t, a_t, \hat{g}), s_{t+1}, \hat{g}),其中?\hat{g} \in \{s_0, s_1, ..., s_T\}
  • 策略目標:\max_{\pi} \mathbb{E}_{\pi} \left[ \sum_{t=0}^T \gamma^t r(s_t, a_t, g) \right]

3、通俗理解

用 “快遞員送貨” 舉例:

  • 目標導向任務:快遞員(智能體)需要把包裹送到目標地址?g(原定目標),但只有送到?g?才有錢(獎勵),中途迷路(失敗)則沒錢。
  • 稀疏獎勵問題:快遞員第一次送陌生地址,99% 的概率找不到,長期沒錢,不知道往哪走。
  • HER 的做法:快遞員雖然沒到?g,但路過了?\hat{g}(比如某個超市),就把 “送到超市” 當作新目標,這次 “成功” 能拿到錢,從而學會 “如何到超市”;多次積累后,就能掌握城市路線,最終學會到任意目標?g?的方法。

4、完整代碼

"""
文件名: 19.1
作者: 墨塵
日期: 2025/7/25
項目名: d2l_learning
備注: 
"""
import torch
import torch.nn.functional as F
import numpy as np
import random
from tqdm import tqdm
import collections
import matplotlib.pyplot as pltclass WorldEnv:def __init__(self):self.distance_threshold = 0.15self.action_bound = 1def reset(self):  # 重置環境# 生成一個目標狀態, 坐標范圍是[3.5~4.5, 3.5~4.5]self.goal = np.array([4 + random.uniform(-0.5, 0.5), 4 + random.uniform(-0.5, 0.5)])self.state = np.array([0, 0])  # 初始狀態self.count = 0return np.hstack((self.state, self.goal))def step(self, action):action = np.clip(action, -self.action_bound, self.action_bound)x = max(0, min(5, self.state[0] + action[0]))y = max(0, min(5, self.state[1] + action[1]))self.state = np.array([x, y])self.count += 1dis = np.sqrt(np.sum(np.square(self.state - self.goal)))reward = -1.0 if dis > self.distance_threshold else 0if dis <= self.distance_threshold or self.count == 50:done = Trueelse:done = Falsereturn np.hstack((self.state, self.goal)), reward, done
class PolicyNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim, action_bound):super(PolicyNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)self.fc3 = torch.nn.Linear(hidden_dim, action_dim)self.action_bound = action_bound  # action_bound是環境可以接受的動作最大值def forward(self, x):x = F.relu(self.fc2(F.relu(self.fc1(x))))return torch.tanh(self.fc3(x)) * self.action_boundclass QValueNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(QValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)self.fc3 = torch.nn.Linear(hidden_dim, 1)def forward(self, x, a):cat = torch.cat([x, a], dim=1)  # 拼接狀態和動作x = F.relu(self.fc2(F.relu(self.fc1(cat))))return self.fc3(x)
class DDPG:''' DDPG算法 '''def __init__(self, state_dim, hidden_dim, action_dim, action_bound,actor_lr, critic_lr, sigma, tau, gamma, device):self.action_dim = action_dimself.actor = PolicyNet(state_dim, hidden_dim, action_dim,action_bound).to(device)self.critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)self.target_actor = PolicyNet(state_dim, hidden_dim, action_dim,action_bound).to(device)self.target_critic = QValueNet(state_dim, hidden_dim,action_dim).to(device)# 初始化目標價值網絡并使其參數和價值網絡一樣self.target_critic.load_state_dict(self.critic.state_dict())# 初始化目標策略網絡并使其參數和策略網絡一樣self.target_actor.load_state_dict(self.actor.state_dict())self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),lr=actor_lr)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),lr=critic_lr)self.gamma = gammaself.sigma = sigma  # 高斯噪聲的標準差,均值直接設為0self.tau = tau  # 目標網絡軟更新參數self.action_bound = action_boundself.device = devicedef take_action(self, state):state = torch.tensor([state], dtype=torch.float).to(self.device)action = self.actor(state).detach().cpu().numpy()[0]# 給動作添加噪聲,增加探索action = action + self.sigma * np.random.randn(self.action_dim)return actiondef soft_update(self, net, target_net):for param_target, param in zip(target_net.parameters(),net.parameters()):param_target.data.copy_(param_target.data * (1.0 - self.tau) +param.data * self.tau)def update(self, transition_dict):states = torch.tensor(transition_dict['states'],dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions'],dtype=torch.float).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)next_q_values = self.target_critic(next_states,self.target_actor(next_states))q_targets = rewards + self.gamma * next_q_values * (1 - dones)# MSE損失函數critic_loss = torch.mean(F.mse_loss(self.critic(states, actions), q_targets))self.critic_optimizer.zero_grad()critic_loss.backward()self.critic_optimizer.step()# 策略網絡就是為了使Q值最大化actor_loss = -torch.mean(self.critic(states, self.actor(states)))self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()self.soft_update(self.actor, self.target_actor)  # 軟更新策略網絡self.soft_update(self.critic, self.target_critic)  # 軟更新價值網絡
class Trajectory:''' 用來記錄一條完整軌跡 '''def __init__(self, init_state):self.states = [init_state]self.actions = []self.rewards = []self.dones = []self.length = 0def store_step(self, action, state, reward, done):self.actions.append(action)self.states.append(state)self.rewards.append(reward)self.dones.append(done)self.length += 1class ReplayBuffer_Trajectory:''' 存儲軌跡的經驗回放池 '''def __init__(self, capacity):self.buffer = collections.deque(maxlen=capacity)def add_trajectory(self, trajectory):self.buffer.append(trajectory)def size(self):return len(self.buffer)def sample(self, batch_size, use_her, dis_threshold=0.15, her_ratio=0.8):batch = dict(states=[],actions=[],next_states=[],rewards=[],dones=[])for _ in range(batch_size):traj = random.sample(self.buffer, 1)[0]step_state = np.random.randint(traj.length)state = traj.states[step_state]next_state = traj.states[step_state + 1]action = traj.actions[step_state]reward = traj.rewards[step_state]done = traj.dones[step_state]if use_her and np.random.uniform() <= her_ratio:step_goal = np.random.randint(step_state + 1, traj.length + 1)goal = traj.states[step_goal][:2]  # 使用HER算法的future方案設置目標dis = np.sqrt(np.sum(np.square(next_state[:2] - goal)))reward = -1.0 if dis > dis_threshold else 0done = False if dis > dis_threshold else Truestate = np.hstack((state[:2], goal))next_state = np.hstack((next_state[:2], goal))batch['states'].append(state)batch['next_states'].append(next_state)batch['actions'].append(action)batch['rewards'].append(reward)batch['dones'].append(done)batch['states'] = np.array(batch['states'])batch['next_states'] = np.array(batch['next_states'])batch['actions'] = np.array(batch['actions'])return batchif __name__ == '__main__':actor_lr = 1e-3critic_lr = 1e-3hidden_dim = 128state_dim = 4action_dim = 2action_bound = 1sigma = 0.1tau = 0.005gamma = 0.98num_episodes = 2000n_train = 20batch_size = 256minimal_episodes = 200buffer_size = 10000device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")random.seed(0)np.random.seed(0)torch.manual_seed(0)env = WorldEnv()replay_buffer = ReplayBuffer_Trajectory(buffer_size)agent = DDPG(state_dim, hidden_dim, action_dim, action_bound, actor_lr,critic_lr, sigma, tau, gamma, device)return_list = []for i in range(10):with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes / 10)):episode_return = 0state = env.reset()traj = Trajectory(state)done = Falsewhile not done:action = agent.take_action(state)state, reward, done = env.step(action)episode_return += rewardtraj.store_step(action, state, reward, done)replay_buffer.add_trajectory(traj)return_list.append(episode_return)if replay_buffer.size() >= minimal_episodes:for _ in range(n_train):transition_dict = replay_buffer.sample(batch_size, True)agent.update(transition_dict)if (i_episode + 1) % 10 == 0:pbar.set_postfix({'episode':'%d' % (num_episodes / 10 * i + i_episode + 1),'return':'%.3f' % np.mean(return_list[-10:])})pbar.update(1)episodes_list = list(range(len(return_list)))plt.plot(episodes_list, return_list)plt.xlabel('Episodes')plt.ylabel('Returns')plt.title('DDPG with HER on {}'.format('GridWorld'))plt.show()

5、實驗結果??

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/92950.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/92950.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/92950.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

2025 XYD Summer Camp 7.21 智靈班分班考 · Day1

智靈班分班考 Day1 時間線 8:00 在濱蘭實驗的遠古機房中的一個鍵盤手感爆炸的電腦上開考。開 T1&#xff0c;推了推發現可以 segment tree 優化 dp&#xff0c;由于按空格需要很大的力氣導致馬蜂被迫改變。后來忍不住了頂著疼痛按空格。8:30 過了樣例&#xff0c;但是沒有大樣…

基于多種主題分析、關鍵詞提取算法的設計與實現【TF-IDF算法、LDA、NMF分解、BERT主題模型】

文章目錄有需要本項目的代碼或文檔以及全部資源&#xff0c;或者部署調試可以私信博主一、項目背景二、研究目標與意義三、數據獲取與處理四、文本分析與主題建模方法1. 傳統方法探索2. 主題模型比較與優化3. 深度語義建模與聚類五、研究成果與應用價值六、總結與展望總結每文一…

MDC(Mapped Diagnostic Context) 的核心介紹與使用教程

關于日志框架中 MDC&#xff08;Mapped Diagnostic Context&#xff09; 的核心介紹與使用教程&#xff0c;結合其在分布式系統中的實際應用場景&#xff0c;分模塊說明&#xff1a; 一、MDC 簡介 MDC&#xff08;映射診斷上下文&#xff09; 是 SLF4J/Logback 提供的一種線程…

Linux隨記(二十一)

一、highgo切換leader&#xff0c;follow - 隨記 【待寫】二、highgo的etcd未授權訪問 - 隨記 【待寫】三、highgo的etcd未授權訪問 - 隨記 【待寫】3.2、etcd的metric未授權訪問 - 隨記 【待寫】四、安裝Elasticsearch 7.17.29 和 Elasticsearch 未授權訪問【原理掃描】…

Java環境配置之各類組件下載安裝教程整理(jdk、idea、git、maven、mysql、redis)

Java環境配置之各類組件下載安裝教程整理&#xff08;jdk、idea、git、maven、mysql、redis&#xff09;1.[安裝配置jdk8]2.[安裝配置idea]3.[安裝配置git]4.[安裝配置maven]5.[安裝配置postman]6.[安裝配置redis和可視化工具]7.[安裝配置mysql和可視化工具]8.[安裝配置docker]…

配置https ssl證書生成

1.可用openssl生成私鑰和自簽名證書 安裝opensslsudo yum install openssl -y 2.生成ssl證書 365天期限sudo openssl req -x509 -nodes -days 365 -newkey rsa:2048 \-keyout /etc/ssl/private/nginx-selfsigned.key \-out /etc/ssl/certs/nginx-selfsigned.crt3、按照提示編…

編程語言Java——核心技術篇(四)集合類詳解

言不信者行不果&#xff0c;行不敏者言多滯. 目錄 4. 集合類 4.1 集合類概述 4.1.1 集合框架遵循原則 4.1.2 集合框架體系 4.2 核心接口和實現類解析 4.2.1 Collection 接口體系 4.2.1.1 Collection 接口核心定義 4.2.1.2 List接口詳解 4.2.1.3 Set 接口詳解 4.2.1.4…

GaussDB 數據庫架構師(八) 等待事件(1)-概述

1、等待事件概述 等待事件&#xff1a;指當數據庫會話(session)因資源競爭或依賴無法繼續執行時&#xff0c;進入"等待"狀態&#xff0c;此時產生的性能事件即等待事件。 2、等待事件本質 性能瓶頸的信號燈&#xff0c;反映CPU,I/O、鎖、網絡等關鍵資源的阻塞情況。…

五分鐘系列-文本搜索工具grep

目錄 1??核心功能?? ??2??基本語法?? 3????常用選項 & 功能詳解?? ??4??經典應用場景 & 示例?? 5????重要的提示 & 技巧?? ??6??總結?? grep 是 Linux/Unix 系統中功能強大的??文本搜索工具??&#xff0c;其名稱源自 …

Java面試題及詳細答案120道之(041-060)

《前后端面試題》專欄集合了前后端各個知識模塊的面試題&#xff0c;包括html&#xff0c;javascript&#xff0c;css&#xff0c;vue&#xff0c;react&#xff0c;java&#xff0c;Openlayers&#xff0c;leaflet&#xff0c;cesium&#xff0c;mapboxGL&#xff0c;threejs&…

【嘗試】本地部署openai-whisper,通過 http請求識別

安裝whisper的教程&#xff0c;已在 https://blog.csdn.net/qq_23938507/article/details/149394418 和 https://blog.csdn.net/qq_23938507/article/details/149326290 中說明。 1、創建whisperDemo1.py from fastapi import FastAPI, UploadFile, File import whisper i…

Visual Studio 的常用快捷鍵

Visual Studio 作為主流的開發工具&#xff0c;提供了大量快捷鍵提升編碼效率。以下按功能分類整理常用快捷鍵&#xff0c;涵蓋基礎操作、代碼編輯、調試等場景&#xff08;以 Visual Studio 2022 為例&#xff0c;部分快捷鍵可在「工具 > 選項 > 環境 > 鍵盤」中自定…

Triton Server部署Embedding模型

在32核CPU、無GPU的服務器上&#xff0c;使用Python后端和ONNX后端部署嵌入模型&#xff0c;并實現并行調用和性能優化策略。方案一&#xff1a;使用Python后端部署Embedding模型 Python后端提供了極大的靈活性&#xff0c;可以直接在Triton中運行您熟悉的sentence-transformer…

Java動態調試技術原理

本文轉載自 美團技術團隊胡健的Java 動態調試技術原理及實踐, 通過學習java agent方式進行動態調試了解目前很多大廠開源的一些基于此的調試工具。 簡介 斷點調試是我們最常使用的調試手段&#xff0c;它可以獲取到方法執行過程中的變量信息&#xff0c;并可以觀察到方法的執…

人工智能-python-OpenCV 圖像基礎認知與運用

文章目錄OpenCV 圖像基礎認知與運用1. OpenCV 簡介與安裝OpenCV 的優勢安裝 OpenCV2. 圖像的基本概念2.1. 圖像的存儲格式2.2. 圖像的表示3. 圖像的基本操作3.1. 創建圖像窗口3.2. 讀取與顯示圖像3.3. 保存圖像3.4. 圖像切片與區域提取3.5. 圖像大小調整4. 圖像繪制與注釋4.1. …

Windows電腦添加、修改打印機的IP地址端口的方法

本文介紹在Windows電腦中&#xff0c;為打印機添加、修改IP地址&#xff0c;從而解決電腦能找到打印機、但是無法打印問題的方法。最近&#xff0c;辦公室的打印機出現問題——雖然在電腦的打印機列表能找到這個打印機&#xff0c;但是選擇打印時&#xff0c;就會顯示文檔被掛起…

告別復雜配置!Spring Boot優雅集成百度OCR的終極方案

1. 準備工作 1.1 注冊百度AI開放平臺 訪問百度AI開放平臺 注冊賬號并登錄 進入控制臺 → 文字識別 → 創建應用 記錄下API Key和Secret Key 2. 項目配置 2.1 添加依賴 (pom.xml) <dependencies><!-- Spring Boot Web --><dependency><groupId>o…

「iOS」——內存五大分區

UI學習iOS-底層原理 24&#xff1a;內存五大區總覽一、棧區&#xff08;Stack&#xff09;1.1 核心特性1.2 優缺點1.3函數棧與棧幀1.3 堆棧溢出風險二、堆區&#xff08;Heap&#xff09;;2.1 核心特性2.2 與棧區對比三、全局 / 靜態區&#xff08;Global/Static&#xff09;3.…

每日一題【刪除有序數組中的重復項 II】

刪除有序數組中的重復項 II思路class Solution { public:int removeDuplicates(vector<int>& nums) {if(nums.size()<2){return nums.size();}int index 2;for (int i 2; i < nums.size();i ) {if(nums[i] ! nums[index-2]) {nums[index]nums[i];}}return ind…

兼容性問題記錄

1、dialog設置高度MATCH_PARENT全屏后&#xff0c;三星機型和好像是一加&#xff0c;會帶出頂部狀態欄&#xff0c;設置隱藏狀態欄屬性無效。解決方法&#xff1a;高度不設置為MATCH_PARENT&#xff0c;通過windowmanager.getdefaultdisplay來獲取并設置高度&#xff0c;再設置…