文章目錄
- 前言
- DQN 算法核心思想
- Q-Learning 與函數近似
- 經驗回放 (Experience Replay)
- 目標網絡 (Target Network)
- PyTorch 代碼實現詳解
- 1. 環境與輔助函數
- 2. 經驗回放池 (ReplayBuffer)
- 3. Q網絡 (Qnet)
- 4. DQN 主類
- 5. 訓練循環
- 6. 設置超參數與開始訓練
- 訓練結果與分析
- 總結
前言
歡迎來到深度強化學習的世界!如果你對 Q-learning 有所了解,你可能會知道它使用一個表格(Q-table)來存儲每個狀態-動作對的價值。然而,當狀態空間變得巨大,甚至是連續的時候(比如一個小車在軌道上的位置),Q-table 就變得不切實際。這時,深度Q網絡(Deep Q-Network, DQN)就閃亮登場了。
DQN 的核心思想是用一個神經網絡來代替 Q-table,實現從狀態到(各個動作的)Q值的映射。這使得我們能夠處理具有連續或高維狀態空間的環境。本文將以經典的 CartPole-v1
環境為例,通過一個完整的 PyTorch 代碼實現,帶你深入理解 DQN 的工作原理及其關鍵組成部分:神經網絡近似、經驗回放和目標網絡。
圖 1 CartPole環境示意圖
在 CartPole 環境中,智能體的任務是左右移動小車,以保持車上的桿子豎直不倒。這個環境的狀態是連續的(車的位置、速度、桿的角度、角速度),而動作是離散的(向左或向右)。這正是DQN大顯身手的完美場景。
讓我們一起通過代碼,揭開DQN的神秘面紗。
完整代碼:下載鏈接
DQN 算法核心思想
在深入代碼之前,我們先回顧一下 DQN 的幾個關鍵概念。
Q-Learning 與函數近似
傳統的 Q-learning 更新規則如下:
Q ( s , a ) ← Q ( s , a ) + α [ r + γ max ? a ′ ∈ A Q ( s ′ , a ′ ) ? Q ( s , a ) ] Q(s,a)\leftarrow Q(s,a)+\alpha\left[r+\gamma\max_{a^{\prime}\in\mathcal{A}}Q(s^{\prime},a^{\prime})-Q(s,a)\right] Q(s,a)←Q(s,a)+α[r+γa′∈Amax?Q(s′,a′)?Q(s,a)]
當狀態是連續的,我們無法用表格記錄所有 Q(s,a)
。因此,我們引入一個帶參數 w
的神經網絡,即 Q-網絡 Q ω ( s , a ) Q_\omega\left(s,a\right) Qω?(s,a),來近似真實的 Q-函數。我們的目標是讓網絡預測的Q值 Q ω ( s , a ) Q_\omega\left(s,a\right) Qω?(s,a) 逼近“目標Q值” r + γ max ? a ′ ∈ A Q ( s ′ , a ′ ) r+\gamma\max_{a^{\prime}\in\mathcal{A}}Q(s',a') r+γmaxa′∈A?Q(s′,a′)。
為此,我們可以定義一個損失函數,最常見的就是均方誤差(MSE Loss):
ω ? = arg ? min ? ω 1 2 N ∑ i = 1 N [ Q ω ( s i , a i ) ? ( r i + γ max ? a ′ Q ω ( s i ′ , a ′ ) ) ] 2 \omega^*=\arg\min_\omega\frac{1}{2N}\sum_{i=1}^N\left[Q_\omega\left(s_i,a_i\right)-\left(r_i+\gamma\max_{a^{\prime}}Q_\omega\left(s_i^{\prime},a^{\prime}\right)\right)\right]^2 ω?=argωmin?2N1?i=1∑N?[Q