Double-DQN算法的原理簡介、與DQN對比等。
參考深度Q網絡進階技巧
1. 原理簡介
在DQN算法中,雖然有target_net和eval_net,但還是容易出現Q值高估的情況,原因在于訓練時用通過target_net選取最優動作
a ? = argmax ? a Q ( s t + 1 , a ; w ? ) a^{\star}=\underset{a}{\operatorname{argmax}} Q\left(s_{t+1}, a ; \mathbf{w}^{-}\right) a?=aargmax?Q(st+1?,a;w?)
并得到其Q值后,再根據
y t = r t + γ ? Q ( s t + 1 , a ? ; w ? ) y_{t}=r_{t}+\gamma \cdot Q\left(s_{t+1}, a^{\star} ; \mathbf{w}^{-}\right) yt?=rt?+γ?Q(st+1?,a?;w?)
算出TD-target,所以一旦高估,就會頻繁被選中然后導致目標值持續較大。
而Double-DQN算法則是設計兩個Q網絡,一個進行動作選取,一個進行Q值計算。即通過eval_net選取最優動作
a ? 2 = argmax ? a Q ( s t + 1 , a ; w ) a^{\star2}=\underset{a}{\operatorname{argmax}} Q\left(s_{t+1}, a ; \mathrm{w}\right) a?2=aargmax?Q(st+1?,a;w)
隨后再通過target_net計算其Q值得到TD_target目標值
y t = r t + γ ? Q ( s t + 1 , a ? 2 ; w ? ) y_{t}=r_{t}+\gamma \cdot Q\left(s_{t+1}, a^{\star2} ; \mathbf{w}^{-}\right) yt?=rt?+γ?Q(st+1?,a?2;w?)
毫無疑問, Q ( s t + 1 , a ? 2 ; w ? ) ≤ Q ( s t + 1 , a ? ; w ? ) Q\left(s_{t+1}, a^{\star2} ; \mathbf{w}^{-}\right) \leq Q\left(s_{t+1}, a^{\star} ; \mathbf{w}^{-}\right) Q(st+1?,a?2;w?)≤Q(st+1?,a?;w?).
2. 與DQN區別
進行TD算法梯度下降時,DQN算法是直接從target_net中選取最大Q值,而Double-DQN則是eval_net選取最優動作,target_net再選取該動作的Q值。
3. 代碼
直接在DQN的代碼上進行幾行的修改即可,修改類中的learn(update)方法。
代碼如下:
class DoubleDQN:
...def learn(self):# target parameter updateif self.learn_step_counter % TARGET_REPLACE_ITER == 0:self.target_net.load_state_dict(self.eval_net.state_dict())self.learn_step_counter += 1# sample batch transitionssample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)b_memory = self.memory[sample_index, :]b_s = torch.FloatTensor(b_memory[:, :N_STATES])b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES + 1].astype(int))b_r = torch.FloatTensor(b_memory[:, N_STATES + 1:N_STATES + 2])b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:])######## 就這一點和DQN不一樣 ########q_eval = self.eval_net(b_s).gather(1, b_a) # 相當于Q(s,a)a_next_eval = self.eval_net(b_s_) # eval_net估計下一步動作q_next = self.target_net(b_s_).detach() # target計算 Q(st+1)q_target = b_r + GAMMA * q_next.gather(1, torch.max(a_next_eval, 1)[1].unsqueeze(1)) # 根據eval_net估計的動作的最大值,找出target_net中對應的Q值,得到TD_targetloss = self.loss_func(q_eval, q_target)self.optimizer.zero_grad() # 梯度重置loss.backward() # 反向求導self.optimizer.step() # 更新模型參數
...