強化學習--多維動作狀態空間的設計

目錄

  • 一、離散動作
  • 二、連續動作
    • 1、例子1
    • 2、知乎給出的示例
    • 2、github里面的代碼

免責聲明:以下代碼部分來自網絡,部分來自ChatGPT,部分來自個人的理解。如有其他觀點,歡迎討論!

一、離散動作

注意:本文均以PPO算法為例。

# time: 2023/11/22 21:04
# author: YanJPimport torch
import torch
import torch.nn as nn
from torch.distributions import Categoricalclass MultiDimensionalActor(nn.Module):def __init__(self, input_dim, output_dims):super(MultiDimensionalActor, self).__init__()# Define a shared feature extraction networkself.feature_extractor = nn.Sequential(nn.Linear(input_dim, 128),nn.ReLU(),nn.Linear(128, 64),nn.ReLU())# Define individual output layers for each action dimensionself.output_layers = nn.ModuleList([nn.Linear(64, num_actions) for num_actions in output_dims])def forward(self, state):# Feature extractionfeatures = self.feature_extractor(state)# Generate Categorical objects for each action dimensioncategorical_objects = [Categorical(logits=output_layer(features)) for output_layer in self.output_layers]return categorical_objects# 定義主函數
def main():# 定義輸入狀態維度和每個動作維度的動作數input_dim = 10output_dims = [5, 8]  # 兩個動作維度,分別有 3 和 4 個可能的動作# 創建 MultiDimensionalActor 實例actor_network = MultiDimensionalActor(input_dim, output_dims)# 生成輸入狀態(這里使用隨機數據作為示例)state = torch.randn(1, input_dim)# 調用 actor 網絡categorical_objects = actor_network(state)# 輸出每個動作維度的采樣動作和對應的對數概率for i, categorical in enumerate(categorical_objects):sampled_action = categorical.sample()log_prob = categorical.log_prob(sampled_action)print(f"Sampled action for dimension {i+1}: {sampled_action.item()}, Log probability: {log_prob.item()}")if __name__ == "__main__":main()#Sampled action for dimension 1: 1, Log probability: -1.4930928945541382
#Sampled action for dimension 2: 3, Log probability: -2.1875085830688477

注意代碼中categorical函數的兩個不同傳入參數的區別:參考鏈接
簡單來說,logits是計算softmax的,probs直接就是已知概率的時候傳進去就行。

二、連續動作

參考鏈接:github、知乎
為什么取對數概率?參考回答
在這里插入圖片描述

1、例子1

先看如下的代碼:

# time: 2023/11/21 21:33
# author: YanJP
#這是對應多維連續變量的例子:
# 參考鏈接:https://github.com/XinJingHao/PPO-Continuous-Pytorch/blob/main/utils.py
# https://www.zhihu.com/question/417161289
import torch.nn as nn
import torch
class Policy(nn.Module):def __init__(self, in_dim, n_hidden_1, n_hidden_2, num_outputs):super(Policy, self).__init__()self.layer = nn.Sequential(nn.Linear(in_dim, n_hidden_1),nn.ReLU(True),nn.Linear(n_hidden_1, n_hidden_2),nn.ReLU(True),nn.Linear(n_hidden_2, num_outputs))class Normal(nn.Module):def __init__(self, num_outputs):super().__init__()self.stds = nn.Parameter(torch.zeros(num_outputs))  #創建一個可學習的參數 def forward(self, x):dist = torch.distributions.Normal(loc=x, scale=self.stds.exp())action = dist.sample((every_dimention_output,))  #這里我覺得是最重要的,不填sample的參數的話,默認每個分布只采樣一個值!!!!!!!!return actionif __name__ == '__main__':policy = Policy(4,20,20,5)normal = Normal(5) #設置5個維度every_dimention_output=10  #每個維度10個輸出observation = torch.Tensor(4)action = normal.forward(policy.layer( observation))print("action: ",action)
  • self.stds.exp(),表示求指數,因為正態分布的標準差都是正數。
  • action = dist.sample((every_dimention_output,))這里最重要!!!

2、知乎給出的示例


class Agent(nn.Module):def __init__(self, envs):super(Agent, self).__init__()self.actor_mean = nn.Sequential(layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),nn.Tanh(),layer_init(nn.Linear(64, 64)),nn.Tanh(),layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01),)self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))def get_action_and_value(self, x, action=None):action_mean = self.actor_mean(x)action_logstd = self.actor_logstd.expand_as(action_mean)action_std = torch.exp(action_logstd)probs = Normal(action_mean, action_std)if action is None:action = probs.sample()return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)

這里的np.prod(envs.single_action_space.shape),表示每個維度的動作數相乘,然后初始化這么多個actor網絡的標準差和均值,最后action里面的sample就是采樣這么多個數據。(感覺還是拉成了一維計算)

2、github里面的代碼

