問題陳述
我們有兩個多臂老虎機(Multi-Armed Bandit),分別稱為左邊的老虎機和右邊的老虎機。每個老虎機的獎勵服從不同的正態分布:
-
左邊的老虎機:獎勵服從均值為?500,標準差為?50?的正態分布,即?N(500,50)N(500,50)。
-
右邊的老虎機:獎勵服從均值為?550,標準差為?100?的正態分布,即?N(550,100)N(550,100)。
我們的目標是使用?ε-greedy 強化學習算法(ε=0.1,初始值為 998)來估計這兩個老虎機的獎勵期望值。具體來說,我們需要通過多次嘗試(拉動手臂)來逐步更新對每個老虎機獎勵的估計,最終找到兩個老虎機的獎勵期望值。
問題分解
-
目標:
-
使用 ε-greedy 算法估計兩個老虎機的獎勵期望值。
-
通過多次嘗試,逐步更新對每個老虎機獎勵的估計。
-
-
ε-greedy 算法:
-
ε=0.1:表示有 10% 的概率進行隨機探索(隨機選擇一個老虎機),90% 的概率進行利用(選擇當前估計獎勵最高的老虎機)。
-
初始值=998:表示每個老虎機的初始獎勵估計值為 998。
-
-
獎勵分布:
-
左邊的老虎機:N(500,50)N(500,50)
-
右邊的老虎機:N(550,100)N(550,100)
-
-
輸出:
-
經過多次嘗試后,輸出兩個老虎機的獎勵期望值的估計結果
-
通過運行代碼,我們可以得到一個圖表,顯示兩個老虎機獎勵期望估計值隨著時間的變化情況。隨著拉動次數的增加,兩個估計值應該逐漸接近它們各自的真實獎勵期望值(500 和 550)。
-
import numpy as np
import matplotlib.pyplot as plt# 參數初始化
epsilon = 0.1 # ε-greedy算法中的ε
Q1 = 998 # 左邊老虎機的獎勵期望估計
Q2 = 998 # 右邊老虎機的獎勵期望估計
n1 = 0 # 左邊老虎機的拉動次數
n2 = 0 # 右邊老虎機的拉動次數
num_plays = 10000 # 總共拉動的次數# 獎勵的真實分布
mu1, sigma1 = 500, 50 # 左邊老虎機的真實獎勵分布(均值,標準差)
mu2, sigma2 = 550, 100 # 右邊老虎機的真實獎勵分布(均值,標準差)# 用于存儲結果
Q1_estimates = []
Q2_estimates = []# ε-greedy策略的實驗
for t in range(num_plays):# 根據ε-greedy策略選擇一個老虎機if np.random.random() < epsilon:action = np.random.choice([1, 2]) # 隨機選擇左或右else:action = 1 if Q1 > Q2 else 2 # 選擇當前估計獎勵最大的老虎機if action == 1:reward = np.random.normal(mu1, sigma1) # 從左邊老虎機獲得獎勵n1 += 1Q1 += (reward - Q1) / n1 # 更新左邊老虎機的獎勵期望估計Q1_estimates.append(Q1)else:reward = np.random.normal(mu2, sigma2) # 從右邊老虎機獲得獎勵n2 += 1Q2 += (reward - Q2) / n2 # 更新右邊老虎機的獎勵期望估計Q2_estimates.append(Q2)# 最終的獎勵期望估計
print(f"最終左邊老虎機的獎勵期望估計: {Q1}")
print(f"最終右邊老虎機的獎勵期望估計: {Q2}")# 繪圖
plt.figure(figsize=(12, 6))# 繪制左邊老虎機獎勵期望估計的變化
plt.plot(Q1_estimates, label="Left Slot Machine (Q1)", color="blue")# 繪制右邊老虎機獎勵期望估計的變化
plt.plot(Q2_estimates, label="Right Slot Machine (Q2)", color="red")# 繪制真實獎勵期望值的參考線
plt.axhline(y=mu1, color="blue", linestyle="--", label="True Q1 (500)")
plt.axhline(y=mu2, color="red", linestyle="--", label="True Q2 (550)")# 圖表設置
plt.title("Reward Estimation in ε-greedy Slot Machine Experiment")
plt.xlabel("Number of Plays")
plt.ylabel("Estimated Reward")
plt.legend(loc="best")
plt.grid(True)# 顯示圖表
plt.show()
顯示結果如圖: