增強學習(Reinforcement Learning)簡介
增強學習是機器學習的一種范式,其核心目標是讓智能體(Agent)通過與環境的交互,基于試錯機制和延遲獎勵反饋,學習如何選擇最優動作以最大化長期累積回報。其核心要素包括:
? 狀態(State):描述環境的當前信息(如棋盤布局、機器人傳感器數據)。
? 動作(Action):智能體在特定狀態下可執行的操作(如移動、下棋)。
? 獎勵(Reward):環境對動作的即時反饋信號(如得分增加或懲罰)。
? 策略(Policy):從狀態到動作的映射規則(如基于Q值選擇動作)。
? 價值函數(Value Function):預測某狀態或動作的長期回報(如Q-Learning中的Q表)。
與監督學習不同,增強學習無需標注數據,而是通過探索-利用權衡(Exploration vs Exploitation)自主學習。
使用PyTorch實現深度Q網絡(DQN)演示
以下以CartPole-v0(平衡桿環境)為例,展示完整代碼及解釋:
- 環境與依賴庫
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random# 初始化環境
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
- 定義DQN網絡
class DQN(nn.Module):def __init__(self, state_dim, action_dim):super(DQN, self).__init__()self.fc = nn.Sequential(nn.Linear(state_dim, 128),nn.ReLU(),nn.Linear(128, 64),nn.ReLU(),nn.Linear(64, action_dim))def forward(self, x):return self.fc(x)
- 經驗回放緩沖區(Replay Buffer)
class ReplayBuffer:def __init__(self, capacity):self.buffer = deque(maxlen=capacity)def push(self, state, action, reward, next_state, done):self.buffer.append((state, action, reward, next_state, done))def sample(self, batch_size):batch = random.sample(self.buffer, batch_size)states, actions, rewards, next_states, dones = zip(*batch)return (torch.FloatTensor(states),torch.LongTensor(actions),torch.FloatTensor(rewards),torch.FloatTensor(next_states),torch.FloatTensor(dones))
- 訓練參數與初始化
# 超參數
batch_size = 64
gamma = 0.99 # 折扣因子
epsilon_start = 1.0
epsilon_decay = 0.995
epsilon_min = 0.01
target_update = 10 # 目標網絡更新頻率# 初始化網絡與優化器
policy_net = DQN(state_dim, action_dim)
target_net = DQN(state_dim, action_dim)
target_net.load_state_dict(policy_net.state_dict())
optimizer = optim.Adam(policy_net.parameters(), lr=0.001)
buffer = ReplayBuffer(10000)
epsilon = epsilon_start
- 訓練循環
num_episodes = 500
for episode in range(num_episodes):state = env.reset()total_reward = 0while True:# ε-貪婪策略選擇動作if random.random() < epsilon:action = env.action_space.sample()else:with torch.no_grad():q_values = policy_net(torch.FloatTensor(state))action = q_values.argmax().item()# 執行動作并存儲經驗next_state, reward, done, _ = env.step(action)buffer.push(state, action, reward, next_state, done)state = next_statetotal_reward += reward# 經驗回放與網絡更新if len(buffer.buffer) >= batch_size:states, actions, rewards, next_states, dones = buffer.sample(batch_size)# 計算目標Q值with torch.no_grad():next_q = target_net(next_states).max(1)[0]target_q = rewards + gamma * next_q * (1 - dones)# 計算當前Q值current_q = policy_net(states).gather(1, actions.unsqueeze(1))# 均方誤差損失loss = nn.MSELoss()(current_q, target_q.unsqueeze(1))# 反向傳播optimizer.zero_grad()loss.backward()optimizer.step()if done:break# 更新目標網絡與εif episode % target_update == 0:target_net.load_state_dict(policy_net.state_dict())epsilon = max(epsilon_min, epsilon * epsilon_decay)print(f"Episode {episode}, Reward: {total_reward}, Epsilon: {epsilon:.2f}")
關鍵點解釋
- 經驗回放(Replay Buffer):通過存儲歷史經驗并隨機采樣,打破數據相關性,提升訓練穩定性。
- 目標網絡(Target Network):固定目標Q值計算網絡,緩解訓練震蕩問題。
- ε-貪婪策略:平衡探索(隨機動作)與利用(最優動作),逐步降低探索率。
結果與優化方向
? 預期效果:經過約200輪訓練,智能體可穩定保持平衡超過195步(CartPole-v1的勝利條件)。
? 優化方法:
? 使用Double DQN或Dueling DQN改進Q值估計。
? 調整網絡結構(如增加卷積層處理圖像輸入)。
? 引入優先級經驗回放(Prioritized Experience Replay)。
完整代碼及更多改進可參考PyTorch官方文檔或強化學習框架(如Stable Baselines3)。