github

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Beta,Normalclass GaussianActor_musigma(nn.Module):def __init__(self, state_dim, action_dim, net_width):super(GaussianActor_musigma, self).__init__()self.l1 = nn.Linear(state_dim, net_width)self.l2 = nn.Linear(net_width, net_width)self.mu_head = nn.Linear(net_width, action_dim)self.sigma_head = nn.Linear(net_width, action_dim)def forward(self, state):a = torch.tanh(self.l1(state))a = torch.tanh(self.l2(a))mu = torch.sigmoid(self.mu_head(a))sigma = F.softplus( self.sigma_head(a) )return mu,sigmadef get_dist(self, state):mu,sigma = self.forward(state)dist = Normal(mu,sigma)return distdef deterministic_act(self, state):mu, sigma = self.forward(state)return mu

上述代碼主要是通過設置mu_head 和sigma_head的個數,來實現多維動作。

class GaussianActor_mu(nn.Module):def __init__(self, state_dim, action_dim, net_width, log_std=0):super(GaussianActor_mu, self).__init__()self.l1 = nn.Linear(state_dim, net_width)self.l2 = nn.Linear(net_width, net_width)self.mu_head = nn.Linear(net_width, action_dim)self.mu_head.weight.data.mul_(0.1)self.mu_head.bias.data.mul_(0.0)self.action_log_std = nn.Parameter(torch.ones(1, action_dim) * log_std)def forward(self, state):a = torch.relu(self.l1(state))a = torch.relu(self.l2(a))mu = torch.sigmoid(self.mu_head(a))return mudef get_dist(self,state):mu = self.forward(state)action_log_std = self.action_log_std.expand_as(mu)action_std = torch.exp(action_log_std)dist = Normal(mu, action_std)return distdef deterministic_act(self, state):return self.forward(state)
class Critic(nn.Module):def __init__(self, state_dim,net_width):super(Critic, self).__init__()self.C1 = nn.Linear(state_dim, net_width)self.C2 = nn.Linear(net_width, net_width)self.C3 = nn.Linear(net_width, 1)def forward(self, state):v = torch.tanh(self.C1(state))v = torch.tanh(self.C2(v))v = self.C3(v)return v

上述代碼只定義了mu的個數與維度數一樣,std作為可學習的參數之一。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/160810.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/160810.shtml
英文地址,請注明出處:http://en.pswp.cn/news/160810.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

ERP、CRM、SRM、PLM、HRM、OA……都是啥意思

在天某微電子上班,經常會聽說一些系統或平臺名稱,例如ERP、CRM、SRM、PLM、HRM、OA、FOL等。 這些系統,都是干啥用的呢? █ ERP(企業資源計劃) 英文全稱:Enterprise Resource Planning 定義…

如何使用SD-WAN提升物流供應鏈網絡效率

案例背景 本次分享的物流供應鏈企業是一家國際性的大型企業,專注于提供全球范圍內的物流和供應鏈解決方案。案例用戶在不同國家和地區均設有多個分支機構和辦公地點,以支持客戶需求和業務運營。 在過去,該企業用戶使用傳統的MPLS網絡來連接各…

OceanBase:04-單機在線轉分布式部署

目錄 1.當前部署情況 2.單Zone多OBServer模式 3.多Zone多OBServer模式 3.1 集群規劃 3.2 安裝OBServer程序 3.3 新增Zone 3.4 啟動Zone 3.5 向Zone新增OBserver節點 3.6重復3.2~3.5新增其他Zone 4.擴充資源 OceanBase 數據庫為單機分布式一體化架構,支持單…

ssh遠程使用jupyter notebook

Jupyter配置 密碼生成哈希值 jupyter lab password拷貝出哈希值 vi /root/.jupyter/jupyter_server_config.json生成配置文件 jupyter-lab --generate-config編輯配置文件 vi /root/.jupyter/jupyter_lab_config.py查找 /password 按n查找一下一個 c.ServerApp.password …

純干貨丨電腦監控軟件有哪些(三款電腦監控軟件大盤點)

電腦監控軟件在日常生活和工作中的應用越來越廣泛。這些軟件可以幫助我們監控電腦的使用情況,保護電腦的安全,提高工作效率。本文將介紹一些高人氣的電腦監控軟件,并分享一些純干貨。 1、 域之盾軟件----電腦監控系統 是一款功能強大的電腦監…

LeetCode:307. 區域和檢索 - 數組可修改(樹狀數組 C++)

目錄 307. 區域和檢索 - 數組可修改 題目描述: 實現代碼與解析: 樹狀數組: 原理思路: 307. 區域和檢索 - 數組可修改 題目描述: 給你一個數組 nums ,請你完成兩類查詢。 其中一類查詢要求 更新 數組…

Linux輸入設備應用編程(觸摸屏獲取坐標信息)

