機器學習從入門到精通 - 強化學習初探:從 Q-Learning 到 Deep Q-Network 實戰
一、開場白:推開強化學習這扇門
不知道你有沒有過這種感覺 —— 盯著一個復雜的系統,既想讓它達到某個目標,又苦于無法用傳統規則去精確描述每一步該怎么做?比如訓練一個機器人走出迷宮,或者讓算法學會玩《超級馬里奧》。這就是強化學習(Reinforcement Learning, RL)大展拳腳的地方了!它不要求你預先告知所有正確答案,而是讓一個"智能體"(Agent)在環境中不斷試錯、根據反饋調整策略,最終學會達成目標。今天這篇長文,咱們就手把手從最經典的 Q-Learning 開始,一路打通關,最后用 PyTorch 實現一個解決經典控制問題的 Deep Q-Network (DQN)! 我保證,過程中你會掉進不少坑,也會看到我是怎么狼狽爬出來的 —— 這才是真實的學習過程嘛。
二、強化學習基本框架:環境、狀態、動作與獎勵
先說個容易踩的坑:很多人一上來就扎進算法公式里,結果連 Agent 和 Environment 怎么交互都搞不清,后面就全亂了套。必須得先理解這個核心交互循環:
- 環境(Environment):智能體存在的世界(比如一個迷宮、一個游戲畫面、一個股票市場)。
- 狀態(State):環境在時刻
t
的完整描述(比如迷宮里的坐標、游戲畫面像素、股票價格+指標)。 - 動作(Action):智能體在狀態
s_t
下能做出的選擇(比如向上/下/左/右移動、買入/賣出/持有)。 - 獎勵(Reward):環境在智能體執行動作
a_t
后,進入新狀態s_{t+1}
時給出的即時評價信號(比如撞墻扣分,到達終點加分)。記住,智能體的終極目標就是最大化長期累積獎勵!
關鍵概念:馬爾可夫決策過程(MDP)
絕大多數強化學習問題都建模成 MDP。它要求:下一個狀態 s_{t+1}
和當前獎勵 r_t
只取決于當前狀態 s_t
和當前動作 a_t
,與之前的歷史無關。 用數學表示就是:
P(s_{t+1}, r_t | s_t, a_t, s_{t-1}, a_{t-1}, ..., s_0, a_0) = P(s_{t+1}, r_t | s_t, a_t)
這個假設是很多強化學習算法(包括 Q-Learning)的理論基石。
三、Q-Learning:價值函數的藝術
好了,現在我們請出今天的第一位主角:Q-Learning。它是一種 無模型(Model-Free)、基于價值(Value-Based) 的強化學習算法。核心思想是學習一個叫 Q-Table
的東西。
什么是 Q 值?
Q(s, a)
表示在狀態 s
下選擇動作 a
,并且之后一直采取最優策略所能獲得的期望累積獎勵。簡單說,它衡量了在 s
選 a
有多“好”。
目標:找到最優 Q 函數 Q^*(s, a)
如果我能知道所有狀態 s
下所有動作 a
的 Q^*(s, a)
,那么最優策略 π^*(s)
就簡單了:永遠選擇當前狀態下 Q
值最大的那個動作!π^*(s) = argmax_a Q^*(s, a)
Q-Learning 的更新魔法:時間差分(TD)
問題是,我們一開始不知道 Q^*
。Q-Learning 的核心在于通過不斷嘗試和更新來逼近 Q^*
。它的更新公式是重中之重(推導來了!):
Q(s_t, a_t) <-- Q(s_t, a_t) + α * [ TD_Target - Q(s_t, a_t) ]
其中:
α
(Alpha):學習率(Learning Rate),控制新信息覆蓋舊信息的程度(0 < α ≤ 1)。這個值吧 —— 選大了震蕩,選小了學得慢,后面實戰會踩坑。TD_Target
:時間差分目標(Temporal Difference Target),代表我們對Q(s_t, a_t)
的最新估計。它是怎么來的呢?
-
貝爾曼方程(Bellman Equation) 是理解的基礎。對于最優 Q 函數,它滿足:
Q^*(s, a) = E [ r + γ * max_{a'} Q^*(s', a') | s, a ]
s'
:執行a
后轉移到的下一個狀態。r
:執行a
后得到的即時獎勵。γ
(Gamma):折扣因子(0 ≤ γ < 1),表示我們有多重視未來獎勵(γ 越接近 1,越重視長遠收益)。max_{a'} Q^*(s', a')
:在下一個狀態s'
下,采取最優動作a'
所能獲得的最大 Q 值(代表了s'
狀態的價值V^*(s')
)。E[...]
:期望值(因為狀態轉移可能有隨機性)。
這個方程說明:當前狀態-動作對的價值等于即時獎勵加上折扣后的下一個狀態的最優價值。 它揭示了 Q 值之間的遞歸關系。
-
從貝爾曼最優方程到 Q-Learning 更新: Q-Learning 用當前估計值去逼近貝爾曼方程定義的理想值。在
s_t
執行a_t
,我們觀察到即時獎勵r_{t+1}
和新狀態s_{t+1}
。這時候,我們會對Q(s_t, a_t)
應該等于什么有一個新的“目標”:
TD_Target = r_{t+1} + γ * max_{a} Q(s_{t+1}, a)
注意這里用的是我們當前的 Q 表來估計s_{t+1}
狀態的價值 (max_{a} Q(s_{t+1}, a)
),而不是Q^*
。 -
更新量:
TD_Target - Q(s_t, a_t)
就是當前估計值和新的目標值之間的差異,稱為 TD 誤差(Temporal Difference Error)。Q-Learning 做的就是:用這個誤差乘以學習率 α,去調整當前的Q(s_t, a_t)
,讓它更接近TD_Target
。
最終合并得到的 Q-Learning 更新公式:
# Q(s_t, a_t) 更新公式
Q[s_t, a_t] = Q[s_t, a_t] + α * ( r_{t+1} + γ * np.max(Q[s_{t+1}, :]) - Q[s_t, a_t] )
符號釋義:
s_t
:當前時刻t
的狀態。a_t
:在s_t
狀態下選擇的動作。r_{t+1}
:執行a_t
后得到的即時獎勵。s_{t+1}
:執行a_t
后轉移到的下一個狀態。Q[s_{t+1}, :]
:在 Q 表中,狀態s_{t+1}
對應的所有Q
值。np.max(Q[s_{t+1}, :])
:下一個狀態s_{t+1}
下,所有可能動作的最大Q
值估計。α
:學習率。γ
:折扣因子。
Q-Learning 算法的偽代碼流程:
關鍵點:探索與利用(ε-greedy)
- 如果每次都選當前 Q 表認為最好的動作(
argmax_a Q(s, a)
),可能永遠發現不了真正更好的動作。 - ε-greedy 策略: 以
1 - ε
的概率選擇當前 Q 值最大的動作(利用),以ε
的概率隨機選擇一個動作(探索)。ε
通常隨著訓練衰減(從 1.0 開始,逐漸減小到 0.01 或 0.1)。ε 衰減策略沒設計好,模型可能學偏或者卡住,這是個大坑點。
四、Q-Table 的局限與 Deep Q-Network 的崛起
Q-Learning 在小規模、離散狀態和動作空間下表現很好。但是,現實世界往往是連續的,狀態維度極高(比如一張游戲圖像有幾十萬個像素點)。Q-Table 的致命傷來了:它無法處理高維或連續狀態空間!
- 存儲問題: 狀態太多(甚至是無限的),Q-Table 根本存不下。想象一下用表格存儲每個可能的像素組合的 Q 值 —— 天文數字!
- 泛化問題: 即使能存儲,遇到沒見過的狀態,Q-Table 無法給出合理的 Q 值估計。它沒有泛化能力。
解決方案:用函數逼近器代替 Q-Table!
這里 —— 深度神經網絡(DNN)閃亮登場。它強大的函數擬合能力,讓它成為學習 Q(s, a; θ)
函數的絕佳選擇,其中 θ
是神經網絡的參數。這就是 Deep Q-Network (DQN)。
DQN 的核心技術(兩大支柱):
-
經驗回放(Experience Replay):
- 問題: 連續采集的經驗
(s_t, a_t, r_{t+1}, s_{t+1}, done)
是強相關的。直接用它們訓練網絡會導致參數更新不穩定、振蕩甚至發散。 - 解決: 建立一個固定大小的經驗池(Replay Buffer)。每次與環境交互得到的經驗元組先存入池中。訓練時,隨機抽取一小批(Mini-batch)經驗進行學習。
- 好處:
- 打破樣本間的時間相關性,使訓練更穩定。
- 提高數據利用率,單個樣本可被多次學習。
- 離線學習(Off-policy):可以重復利用過去的經驗。經驗池大小和采樣方式的選擇很關鍵,太小容易過時,太大訓練慢,后面會踩坑。
- 問題: 連續采集的經驗
-
目標網絡(Target Network):
- 問題: 在計算
TD_Target = r + γ * max_a Q(s', a; θ)
時,θ
是我們正在更新的網絡參數。更新θ
會導致TD_Target
本身也在快速變化(像個移動的目標),加劇訓練的不穩定性。 - 解決: 引入一個結構相同但參數不同的目標網絡
Q(s, a; θ?)
。這個網絡的參數θ?
并不是每一步都更新,而是定期(比如每 N 步)從當前訓練網絡Q(s, a; θ)
復制參數(θ? ← θ
)。計算TD_Target
時使用目標網絡,并且需要考慮終止狀態:
TD_Target = r
(如果s'
是終止狀態)
TD_Target = r + γ * max_a Q(s', a; θ?)
(如果s'
不是終止狀態) - 好處:
TD_Target
在一段時間內相對穩定,為訓練網絡Q(s, a; θ)
提供了一個更可靠的更新目標。更新頻率N
是個超參數,需要調。
- 問題: 在計算
DQN 算法流程:
網絡架構設計(以 CartPole 為例):
- 輸入: 狀態
s
(CartPole 中是 4 維向量[cart_position, cart_velocity, pole_angle, pole_angular_velocity]
)。 - 輸出: 每個可能動作
a
的 Q 值估計(CartPole 中是 2 維向量[Q(s, left), Q(s, right)]
)。 - 隱藏層: 通常使用全連接層(FC)。對于簡單問題如 CartPole,1-2 個隱藏層(如 128 或 256 個神經元)足夠。激活函數通常用 ReLU。
import torch
import torch.nn as nn
import torch.optim as optimclass DQN(nn.Module):def __init__(self, state_dim, action_dim, hidden_dim=128):super(DQN, self).__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, hidden_dim) # 可選的第二層self.fc3 = nn.Linear(hidden_dim, action_dim)self.relu = nn.ReLU()def forward(self, x):x = self.relu(self.fc1(x))x = self.relu(self.fc2(x)) # 如果只有一層隱藏層則去掉這行x = self.fc3(x)return x
五、實戰:用 PyTorch 實現 DQN 解決 CartPole 問題
為什么選 CartPole? 它是 OpenAI Gym 提供的經典控制環境,狀態簡單(4維),動作離散(2個)。DQN 能很好地解決它,非常適合入門演示。目標: 控制小車左右移動,讓上面的桿子盡可能長時間保持豎直不倒。成功標準通常是連續保持平衡 195 步(或平均獎勵達到 195)。
1. 安裝環境 & 導入庫
!pip install gymnasium[classic_control]==0.29.1 numpy==1.26.4 torch==2.2.2 matplotlib # 強烈建議指定版本避免兼容性問題import gymnasium as gym
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque # 用于實現經驗回放池
import matplotlib.pyplot as plt
2. 核心組件實現
經驗回放池 (Replay Buffer)
class ReplayBuffer:def __init__(self, capacity):self.buffer = deque(maxlen=capacity) # 固定大小的雙端隊列def add(self, state, action, reward, next_state, done):"""存儲一條經驗 (s, a, r, s', done)"""self.buffer.append((state, action, reward, next_state, done))def sample(self, batch_size):"""隨機采樣一批經驗"""experiences = random.sample(self.buffer, batch_size)# 拆分元組為獨立的 NumPy 數組states, actions, rewards, next_states, dones = zip(*experiences)# 轉換為 PyTorch Tensor (注意! 后面踩坑點)return (torch.tensor(np.array(states), dtype=torch.float32),torch.tensor(actions, dtype=torch.long).unsqueeze(1), # 增加批次維度torch.tensor(rewards, dtype=torch.float32).unsqueeze(1),torch.tensor(np.array(next_states), dtype=torch.float32),torch.tensor(dones, dtype=torch.float32).unsqueeze(1))def __len__(self):return len(self.buffer)
踩坑記錄1:數據類型轉換
- 環境返回的
state
,next_state
是np.ndarray
。 action
,reward
,done
是標量或布爾值。- 必須小心地轉換為正確數據類型的
torch.Tensor
,并確保維度一致(特別是actions
,rewards
,dones
通常需要增加一個維度表示 batch)。 不注意這個會在計算損失函數時報各種維度不匹配的錯誤。
DQN Agent
class DQNAgent:def __init__(self, state_dim, action_dim, hidden_dim, lr, gamma, epsilon, target_update_freq, buffer_capacity):self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"Using device: {self.device}")self.action_dim = action_dim# Networksself.policy_net = DQN(state_dim, action_dim, hidden_dim).to(self.device)self.target_net = DQN(state_dim, action_dim, hidden_dim).to(self.device)self.target_net.load_state_dict(self.policy_net.state_dict()) # 同步初始權重self.target_net.eval() # 目標網絡不進行梯度計算self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)self.memory = ReplayBuffer(buffer_capacity)# Hyperparametersself.gamma = gammaself.epsilon = epsilonself.target_update_freq = target_update_freqself.learn_step_counter = 0def select_action(self, state):"""根據 ε-greedy 策略選擇動作"""if random.random() < self.epsilon:return random.randrange(self.action_dim) # 探索:隨機選擇動作else:# 利用:選擇Q值最高的動作with torch.no_grad():state_tensor = torch.tensor(np.array([state]), dtype=torch.float32).to(self.device)q_values = self.policy_net(state_tensor)return q_values.argmax().item()def learn(self, batch_size):"""從經驗池中采樣學習"""if len(self.memory) < batch_size:return # 經驗池不夠,先不學習states, actions, rewards, next_states, dones = self.memory.sample(batch_size)states = states.to(self.device)actions = actions.to(self.device)rewards = rewards.to(self.device)next_states = next_states.to(self.device)dones = dones.to(self.device)# 1. 計算當前狀態的 Q 值: Q(s_t, a_t)# self.policy_net(states) 輸出所有動作的Q值# .gather(1, actions) 提取出實際采取動作 a_t 對應的 Q 值current_q_values = self.policy_net(states).gather(1, actions)# 2. 計算 TD Targetwith torch.no_grad():# 用目標網絡計算下一個狀態的最大 Q 值next_q_values = self.target_net(next_states).max(1)[0].unsqueeze(1)# 如果 done=True (值為1), 那么未來的獎勵為0td_target = rewards + self.gamma * next_q_values * (1 - dones)# 3. 計算損失loss = F.mse_loss(current_q_values, td_target)# 4. 優化模型self.optimizer.zero_grad()loss.backward()self.optimizer.step()self.learn_step_counter += 1# 5. 定期更新目標網絡if self.learn_step_counter % self.target_update_freq == 0:self.target_net.load_state_dict(self.policy_net.state_dict())
3. 設置超參數與訓練循環
踩坑記錄2:超參數調優是門玄學!
DQN 的超參數非常敏感。LEARNING_RATE
太大,訓練會不穩定;GAMMA
太小,智能體會變得短視;EPSILON
衰減太快,探索不足;TARGET_UPDATE_FREQ
太頻繁或太稀疏都不好。下面的參數是針對 CartPole 調優過的一組,但并不唯一。
# --- Hyperparameters ---
EPISODES = 400
BUFFER_CAPACITY = 10000
BATCH_SIZE = 64
LEARNING_RATE = 0.001
GAMMA = 0.99# Epsilon-greedy 策略參數
EPSILON_START = 1.0
EPSILON_END = 0.05
EPSILON_DECAY = (EPSILON_START - EPSILON_END) / (EPISODES * 0.6) # 線性衰減TARGET_UPDATE_FREQ = 100 # 每100步學習更新一次目標網絡
HIDDEN_DIM = 128# --- Setup ---
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.nagent = DQNAgent(state_dim, action_dim, HIDDEN_DIM, LEARNING_RATE, GAMMA, EPSILON_START, TARGET_UPDATE_FREQ, BUFFER_CAPACITY)episode_rewards = []# --- Training Loop ---
for i_episode in range(EPISODES):state, info = env.reset()done = Falsetotal_reward = 0while not done:action = agent.select_action(state)next_state, reward, terminated, truncated, info = env.step(action)done = terminated or truncatedagent.memory.add(state, action, reward, next_state, done)state = next_statetotal_reward += rewardagent.learn(BATCH_SIZE)# 更新 Epsilonif agent.epsilon > EPSILON_END:agent.epsilon -= EPSILON_DECAYepisode_rewards.append(total_reward)if (i_episode + 1) % 20 == 0:print(f"Episode {i_episode+1}/{EPISODES}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.4f}")print("Training finished.")
env.close()
4. 結果可視化
def plot_rewards(rewards):plt.figure(figsize=(12, 6))plt.plot(rewards, label='Reward per Episode')# 計算并繪制100個episode的移動平均線,以更好地觀察趨勢moving_avg = np.convolve(rewards, np.ones(100)/100, mode='valid')plt.plot(np.arange(len(rewards) - 99), moving_avg, label='100-Episode Moving Average')plt.title('CartPole DQN Training Performance')plt.xlabel('Episode')plt.ylabel('Total Reward')plt.grid(True)plt.legend()plt.show()plot_rewards(episode_rewards)
運行代碼后,你大概率會看到一張獎勵曲線圖。一開始獎勵很低(智能體在隨機亂撞),但隨著訓練的進行,曲線會逐漸攀升并最終穩定在高位(比如 200 以上,甚至達到 CartPole-v1 的上限 500),移動平均線能更清晰地展示這個趨勢。這就是 DQN 學會了如何平衡桿子的證明!
六、總結與展望:DQN 之后的路
今天我們從強化學習最基礎的交互框架出發,深入了 Q-Learning 的核心更新機制,然后為了克服 Q-Table 的局限性,引入了 DQN 的兩大支柱——經驗回放和目標網絡。最后,我們用 PyTorch 從零到一地實現了一個能解決 CartPole 問題的 DQN 智能體。
回顧我們踩過的坑:
- 數據類型與維度:在
ReplayBuffer
和learn
函數中,NumPy 和 PyTorch Tensor 之間的轉換、維度的匹配是 bug 高發區。 - 超參數敏感:DQN 的表現嚴重依賴于超參數的選擇,需要耐心調優。沒有一組“萬能”參數。
- TD Target 的終止狀態:計算 TD Target 時忘記處理
done
信號,會導致智能體錯誤地評估終局的價值,這是個常見的邏輯錯誤。
DQN 并非終點,而是起點。 它本身也存在一些問題,比如 Q 值過高估計(Overestimation Bias)。后續的研究提出了許多改進方案,構建了龐大的“彩虹 DQN”(Rainbow DQN)家族:
- Double DQN (DDQN):解耦了“選擇”和“評估”下一個狀態 Q 值的網絡,緩解 Q 值過高估計問題。
- Dueling DQN:將 Q 值網絡結構分解為“狀態價值 V(s)”和“動作優勢 A(s, a)”,學習更高效。
- Prioritized Experience Replay (PER):不再隨機采樣經驗,而是優先學習那些 TD 誤差大的、“更值得學習”的經驗。