手寫PPO_clip(FrozenLake環境)

參考:白話PPO訓練

成功截圖

算法組件

四大部分

???????? 同A2C相比,PPO算法額外引入了一個old_actor_model.?

????????在PPO的訓練中,首先使用old_actor_model與環境進行交互得到經驗,然后利用一批經驗優化actor_model,最后再將actor_model的參數復制回old_actor_model

超參數

? ? ? ? 同A2C相比,PPO_clip多了兩個參數: 單批數據更新次數和截斷閾值

  • times_per_update:?在收集到的一批數據上,進行多少次梯度更新。
  • clip_param(ε)?:?PPO裁剪目標函數中的閾值,通常取 0.1 或 0.2

訓練過程

? ? ? ? 整體訓練框架同A2C, 差別在于使用old_policy采集經驗,然后優化new_policy,最后復制回old_policy.

? ? ? ? PPO為了高效利用經驗數據,在一批經驗上進行多次數據更新。

目標函數

?1. critic的目標函數同A2C

?2. actor的目標函數為PPO_clip

? ? ?

完整代碼

import torch
import torch.nn as nn
from torch.nn import functional as F
import gymnasium as gym
import tqdm
from torch.distributions import Categorical
from typing import  Tuple
import copyclass PolicyNetwork(nn.Module):def __init__(self, n_observations: int, n_actions: int):super(PolicyNetwork, self).__init__()self.layer1 = nn.Linear(n_observations, 32)   self.layer2 = nn.Linear(32, 16)               self.layer3 = nn.Linear(16, n_actions)        def forward(self, x: torch.Tensor) -> Categorical: x = F.relu(self.layer1(x))x = F.relu(self.layer2(x))action_logits = self.layer3(x)return Categorical(logits=action_logits)class PPO_clip:def __init__(self, env, total_episodes):#############超參數#############self.actor_lr = 0.01self.critic_lr = 0.01self.batch_size = 64self.times_per_update = 5 # 多次更新參數self.clip_param = 0.2     # 比率截斷參數,一般取0.2或0.1self.entropy_coeff = 0.01self.value_loss_coeff = 0.5self.gae_lambda = 0.95 self.discount_rate = 0.9 self.total_episodes = total_episodes#############PPO_clip的核心要件#############self.replay_buffer = []self.actor_model = PolicyNetwork(16, 4)self.old_actor_model = copy.deepcopy(self.actor_model)self.critic_model = nn.Sequential( # 不需要像 actor model那么復雜nn.Linear(16, 16), nn.ReLU(),nn.Linear(16, 1))############優化組件#############self.actor_optimizer = torch.optim.Adam(self.actor_model.parameters(), lr=self.actor_lr) self.critic_optimizer = torch.optim.Adam(self.critic_model.parameters(), lr=self.critic_lr)self.env = envself.count = 0self.success = 0def train(self):bar = tqdm.tqdm(range(self.total_episodes), desc=f"episode {0} {self.success / (self.count+1e-8)}") for i in bar:state, info = self.env.reset()done = Falsetruncated = False# 收集經驗 old_policy (fixed)while not done or truncated:action = self.choose_action(state)new_state, r, done, truncated, info = self.env.step(action) self.append_data(state, action, r, new_state, done)state = new_stateif done or truncated:self.count+=1if new_state == 15: self.success+=1# 優化模型 new_policy (updated)if len(self.replay_buffer) == self.batch_size:self.optimize_model()self.replay_buffer.clear()# 復制new_policy到old_policyself.old_actor_model.load_state_dict(self.actor_model.state_dict()) if i % 100 == 0:self.count = 0self.success = 0bar.set_description(f"episode {i} {self.success / (self.count+1e-8)}")def choose_action(self, state):with torch.no_grad():policy_dist = self.old_actor_model(self.state_to_input(state))action_tensor = policy_dist.sample()action = action_tensor.item()return actiondef optimize_model(self):state = torch.stack([self.state_to_input(tup[0]) for tup in self.replay_buffer[-self.batch_size:]])action = torch.IntTensor([tup[1] for tup in self.replay_buffer[-self.batch_size:]])reward = torch.FloatTensor([tup[2] for tup in self.replay_buffer[-self.batch_size:]])new_state = torch.stack([self.state_to_input(tup[3]) for tup in self.replay_buffer[-self.batch_size:]])done = torch.FloatTensor([tup[4] for tup in self.replay_buffer[-self.batch_size:]])# 以上state和new_state是二維的, 其他是一維的,即batch維with torch.no_grad():value = self.critic_model(state).squeeze()last_value = self.critic_model(new_state[:-1]).squeeze()next_value = torch.cat((value[1:], last_value))# 相比一次TD誤差, GAE效果顯著之好 advantages, returns_to_go = self.compute_gae_and_returns(reward, value, next_value, done, self.discount_rate, self.gae_lambda)# 一份batch上的數據多次更新for _ in range(self.times_per_update):# 更新actorpolicy_dist = self.actor_model(state)old_policy_dist = self.old_actor_model(state) new_log_prob = policy_dist.log_prob(action)old_log_prob = old_policy_dist.log_prob(action).detach() # old 不要梯度 r = torch.exp(new_log_prob - old_log_prob) # 計算比率用exp(ln(a)-ln(b)) 就是 a/bnew_div_old_rate = ractor_fn = -(torch.min(new_div_old_rate*advantages, torch.clamp(new_div_old_rate, 1-self.clip_param, 1+self.clip_param)*advantages) + self.entropy_coeff * policy_dist.entropy()) self.actor_optimizer.zero_grad()actor_fn.mean().backward(retain_graph=True) # .mean() torch要求梯度得標量函數self.actor_optimizer.step()# 更新criticv = self.critic_model(state).squeeze()critic_fn = F.mse_loss(v, returns_to_go)self.critic_optimizer.zero_grad()(self.value_loss_coeff * critic_fn).backward()self.critic_optimizer.step()def compute_gae_and_returns(self,rewards: torch.Tensor, values: torch.Tensor, next_values: torch.Tensor, dones: torch.Tensor, discount_rate: float, lambda_gae: float, ) -> Tuple[torch.Tensor, torch.Tensor]:advantages = torch.zeros_like(rewards)last_advantage = 0.0n_steps = len(rewards)# 計算GAEfor t in reversed(range(n_steps)):mask = 1.0 - dones[t]delta = rewards[t] + discount_rate * next_values[t] * mask - values[t] advantages[t] = delta + discount_rate * lambda_gae * last_advantage * masklast_advantage = advantages[t]# 返回給critic作為TD目標  returns_to_go = advantages + values return advantages, returns_to_godef append_data(self, state, action, r, new_state, done):self.replay_buffer.append((state, action, r, new_state, done))def state_to_input(self, state):input_dim = 16input = torch.zeros(input_dim, dtype=torch.float)input[int(state)] = 1return inputenv = gym.make("FrozenLake-v1", is_slippery=False)
policy = PPO_clip(env, 2000)
policy.train()env = gym.make("FrozenLake-v1", is_slippery=False, render_mode="human")
state, info = env.reset()
done = False
truncated = False
while True:with torch.no_grad():action=policy.choose_action(state) new_state, reward, done, truncated, info = env.step(action)state=new_stateif done or truncated:state, info = env.reset()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/90734.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/90734.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/90734.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

