reinforce 跑 CartPole-v1

gym版本是0.26.1
CartPole-v1的詳細信息,點鏈接里看就行了。
修改了下動手深度強化學習對應的代碼。

然后這里 J ( θ ) J(\theta) J(θ)梯度上升更新的公式是用的不嚴謹的,這個和王樹森書里講的嚴謹公式有點區別。

代碼

import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import rl_utils # 這個要下載源碼,然后放到同個文件目錄下,鏈接在上面給出了
from d2l import torch as d2l # 這個是動手深度學習的庫, pip/conda install d2l 就好了class PolicyNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super().__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, action_dim)def forward(self, X):X = F.relu(self.fc1(X))return F.softmax(self.fc2(X),dim=1)class REINFORCE:def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, device):self.policy_net = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr = learning_rate)self.gamma = gamma # 折扣因子self.device = devicedef take_action(self, state): # 根據動作概率分布隨機采樣state = torch.tensor(np.array([state]),dtype=torch.float).to(self.device)probs = self.policy_net(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def update(self, transition_dict):  # 公式用的是簡化推導reward_list = transition_dict['rewards']state_list = transition_dict['states']action_list = transition_dict['actions']G = 0self.optimizer.zero_grad()for i in reversed(range(len(reward_list))):  # 從最后一步算起reward = reward_list[i]state = torch.tensor(np.array([state_list[i]]), dtype=torch.float).to(self.device)action = torch.tensor([action_list[i]]).reshape(-1,1).to(self.device)log_prob = torch.log(self.policy_net(state).gather(1, action))G = self.gamma * G + reward loss = -log_prob * G  # 因為梯度更新是減的,所以取個負號loss.backward()self.optimizer.step()
lr = 1e-3
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = d2l.try_gpu()env_name="CartPole-v1"
env = gym.make(env_name)
print(f"_max_episode_steps:{env._max_episode_steps}")
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.nagent = REINFORCE(state_dim, hidden_dim, action_dim, lr, gamma, device)
return_list = []
for i in range(10):with tqdm(total=int(num_episodes/10), desc=f'Iteration {i}') as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}state = env.reset()[0]done, truncated= False, Falsewhile not done and not truncated :  # 主要是這部分和原始的有點不同action = agent.take_action(state)next_state, reward, done, truncated, info = env.step(action)transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)agent.update(transition_dict)if (i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'REINFORCE on {env_name}')
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'REINFORCE on {env_name}')
plt.show()

我是在jupyter里直接跑的,結果如下所示。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/212418.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/212418.shtml
英文地址,請注明出處:http://en.pswp.cn/news/212418.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

innobackupex備份目錄

innobackupeex全備腳本思路 四個需求如下: (1)每天晚上23點執行,這需要linux系統做一個定時任務 00 23 * * * /bin/sh /shell/tencent_xtrabackup_all.sh /dev/null 2>&1 (2)每天。。看到這個詞…

標識符···

定義 標識符只能由字母、數字、下劃線(_)和美元符號($)組成。標識符必須以字母、下劃線或美元符號開頭,不能以數字開頭。標識符對大小寫敏感,例如"myVariable"和"myvariable"是不同的…

Android 11 適配——整理總結篇

背景 > 經過檢測,我們識別到您的應用,目前未適配安卓11(API30),請您關注適配截止時間,盡快開展適配工作,避免影響應用正常發布和經營。 > targetSdkVersion30 升級適配工作參考文檔&am…

從零開發短視頻電商 Jmeter壓測示例模板詳解(無認證場景)

文章目錄 添加線程組添加定時器添加HTTP請求默認值添加HTTP頭管理添加HTTP請求添加結果斷言響應斷言 Response AssertionJSON斷言 JSON Assertion持續時間斷言 Duration Assertion 添加察看結果樹添加聚合報告添加表格察看結果參考 以壓測百度搜索為例 https://www.baidu.com/s…

class066 一維動態規劃【算法】

class066 一維動態規劃 算法講解066【必備】從遞歸入手一維動態規劃 code1 509斐波那契數列 // 斐波那契數 // 斐波那契數 (通常用 F(n) 表示)形成的序列稱為 斐波那契數列 // 該數列由 0 和 1 開始,后面的每一項數字都是前面兩項數字的和。…

kotlin - ViewBinding

