文章目錄
- 前言
- 算法原理
- 1. 從策略梯度到Actor-Critic
- 2. Actor 和 Critic 的角色
- 3. Critic 的學習方式:時序差分 (TD)
- 4. Actor 的學習方式:策略梯度
- 5. 算法流程
- 代碼實現
- 1. 環境與工具函數
- 2. 構建Actor-Critic智能體
- 3. 組織訓練流程
- 4. 主程序:啟動訓練
- 5. 實驗結果
- 總結
前言
在深度強化學習(DRL)的廣闊天地中,算法可以大致分為兩大家族:基于價值(Value-based)的算法和基于策略(Policy-based)的算法。像DQN這樣的算法通過學習一個價值函數來間接指導策略,而像REINFORCE這樣的算法則直接對策略進行參數化和優化。
然而,這兩種方法各有優劣。基于價值的方法通常數據效率更高、更穩定,但難以處理連續動作空間;基于策略的方法可以直接處理各種動作空間,并能學習隨機策略,但其學習過程往往伴隨著高方差,導致訓練不穩定、收斂緩慢。
為了融合兩者的優點,Actor-Critic(演員-評論家) 框架應運而生。它構成了現代深度強化學習的基石,許多前沿算法(如A2C, A3C, DDPG, TRPO, PPO等)都屬于這個大家族。
本文將從理論出發,結合一個完整的 PyTorch 代碼實例,帶您深入理解基礎的 Actor-Critic 算法。我們將通過經典的 CartPole(車桿)環境,一步步構建、訓練并評估一個 Actor-Critic 智能體,直觀地感受它是如何工作的。
完整代碼:下載鏈接
算法原理
Actor-Critic 算法本質上是一種基于策略的算法,其目標是優化一個帶參數的策略。與REINFORCE算法不同的是,它會額外學習一個價值函數,用這個價值函數來“評論”策略的好壞,從而幫助策略函數更好地學習。
1. 從策略梯度到Actor-Critic
在策略梯度方法中,目標函數的梯度可以寫成一個通用的形式:
g = E [ ∑ t = 0 T ψ t ? θ log ? π θ ( a t ∣ s t ) ] g=\mathbb{E}\left[\sum_{t=0}^T\psi_t\nabla_\theta\log\pi_\theta(a_t|s_t)\right] g=E[t=0∑T?ψt??θ?logπθ?(at?∣st?)]
其中,ψt
是一個用于評估在狀態 st
下采取動作 at
的優劣的標量。ψt
的選擇直接影響了算法的性能:
- 形式2:
ψt
是動作at
之后的所有回報之和。這是 REINFORCE 算法使用的形式。它使用蒙特卡洛方法來估計動作的價值,雖然是無偏估計,但由于包含了從t
時刻到回合結束的所有隨機性,其方差非常大。 - 形式6:
ψt
是 時序差分誤差(TD Error)。這是本文 Actor-Critic 算法將采用的核心形式。它只利用了一步的真實獎勵r_t
和對下一狀態價值的估計V(s_t+1)
,極大地降低了方差。
這個轉變正是 Actor-Critic 算法的核心思想:不再使用完整的、高方差的軌跡回報,而是引入一個價值函數來提供更穩定、低方差的指導信號。
2. Actor 和 Critic 的角色
我們將 Actor-Critic 算法拆分為兩個核心部分:
- Actor (演員):即策略網絡。它的任務是與環境進行交互,并根據 Critic 的“評價”來學習一個更好的策略。它決定了在某個狀態下應該采取什么動作。
- Critic (評論家):即價值網絡。它的任務是通過觀察 Actor 與環境的交互數據,學習一個價值函數。這個價值函數用于判斷在當前狀態下,Actor 選擇的動作是“好”還是“壞”,從而指導 Actor 的策略更新。
3. Critic 的學習方式:時序差分 (TD)
Critic 的目標是準確地估計狀態價值函數 V(s)
。它采用**時序差分(Temporal-Difference, TD)**學習方法。具體來說,是TD(0)方法。
在TD學習中,我們希望價值網絡的預測值 V(s_t)
能夠逼近 TD目標 (TD Target),即 r_t + γV(s_t+1)
。因此,Critic 的損失函數定義為兩者之間的均方誤差:
L ( ω ) = 1 2 ( r + γ V ω ( s t + 1 ) ? V ω ( s t ) ) 2 \mathcal{L}(\omega)=\frac{1}{2}(r+\gamma V_\omega(s_{t+1})-V_\omega(s_t))^2 L(ω)=21?(r+γVω?(st+1?)?Vω?(st?))2
當我們對這個損失函數求梯度以更新 Critic 的網絡參數 w
時,有一個非常關鍵的點:
在TD學習中,目標值
r_t + γV(s_t+1)
被視為一個固定的“標簽”(Target),不參與反向傳播。因此,梯度只對當前狀態的值函數V(s_t)
求導。
Critic 價值網絡表示為 V w V_w Vw?,參數為 w w w。價值函數的梯度為:
? ω L ( ω ) = ? ( r + γ V ω ( s t + 1 ) ? V ω ( s t ) ) ? ω V ω ( s t ) \nabla_\omega\mathcal{L}(\omega)=-(r+\gamma V_\omega(s_{t+1})-V_\omega(s_t))\nabla_\omega V_\omega(s_t) ?ω?L(ω)=