優先經驗回放(prioritized experience replay)

prioritized experience replay 思路

優先經驗回放出自ICLR 2016的論文《prioritized experience replay》。

prioritized experience replay的作者們認為,按照一定的優先級來對經驗回放池中的樣本采樣,相比于隨機均勻的從經驗回放池中采樣的效率更高,可以讓模型更快的收斂。其基本思想是RL agent在一些轉移樣本上可以更有效的學習,也可以解釋成“更多地訓練會讓你意外的數據”。

那優先級如何定義呢?作者們使用的是樣本的TD error δ \delta δ 的幅值。對于新生成的樣本,TD error未知時,將樣本賦值為最大優先級,以保證樣本至少將會被采樣一次。每個采樣樣本的概率被定義為
P ( i ) = p i α ∑ k p k α P(i) = \frac {p_i^{\alpha}} {\sum_k p_k^{\alpha}} P(i)=k?pkα?piα??
上式中的 p i > 0 p_i >0 pi?>0是回放池中的第i個樣本的優先級, α \alpha α則強調有多重視該優先級,如果 α = 0 \alpha=0 α=0,采樣就退化成和基礎DQN一樣的均勻采樣了。

p i p_i pi?如何取值,論文中提供了如下兩種方法,兩種方法都是關于TD error δ \delta δ 單調的:

  • 基于比例的優先級: p i = ∣ δ i ∣ + ? p_i = |\delta_i| + \epsilon pi?=δi?+? ? \epsilon ?是一個很小的正數常量,防止當TD error為0時樣本就不會被訪問到的情形。(目前大部分實現都是使用的這個形式的優先級)
  • 基于排序的優先級: p i = 1 r a n k ( i ) p_i = \frac {1}{rank(i)} pi?=rank(i)1?, 式中的 r a n k ( i ) rank(i) rank(i)是樣本根據 ∣ δ i ∣ |\delta_i| δi? 在經驗回放池中的排序號,此時P就變成了帶有指數 α \alpha α的冪率分布了。

作者們定義的概率調整了樣本的優先級,因此也就在數據分布中引入了偏差,為了彌補偏差,使用了重要性采樣權重(importance-sampling (IS) weights):
w i = ( 1 N ? 1 P ( i ) ) β w_i = \left( \frac{1}{N} \cdot \frac{1}{P(i)} \right)^{\beta} wi?=(N1??P(i)1?)β
上式權重中,當 β = 1 \beta=1 β=1時就完全補償了非均勻概率采樣引入的偏差,作者們提到為了收斂性考慮,最后讓 β \beta β從0到1中的某個值開始,并逐漸增加到1。在Q-learning更新時使用這些權重乘以TD error,也就是使用 w i δ i w_i \delta_i wi?δi?而不是原來的 δ i \delta_i δi?。此外,為了使訓練更穩定,總是對權重乘以 1 / m a x i w i 1/\mathcal{max}_i{w_i} 1/maxi?wi?進行歸一化。

以Double DQN為例,使用優先經驗回放的算法(論文算法1)如下圖:

在這里插入圖片描述

prioritized experience replay 實現

直接實現優先經驗回放池如下代碼(修改自代碼 )

class PrioReplayBufferNaive:def __init__(self, buf_size, prob_alpha=0.6, epsilon=1e-5, beta=0.4, beta_increment_per_sampling=0.001):self.prob_alpha = prob_alphaself.capacity = buf_sizeself.pos = 0self.buffer = []self.priorities = np.zeros((buf_size, ), dtype=np.float32)self.beta = betaself.beta_increment_per_sampling = beta_increment_per_samplingself.epsilon = epsilondef __len__(self):return len(self.buffer)def size(self):  # 目前buffer中數據的數量return len(self.buffer)def add(self, sample):# 新加入的數據使用最大的優先級,保證數據盡可能的被采樣到max_prio = self.priorities.max() if self.buffer else 1.0if len(self.buffer) < self.capacity:self.buffer.append(sample)else:self.buffer[self.pos] = sampleself.priorities[self.pos] = max_prioself.pos = (self.pos + 1) % self.capacitydef sample(self, batch_size):if len(self.buffer) == self.capacity:prios = self.prioritieselse:prios = self.priorities[:self.pos]probs = np.array(prios, dtype=np.float32) ** self.prob_alphaprobs /= probs.sum()indices = np.random.choice(len(self.buffer), batch_size, p=probs, replace=True)samples = [self.buffer[idx] for idx in indices]total = len(self.buffer)self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])weights = (total * probs[indices]) ** (-self.beta)weights /= weights.max()return samples, indices, np.array(weights, dtype=np.float32)def update_priorities(self, batch_indices, batch_priorities):'''更新樣本的優先級'''for idx, prio in zip(batch_indices, batch_priorities):self.priorities[idx] = prio + self.epsilon

直接實現的優先經驗回放,在樣本數很大時的采樣效率不夠高,作者們通過定義了sumtree的數據結構來存儲樣本優先級,該數據結構的每一個節點的值為其子節點之和,而樣本優先級被放在樹的葉子節點上,樹的根節點的值為所有優先級之和 p t o t a l p_{total} ptotal?,更新和采樣時的效率為 O ( l o g N ) O(logN) O(logN)。在采樣時,設采樣批次大小為k,將 [ 0 , p t o t a l ] [0, p_{total}] [0,ptotal?]均分為k等份,然后在每一個區間均勻的采樣一個值,再通過該值從樹中提取到對應的樣本。python 實現如下(代碼來源)

class SumTree:"""父節點的值是其子節點值之和的二叉樹數據結構"""write = 0def __init__(self, capacity):self.capacity = capacityself.tree = np.zeros(2 * capacity - 1)self.data = np.zeros(capacity, dtype=object)self.n_entries = 0# update to the root nodedef _propagate(self, idx, change):parent = (idx - 1) // 2self.tree[parent] += changeif parent != 0:self._propagate(parent, change)# find sample on leaf nodedef _retrieve(self, idx, s):left = 2 * idx + 1right = left + 1if left >= len(self.tree):return idxif s <= self.tree[left]:return self._retrieve(left, s)else:return self._retrieve(right, s - self.tree[left])def total(self):return self.tree[0]# store priority and sampledef add(self, p, data):idx = self.write + self.capacity - 1self.data[self.write] = dataself.update(idx, p)self.write += 1if self.write >= self.capacity:self.write = 0if self.n_entries < self.capacity:self.n_entries += 1# update prioritydef update(self, idx, p):change = p - self.tree[idx]self.tree[idx] = pself._propagate(idx, change)# get priority and sampledef get(self, s):idx = self._retrieve(0, s)dataIdx = idx - self.capacity + 1return (idx, self.tree[idx], self.data[dataIdx])class PrioReplayBuffer:  # stored as ( s, a, r, s_ ) in SumTreeepsilon = 0.01alpha = 0.6beta = 0.4beta_increment_per_sampling = 0.001def __init__(self, capacity):self.tree = SumTree(capacity)self.capacity = capacitydef _get_priority(self, error):return (np.abs(error) + self.epsilon) ** self.alphadef add(self, error, sample):p = self._get_priority(error)self.tree.add(p, sample)def sample(self, n):batch = []idxs = []segment = self.tree.total() / npriorities = []self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])for i in range(n):a = segment * ib = segment * (i + 1)s = random.uniform(a, b)(idx, p, data) = self.tree.get(s)priorities.append(p)batch.append(data)idxs.append(idx)sampling_probabilities = priorities / self.tree.total()is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)is_weight /= is_weight.max()return batch, idxs, is_weightdef update(self, idx, error):'''這里是一次更新一個樣本,所以在調用時,寫for循環依次更次樣本的優先級'''p = self._get_priority(error)self.tree.update(idx, p)

參考資料

  1. Schaul, Tom, John Quan, Ioannis Antonoglou, and David Silver. 2015. “Prioritized Experience Replay.” arXiv: Learning,arXiv: Learning, November.

  2. sum_tree的實現代碼

  3. 相關blog: 1 (對應的代碼), 2, 3

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/160617.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/160617.shtml
英文地址,請注明出處:http://en.pswp.cn/news/160617.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