人形機器人指南(八)操作

八、環境交互與操作能力——人形機器人的“靈巧雙手”環境交互與操作能力是人形機器人區別于移動平臺的核心能力標志。通過仿生學設計的運動鏈與智能控制算法,機器人得以在非結構化環境中執行抓取、操縱、裝配等復雜任務。本章將系統解析機械臂運動學架構、靈巧手設…

管理 GitHub Pages 站點的自定義域(Windows)

管理 GitHub Pages 站點的自定義域(Windows) 你可以設置或更新某些 DNS 記錄和存儲庫設置,以將 GitHub Pages 站點的默認域指向自定義域。 誰可以使用此功能? GitHub Pages 在公共存儲庫中提供 GitHub Free 和 GitHub Free for organizations,在公共和私有存儲庫中提供 Gi…

【PCIe 總線及設備入門學習專欄 5.1.3 -- PCIe PERST# 時序要求】

文章目錄 Overview 什么是PERST# 第一條要求 術語解釋 要求含義 第二條要求 術語解釋 要求含義 Perst 示例說明 過程如下 總結 Overview 首先我們看下 PCIe x協議對 PERST 的要求: A component must enter the LTSSM Detect state within 20 rms of the end of Fundamental R…

圖像認知與OpenCV——圖像預處理

目錄 一、顏色加法 顏色加法 顏色加權加法 示例 二、顏色空間轉換 RGB轉Gray(灰度) RGB轉HSV HSV轉RGB 示例 三、灰度化 最大值法 平均值法 加權平均值法 四、圖像二值化處理 閾值法 反閾值法 截斷閾值法 低閾值零處理 超閾值法 OTSU…

Vue 3 組件通信全解析:從 Props 到 Pinia 的深入實踐

引言 Vue 3 作為現代前端框架的代表之一,以其靈活性和高效性受到開發者的廣泛喜愛。在 Vue 3 中,組件是構建用戶界面的核心單元,而組件之間的通信則是實現動態交互和數據流動的關鍵環節。無論是簡單的父子組件通信,還是復雜的跨組…

CodeBuddy IDE實戰:用AI全棧能力快速搭建課程表網頁

聲明:本文僅是實踐測評,并非廣告 1.前言 在數字化開發的浪潮中,工具的革新往往是效率躍遷的起點。騰訊云 CodeBuddy IDE 是 “全球首個產設研一體 AI 全棧開發平臺” ,它不僅打破了產品、設計與研發的職能壁壘,更重新…

11. HTML 中 DOCTYPE 的作用

總結H5 的聲明HTML5 的 DOCTYPE 聲明 HTML5 中的 <!DOCTYPE html> 聲明用于告訴瀏覽器當前文檔使用的是 HTML5 的文檔類型。它必須是文檔中的第一行內容&#xff08;在任何 HTML 標簽之前&#xff09;&#xff0c;以確保瀏覽器能夠正確地解析和渲染頁面。DOCTYPE 的作用 …

Linux C 網絡基礎編程

基礎知識在進行網絡編程之前&#xff0c;我們需要簡單回顧一下計算機網絡五層模型的網絡層和傳輸層&#xff0c;這兩層在面向后端編程時用的最多。物理層和鏈路層過于底層&#xff0c;已經完全由內核協議棧實現&#xff0c;不再細述。這里假設讀者已經對計算機網絡有一個大致的…

循環神經網絡--NLP基礎

一、簡單介紹NLP&#xff08;Natural Language Processing&#xff09;&#xff1a;自然語言處理是人工智能和語言領域的一個分支&#xff0c;它涉及計算機和人類語言之間的相互作用。二、NLP基礎概念詞表&#xff08;詞庫&#xff09;&#xff1a;文本數據集出現的所有單詞的集…

【Android】約束布局總結(1)

三三要成為安卓糕手 零&#xff1a;創建布局文件方式 1&#xff1a;創建步驟ctrl alt 空格 設置根元素2&#xff1a;處理老版本約束布局 在一些老的工程中&#xff0c;constrainlayout可能沒有辦法被直接使用&#xff0c;這里需要手動添加依賴implementation androidx.const…

S7-200 SMART 數字量 I/O 組態指南:從參數設置到實戰案例

在工業自動化控制中&#xff0c;PLC 的數字量輸入&#xff08;DI&#xff09;和輸出&#xff08;DO&#xff09;是連接傳感器、執行器與控制系統的 “神經末梢”。西門子 S7-200 SMART 作為一款高性價比的小型 PLC&#xff0c;其數字量 I/O 的靈活組態直接影響系統的穩定性與響…

