Chapter 15: Imitation Learning (gym version >= 0.26)
Summary
This series of notes walks through the tricky points in Hands-on Reinforcement Learning (動手學強化學習); for the full treatment, please read the book itself.
This note corresponds to the Imitation Learning chapter of Hands-on Reinforcement Learning.
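The "gym >= 0.26" in the title matters because the environment API changed: reset() now returns an (observation, info) tuple, and step() returns five values instead of four. The code below repeatedly works around both changes (unpacking tuples returned by reset(), merging the terminated/truncated flags, and patching a seed() method back onto the environment). A minimal illustration of the newer API, for reference only:

import gym

env_demo = gym.make('CartPole-v0')
obs, info = env_demo.reset(seed=0)  # gym >= 0.26: reset() returns (observation, info)
obs, reward, terminated, truncated, info = env_demo.step(env_demo.action_space.sample())
done = terminated or truncated      # the old single "done" flag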
# -*- coding: utf-8 -*-
import gym
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import rl_utils

class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)

class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class PPO:
    ''' PPO algorithm with the clipped surrogate objective '''
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 lmbda, epochs, eps, gamma, device):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.lmbda = lmbda
        self.epochs = epochs  # number of training epochs run on one collected trajectory
        self.eps = eps  # clipping range parameter in PPO
        self.device = device

    def take_action(self, state):
        if isinstance(state, tuple):
            state = state[0]
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        # Build a Categorical distribution from the action probabilities and sample a discrete action.
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        processed_state = []
        for s in transition_dict['states']:
            if isinstance(s, tuple):
                # gym >= 0.26: reset() returns (obs, info), so keep only the observation
                processed_state.append(s[0])
            else:
                processed_state.append(s)
        # states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        states = torch.tensor(processed_state, dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        # TD target (the regression target): td_target = r + gamma * V(s') * (1 - done)
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        # TD residual (basis of the advantage estimate): TD target minus the critic's current value estimate.
        td_delta = td_target - self.critic(states)
        # Helper defined in rl_utils: compute the advantage, typically via generalized advantage estimation (GAE).
        advantage = rl_utils.compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)
        # Feed the states through the actor to get action probabilities of shape (batch_size, action_dim),
        # gather the probability of each executed action (the shape of actions must match), take the log,
        # and detach() so the old policy's log-probabilities are treated as constants.
        old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()
        for _ in range(self.epochs):
            # Recompute the log-probabilities under the current policy and compare them with the old ones.
            log_probs = torch.log(self.actor(states).gather(1, actions))
            # Probability ratio between the new and old policies, used in PPO's clipped loss.
            ratio = torch.exp(log_probs - old_log_probs)
            # Unclipped surrogate objective: ratio times advantage.
            surr1 = ratio * advantage
            # Clip the ratio to [1 - eps, 1 + eps] (e.g. [0.8, 1.2]) before multiplying by the advantage.
            surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage  # clipped surrogate
            # PPO maximizes the minimum of the two surrogates, so take the negative of the minimum
            # and average over the batch as the actor loss.
            actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO loss
            # Critic loss: mean squared error between the critic's estimate and the TD target.
            critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_optimizer.step()
            self.critic_optimizer.step()

actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 250
hidden_dim = 128
gamma = 0.98
lmbda = 0.95
epochs = 10
eps = 0.2
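PPO.update above delegates the advantage computation to rl_utils.compute_advantage, which is not shown in this note. In the Hands-on RL code it implements generalized advantage estimation (GAE); a minimal sketch of what such a helper typically looks like (the exact implementation in rl_utils may differ) is:

def compute_advantage(gamma, lmbda, td_delta):
    # GAE: a discounted, lambda-weighted sum of TD residuals, accumulated backwards over the trajectory.
    td_delta = td_delta.detach().numpy()
    advantage_list = []
    advantage = 0.0
    for delta in td_delta[::-1]:
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()
    return torch.tensor(advantage_list, dtype=torch.float)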
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
env_name = 'CartPole-v0'
env = gym.make(env_name)
if not hasattr(env, 'seed'):
    # gym >= 0.26 removed env.seed(); patch in a compatible method that seeds via reset()
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
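The PPO agent constructed next is trained through rl_utils.train_on_policy_agent, which is also defined outside this note. Roughly, it rolls out one episode at a time into a transition_dict and calls agent.update after each episode; a simplified sketch under that assumption (the actual helper additionally tracks returns with a progress bar):

def train_on_policy_agent(env, agent, num_episodes):
    # Simplified on-policy loop: collect one full episode, then update the agent on it.
    return_list = []
    for _ in range(num_episodes):
        transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
        state, done, episode_return = env.reset(), False, 0
        while not done:
            action = agent.take_action(state)
            result = env.step(action)
            if len(result) == 5:  # gym >= 0.26
                next_state, reward, terminated, truncated, _ = result
                done = terminated or truncated
            else:
                next_state, reward, done, _ = result
            transition_dict['states'].append(state)
            transition_dict['actions'].append(action)
            transition_dict['next_states'].append(next_state)
            transition_dict['rewards'].append(reward)
            transition_dict['dones'].append(done)
            state = next_state
            episode_return += reward
        return_list.append(episode_return)
        agent.update(transition_dict)
    return return_list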
ppo_agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device)
return_list = rl_utils.train_on_policy_agent(env, ppo_agent, num_episodes)

def sample_expert_data(n_episode):
    states = []
    actions = []
    for episode in range(n_episode):
        state = env.reset()
        done = False
        while not done:
            action = ppo_agent.take_action(state)
            states.append(state)
            actions.append(action)
            result = env.step(action)
            if len(result) == 5:
                next_state, reward, done, truncated, info = result
                done = done or truncated  # merge the terminated and truncated flags
            else:
                next_state, reward, done, info = result
            # next_state, reward, done, _ = env.step(action)
            state = next_state
    processed_states = []
    for s in states:
        if isinstance(s, tuple):
            # gym >= 0.26: reset() returns (obs, info), keep only the observation
            processed_states.append(s[0])
        else:
            processed_states.append(s)
    return np.array(processed_states), np.array(actions)

if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
random.seed(0)
n_episode = 1
expert_s, expert_a = sample_expert_data(n_episode)
n_samples = 30  # draw 30 expert samples
random_index = random.sample(range(expert_s.shape[0]), n_samples)
expert_s = expert_s[random_index]
expert_a = expert_a[random_index]

class BehaviorClone:
    def __init__(self, state_dim, hidden_dim, action_dim, lr):
        self.policy = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)

    def learn(self, states, actions):
        """Update the policy network on a batch of expert states and actions."""
        states = torch.tensor(states, dtype=torch.float).to(device)
        actions = torch.tensor(actions).view(-1, 1).to(device)
        # Feed the states through the policy to get a (batch_size, action_dim) probability matrix,
        # gather the probability of each expert action with .gather(1, actions.long()) (the actions
        # must be long-integer indices), and take the log to obtain the log-likelihood.
        log_probs = torch.log(self.policy(states).gather(1, actions.long()))
        # log_probs = torch.log(self.policy(states).gather(1, actions))
        # Behavior cloning loss: the negative log-likelihood, averaged over the batch (maximum likelihood estimation).
        bc_loss = torch.mean(-log_probs)
        self.optimizer.zero_grad()
        bc_loss.backward()
        self.optimizer.step()

    def take_action(self, state):
        if isinstance(state, tuple):
            state = state[0]
        state = torch.tensor([state], dtype=torch.float).to(device)
        probs = self.policy(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

def test_agent(agent, env, n_episode):
    return_list = []
    for episode in range(n_episode):
        episode_return = 0
        state = env.reset()
        done = False
        while not done:
            action = agent.take_action(state)
            result = env.step(action)
            if len(result) == 5:
                next_state, reward, done, truncated, info = result
                done = done or truncated  # merge the terminated and truncated flags
            else:
                next_state, reward, done, info = result
            # next_state, reward, done, _ = env.step(action)
            state = next_state
            episode_return += reward
        return_list.append(episode_return)
    return np.mean(return_list)

if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
np.random.seed(0)
lr = 1e-3
bc_agent = BehaviorClone(state_dim, hidden_dim, action_dim, lr)
n_iterations = 1000
batch_size = 64
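A side note before the training loop: bc_loss in BehaviorClone.learn is the mean negative log-likelihood of the expert actions. Assuming the same PolicyNet outputs, an equivalent way to write that loss uses F.nll_loss, shown here purely for intuition and not used in training:

# Equivalent formulation of the behavior-cloning loss (F.nll_loss expects log-probabilities
# and integer class targets, and averages -log p(a|s) over the batch).
probs = bc_agent.policy(torch.tensor(expert_s, dtype=torch.float).to(device))
targets = torch.tensor(expert_a, dtype=torch.long).to(device)
bc_loss_alt = F.nll_loss(torch.log(probs), targets)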
test_returns = []
with tqdm(total=n_iterations, desc="Progress") as pbar:
    for i in range(n_iterations):
        sample_indices = np.random.randint(low=0, high=expert_s.shape[0], size=batch_size)
        bc_agent.learn(expert_s[sample_indices], expert_a[sample_indices])
        current_return = test_agent(bc_agent, env, 5)
        test_returns.append(current_return)
        if (i + 1) % 10 == 0:
            pbar.set_postfix({'return': '%.3f' % np.mean(test_returns[-10:])})
        pbar.update(1)

iteration_list = list(range(len(test_returns)))
plt.plot(iteration_list, test_returns)
plt.xlabel('Iterations')
plt.ylabel('Returns')
plt.title('BC on {}'.format(env_name))
plt.show()

class Discriminator(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(Discriminator, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        cat = torch.cat([x, a], dim=1)
        x = F.relu(self.fc1(cat))
        return torch.sigmoid(self.fc2(x))

class GAIL:
    def __init__(self, agent, state_dim, action_dim, hidden_dim, lr_d):
        print(state_dim, action_dim, hidden_dim)
        self.discriminator = Discriminator(state_dim, hidden_dim, action_dim).to(device)
        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d)
        self.agent = agent

    def learn(self, expert_s, expert_a, agent_s, agent_a, next_s, dones):
        expert_states = torch.tensor(expert_s, dtype=torch.float).to(device)
        expert_actions = torch.tensor(expert_a).to(device)
        processed_state = []
        for s in agent_s:
            if isinstance(s, tuple):
                # gym >= 0.26: reset() returns (obs, info), keep only the observation
                processed_state.append(s[0])
            else:
                processed_state.append(s)
        agent_states = torch.tensor(processed_state, dtype=torch.float).to(device)
        agent_actions = torch.tensor(agent_a).to(device)
        # Convert the discrete actions to one-hot vectors (as floats) before feeding the discriminator.
        expert_actions = F.one_hot(expert_actions.long(), num_classes=2).float()
        agent_actions = F.one_hot(agent_actions.long(), num_classes=2).float()
        expert_prob = self.discriminator(expert_states, expert_actions)
        agent_prob = self.discriminator(agent_states, agent_actions)
        # Binary cross-entropy (BCE) loss for the discriminator: agent data gets target label 1
        # and expert data gets target label 0; the two terms are summed.
        discriminator_loss = (nn.BCELoss()(agent_prob, torch.ones_like(agent_prob)) +
                              nn.BCELoss()(expert_prob, torch.zeros_like(expert_prob)))
        self.discriminator_optimizer.zero_grad()
        discriminator_loss.backward()
        self.discriminator_optimizer.step()
        # Use the discriminator output on agent data as a reward signal: -log(agent_prob), so a smaller
        # agent_prob (more expert-like) yields a larger reward; detach(), move to CPU, and convert to numpy
        # before passing it into agent.update.
        rewards = -torch.log(agent_prob).detach().cpu().numpy()
        transition_dict = {
            'states': agent_s,
            'actions': agent_a,
            'rewards': rewards,
            'next_states': next_s,
            'dones': dones
        }
        self.agent.update(transition_dict)

if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
lr_d = 1e-3
agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device)
gail = GAIL(agent, state_dim, action_dim, hidden_dim, lr_d)
n_episode = 500
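GAIL.learn above trains the discriminator to output values near 1 for agent transitions and near 0 for expert transitions, then feeds -log D(s, a) back to PPO as a reward, so transitions the discriminator judges to be expert-like earn more. A small standalone check of that plumbing, using made-up tensors that are not part of the training loop:

# Hypothetical mini-batch of 3 states and actions, only to illustrate the (state, one-hot action)
# input format of the Discriminator and the shape of the surrogate reward.
demo_states = torch.randn(3, state_dim, device=device)
demo_actions = F.one_hot(torch.tensor([0, 1, 1], device=device), num_classes=action_dim).float()
d_out = gail.discriminator(demo_states, demo_actions)  # probabilities in (0, 1)
demo_rewards = -torch.log(d_out)                       # smaller D(s, a) -> larger reward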
return_list = []
with tqdm(total=n_episode, desc="Progress") as pbar:
    for i in range(n_episode):
        episode_return = 0
        state = env.reset()
        done = False
        state_list = []
        action_list = []
        next_state_list = []
        done_list = []
        while not done:
            action = agent.take_action(state)
            result = env.step(action)
            if len(result) == 5:
                next_state, reward, done, truncated, info = result
                done = done or truncated  # merge the terminated and truncated flags
            else:
                next_state, reward, done, info = result
            # next_state, reward, done, _ = env.step(action)
            state_list.append(state)
            action_list.append(action)
            next_state_list.append(next_state)
            done_list.append(done)
            state = next_state
            episode_return += reward
        return_list.append(episode_return)
        gail.learn(expert_s, expert_a, state_list, action_list, next_state_list, done_list)
        if (i + 1) % 10 == 0:
            pbar.set_postfix({'return': '%.3f' % np.mean(return_list[-10:])})
        pbar.update(1)

iteration_list = list(range(len(return_list)))
plt.plot(iteration_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('GAIL on {}'.format(env_name))
plt.show()