UML建模圖文詳解教程——類圖

版權聲明 本文原創作者&#xff1a;谷哥的小弟作者博客地址&#xff1a;http://blog.csdn.net/lfdfhl本文參考資料&#xff1a;《UML面向對象分析、建模與設計&#xff08;第2版&#xff09;》呂云翔&#xff0c;趙天宇 著 類圖概述 類圖用來描述系統內各種實體的類型以及不同…

Unsupervised MVS論文筆記

Unsupervised MVS論文筆記 摘要1 引言2 相關工作3 實現方法 Tejas Khot and Shubham Agrawal and Shubham Tulsiani and Christoph Mertz and Simon Lucey and Martial Hebert. Tejas Khot and Shubham Agrawal and Shubham Tulsiani and Christoph Mertz and Simon Lucey and …

JAVA小游戲拼圖

第一步是創建項目 項目名自擬 第二部創建個包名 來規范class 然后是創建類 創建一個代碼類 和一個運行類 代碼如下&#xff1a; package heima; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.event.KeyEvent; import …

10、信息打點——APP小程序篇抓包封包XP框架反編譯資產提取

APP信息搜集思路 外在——抓包封包——資產安全測試 抓包&#xff08;Fiddle&茶杯&burp&#xff09;封包&#xff08;封包監聽工具&#xff09;&#xff0c;提取資源信息 資產收集——資源提取——ICO、MAD、hash——FOFA等網絡測繪進行資產搜集 外在——功能邏輯 內在…

國際版Amazon Lightsail的功能解析

Amazon Lightsail是一項易于使用的云服務,可為您提供部署應用程序或網站所需的一切,從而實現經濟高效且易于理解的月度計劃。它是部署簡單的工作負載、網站或開始使用亞馬遜云科技的理想選擇。 作為 AWS 免費套餐的一部分&#xff0c;可以免費開始使用 Amazon Lightsail。注冊…

【Python進階】近200頁md文檔14大體系第4篇:Python進程使用詳解(圖文演示)

本文從14大模塊展示了python高級用的應用。分別有Linux命令&#xff0c;多任務編程、網絡編程、Http協議和靜態Web編程、htmlcss、JavaScript、jQuery、MySql數據庫的各種用法、python的閉包和裝飾器、mini-web框架、正則表達式等相關文章的詳細講述。 Python全套筆記直接地址…

PostgreSQL10安裝postgis插件

1.安裝pgsql10 2.下載插件&#xff0c;以Windows為例&#xff0c;地址&#xff1a;Index of /postgis/windows/pg10/ 3.安裝插件&#xff0c;直接安裝&#xff0c;和pgsql的目錄相同即可&#xff0c;一直下一步 4.安裝之后&#xff0c;需要執行sql打開 CREATE EXTENSION po…

028 - STM32學習筆記 - ADC結構體學習(二)

028 - STM32學習筆記 - 結構體學習&#xff08;二&#xff09; 上節對ADC基礎知識進行了學習&#xff0c;這節在了解一下ADC相關的結構體。 一、ADC初始化結構體 在標準庫函數中基本上對于外設都有一個初始化結構體xx_InitTypeDef&#xff08;其中xx為外設名&#xff0c;例如…

Redis設計與實現-數據結構(建設進度17%)

Redis數據結構 引言數據結構stringSDS數據結構原生string的不足 hash 本博客基于《Redis設計與實現》進行整理和補充&#xff0c;該書依賴于Redis 3.0版本&#xff0c;但是Redis6.0版本在一些底層實現上仍然沒有明顯的變動&#xff0c;因此本文將在該書的基礎上&#xff0c;對于…

PostgreSQL基本操作

1.查詢某個表的所在磁盤大小 select pg_size_pretty(pg_relation_size(grb_grid)); 2.插入point類型的記錄 insert into tb_person ("name", "address", "location", "create_time", "area", "girls") values …

Java 兩個線程交替打印1-100