可調諧激光器原理與設計 【DFB 與 DBR 激光器剖析】

可調諧激光器原理與設計 【DFB 與 DBR 激光器剖析】1. 可調諧激光器的原理與分類簡介2. DFB 與 DBR 激光器結構原理比較2.1 DFB&#xff08;Distributed Feedback Laser&#xff09;激光器2.2 DBR&#xff08;Distributed Bragg Reflector&#xff09;激光器2.3 DFB 激光器與 D…

【前端工程化】前端項目開發過程中如何做好通知管理?

在企業級后臺系統中&#xff0c;通知是保障團隊協作、監控系統狀態和及時響應問題的重要手段。與 C 端產品不同&#xff0c;B 端更關注構建完成、部署狀態、異常報警等關鍵節點的推送機制。 本文主要圍繞通知場景、通知內容、通知渠道、自動化集成等方面展開&#xff0c;適用于…

MySQL 9.4.0創新版發布,AI開始輔助編寫發布說明

2025 年 7 月 22 日&#xff0c;MySQL 9.4.0 正式發布。 作為一個創新版&#xff0c;MySQL 9.4.0 最大的創新應該就是使用 Oracle HeatWave GenAI 作為助手幫助編寫版本發布說明了。難道下一步要開始用 AI 輔助編寫數據庫文檔了&#xff1f; 該版本包含的核心功能更新以及問題修…

基于WebSockets和OpenCV的安卓眼鏡視頻流GPU硬解碼實現

基于WebSockets和OpenCV的安卓眼鏡視頻流GPU硬解碼實現 前些天發現了一個巨牛的人工智能學習網站&#xff0c;通俗易懂&#xff0c;風趣幽默&#xff0c;忍不住分享一下給大家&#xff0c;覺得好請收藏。點擊跳轉到網站。 1. 項目概述 本項目旨在實現一個通過WebSockets接收…

人大金倉 kingbase 連接數太多, 清理數據庫連接數

問題描述 kingbase 連接數太多, 清理數據庫連接數 [rootFCVMDZSZNST25041 ~]# su root [rootFCVMDZSZNST25041 ~]# [rootFCVMDZSZNST25041 ~]# su kingbase [kingbaseFCVMDZSZNST25041 root]$ [kingbaseFCVMDZSZNST25041 root]$ ksql could not change directory to "/r…

SpringMVC相關基礎知識

1. servlet.multipart 大小配置 SpringBoot 文件上傳接口中有 MultipartFile 類型的文件參數,上傳較大文件時報錯: org.springframework.web.multipart.MaxUploadSizeExceededException: Maximum upload size exceeded; nested exception is java.lang.IllegalStateExceptio…

HCIP第一次實驗報告

一.實驗需求及拓撲圖&#xff1a;二.實驗需求分析根據提供的網絡拓撲圖和實驗要求&#xff0c;以下是對實驗需求的詳細分析&#xff1a;R5作為ISP:R5只能進行IP地址配置&#xff0c;其所有接口均配置為公有IP地址。認證方式:R1和R5之間使用PPP的PAP認證&#xff0c;R5為主認證方…

React入門學習——指北指南(第五節)

React 交互性:過濾與條件渲染 在前文我們學習了 React 中事件處理和狀態管理的基礎。本節將聚焦兩個重要的進階技巧 ——條件渲染(根據狀態動態顯示不同 UI)和列表過濾(根據條件篩選數據),這兩者是構建交互式應用的核心能力,能讓界面根據用戶操作呈現更智能的響應。 條…

學習嵌入式的第二十九天-數據結構-(2025.7.16)線程控制:互斥與同步

以下是您提供的文本內容的排版整理版本。我已根據內容主題將其分為幾個主要部分&#xff08;互斥鎖、信號量、死鎖、IPC進程間通信、管道操作&#xff09;&#xff0c;并使用清晰的結構組織信息&#xff1a;代碼片段用代碼塊格式&#xff08;指定語言為C&#xff09;突出顯示。…