在《Python實戰進階》No37: 強化學習入門:Q-Learning 與 DQN 這篇文章中,我們介紹了Q-Learning算法走出迷宮的代碼實踐,本文加餐,把Q-Learning算法通過代碼可視化呈現。我嘗試了使用Matplotlib實現,但局限于Matplotlib對動畫不支持,做出來的仿動畫太僵硬,所以使用 pygame
重新設計 Q-Learning 的可視化程序可以顯著提升動畫的流暢性和交互性。相比于 matplotlib
,pygame
更適合處理實時動畫和游戲化的內容。以下是一個完整的基于 pygame
的實現方案,
視頻:Q-Learning算法訓練可視化
目標
- 迷宮布局:動態繪制迷宮(包括起點、終點和墻壁)。
- 智能體移動:實時更新智能體的位置。
- 最優路徑:訓練完成后顯示從起點到終點的最優路徑。
- 最終目標:完整呈現Q-Learning算法的訓練過程。
實現步驟
步驟 1:安裝依賴
確保安裝了 pygame
庫:
pip install pygame
步驟 2:修改迷宮環境
我們對迷宮環境進行一些擴展,以便更好地支持 pygame
可視化。
import numpy as npclass MazeEnv:def __init__(self):self.maze = [['.', '.', '.', '#', '.'],['.', '#', '.', '.', '.'],['.', '#', '.', '#', '.'],['.', '.', '.', '#', '.'],['.', '#', 'G', '#', '.']]self.maze = np.array(self.maze)self.start = (0, 0)self.goal = (4, 2)self.current_state = self.startself.actions = [(0, 1), (0, -1), (1, 0), (-1, 0)] # 右、左、下、上def reset(self):self.current_state = self.startreturn self.current_statedef step(self, action):next_state = (self.current_state[0] + action[0], self.current_state[1] + action[1])if (next_state[0] < 0 or next_state[0] >= self.maze.shape[0] ornext_state[1] < 0 or next_state[1] >= self.maze.shape[1] orself.maze[next_state] == '#'):next_state = self.current_state # 如果撞墻,保持原位置reward = -1 # 每步移動的默認獎勵done = Falseif next_state == self.goal:reward = 10 # 到達終點的獎勵done = Trueself.current_state = next_statereturn next_state, reward, donedef get_maze_size(self):return self.maze.shapedef is_wall(self, position):return self.maze[position] == '#'def is_goal(self, position):return position == self.goal
步驟 3:設計 pygame
可視化程序
以下是基于 pygame
的完整可視化代碼:
import pygame
import time
import random
import numpy as np# 初始化 pygame
pygame.init()# 定義顏色
WHITE = (255, 255, 255) # 空地
BLACK = (0, 0, 0) # 墻壁
GREEN = (0, 255, 0) # 終點
RED = (255, 0, 0) # 智能體
BLUE = (0, 0, 255) # 最優路徑# 定義單元格大小
CELL_SIZE = 50
FPS = 10 # 動畫幀率def visualize_with_pygame(env, agent, num_episodes=1000):rows, cols = env.get_maze_size()screen_width = cols * CELL_SIZEscreen_height = rows * CELL_SIZE# 初始化屏幕screen = pygame.display.set_mode((screen_width, screen_height))pygame.display.set_caption("Q-Learning Maze Visualization")clock = pygame.time.Clock()def draw_maze():for i in range(rows):for j in range(cols):rect = pygame.Rect(j * CELL_SIZE, i * CELL_SIZE, CELL_SIZE, CELL_SIZE)if env.is_wall((i, j)):pygame.draw.rect(screen, BLACK, rect)elif env.is_goal((i, j)):pygame.draw.rect(screen, GREEN, rect)else:pygame.draw.rect(screen, WHITE, rect)def draw_agent(position):x, y = positioncenter = (y * CELL_SIZE + CELL_SIZE // 2, x * CELL_SIZE + CELL_SIZE // 2)pygame.draw.circle(screen, RED, center, CELL_SIZE // 3)def draw_path(path):for (x, y) in path:rect = pygame.Rect(y * CELL_SIZE, x * CELL_SIZE, CELL_SIZE, CELL_SIZE)pygame.draw.rect(screen, BLUE, rect)# 訓練過程可視化for episode in range(num_episodes):state = env.reset()done = Falsepath = [state]while not done:# 處理退出事件for event in pygame.event.get():if event.type == pygame.QUIT:pygame.quit()return# 清屏并繪制迷宮screen.fill(WHITE)draw_maze()# 獲取動作action = agent.get_action(state)next_state, reward, done = env.step(action)agent.update_q_table(state, action, reward, next_state)state = next_statepath.append(state)# 繪制智能體draw_agent(state)# 更新屏幕pygame.display.flip()clock.tick(FPS)if episode % 100 == 0:print(f"Episode {episode}: Training...")# 測試過程可視化state = env.reset()done = Falsepath = [state]while not done:for event in pygame.event.get():if event.type == pygame.QUIT:pygame.quit()returnscreen.fill(WHITE)draw_maze()action = agent.get_action(state)state, _, done = env.step(action)path.append(state)draw_agent(state)pygame.display.flip()clock.tick(FPS)# 顯示最終路徑screen.fill(WHITE)draw_maze()draw_path(path)pygame.display.flip()# 等待用戶關閉窗口running = Truewhile running:for event in pygame.event.get():if event.type == pygame.QUIT:running = Falsepygame.quit()
步驟 4:集成到 Q-Learning 算法
將 pygame
可視化函數集成到 Q-Learning 的訓練和測試過程中。
class QLearningAgent:def __init__(self, env, learning_rate=0.1, discount_factor=0.9, epsilon=0.1):self.env = envself.q_table = {}self.learning_rate = learning_rateself.discount_factor = discount_factorself.epsilon = epsilondef get_action(self, state):if random.uniform(0, 1) < self.epsilon:return random.choice(self.env.actions) # 探索else:q_values = [self.get_q_value(state, action) for action in self.env.actions]return self.env.actions[np.argmax(q_values)] # 貪婪策略def get_q_value(self, state, action):key = (state, action)return self.q_table.get(key, 0.0)def update_q_table(self, state, action, reward, next_state):old_q = self.get_q_value(state, action)max_next_q = max([self.get_q_value(next_state, a) for a in self.env.actions])new_q = old_q + self.learning_rate * (reward + self.discount_factor * max_next_q - old_q)self.q_table[(state, action)] = new_q
步驟 5:運行代碼
創建迷宮環境和智能體,并運行訓練和測試代碼。
# 創建環境和智能體
env = MazeEnv()
agent = QLearningAgent(env)# 使用 pygame 可視化訓練和測試
visualize_with_pygame(env, agent, num_episodes=1000)
效果
- 流暢的動畫:
pygame
提供了高效的繪圖性能,動畫更加流暢。 - 實時更新:智能體的位置和路徑會實時更新,清晰展示學習過程。
- 交互性:用戶可以通過關閉窗口隨時停止程序。
擴展功能
- 優化動畫速度:通過調整
FPS
和clock.tick()
控制動畫速度。 - 添加熱力圖:使用不同顏色表示 Q 值表的變化。
- 支持更大迷宮:通過縮放單元格大小(
CELL_SIZE
)適應更大迷宮。
通過以上方法,你可以實現一個高效且流暢的 Q-Learning 可視化程序!