AWD-LSTM為什么這么棒?

摘要: AWD-LSTM為什么這么棒,看完你就明白啦!

AWD-LSTM是目前最優秀的語言模型之一。在眾多的頂會論文中,對字級模型的研究都采用了AWD-LSTMs,并且它在字符級模型中的表現也同樣出色。

本文回顧了論文——Regularizing and Optimizing LSTM Language Models ,在介紹AWD-LSTM模型的同時并解釋其中所涉及的各項策略。該論文提出了一系列基于詞的語言模型的正則化和優化策略。這些策略不僅行之有效,而且能夠在不改變現有LSTM模型的基礎上使用。

AWD-LSTM即ASGD Weight-Dropped LSTM。它使用了DropConnect及平均隨機梯度下降的方法,除此之外還有包含一些其它的正則化策略。我們將在后文詳細講解這些策略。本文將著重于介紹它們在語言模型中的成功應用。

實驗代碼獲取:awd-lstm-lm GitHub repository

LSTM中的數學公式:

it = σ(Wixt + Uiht-1)

ft = σ(Wfxt + Ufht-1)

ot = σ(Woxt + Uoht-1)

c’t = tanh(Wcxt + Ucht-1)

ct = it ⊙ c’t + ft ⊙ c’t-1

ht = ot ⊙ tanh(ct)

其中, Wi, Wf, Wo, Wc, Ui, Uf, Uo, Uc都是權重矩陣,xt表示輸入向量,ht表示隱藏單元向量,ct表示單元狀態向量, ⊙表示element-wise乘法。

接下來我們將逐一介紹作者提出的策略:

權重下降的LSTM

RNN的循環連接容易導致過擬合問題,如何解決這一問題也成了一個較為熱門的研究領域。Dropouts的引入在前饋神經網絡和卷積網絡中取得了巨大的成功。但將Dropouts引入到RNN中卻反響甚微,這是由于Dropouts的加入破壞了RNN長期依賴的能力。

研究學者們就此提出了許多解決方案,但是這些方法要么作用于隱藏狀態向量ht-1,要么是對單元狀態向量ct進行更新。上述操作能夠解決高度優化的“黑盒”RNN,例如NVIDIA’s cuDNN LSTM中的過擬合問題。

但僅如此是不夠的,為了更好的解決這個問題,研究學者們引入了DropConnect。DropConnect是在神經網絡中對全連接層進行規范化處理。Dropout是指在模型訓練時隨機的將隱層節點的權重變成0,暫時認為這些節點不是網絡結構的一部分,但是會把它們的權重保留下來。與Dropout不同的是DropConnect在訓練神經網絡模型過程中,并不隨機的將隱層節點的輸出變成0,而是將節點中的每個與其相連的輸入權值以1-p的概率變成0。

clipboard.png

DropConnect作用在hidden-to-hidden權重矩陣(Ui、Uf、Uo、Uc)上。在前向和后向遍歷之前,只執行一次dropout操作,這對訓練速度的影響較小,可以用于任何標準優化的“黑盒”RNN中。通過對hidden-to-hidden權重矩陣進行dropout操作,可以避免LSTM循環連接中的過度擬合問題。

你可以在 awd-lstm-lm 中找到weight_drop.py 模塊用于實現。

作者表示,盡管DropConnect是通過作用在hidden-to-hidden權重矩陣以防止過擬合問題,但它也可以作用于LSTM的非循環權重。

使用非單調條件來確定平均觸發器

研究發現,對于特定的語言建模任務,傳統的不帶動量的SGD算法優于帶動量的SGD、Adam、Adagrad及RMSProp等算法。因此,作者基于傳統的SGD算法提出了ASGD(Average SGD)算法。

Average SGD

ASGD算法采用了與SGD算法相同的梯度更新步驟,不同的是,ASGD沒有返回當前迭代中計算出的權值,而是考慮的這一步和前一次迭代的平均值。

傳統的SGD梯度更新:

clipboard.png

AGSD梯度更新:

clipboard.png

其中,k是在加權平均開始之前運行的最小迭代次數。在k次迭代開始之前,ASGD與傳統的SGD類似。t是當前完成的迭代次數,sum(w_prevs)是迭代k到t的權重之和,lr_t是迭代次數t的學習效率,由學習率調度器決定。

你可以在這里找到AGSD的PyTorch實現。

但作者也強調,該方法有如下兩個缺點:

? 學習率調度器的調優方案不明確

? 如何選取合適的迭代次數k。值太小會對方法的有效性產生負面影響,值太大可能需要額外的迭代才能收斂。

基于此,作者在論文中提出了使用非單調條件來確定平均觸發器,即NT-ASGD,其中:

? 當驗證度量不能改善多個循環時,就會觸發平均值。這是由非單調區間的超參數n保證的。因此,每當驗證度量沒有在n個周期內得到改進時,就會使用到ASGD算法。通過實驗發現,當n=5的時候效果最好。

? 整個實驗中使用一個恒定的學習速率,不需要進一步的調整。

正則化方法

除了上述提及的兩種方法外,作者還使用了一些其它的正則化方法防止過擬合問題及提高數據效率。

長度可變的反向傳播序列

作者指出,使用固定長度的基于時間的反向傳播算法(BPTT)效率較低。試想,在一個時間窗口大小固定為10的BPTT算法中,有100個元素要進行反向傳播操作。在這種情況下,任何可以被10整除的元素都不會有可以反向支撐的元素。這導致了1/10的數據無法以循環的方式進行自我改進,8/10的數據只能使用到部分的BPTT窗口。

為了解決這個問題,作者提出了使用可變長度的反向傳播序列。首先選取長度為bptt的序列,概率為p以及長度為bptt/2的序列,概率為1-p。在PyTorch中,作者將p設為0.95。

clipboard.png

其中,base_bptt用于獲取seq_len,即序列長度,在N(base_bptt, s)中,s表示標準差,N表示服從正態分布。代碼如下:

clipboard.png

學習率會根據seq_length進行調整。由于當學習速率固定時,會更傾向于對段序列而非長序列進行采樣,所以需要進行縮放。

clipboard.png

Variational Dropout

在標準的Dropout中,每次調用dropout連接時都會采樣到一個新的dropout mask。而在Variational Dropout中,dropout mask在第一次調用時只采樣一次,然后locked dropout mask將重復用于前向和后向傳播中的所有連接。

雖然使用了DropConnect而非Variational Dropout以規范RNN中hidden-to-hidden的轉換,但是對于其它的dropout操作均使用的Variational Dropout,特別是在特定的前向和后向傳播中,對LSTM的所有輸入和輸出使用相同的dropout mask。

點擊查看官方awd-lstm-lm GitHub存儲庫的Variational dropout實現。詳情請參閱原文。

Embedding Dropout

論文中所提到的Embedding Dropout首次出現在——《A Theoretically Grounded Application of Dropout in Recurrent Neural Networks》一文中。該方法是指將dropout作用于嵌入矩陣中,且貫穿整個前向和反向傳播過程。在該過程中出現的所有特定單詞均會消失。

Weight Tying(權重綁定)

權重綁定共享嵌入層和softmax層之間的權重,能夠減少模型中大量的參數。

Reduction in Embedding Size

對于語言模型來說,想要減少總參數的數量,最簡單的方法是降低詞向量的維數。即使這樣無法幫助緩解過擬合問題,但它能夠減少嵌入層的維度。對LSTM的第一層和最后一層進行修改,可以使得輸入和輸出的尺寸等于減小后的嵌入尺寸。

Activation Regularization(激活正則化)

L2正則化是對權重施加范數約束以減少過擬合問題,它同樣可以用于單個單元的激活,即激活正則化。激活正則化可作為一種調解網絡的方法。

clipboard.png

Temporal Activation Regularization(時域激活正則化)

同時,L2正則化能對RNN在不同時間步驟上的輸出差值進行范數約束。它通過在隱藏層產生較大變化對模型進行懲罰。

clipboard.png

其中,alpha和beta是縮放系數,AR和TAR損失函數僅對RNN最后一層的輸出起作用。

模型分析

作者就上述模型在不同的數據集中進行了實驗,為了對分分析,每次去掉一種策略。

clipboard.png

圖中的每一行表示去掉特定策略的困惑度(perplexity)分值,從該圖中我們能夠直觀的看出各策略對結果的影響。

實驗細節

數據——來自Penn Tree-bank(PTB)數據集和WikiText-2(WT2)數據集。

網絡體系結構

——所有的實驗均使用的是3層LSTM模型。

批尺寸——WT2數據集的批尺寸為80,PTB數據集的批尺寸為40。根據以往經驗來看,較大批尺寸(40-80)的性能優于較小批尺寸(10-20)。

其它超參數的選擇請參考原文。

總結

該論文很好的總結了現有的正則化及優化策略在語言模型中的應用,對于NLP初學者甚至研究者都大有裨益。論文中強調,雖然這些策略在語言建模中獲得了成功,但它們同樣適用于其他序列學習任務。

本文作者:【方向】

閱讀原文

本文為云棲社區原創內容,未經允許不得轉載。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/282049.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/282049.shtml
英文地址,請注明出處:http://en.pswp.cn/news/282049.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Spread / Rest 操作符

Spread / Rest 操作符指的是 ...,具體是 Spread 還是 Rest 需要看上下文語境。 當被用于迭代器中時,它是一個 Spread 操作符:(參數為數組) function foo(x,y,z) {console.log(x,y,z); }let arr [1,2,3]; foo(...arr);…

python postman腳本自動化_如何用Postman做接口自動化測試

什么是自動化測試把人對軟件的測試行為轉化為由機器執行測試行為的一種實踐。例如GUI自動化測試,模擬人去操作軟件界面,把人從簡單重復的勞動中解放出來本質是用代碼去測試另一段代碼,屬于一種軟件開發工作,已經開發完成的用例還必…

Mac上,為虛擬機集群上的每臺虛擬機設置固定IP

一、環境介紹 本機:macOS系統 虛擬機軟件:VMware Fusion 虛擬機上:centos7內核的Linux系統集群 二、為什么要為每臺虛擬機設置固定ip 由于每次啟動虛擬機,得到的ip可能不一樣,這樣對遠程連接非常不友好&#xff0c…

朱曄的互聯網架構實踐心得S1E7:三十種架構設計模式(上)

設計模式是前人通過大量的實踐總結出來的一些經驗總結和最佳實踐。在經過多年的軟件開發實踐之后,回過頭來去看23種設計模式你會發現很多平時寫代碼的套路和OO的套路和設計模式里總結的類似,這也說明了你悟到的東西和別人悟到的一樣,經過大量…

記一次某制造業ERP系統 CPU打爆事故分析

一:背景 1.講故事前些天有位朋友微信找到我,說他的程序出現了CPU階段性爆高,過了一會就下去了,咨詢下這個爆高階段程序內部到底發生了什么?畫個圖大概是下面這樣,你懂的。按經驗來說,這種情況一…

PC端和移動APP端CSS樣式初始化

CSS樣式初始化分為PC端和移動APP端 1.PC端:使用Normalize.css Normalize.css是一種CSS reset的替代方案。 我們創造normalize.css有下面這幾個目的: 保護有用的瀏覽器默認樣式而不是完全去掉它們一般化的樣式:為大部分HTML元素提供修復瀏覽器…

FPGA浮點數定點化

因為在普通的fpga芯片里面,寄存器只可以表示無符號型,不可以表示小數,所以在計算比較精確的數值時,就需要做一些處理,不過在altera在Arria 10 中增加了硬核浮點DSP模塊,這樣更加適合硬件加速和做一些比較精…

框架實現修改功能的原理_JAVA集合框架的特點及實現原理簡介

