PG(1984,Sutton) 核心改進點
策略梯度算法 (PG): 直接對策略函數進行建模,可以適用于連續的動作空間
- model-free, on-policy, PG, stochastic 策略
核心改進點 | 說明 |
---|---|
策略梯度優化 | 通過Actor網絡直接優化策略,適應連續動作問題: θ n e w = θ o l d + α ? θ J ( θ ) \theta_{new} = \theta_{old} + \alpha \nabla_\theta J(\theta) θnew?=θold?+α?θ?J(θ) |
PG 網絡更新 – 基于蒙特卡洛估計的 REINFORCE
? θ J ( θ ) ≈ ∑ t = 0 T ? 1 ? θ log ? π θ ( a t ∣ s t ) G t ,where? G t = ∑ t ′ = t T γ t ′ ? t r t ′ \nabla_\theta J(\theta) \approx \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t | s_t) G_t,\text{where } G_t = \sum_{t'=t}^{T} \gamma^{t' - t} r_{t'} ?θ?J(θ)≈t=0∑T?1??θ?logπθ?(at?∣st?)Gt?,where?Gt?=t′=t∑T?γt′?trt′?
詳細網絡更新公式推導
策略更新目標:使得 θ \theta θ 策略下得到的所有軌跡 τ \tau τ 的回報期望 R ˉ θ \bar{R}_\theta Rˉθ? 最大化: 可以用 N 條軌跡的均值近似
- τ = { s 1 , a 1 , r 1 , s 2 , a 2 , r 2 , … , s τ , a τ , r τ } \tau = \{s_1, a_1, r_1, s_2, a_2, r_2, \dots, s_\tau, a_\tau, r_\tau\} τ={s1?,a1?,r1?,s2?,a2?,r2?,…,sτ?,aτ?,rτ?}
R ˉ θ = ∑ τ R ( τ ) P ( τ ∣ θ ) ≈ 1 N ∑ n N R ( τ n ) \bar{R}_\theta =\textcolor{red}{\sum_\tau} R(\tau) \textcolor{red}{P(\tau | \theta)} \approx \textcolor{blue}{\frac{1}{N} \sum_n^N}R(\tau^n) Rˉθ?=τ∑?R(τ)P(τ∣θ)≈N1?n∑N?R(τn)
計算梯度 (近似)
? R ˉ θ = ∑ τ R ( τ ) ? P ( τ ∣ θ ) = ∑ τ R ( τ ) P ( τ ∣ θ ) ? P ( τ ∣ θ ) P ( τ ∣ θ ) = ∑ τ R ( τ ) P ( τ ∣ θ ) ? θ log ? P ( τ ∣ θ ) ≈ 1 N ∑ n = 1 N R ( τ n ) ? θ log ? P ( τ n ∣ θ ) \nabla \bar{R}_\theta = \sum_{\tau} R(\tau) \nabla P(\tau | \theta) = \sum_\tau R(\tau) P(\tau | \theta) \frac{\nabla P(\tau | \theta)}{P(\tau | \theta)}=\textcolor{red}{\sum_\tau} R(\tau) \textcolor{red}{P(\tau | \theta)} \nabla_\theta \log P(\tau | \theta)\\ \approx \textcolor{blue}{\frac{1}{N} \sum_{n=1}^N} R(\tau^n) \nabla_\theta \log P(\tau^n | \theta) ?Rˉθ?=τ∑?R(τ)?P(τ∣θ)=τ∑?R(τ)P(τ∣θ)P(τ∣θ)?P(τ∣θ)?=τ∑?R(τ)P(τ∣θ)?θ?logP(τ∣θ)≈N1?n=1∑N?R(τn)?θ?logP(τn∣θ)
- 注:轉為
log
時利用了公式 d log ? ( f ( x ) ) d x = 1 f ( x ) ? d f ( x ) d x \frac{d \log(f(x))}{dx} = \frac{1}{f(x)} \cdot \frac{d f(x)}{dx} dxdlog(f(x))?=f(x)1??dxdf(x)?
其中, ? θ log ? P ( τ n ∣ θ ) \nabla_\theta\log P(\tau^n | \theta) ?θ?logP(τn∣θ) 可以做進一步表示
P ( τ ∣ θ ) = p ( s 1 ) ∏ t = 1 T p ( a t ∣ s t , θ ) p ( r t , s t + 1 ∣ s t , a t ) log ? P ( τ ∣ θ ) = log ? p ( s 1 ) + ∑ t = 1 T log ? p ( a t ∣ s t , θ ) + log ? p ( r t , s t + 1 ∣ s t , a t ) ? θ log ? P ( τ ∣ θ ) = ∑ t = 1 T ? θ log ? p ( a t ∣ s t , θ ) P(\tau|\theta) = p(s_1) \prod_{t=1}^{T} p(a_t|s_t, \theta) p(r_t, s_{t+1}|s_t, a_t) \\ \log P(\tau|\theta) = \log p(s_1) + \sum_{t=1}^{T} \log p(a_t|s_t, \theta) + \log p(r_t, s_{t+1}|s_t, a_t)\\ \nabla_\theta\log P(\tau | \theta) = \sum_{t=1}^{T} \nabla_\theta \log p(a_t | s_t, \theta) P(τ∣θ)=p(s1?)t=1∏T?p(at?∣st?,θ)p(rt?,st+1?∣st?,at?)logP(τ∣θ)=logp(s1?)+t=1∑T?logp(at?∣st?,θ)+logp(rt?,st+1?∣st?,at?)?θ?logP(τ∣θ)=t=1∑T??θ?logp(at?∣st?,θ)
所以梯度 (近似)的表示更新為
? R ˉ θ ≈ 1 N ∑ n = 1 N ∑ t = 1 T n R ( τ n ) ? θ log ? p ( a t n ∣ s t n , θ ) \nabla \bar{R}_\theta \approx {\frac{1}{N} \sum_{n=1}^N} \sum_{t=1}^{T^n} R(\tau^n) \nabla_\theta \log p(a_t^n | s_t^n, \theta) ?Rˉθ?≈N1?n=1∑N?t=1∑Tn?R(τn)?θ?logp(atn?∣stn?,θ)
- 注:梯度用的是總的回報 R ( τ n ) R(\tau^n) R(τn) 而不是 a t n a_t^n atn? 對應的即時獎勵,也就是說,總的回報會
增強/減弱
軌跡上所有有利/有害
的動作輸出;進一步,由于對于第t
個step,所選擇的動作只會影響未來的 U t n = ∑ t T n r t n U^n_t = \sum_t^{T^n} r^n_t Utn?=t∑Tn?rtn? 所以 R ( τ n ) R(\tau^n) R(τn) 可以被優化為 U t n U^n_t Utn?,對應本文一開始所給出的梯度公式
關于如何理解這個梯度,李宏毅老師類比分類學習的講法也很有啟發,強烈推薦學習下 【PG 李宏毅 B 站】

進一步的 還可以通過添加 baseline 等方法進一步優化表現
- 解決全正數值的獎勵導致的 – 沒有被 sample 到的 action 輸出概率會下降 (因為其他被 sample 到了的 actions,獲得了正數值的獎勵導致其被視為
有利
的動作,進而被增強
了其的輸出) 的問題

基于 stable_baselines3 的快速代碼示例
- 見后續 PPO 算法章節
參考資料:策略梯度算法(PG)詳解