基于“動手學強化學習”的知識點(二):第 15 章 模仿學習(gym版本 >= 0.26)

第 15 章 模仿學習(gym版本 >= 0.26)

  • 摘要

摘要

本系列知識點講解基于動手學強化學習中的內容進行詳細的疑難點分析!具體內容請閱讀動手學強化學習!


對應動手學強化學習——模仿學習


# -*- coding: utf-8 -*-import gym
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import rl_utilsclass PolicyNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(PolicyNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))return F.softmax(self.fc2(x), dim=1)class ValueNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim):super(ValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x):x = F.relu(self.fc1(x))return self.fc2(x)class PPO:''' PPO算法,采用截斷方式 '''def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,lmbda, epochs, eps, gamma, device):self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.critic = ValueNet(state_dim, hidden_dim).to(device)self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)self.gamma = gammaself.lmbda = lmbdaself.epochs = epochs  # 一條序列的數據用于訓練輪數self.eps = eps  # PPO中截斷范圍的參數self.device = devicedef take_action(self, state):if isinstance(state, tuple):state = state[0]state = torch.tensor([state], dtype=torch.float).to(self.device)probs = self.actor(state)'''根據概率分布創建一個離散分類分布對象,用于采樣離散動作。離散的概率模型。'''action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def update(self, transition_dict):    processed_state = []for s in transition_dict['states']:if isinstance(s, tuple):# 如果元素是元組,則取元組的第一個元素processed_state.append(s[0])else:processed_state.append(s)# states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)states = torch.tensor(processed_state, dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)'''計算 TD 目標(即回歸目標):td_target=r+γ×V(s′)×(1?done)'''td_target = rewards + self.gamma * self.critic(next_states) * (1 -dones)'''計算 TD 殘差(或優勢估計的基礎):當前狀態的 TD 目標減去當前 critic 估計的狀態價值。'''td_delta = td_target - self.critic(states)'''調用輔助函數(在 rl_utils 模塊中定義)計算優勢函數,通常使用廣義優勢估計(GAE)。'''advantage = rl_utils.compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)'''先將狀態輸入 actor 網絡得到動作概率分布(例如 shape 為 (batch_size, action_dim))。使用 .gather(1, actions) 選出每個樣本所執行動作對應的概率(注意 actions 的形狀必須匹配)。取對數得到舊的對數概率,再 detach() 阻斷梯度流,保存舊策略下的概率值。'''old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()for _ in range(self.epochs):'''在當前策略下重新計算所有樣本的對數概率,與舊對數概率進行比較。'''log_probs = torch.log(self.actor(states).gather(1, actions))'''計算概率比率,即新舊策略的概率之比,用于 PPO 的 clip 損失計算。'''ratio = torch.exp(log_probs - old_log_probs)'''計算無截斷的策略目標,乘上優勢值。'''surr1 = ratio * advantage'''對 ratio 進行截斷,確保其在 [1??,1+?] 范圍內(例如 [0.8, 1.2]),然后乘以優勢。'''surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage  # 截斷'''PPO 算法的目標是最大化最小值,因此這里取兩者中的較小值再取負號作為損失。對整個 batch 求均值。'''actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO損失函數'''計算 critic 的均方誤差(MSE)損失:當前 critic 估計與 TD 目標之間的誤差,對整個 batch 取平均。'''critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))self.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()actor_loss.backward()critic_loss.backward()self.actor_optimizer.step()self.critic_optimizer.step()actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 250
hidden_dim = 128
gamma = 0.98
lmbda = 0.95
epochs = 10
eps = 0.2
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")env_name = 'CartPole-v0'
env = gym.make(env_name)
if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
ppo_agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device)return_list = rl_utils.train_on_policy_agent(env, ppo_agent, num_episodes)def sample_expert_data(n_episode):states = []actions = []for episode in range(n_episode):state = env.reset()done = Falsewhile not done:action = ppo_agent.take_action(state)states.append(state)actions.append(action)result = env.step(action)if len(result) == 5:next_state, reward, done, truncated, info = resultdone = done or truncated  # 可合并 terminated 和 truncated 標志else:next_state, reward, done, info = result# next_state, reward, done, _ = env.step(action)state = next_stateprocessed_states = []for s in states:if isinstance(s, tuple):# 如果元素是元組,則取元組的第一個元素processed_states.append(s[0])else:processed_states.append(s)return np.array(processed_states), np.array(actions)if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
random.seed(0)
n_episode = 1
expert_s, expert_a = sample_expert_data(n_episode)n_samples = 30  # 采樣30個數據
random_index = random.sample(range(expert_s.shape[0]), n_samples)
expert_s = expert_s[random_index]
expert_a = expert_a[random_index]class BehaviorClone:def __init__(self, state_dim, hidden_dim, action_dim, lr):self.policy = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)def learn(self, states, actions):"""解釋:定義一個學習函數,接收一批專家數據中的狀態和動作,用于更新策略網絡。"""states = torch.tensor(states, dtype=torch.float).to(device)actions = torch.tensor(actions).view(-1, 1).to(device)'''- 將 states 輸入 policy 網絡,得到每個狀態下所有動作的概率分布,假設輸出形狀為 (batch_size, action_dim);- 使用 .gather(1, actions.long()) 從概率分布中取出對應專家動作的概率(注意動作需要轉換為長整型索引);- 對這些概率取對數,得到對數概率(log likelihood)。'''log_probs = torch.log(self.policy(states).gather(1, actions.long()))# log_probs = torch.log(self.policy(states).gather(1, actions))'''計算行為克隆的損失,即負對數似然損失。對所有樣本的負對數概率取均值。'''bc_loss = torch.mean(-log_probs)  # 最大似然估計self.optimizer.zero_grad()bc_loss.backward()self.optimizer.step()def take_action(self, state):if isinstance(state, tuple):state = state[0]state = torch.tensor([state], dtype=torch.float).to(device)probs = self.policy(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def test_agent(agent, env, n_episode):return_list = []for episode in range(n_episode):episode_return = 0state = env.reset()done = Falsewhile not done:action = agent.take_action(state)result = env.step(action)if len(result) == 5:next_state, reward, done, truncated, info = resultdone = done or truncated  # 可合并 terminated 和 truncated 標志else:next_state, reward, done, info = result# next_state, reward, done, _ = env.step(action)state = next_stateepisode_return += rewardreturn_list.append(episode_return)return np.mean(return_list)if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
np.random.seed(0)lr = 1e-3
bc_agent = BehaviorClone(state_dim, hidden_dim, action_dim, lr)
n_iterations = 1000
batch_size = 64
test_returns = []with tqdm(total=n_iterations, desc="進度條") as pbar:for i in range(n_iterations):sample_indices = np.random.randint(low=0, high=expert_s.shape[0], size=batch_size)bc_agent.learn(expert_s[sample_indices], expert_a[sample_indices])current_return = test_agent(bc_agent, env, 5)test_returns.append(current_return)if (i + 1) % 10 == 0:pbar.set_postfix({'return': '%.3f' % np.mean(test_returns[-10:])})pbar.update(1)iteration_list = list(range(len(test_returns)))
plt.plot(iteration_list, test_returns)
plt.xlabel('Iterations')
plt.ylabel('Returns')
plt.title('BC on {}'.format(env_name))
plt.show()class Discriminator(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(Discriminator, self).__init__()self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x, a):cat = torch.cat([x, a], dim=1)x = F.relu(self.fc1(cat))return torch.sigmoid(self.fc2(x))class GAIL:def __init__(self, agent, state_dim, action_dim, hidden_dim, lr_d):print(state_dim, action_dim, hidden_dim)self.discriminator = Discriminator(state_dim, hidden_dim, action_dim).to(device)self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d)self.agent = agentdef learn(self, expert_s, expert_a, agent_s, agent_a, next_s, dones):expert_states = torch.tensor(expert_s, dtype=torch.float).to(device)expert_actions = torch.tensor(expert_a).to(device)processed_state = []for s in agent_s:if isinstance(s, tuple):# 如果元素是元組,則取元組的第一個元素processed_state.append(s[0])else:processed_state.append(s)agent_states = torch.tensor(processed_state, dtype=torch.float).to(device)agent_actions = torch.tensor(agent_a).to(device)'''作用:將專家動作轉換為 one-hot 編碼形式,轉換為浮點數。'''expert_actions = F.one_hot(expert_actions.long(), num_classes=2).float()agent_actions = F.one_hot(agent_actions.long(), num_classes=2).float()expert_prob = self.discriminator(expert_states, expert_actions)agent_prob = self.discriminator(agent_states, agent_actions)'''作用:計算二元交叉熵損失(BCE):- 對 agent 數據,目標標簽設為 1(即希望判別器認為 agent 數據為“真”),損失為 BCE(agent_prob, 1);- 對專家數據,目標標簽設為 0(希望判別器認為專家數據為“假”),損失為 BCE(expert_prob, 0)。- 然后將兩部分損失相加。'''discriminator_loss = nn.BCELoss()(agent_prob, torch.ones_like(agent_prob)) + nn.BCELoss()(expert_prob, torch.zeros_like(expert_prob))self.discriminator_optimizer.zero_grad()discriminator_loss.backward()self.discriminator_optimizer.step()'''作用:利用判別器對 agent 數據輸出計算獎勵:- 計算 –log(agent_prob) 作為獎勵信號(當 agent_prob 較小時,獎勵較高,鼓勵 agent 模仿專家);- detach() 阻斷梯度,轉移到 CPU 并轉換為 numpy 數組,方便后續傳入 agent.update。'''rewards = -torch.log(agent_prob).detach().cpu().numpy()transition_dict = {'states': agent_s,'actions': agent_a,'rewards': rewards,'next_states': next_s,'dones': dones}self.agent.update(transition_dict)if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
lr_d = 1e-3
agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device)
gail = GAIL(agent, state_dim, action_dim, hidden_dim, lr_d)
n_episode = 500
return_list = []with tqdm(total=n_episode, desc="進度條") as pbar:for i in range(n_episode):episode_return = 0state = env.reset()done = Falsestate_list = []action_list = []next_state_list = []done_list = []while not done:action = agent.take_action(state)result = env.step(action)if len(result) == 5:next_state, reward, done, truncated, info = resultdone = done or truncated  # 可合并 terminated 和 truncated 標志else:next_state, reward, done, info = result# next_state, reward, done, _ = env.step(action)state_list.append(state)action_list.append(action)next_state_list.append(next_state)done_list.append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)gail.learn(expert_s, expert_a, state_list, action_list, next_state_list, done_list)if (i + 1) % 10 == 0:pbar.set_postfix({'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)    iteration_list = list(range(len(return_list)))
plt.plot(iteration_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('GAIL on {}'.format(env_name))
plt.show()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/73682.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/73682.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/73682.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

JAVA面試_進階部分_Java JVM:垃圾回收(GC 在什么時候,對什么東西,做了什么事情)

在什么時候: 首先需要知道,GC又分為minor GC 和 Full GC(major GC)。Java堆內存分為新生代和老年代,新生代 中又分為1個eden區和兩個Survior區域。 一般情況下,新創建的對象都會被分配到eden區&#xff…

2024年消費者權益數據分析

📅 2024年315消費者權益數據分析 數據見:https://mp.weixin.qq.com/s/eV5GoionxhGpw7PunhOVnQ 一、引言 在數字化時代,消費者維權數據對于市場監管、商家誠信和行業發展具有重要價值。本文基于 2024年315平臺線上投訴數據,采用數…

設計模式Python版 訪問者模式

文章目錄 前言一、訪問者模式二、訪問者模式示例 前言 GOF設計模式分三大類: 創建型模式:關注對象的創建過程,包括單例模式、簡單工廠模式、工廠方法模式、抽象工廠模式、原型模式和建造者模式。結構型模式:關注類和對象之間的組…

安全無事故連續天數計算,python 時間工具的高效利用

安全天數計算,數據系統時間直取,安全標準高效便捷好用。 筆記模板由python腳本于2025-03-17 23:50:52創建,本篇筆記適合對python時間工具有研究欲的coder翻閱。 【學習的細節是歡悅的歷程】 博客的核心價值:在于輸出思考與經驗&am…

大型語言模型(LLM)部署中的內存消耗計算

在部署大型語言模型(LLM)時,顯存(VRAM)的合理規劃是決定模型能否高效運行的核心問題。本文將通過詳細的公式推導和示例計算,系統解析模型權重、鍵值緩存(KV Cache)、激活內存及額外開…

Mysql表的查詢

一:創建一個新的數據庫(companydb),并查看數據庫。 二:使用該數據庫,并創建表worker。 mysql> use companydb;mysql> CREATE TABLE worker(-> 部門號 INT(11) NOT NULL,-> 職工號 INT(11) NOT NULL,-> 工作時間 D…

ASP.NET Webform和ASP.NET MVC 后臺開發 大概80%常用技術

本文涉及ASP.NET Webform和ASP.NET MVC 后臺開發大概80%技術 2019年以前對標 深圳22K左右 廣州18K左右 武漢16K左右 那么有人問了2019年以后的呢? 答:吉祥三寶。。。 So 想繼續看下文的 得有自己的獨立判斷能力。 C#.NET高級筆試題 架構 優化 性能提…

首頁性能優化

首頁性能提升是前端優化中的核心任務之一,因為首頁是用戶訪問的第一入口,其加載速度和交互體驗直接影響用戶的留存率和轉化率。 1. 性能瓶頸分析 在優化之前,首先需要通過工具分析首頁的性能瓶頸。常用的工具包括: Chrome DevTo…

一周學會Flask3 Python Web開發-SQLAlchemy刪除數據操作-班級模塊

鋒哥原創的Flask3 Python Web開發 Flask3視頻教程&#xff1a; 2025版 Flask3 Python web開發 視頻教程(無廢話版) 玩命更新中~_嗶哩嗶哩_bilibili 首頁list.html里加上刪除鏈接&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta c…

改變一生的思維模型【12】笛卡爾思維模型

目錄 基本結構 警惕認知暗礁 案例分析應用 一、懷疑階段:破除慣性認知 二、解析階段:拆解問題為最小單元 三、整合階段:重構邏輯鏈條 四、檢驗階段:多維驗證解決方案 總結與啟示 笛卡爾說,唯獨自己的思考是可以相信的。 世界上所有的事情,都是值得被懷疑的,但是…

需求文檔(PRD,Product Requirement Document)的基本要求和案例參考:功能清單、流程圖、原型圖、邏輯能力和表達能力

文章目錄 引言I 需求文檔的基本要求結構清晰內容完整語言準確圖文結合版本管理II 需求文檔案例參考案例1:電商平臺“商品中心”功能需求(簡化版)案例2:教育類APP“記憶寶盒”非功能需求**案例3:軟件項目的功能需求模板3.1 功能需求III 需求文檔撰寫技巧1. **從核心邏輯出發…

五大方向全面對比 IoTDB 與 OpenTSDB

對比系列第三彈&#xff0c;詳解 IoTDB VS OpenTSDB&#xff01; 之前&#xff0c;我們已經深入探討了時序數據庫 Apache IoTDB 與 InfluxDB、Apache HBase 在架構設計、性能和功能方面等多個維度的區別。還沒看過的小伙伴可以點擊閱讀&#xff1a; Apache IoTDB vs InfluxDB 開…

Electron使用WebAssembly實現CRC-16 MAXIM校驗

Electron使用WebAssembly實現CRC-16 MAXIM校驗 將C/C語言代碼&#xff0c;經由WebAssembly編譯為庫函數&#xff0c;可以在JS語言環境進行調用。這里介紹在Electron工具環境使用WebAssembly調用CRC-16 MAXIM格式校驗的方式。 CRC-16 MAXIM校驗函數WebAssembly源文件 C語言實…

vue3vue-elementPlus-admin框架中form組件的upload寫法

dialog中write組件代碼 let ImageList reactive<UploadFile[]>([])const formSchema reactive<FormSchema[]>([{field: ImageFiles,label: 現場圖片,component: Upload,colProps: { span: 24 },componentProps: {limit: 5,action: PATH_URL /upload,headers: {…

Linux mount和SSD分區

為什么要用 mount&#xff1f; Linux 的文件系統結構是單一的樹狀層次 所有文件、目錄和設備都從根目錄 / 開始延伸。 外部的存儲設備&#xff08;如硬盤、U盤、網絡存儲&#xff09;或虛擬文件系統&#xff08;如 /proc、/sys&#xff09;必須通過掛載點“嫁接”到這棵樹上&a…

【Function】Azure Function通過托管身份或訪問令牌連接Azure SQL數據庫

【Function】Azure Function通過托管身份或訪問令牌連接Azure SQL數據庫 推薦超級課程: 本地離線DeepSeek AI方案部署實戰教程【完全版】Docker快速入門到精通Kubernetes入門到大師通關課AWS云服務快速入門實戰目錄 【Function】Azure Function通過托管身份或訪問令牌連接Azu…

舉例說明 牛頓法 Hessian 矩陣

矩陣求逆的方法及示例 目錄 矩陣求逆的方法及示例1. 伴隨矩陣法2. 初等行變換法矩陣逆的實際意義1. 求解線性方程組2. 線性變換的逆操作3. 數據分析和機器學習4. 優化問題牛頓法原理解釋舉例說明 牛頓法 Hessian 矩陣1. 伴隨矩陣法 原理:對于一個 n n n 階方陣 A A

安科瑞分布式光伏監測系統:推動綠色能源高效發展

安科瑞顧強 為應對傳統能源污染與資源短缺&#xff0c;分布式光伏發電成為關鍵解決方案。安科瑞Acrel-1000DP分布式光伏監控系統結合光功率預測技術&#xff0c;有效提升發電穩定性&#xff0c;助力上海汽車變速器有限公司8.3MW屋頂光伏項目實現清潔能源高效利用。 項目亮點 …

從零開始使用 **Taki + Node.js** 實現動態網頁轉靜態網站的完整代碼方案

以下是從零開始使用 Taki Node.js 實現動態網頁轉靜態網站的完整代碼方案&#xff0c;包含預渲染、自動化構建、靜態托管及優化功能&#xff1a; 一、環境準備 1. 初始化項目 mkdir static-site && cd static-site npm init -y2. 安裝依賴 npm install taki expre…

商業智能BI分析中,汽車4S銷售行業的返廠頻次有什么分析價值?

買過車的朋友會發現&#xff0c;同一款車不管在哪個4S店去買&#xff0c;基本上價格都相差不大。即使有些差別&#xff0c;也是帶著附加條件的&#xff0c;比如要做些加裝需要額外再付一下費用。為什么汽車4S銷售行業需要商業智能BI&#xff1f;就是因為在汽車4S銷售行業&#…