1.集合框架總體架構集合大致分為Set、List、Queue、Map四種體系,其中List,Set,Queue繼承自Collection接口,Map為獨立接口Set的實現類有:HashSet,LinkedHashSet,TreeSet...List下有ArrayList,Vector,LinkedList...Map下…

NPM報錯終極大法

2019獨角獸企業重金招聘Python工程師標準>>> 所有的錯誤基本上都跟node的版本相關 直接刪除系統中的node 重新安裝 sudo rm -rf /usr/local/{bin/{node,npm},lib/node_modules/npm,lib/node,share/man/*/node.*} 重新安裝 $ n lts $ npm install -g npm $ n stable…

自己使用的一個.NET輕量開發結構

三個文件夾,第一個是放置前端部分,第二個是各種支持的類文件,第三個是單元測試文件。Core文件類庫放置的是與數據庫做交互的文件,以及一些第三方類庫,還有與數據庫連接的文件1.Lasy.Validator是一個基于Attribute驗證器…

英語影視臺詞---八、the shawshank redemption

英語影視臺詞---八、the shawshank redemption 一、總結 一句話總結:肖申克的救贖 1、Its funny. On the outside, I was an honest man. Straight as an arrow. I had to come to prison to be a crook.? 這很有趣。 在外面,我是一個誠實的人…

10.python網絡編程(socket server 實現并發 part 2)

一、基于tcp的socket通信的基本原理分析。基于tcp的socket通信,主要依靠兩個循環,分別是連接循環和通信循環。這個前面的文章有寫過,在這里就不再重復了。二、socketserver實現多并發的原理分析。1.server類:2.reques類。類繼承關…

如何在一小時內更新100篇文章?-Evernote Sync插件介紹

上一篇“手把手教你制作微信小程序,開源、免費、快速搞定”,已經教會你如何快速制作一個小程序,但作為資訊類小程序,內容不可少,并且還需要及時更新。 但是,如果讓你復制粘貼,可能還需要上傳圖片…

linux awk

grep 文本過濾器sed 流編輯器awk 報告生成器 格式化以后顯示awk [option] PATTERN {action} file1 file2awk -F"|" BEGIN{OFS":"} {print $1,$2,$3} test.txt #文本字符串用雙引號awk -F"|" BEGIN{OFS":"} {print $1,"jksong&quo…

iOS無線真機調試

為什么80%的碼農都做不了架構師?>>> Xcode從9開始 就支持無線真機調試,那么怎么操作呢? 首先用數據線連接你的設備,接下來Xcode- Window-Devices and Simulators 點開之后看到你的設備 默認情況下Connect via networ…

Mybatis中jdbcType和javaType的對應關系

2019獨角獸企業重金招聘Python工程師標準>>> Mybatis中jdbcType和javaType的對應關系 1 JDBC Type Java Type 2 CHAR String 3 VARCHAR String 4 LONGVARCHAR String 5 NUMERIC java.math.…

java貪吃蛇

使用雙向鏈表實現貪吃蛇程序 1.鏈表節點定義: package snake;public class SnakeNode {private int x;private int y;private SnakeNode next;private SnakeNode ahead;public SnakeNode() {}public SnakeNode(int x, int y) {super();this.x x;this.y y;}public …

【死磕 Spring】----- IOC 之解析 bean 標簽:解析自定義標簽

前面四篇文章都是分析 Bean 默認標簽的解析過程,包括基本屬性、六個子元素(meta、lookup-method、replaced-method、constructor-arg、property、qualifier),涉及內容較多,拆分成了四篇文章,導致我們已經忘…

Codeigniter 4.0-dev 版源碼學習筆記之四——詳細路由過程

前言 我個人覺得在當前 MVC 流行的架構下,要想去了解一個框架,或者是一個基于此架構下的應用程序,最好的入手方式就是先看路由,雖然路由不是 MVC 里的任何一個,但是知道了路由的來龍去脈就知道了整個框架或者是應用的結…