了解PPO算法(Proximal Policy Optimization)

Proximal Policy Optimization (PPO) 是一種強化學習算法,由 OpenAI 提出,旨在解決傳統策略梯度方法中策略更新過大的問題。PPO 通過引入限制策略更新范圍的機制,在保證收斂性的同時提高了算法的穩定性和效率。

PPO算法原理

PPO 算法的核心思想是通過優化目標函數來更新策略,但在更新過程中限制策略變化的幅度。具體來說,PPO 引入了裁剪(Clipping)和信賴域(Trust Region)的思想,以確保策略不會發生過大的改變。

PPO算法公式

PPO 主要有兩種變體:裁剪版(Clipped PPO)和信賴域版(Adaptive KL Penalty PPO)。本文重點介紹裁剪版的 PPO。

  • 舊策略:

    \pi_{\theta_{\text{old}}}(a|s)

    其中,\theta_{\text{old}}? 是上一次更新后的策略參數。

  • 計算概率比率:

    r_t(\theta) = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}
  • 裁剪后的目標函數:

    L^{\text{CLIP}}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_t \right) \right]

    其中,\hat{A}_t? 是優勢函數(Advantage Function),\epsilon?是裁剪范圍的超參數,通常取值為0.2。

  • 更新策略參數:

    a_{\text{new}} = \arg\max_{\theta} L^{\text{CLIP}}(\theta)
PPO算法的實現

下面是用Python和TensorFlow實現 PPO 算法的代碼示例:

import tensorflow as tf
import numpy as np
import gym# 定義策略網絡
class PolicyNetwork(tf.keras.Model):def __init__(self, action_space):super(PolicyNetwork, self).__init__()self.dense1 = tf.keras.layers.Dense(128, activation='relu')self.dense2 = tf.keras.layers.Dense(128, activation='relu')self.logits = tf.keras.layers.Dense(action_space, activation=None)def call(self, inputs):x = self.dense1(inputs)x = self.dense2(x)return self.logits(x)# 定義值函數網絡
class ValueNetwork(tf.keras.Model):def __init__(self):super(ValueNetwork, self).__init__()self.dense1 = tf.keras.layers.Dense(128, activation='relu')self.dense2 = tf.keras.layers.Dense(128, activation='relu')self.value = tf.keras.layers.Dense(1, activation=None)def call(self, inputs):x = self.dense1(inputs)x = self.dense2(x)return self.value(x)# 超參數
learning_rate = 0.0003
clip_ratio = 0.2
epochs = 10
batch_size = 64
gamma = 0.99# 創建環境
env = gym.make('CartPole-v1')
obs_dim = env.observation_space.shape[0]
n_actions = env.action_space.n# 創建策略和值函數網絡
policy_net = PolicyNetwork(n_actions)
value_net = ValueNetwork()# 優化器
policy_optimizer = tf.keras.optimizers.Adam(learning_rate)
value_optimizer = tf.keras.optimizers.Adam(learning_rate)def get_action(observation):logits = policy_net(observation)action = tf.random.categorical(logits, 1)return action[0, 0]def compute_advantages(rewards, values, next_values, done):advantages = []gae = 0for i in reversed(range(len(rewards))):delta = rewards[i] + gamma * next_values[i] * (1 - done[i]) - values[i]gae = delta + gamma * gaeadvantages.insert(0, gae)return np.array(advantages)def ppo_update(observations, actions, advantages, returns):with tf.GradientTape() as tape:old_logits = policy_net(observations)old_log_probs = tf.nn.log_softmax(old_logits)old_action_log_probs = tf.reduce_sum(old_log_probs * tf.one_hot(actions, n_actions), axis=1)logits = policy_net(observations)log_probs = tf.nn.log_softmax(logits)action_log_probs = tf.reduce_sum(log_probs * tf.one_hot(actions, n_actions), axis=1)ratio = tf.exp(action_log_probs - old_action_log_probs)surr1 = ratio * advantagessurr2 = tf.clip_by_value(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * advantagespolicy_loss = -tf.reduce_mean(tf.minimum(surr1, surr2))policy_grads = tape.gradient(policy_loss, policy_net.trainable_variables)policy_optimizer.apply_gradients(zip(policy_grads, policy_net.trainable_variables))with tf.GradientTape() as tape:value_loss = tf.reduce_mean((returns - value_net(observations))**2)value_grads = tape.gradient(value_loss, value_net.trainable_variables)value_optimizer.apply_gradients(zip(value_grads, value_net.trainable_variables))# 訓練循環
for epoch in range(epochs):observations = []actions = []rewards = []values = []next_values = []dones = []obs = env.reset()done = Falsewhile not done:obs = obs.reshape(1, -1)observations.append(obs)action = get_action(obs)actions.append(action)value = value_net(obs)values.append(value)obs, reward, done, _ = env.step(action.numpy())rewards.append(reward)dones.append(done)if done:next_values.append(0)else:next_value = value_net(obs.reshape(1, -1))next_values.append(next_value)returns = compute_advantages(rewards, values, next_values, dones)advantages = returns - valuesobservations = np.concatenate(observations, axis=0)actions = np.array(actions)returns = np.array(returns)advantages = np.array(advantages)ppo_update(observations, actions, advantages, returns)print(f'Epoch {epoch+1} completed')
總結

PPO 算法通過引入裁剪機制和信賴域約束,限制了策略更新的幅度,提高了訓練過程的穩定性和效率。其簡單而有效的特性使其成為目前強化學習中最流行的算法之一。通過理解并實現 PPO 算法,可以更好地應用于各種強化學習任務,提升模型的性能。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/42633.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/42633.shtml
英文地址,請注明出處:http://en.pswp.cn/web/42633.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Oracle數據庫自帶的內置表和視圖、常用內部視圖

文章目錄 一.Oracle數據庫自帶的內置表和視圖1.dba_開頭表2.user_開頭表3.v$開頭表4.all_開頭表5.session_開頭表6.index_開頭表 三.按組分的幾組重要的性能視圖1.System的over view2.某個session的當前情況3.SQL的情況4.Latch/lock/ENQUEUE5.IO方面的 分類類別關系群集、表、視…

【docker 把系統盤空間耗沒了!】windows11 更改 ubuntu 子系統存儲位置

系統:win11 ubuntu 22 子系統,docker 出現問題:系統盤突然沒空間了,一片紅 經過排查,發現 AppData\Local\packages\CanonicalGroupLimited.Ubuntu22.04LTS_79rhkp1fndgsc\ 這個文件夾竟然有 90GB 下面提供解決辦法 步…

Spring-AOP(二)

作者:月下山川 公眾號:月下山川 1、什么是AOP AOP(Aspect Oriented Programming)是一種設計思想,是軟件設計領域中的面向切面編程,它是面向對象編程的一種補充和完善,它以通過預編譯方式和運行期…

【課程總結】Day13(下):人臉識別和MTCNN模型

前言 在上一章課程【課程總結】Day13(上):使用YOLO進行目標檢測,我們了解到目標檢測有兩種策略,一種是以YOLO為代表的策略:特征提取→切片→分類回歸;另外一種是以MTCNN為代表的策略:先圖像切片→特征提取→分類和回歸。因此,本章內容將深入了解MTCNN模型,包括:MTC…

CountDownLatch 是 Java 中的一個同步輔助工具類

下面是一個使用 CountDownLatch 的案例分析,我們將通過一個簡單的示例來展示如何使用 CountDownLatch 來同步多個線程的操作。 ### 場景描述 假設我們有一個任務,需要從多個數據源(比如多個數據庫表或文件)中讀取數據&#xff0c…

使用jdk11運行javafx程序和jdk11打包jre包含javafx模塊

我們都知道jdk11是移除了javafx的,如果需要使用javafx,需要單獨下載。 這就導致我們使用javafx開發的桌面程序使用jdk11時提示缺少javafx依賴。但這是可以通過下面的方法解決。 一,使用jdk11運行javafx程序 我們可以通過設置vmOptions來使用jdk11運行javafx程序 1,添加j…

【RAG KG】GraphRAG開源:查詢聚焦摘要的圖RAG方法

前言 傳統的 RAG 方法在處理針對整個文本語料庫的全局性問題時存在不足,例如查詢:“數據中的前 5 個主題是什么?” 對于此類問題,是因為這類問題本質上是查詢聚焦的摘要(Query-Focused Summarization, QFS&#xff09…

嵌入式單片機,兩者有什么關聯又有什么區別?

在開始前剛好我有一些資料,是我根據網友給的問題精心整理了一份「嵌入式的資料從專業入門到高級教程」, 點個關注在評論區回復“666”之后私信回復“666”,全部無償共享給大家!!!使用單片機是嵌入式系統的…

iOS 國際化語言第一語言不支持時候默認語言強轉英文

對bundle擴展 直接貼代碼 .h文件 // // NSBundleKdLocalBundle.h // QooCam // // Created by bob bob on 2023/9/8.//#import <Foundation/Foundation.h>NS_ASSUME_NONNULL_BEGINinterface NSBundle (KdLocalBundle)end interface KdLocalBundle:NSBundleend interf…

CurrentHashMap巧妙利用位運算獲取數組指定下標元素

先來了解一下數組對象在堆中的存儲形式【數組長度&#xff0c;數組元素類型信息等】 【存放元素對象的空間】 Ma 基礎信息實例數據內存填充Mark Word,ClassPointer,數組長度第一個元素第二個元素固定的填充內容 所以我們想要獲取某個下標的元素首先要獲取這個元素的起始位置…

軟件工程常見知識點

下午收到字節日常實習的面試邀請&#xff0c;希望這次能有一個好的表現。言歸正傳&#xff0c;郵件中提到這些問題&#xff0c;我這邊借了書并查了網上的資料&#xff0c;做一個提前準備。 軟件工程核心概念&#xff1a; 如何從一個需求落實到一個系統設計&#xff1f; 經過我…

c++ primer plus 第15章友,異常和其他:異常,15.3.7 其他異常特性

c primer plus 第15章友&#xff0c;異常和其他&#xff1a;異常,15.3.7 其他異常特性 c primer plus 第15章友&#xff0c;異常和其他&#xff1a;異常,15.3.7 其他異常特性 文章目錄 c primer plus 第15章友&#xff0c;異常和其他&#xff1a;異常,15.3.7 其他異常特性 15.…

Sorted Set 類型命令(命令語法、操作演示、命令返回值、時間復雜度、注意事項)

Sorted Set 類型 文章目錄 Sorted Set 類型zadd 命令zrange 命令zcard 命令zcount 命令zrevrange 命令zrangebyscore 命令zpopmax 命令bzpopmax 命令zpopmin 命令bzpopmin 命令zrank 命令zscore 命令zrem 命令zremrangebyrank 命令zremrangebyscore 命令zincrby 命令zinterstor…

線程池案例

秒殺 需求 10個禮物20個客戶搶隨機10個客戶獲取禮物&#xff0c;另外10無法獲取禮物 任務類 記得給共享資源加鎖 public class MyTask implements Runnable{// 禮物列表private ArrayList<String> gifts ;// 用戶名private String username;public MyTask( String user…

android Dialog全屏沉浸式狀態欄實現

在Android中&#xff0c;創建沉浸式狀態欄通常意味著讓狀態欄背景與應用的主題顏色一致&#xff0c;并且讓對話框在狀態欄下面顯示&#xff0c;而不是浮動。為了實現這一點&#xff0c;你可以使用以下代碼片段&#xff1a; 1、實際效果圖&#xff1a; 2、代碼實現&#xff1a;…

揭秘GPT-4o:未來智能的曙光

引言 近年來&#xff0c;人工智能&#xff08;AI&#xff09;的發展突飛猛進&#xff0c;尤其是自然語言處理&#xff08;NLP&#xff09;領域的進步&#xff0c;更是引人注目。在這一背景下&#xff0c;OpenAI發布的GPT系列模型成為了焦點。本文將詳細探討最新的模型GPT-4o&a…

Unity海面效果——6、反射和高光

Unity引擎制作海面效果 大家好&#xff0c;我是阿趙。 上一篇的結束時&#xff0c;海面效果已經做成這樣了&#xff1a; 這個Shader的復雜程度已經比較高了&#xff1a; 不過還有一些美中不足的地方。 1、 海平面沒有反射到天空球 2、 在近岸邊看得到水底的部分&#xff0c;水…

JVM調優:深入理解與實戰指南

引言 Java虛擬機&#xff08;JVM&#xff09;作為Java應用程序的運行環境&#xff0c;其性能直接影響到應用程序的響應速度、吞吐量和穩定性。JVM調優是Java開發者必須掌握的一項關鍵技能&#xff0c;它能夠幫助我們更好地利用系統資源&#xff0c;提升應用程序的性能。本文將…

一些關于C++的基礎知識

引言&#xff1a;C兼容C的大部分內容&#xff0c;但其中仍有許多小細節的東西需要大家注意 一.C的第一個程序 #include <iostream> using namespace std;int main() {cout << "hello world!" << endl;return 0; } 第一次看這個是否感覺一頭霧水…

數據挖掘——matplotlib

matplotlib概述 Mat指的是Matlab&#xff0c;plot指的是畫圖&#xff0c;lib即library&#xff0c;顧名思義&#xff0c;matplotlib是python專門用于開發2D圖表的第三方庫&#xff0c;使用之前需要下載該庫&#xff0c;使用pip命令即可下載。 pip install matplotlib1、matpl…