為了理解 DeepSpeed-Chat RLHF 的 RLHF 全部過程,這個系列會分三篇文章分別介紹:
原始 PPO 代碼解讀RLHF 獎勵函數代碼解讀RLHF PPO 代碼解讀
這是系列的第一篇文章,我們來一步一步的看 PPO 算法的代碼實現,對于 PPO 算法原理不太了解的同學,可以參考之前的文章:
深度強化學習(DRL)算法 2 —— PPO 之 Clipped Surrogate Objective 篇
深度強化學習(DRL)算法 2 —— PPO 之 GAE 篇
Clipped Surrogate 函數實現
# code from cleanrl: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py
for start in range(0, args.batch_size, args.minibatch_size):end = start + args.minibatch_sizemb_inds = b_inds[start:end]_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])logratio = newlogprob - b_logprobs[mb_inds]ratio = logratio.exp()mb_advantages = b_advantages[mb_inds]if args.norm_adv:mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)# Policy losspg_loss1 = -mb_advantages * ratiopg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)pg_loss = torch.max(pg_loss1, pg_loss2).mean()
Clipped Surrogate 函數的實現很簡單,這里不再贅述,理解算法原理,代碼自然而然就可以看懂,核心是 get_action_and_value 函數的理解。
def get_action_and_value(self, x, action=None):logits = self.actor(x)# probs 相當于計算 softmaxprobs = Categorical(logits=logits)if action is None:action = probs.sample()# probs.log_prob(action) 計算的是 p(a|s) 的 log 形式,方便計算 Clipped Surrogate 函數里的 ratioreturn action, probs.log_prob(action), probs.entropy(), self.critic(x)
GAE 實現
直接來看 gae 可能比較抽象,我們先來看蒙特卡洛方法實現的優勢估計,對蒙特卡洛方法不熟悉的同學,可以參考之前的文章。
深度強化學習(DRL)算法 附錄 3 —— 蒙特卡洛方法(MC)和時序差分(TD)
兩種方法都采用了反向迭代(因為反向迭代更好計算)的方式來實現優勢估計。
# code from cleanrl: https://github.com/vwxyzjn/cleanrl/commit/b7088a41e5e6f0f5f6940fd29054a35118083b28
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)returns = torch.zeros_like(rewards).to(device)
for t in reversed(range(args.num_steps)):if t == args.num_steps - 1:nextnonterminal = 1.0 - next_donenext_return = last_valueelse:nextnonterminal = 1.0 - dones[t+1]next_return = returns[t+1]returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
advantages = returns - values
上面的代碼做了什么事情呢,last_value 對應最后的 step(對應 step t) 產生的期望回報,如果 step t-1 整個流程沒有結束,那么 t-1 時刻的期望回報就是 reward(t-1) + args.gamma * nextnonterminal * next_return,這樣一步一步往后推,就可以計算每一個 step 的期望回報,從而得到每一步的優勢,還沒理解的話,看下面每個時間步的拆解。關于 last_value 的使用,這里由于沒有后續的回報可以累積,所以直接使用 last_value 作為最后一個時間步的回報。關于下面為啥用 return[t-1] 替換原始公式的 value[t-1],這樣計算的話就相當于蒙特卡洛方法的優勢估計,如果next_return = returns[t+1] 改成 next_value = values[t+1] 就相當于 TD(1) 的優勢估計。
# t
return(t) = v(t)
# t - 1
return(t-1) = reward(t-1) + gamma * return(t) = reward(t-1) + gamma * return(t)
# t - 2
return(t-2) = reward(t-2) + gamma * return(t-1) = reward(t-2) + gamma * (reward(t-1) + gamma * return(t))
......
# 我們可以看到一步一步往前推,最后就得到蒙特卡洛方法的優勢估計
理解了上面講的蒙特卡洛方法實現的優勢估計,再來看 gae 的實現,我們可以看到代碼實現上十分的相似,只是多了 delta 的計算,這里的 delta 對應的就是之前 PPO GAE 篇里介紹的 delta。
# code from cleanrl: https://github.com/vwxyzjn/cleanrl/commit/b7088a41e5e6f0f5f6940fd29054a35118083b28
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
for t in reversed(range(args.num_steps)):if t == args.num_steps - 1:nextnonterminal = 1.0 - next_donenextvalues = last_valueelse:nextnonterminal = 1.0 - dones[t+1]nextvalues = values[t+1]delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
returns = advantages + values
這里通過反向迭代的方式計算 GAE advantage,可能理解上比較抽象,舉個例子,就很好理解了:
# advantage(t)
adv[t] = lastgaelam = rewards[t] + gamma * values[t+1] - values[t]
# t-1
adv[t-1] = lastgaelam = rewards[t-1] + gamma * values[t] - values[t-1] + gamma * lambda * lastgaelam
# t-2
adv[t-2] = lastgaelam = rewards[t-2] + gamma * values[t-1] - values[t-2] + gamma * lambda * lastgaelam
...
可以看到,逐項展開,每一時刻的 GAE Advantage 和 PPO GAE 篇里介紹的公式是一模一樣的,這里 GAE 就是一種數學公式表達,核心思想是 n step 的優勢估計的加權平均,通過數學技巧恰好是上面的形式。
參考
- The 37 Implementation Details of Proximal Policy Optimization · The ICLR Blog Track (iclr-blog-track.github.io)
- HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION