PPO(2017,OpenAI)核心改進點
Proximal Policy Optimization (PPO):一種基于信賴域優化的強化學習算法,旨在克服傳統策略梯度方法在更新時不穩定的問題,采用簡單易實現的目標函數來保證學習過程的穩定性
- 解決問題:在強化學習中,直接優化策略會導致不穩定的訓練,模型可能因為過大的參數更新而崩潰
- model-free,off-policy,actor-critic
核心改進點 | 說明 |
---|---|
剪切目標函數 | 使用剪切函數 clip 限制策略更新的幅度,避免策略大幅更新導致性能崩潰 |
off-policy | 每個采樣數據可用于多輪更新,提升樣本利用率,提高學習效率 |
PPO 網絡更新
策略網絡
PPO 使用舊策略和新策略的比值來定義目標函數,在保持改進的同時防止策略變化過大:
L C L I P ( θ ) = E t [ min ? ( r t ( θ ) A t , clip ( r t ( θ ) , 1 ? ? , 1 + ? ) A t ) ] , where? r t = π θ ( a t ∣ s t ) π θ old ( a t ∣ s t ) L^{CLIP}(\theta) = {\mathbb{E}}_t \left[ \min \left( r_t(\theta) {A}_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) {A}_t \right) \right], \text{where } r_t = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_\text{old}}(a_t|s_t)} LCLIP(θ)=Et?[min(rt?(θ)At?,clip(rt?(θ),1??,1+?)At?)],where?rt?=πθold??(at?∣st?)πθ?(at?∣st?)?
- Advantage 優勢函數 A t {A}_t At?:如 Q ( s t , a t ) ? V ( s t ) Q(s_t, a_t) - V(s_t) Q(st?,at?)?V(st?)
- 剪切系數 ? \epsilon ?:如 0.2
價值網絡
L V F ( θ μ ) = E t [ ( V θ μ ( s t ) ? R t ) 2 ] L^{VF}(\theta^\mu) = \mathbb{E}_t \left[ (V_{\theta^\mu}(s_t) - R_t)^2 \right] LVF(θμ)=Et?[(Vθμ?(st?)?Rt?)2]
- 真實或估算的回報 R t R_t Rt?:如 ∑ k = 0 n = γ k r t + k \sum^n_{k=0} = \gamma^k r_{t+k} ∑k=0n?=γkrt+k?
總損失函數
PPO 的總損失是策略損失、值函數損失和熵正則項 (鼓勵探索) 的加權和:
L ( θ ) = L C L I P ( θ ) ? c 1 L V F ( θ μ ) + c 2 H ( π ( s t ) ) L(\theta) = L^{CLIP}(\theta) - c_1 L^{VF}(\theta^\mu) + c_2 H(\pi(s_t)) L(θ)=LCLIP(θ)?c1?LVF(θμ)+c2?H(π(st?))
- c 1 , c 2 c_1, c_2 c1?,c2?:權重系數,常用 c 1 = 0.5 c_1=0.5 c1?=0.5, c 2 = 0.01 c_2=0.01 c2?=0.01
基于 stable_baselines3 的快速代碼示例
import gymnasium as gym
from stable_baselines3 import PPO# 創建環境
env = gym.make("CartPole-v1")
env.reset(seed=0)# 初始化模型
model = PPO("MlpPolicy", env, verbose=1)# 訓練模型
model.learn(total_timesteps=100_000)
model.save("ppo_cartpole_v1")# 測試模型
obs, _ = env.reset()
total_reward = 0
for _ in range(200):action, _ = model.predict(obs, deterministic=True) obs, reward, terminated, truncated, _ = env.step(action)total_reward += rewardif terminated or truncated:breakprint("Test total reward:", total_reward)
參考資料:PPO 詳解