線程題&#xff1a;交替打印1-100 這里演示兩個線程&#xff0c;一個打印奇數&#xff0c;一個打印偶數 方式一&#xff1a;synchronized FixedThreadPool public class example {private static int count 1;private static final Object lock new Object();public stat…

WPF基礎DataGrid控件

WPF DataGrid 是一個用于顯示和編輯表格數據的強大控件。它提供了豐富的功能&#xff0c;包括排序、篩選、分組、編輯、選擇等&#xff0c;使你能夠以類似電子表格的方式呈現和操作數據。 DataGrid 的布局主要由以下部分組成&#xff1a; 列定義 (Columns): DataGrid 列定義了…

YOLO目標檢測——衛星遙感多類別檢測數據集下載分享【含對應voc、coco和yolo三種格式標簽】

實際項目應用&#xff1a;衛星遙感目標檢測數據集說明&#xff1a;衛星遙感多類別檢測數據集&#xff0c;真實場景的高質量圖片數據&#xff0c;數據場景豐富&#xff0c;含網球場、棒球場、籃球場、田徑場、儲罐、車輛、橋、飛機、船等類別標簽說明&#xff1a;使用lableimg標…

2023年【上海市安全員C證】考試及上海市安全員C證找解析

題庫來源&#xff1a;安全生產模擬考試一點通公眾號小程序 2023年上海市安全員C證考試為正在備考上海市安全員C證操作證的學員準備的理論考試專題&#xff0c;每個月更新的上海市安全員C證找解析祝您順利通過上海市安全員C證考試。 1、【多選題】2017年9月頒發的《中共上海市委…

基于STM32的煙霧濃度檢測報警仿真設計(仿真+程序+講解視頻)

這里寫目錄標題 &#x1f4d1;1.主要功能&#x1f4d1;2.仿真&#x1f4d1;3. 程序&#x1f4d1;4. 資料清單&下載鏈接&#x1f4d1;[資料下載鏈接](https://docs.qq.com/doc/DS0VHTmxmUHBtVGVP) 基于STM32的煙霧濃度檢測報警仿真設計(仿真程序講解&#xff09; 仿真圖prot…

【數據結構】B : DS圖應用--最短路徑

B : DS圖應用–最短路徑 文章目錄 B : DS圖應用--最短路徑DescriptionInputOutputSampleInput Output 解題思路&#xff1a;初始化主循環心得&#xff1a; AC代碼 Description 給出一個圖的鄰接矩陣&#xff0c;再給出指定頂點v0&#xff0c;求頂點v0到其他頂點的最短路徑 In…

SkyWalking配置報警推送到企業微信

1、先在企業微信群里創建一個機器人&#xff0c;復制webhook的地址&#xff1a; 2、找到SkyWalking部署位置的alarm-settings.yml文件 編輯&#xff0c;在最后面加上此段配置 &#xff01;&#xff01;&#xff01;一定格式要對&#xff0c;不然一直報警報不出來按照網上指導…

JVM 堆外內存詳解

Java 進程內存占用除了JVM 運行時數據區&#xff0c;還有直接內存&#xff08;Direct Memory&#xff09;區域及 JVM 程序自身也會占用內存 直接內存&#xff08;Direct Memory&#xff09;區域&#xff1a;直接內存通過使用Native堆外內存來存儲數據&#xff0c;這意味著數據…

大數據平臺實踐之CDH6.2.1+spark3.3.0+kyuubi-1.6.0

前言&#xff1a;關于kyuubi的原理和功能這里不做詳細的介紹&#xff0c;感興趣的同學可以直通官網&#xff1a;https://kyuubi.readthedocs.io/en/v1.7.1-rc0/index.html 下載軟件版本 wget http://distfiles.macports.org/scala2.12/scala-2.12.16.tgz wget https://archi…

pikachu_php反序列化

pikachu_php反序列化 源代碼 class S{var $test "pikachu";function __construct(){echo $this->test;} }//O:1:"S":1:{s:4:"test";s:29:"<script>alert(xss)</script>";} $html; if(isset($_POST[o])){$s $_POST[…