前言 為什么用ViewBinding,而不用findViewById(),這個有很多優秀的博主都做了講解,就不再列出了。 可參考下列博主的文章: kotlin ViewBinding的使用 文章里也給出了如何在gradle中做出相應的配置。 (我建議先看這位博…

【LeetCode熱題100】【滑動窗口】無重復字符的最長子串

給定一個字符串 s ,請你找出其中不含有重復字符的 最長子串 的長度。 示例 1: 輸入: s "abcabcbb" 輸出: 3 解釋: 因為無重復字符的最長子串是 "abc",所以其長度為 3。示例 2: 輸入: s "bbbbb" 輸出: 1 解釋: 因為無…

Docker安裝教程

docker官網 1.卸載舊版 yum remove docker \docker-client \docker-client-latest \docker-common \docker-latest \docker-latest-logrotate \docker-logrotate \docker-engine2.配置Docker的yum庫 安裝yum工具 yum install -y yum-utils配置Docker的yum源 yum-config-ma…

Redis,什么是緩存穿透?怎么解決?

Redis,什么是緩存穿透?怎么解決? 1、緩存穿透 一般的緩存系統,都是按照key去緩存查詢,如果不存在對用的value,就應該去后端系統查找(比如DB數據庫)。一些惡意的請求會故意查詢不存在…

不想寫大量 if 判斷?試試用規則執行器優化,就很絲滑!

近日在公司領到一個小需求,需要對之前已有的試用用戶申請規則進行拓展。我們的場景大概如下所示: if (是否海外用戶) {return false; }if (刷單用戶) {return false; }if (未付費用戶 && 不再服務時段) {return false }if (轉介紹用戶 || 付費用戶 || 內推…

16ASM 分段和機器碼

8086CPU存儲分段管理 問題1:8086是16位cpu,最多可訪問(尋址)多大內存? 運算器一次最多處理16位的數據。地址寄存器的最大寬度為16位。訪問的最大內存為:216 64K 即 0000 - FFFF。 問題2:808…

Hadoop集群破壞試驗可靠性驗證

集群環境說明: 準備5臺服務器,hadoop1、hadoop2、hadoop3、hadoop4、hadoop5; 分別部署5個節點的zookeeper集群、hadoop集群、hbase集群 本次對于Hadoop集群測試主要分為五個方面: 手動進行datanode節點刪除:&#…

typedef 與#define 的區別

typedef 與#define 的區別 typedef : 給一個已經存在的數據類型(注意:是類型不是變量)取一個別名,而非定義一個新的數據類型 #define宏定義: #define宏定義:在預編譯時直接進行簡單的文本替換 舉…

WIFI直連(Wi-Fi P2P)

一、概述 Wifi peer-to-peer(也稱Wifi-Direct)是Wifi聯盟推出的一項基于原來WIfi技術的可以讓設備與設備間直接連接的技術,使用戶不需要借助局域網或者AP(Access Point)就可以進行一對一或一對多通信。這種技術的應用…

計算機畢業設計 SpringBoot的樂樂農產品銷售系統 Javaweb項目 Java實戰項目 前后端分離 文檔報告 代碼講解 安裝調試

🍊作者:計算機編程-吉哥 🍊簡介:專業從事JavaWeb程序開發,微信小程序開發,定制化項目、 源碼、代碼講解、文檔撰寫、ppt制作。做自己喜歡的事,生活就是快樂的。 🍊心愿:點…

Xmanager

什么是 XManager Xmanager 是市場上領先的 PC X 服務器,可將X應用程序的強大功能帶入 Windows 環境。 提供了強大的會話管理控制臺,易于使用的 X 應用程序啟動器,X 服務器配置文件管理工具,SSH 模塊和高性能 PC X 服務器。 Xman…

javaScript(六):DOM操作

文章目錄 1、DOM介紹2、DOM:獲取Element對象3、DOM:事件監聽3.1、事件介紹3.2、常見事件3.3、設置事件的兩種方式3.4、事件案例 1、DOM介紹 概念 Document Object Model ,文檔對象模型 將標記語言的各個組成部分封裝為對應的對象&#xff1a…

Realme X7 Pro Root 刷機教程

Realme X7 Pro 刷機教程 Just For Fun,最近倒騰了下Realme X7 Pro 刷root。此博客為個人記錄刷機過程,如有機友跟隨本教程操作,請謹慎操作!!! 以下教程真針對Realme X7 Pro,其他版本方法未知&…

springboot(ssm高校競賽管理系統 在線競賽平臺 Java系統

springboot(ssm高校競賽管理系統 在線競賽平臺 Java系統 開發語言:Java 框架:ssm/springboot vue JDK版本:JDK1.8(或11) 服務器:tomcat 數據庫:mysql 5.7(或8.0) 數…

qt 模型視圖結構

在Qt中,Model、View和Delegate三者之間的關系如下: Model(模型):Model是數據的抽象表示,它提供了一種結構化的方式來存儲和管理數據。Model負責維護數據的狀態,并提供接口供其他組件&#xff08…