上一章學習了開發板外接鍵盤并獲取鍵盤的的輸入 Linux輸入設備應用編程(鍵盤,按鍵)-CSDN博客 本章編寫觸摸屏應用程序,獲取觸摸屏的坐標信息并將其打印出來 目錄 一 觸摸屏數據分析(觸摸,點擊&#xff…

采用connector-c++ 8.0操作數據庫

1.下載最新的Connector https://dev.mysql.com/downloads/connector/cpp/,下載帶debug的庫。 解壓縮到本地,本次使用的是帶debug模式的connector庫: 注:其中mysqlcppconn與mysqlcppconn8的區別是: 2.在cmakelist…

請簡要說明 Mysql 中 MyISAM 和 InnoDB 引擎的區別

“請簡要說明 Mysql 中 MyISAM 和 InnoDB 引擎的區別”。 屏幕前有多少同學在面試過程與遇到過類似問題, 可以在評論區留言:遇到過。 考察目的 對于 xxxx 技術的區別,在面試中是很常見的一個問題 一般情況下,面試官會通過這類…

SpringBoot監聽器解析

監聽器模式介紹 監聽器模式的要素 事件監聽器廣播器觸發機制 SpringBoot監聽器實現 系統事件 事件發送順序 監聽器注冊 監聽器注冊和初始化器注冊流程類似 監聽器觸發機制 獲取監聽器列表核心流程: 通用觸發條件: 自定義監聽器實現 實現方式1 實現監聽器接口: Order(1) …

[操作系統]進程和線程

目錄 1.什么是進程 1.1進程控制塊抽象 1.2 CPU 分配 —— 進程調度(Process Scheduling) 1.3內存分配 —— 內存管理(Memory Manage) 1.4進程間通信(Inter Process Communication) 2.線程 2.1概念 2.2為什么要有線程 2.3線…

論文閱讀 Forecasting at Scale (二)

最近在看時間序列的文章,回顧下經典 論文地址 項目地址 Forecasting at Scale 3.2、季節性 3.3、假日和活動事件3.4、模型擬合3.5、分析師參與的循環建模4、自動化預測評估4.1、使用基線預測4.2、建模預測準確性4.3、模擬歷史預測4.4、識別大的預測誤差 5、結論6、致…

【Python】重磅!這本30w人都在看的Python數據分析暢銷書更新了!

Python 語言極具吸引力。自從 1991 年誕生以來,Python 如今已經成為最受歡迎的解釋型編程語言。 【文末送書】今天推薦一本Python領域優質數據分析書籍,這本30w人都在看的書,值得入手。 目錄 作譯者簡介主要變動導讀視頻購書鏈接文末送書 pan…

【計算機方向】通信、算法、自動化、機器人、電子電氣、計算機工程、控制工程、計算機視覺~~~~~合集!!!

◆本文為大家梳理了近期可投的EI國際會議,涵蓋計算機各個學科方向,均可EI檢索 本期EI會議匯總合集涵蓋領域:計算機視覺、物聯網、算法、通信、智能技術、人工智能、人機交互、機器人、電子電氣等眾多領域! 本期所推薦的EI會議有…

ros2不同機器通訊時IP設置

看到這就是不同機器的IP地址,為了避免在路由器為不同的機器使用DHCP分配到上面的地址,可以設置DHCP分配的范圍:(我的路由器是如下設置的,一般路由器型號都不一樣,自己找一下) 防火墻設置-----&…

Leetcode—13.羅馬數字轉整數【簡單】

2023每日刷題(三十七) Leetcode—13.羅馬數字轉整數 算法思想 當前位置的元素比下個位置的元素小,就減去當前值,否則加上當前值 實現代碼 int getValue(char c) {switch(c) {case I:return 1;case V:return 5;case X:return 1…

elasticsearch 8安裝

問題提前報 max virtual memory areas error max virtual memory areas vm.max_map_count [65530] is too low, increase to at least [262144] 如果您的環境是Linux,注意要做以下操作,否則es可能會啟動失敗 1 用編輯工具打開文件/etc/sysctl.conf 2 …

wpf使用CefSharp.OffScreen模擬網頁登錄,并獲取身份cookie

目錄 框架信息&#xff1a;MainWindow.xamlMainWindow.xaml.cs爬取邏輯模擬登錄攔截請求Cookie獲取 CookieVisitorHandle 框架信息&#xff1a; CefSharp.OffScreen.NETCore 119.1.20 MainWindow.xaml <Window x:Class"Wpf_CHZC_Img_Identy_ApiDataGet.MainWindow&qu…

selinux-policy-default(2:2.20231119-2)軟件包內容詳細介紹(2)

接前一篇文章&#xff1a;selinux-policy-default&#xff08;2:2.20231119-2&#xff09;軟件包內容詳細介紹&#xff08;1&#xff09; 4. 重點文件內容解析 &#xff08;1&#xff09;control/postist文件 文件內容如下&#xff1a; #!/bin/sh set -e# summary of how th…

22LLMSecEval數據集及其在評估大模型代碼安全中的應用:GPT3和Codex根據LLMSecEval的提示生成代碼和代碼補全,CodeQL進行安全評估

LLMSecEval: A Dataset of Natural Language Prompts for Security Evaluations 寫在最前面主要工作 課堂討論大模型和密碼方向&#xff08;沒做&#xff0c;只是一個idea&#xff09; 相關研究提示集目標NL提示的建立NL提示的建立流程 數據集數據集分析 存在的問題 寫在最前面…