【深度學習】參數優化和訓練技巧

尋找合適的學習率(learning rate)

學習率是一個非常非常重要的超參數,這個參數呢,面對不同規模、不同batch-size、不同優化方式、不同數據集,其最合適的值都是不確定的,我們無法光憑經驗來準確地確定lr的值,我們唯一可以做的,就是在訓練中不斷尋找最合適當前狀態的學習率。

比如下圖利用fastai中的lr_find()函數尋找合適的學習率,根據下方的學習率-損失曲線得到此時合適的學習率為1e-2。
深度學習參數優化和訓練技巧總結

推薦一篇fastai首席設計師「Sylvain Gugger」的一篇博客:How Do You Find A Good Learning Rate[1]

以及相關的論文Cyclical Learning Rates for Training Neural Networks[2]。

learning-rate與batch-size的關系
一般來說,越大的batch-size使用越大的學習率。

原理很簡單,越大的batch-size意味著我們學習的時候,收斂方向的confidence越大,我們前進的方向更加堅定,而小的batch-size則顯得比較雜亂,毫無規律性,因為相比批次大的時候,批次小的情況下無法照顧到更多的情況,所以需要小的學習率來保證不至于出錯。

可以看下圖損失Loss與學習率Lr的關系:
深度學習參數優化和訓練技巧總結

在顯存足夠的條件下,最好采用較大的batch-size進行訓練,找到合適的學習率后,可以加快收斂速度。

另外,較大的batch-size可以避免batch normalization出現的一些小問題,參考如下Pytorch庫Issue[3]

權重初始化

權重初始化相比于其他的trick來說在平常使用并不是很頻繁。

因為大部分人使用的模型都是預訓練模型,使用的權重都是在大型數據集上訓練好的模型,當然不需要自己去初始化權重了。只有沒有預訓練模型的領域會自己初始化權重,或者在模型中去初始化神經網絡最后那幾個全連接層的權重。

常用的權重初始化算法是「kaiming_normal」或者「xavier_normal」。

相關論文:

Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification[4]
Understanding the difficulty of training deep feedforward neural networks[5]
Xavier初始化論文[6]
He初始化論文[7]
不初始化可能會減慢收斂速度,影響收斂效果。
深度學習參數優化和訓練技巧總結

dropout

dropout是指在深度學習網絡的訓練過程中,對于神經網絡單元,按照一定的概率將其暫時從網絡中丟棄。注意是「暫時」,對于隨機梯度下降來說,由于是隨機丟棄,故而每一個mini-batch都在訓練不同的網絡。

Dropout類似于bagging ensemble減少variance。也就是投通過投票來減少可變性。通常我們在全連接層部分使用dropout,在卷積層則不使用。但「dropout」并不適合所有的情況,不要無腦上Dropout。

Dropout一般適合于全連接層部分,而卷積層由于其參數并不是很多,所以不需要dropout,加上的話對模型的泛化能力并沒有太大的影響。
深度學習參數優化和訓練技巧總結
我們一般在網絡的最開始和結束的時候使用全連接層,而hidden layers則是網絡中的卷積層。所以一般情況,在全連接層部分,采用較大概率的dropout而在卷積層采用低概率或者不采用dropout。

數據集處理

主要有「數據篩選」 以及 「數據增強」

fastai中的圖像增強技術為什么相對比較好[9]

難例挖掘 hard-negative-mining

分析模型難以預測正確的樣本,給出針對性方法。

多模型融合

Ensemble是論文刷結果的終極核武器,深度學習中一般有以下幾種方式

同樣的參數,不同的初始化方式
不同的參數,通過cross-validation,選取最好的幾組
同樣的參數,模型訓練的不同階段,即不同迭代次數的模型。
不同的模型,進行線性融合. 例如RNN和傳統模型.
提高模型性能和魯棒性大法:probs融合 和 投票法。

假設這里有model 1, model 2, model 3,可以這樣融合:

  1. model1 probs model2 probs model3 probs ==> final label
  2. model1 label , model2 label , model3 label ==> voting ==> final label
  3. model1_1 probs … model1_n probs ==> mode1 label, model2 label與model3獲取的label方式與1相同 ==> voting ==> final label

第三個方式的啟發來源于,如果一個model的隨機種子沒有固定,多次預測得到的結果可能不同。

以上方式的效果要根據label個數,數據集規模等特征具體問題具體分析,表現可能不同,方式無非是probs融合和投票法的單獨使用or結合。

差分學習率與遷移學習

首先說下遷移學習,遷移學習是一種很常見的深度學習技巧,我們利用很多預訓練的經典模型直接去訓練我們自己的任務。雖然說領域不同,但是在學習權重的廣度方面,兩個任務之間還是有聯系的。
深度學習參數優化和訓練技巧總結

由上圖,我們拿來「model A」訓練好的模型權重去訓練我們自己的模型權重(「Model B」),其中,modelA可能是ImageNet的預訓練權重,而ModelB則是我們自己想要用來識別貓和狗的預訓練權重。

那么差分學習率和遷移學習有什么關系呢?我們直接拿來其他任務的訓練權重,在進行optimize的時候,如何選擇適當的學習率是一個很重要的問題。

一般地,我們設計的神經網絡(如下圖)一般分為三個部分,輸入層,隱含層和輸出層,隨著層數的增加,神經網絡學習到的特征越抽象。因此,下圖中的卷積層和全連接層的學習率也應該設置的不一樣,一般來說,卷積層設置的學習率應該更低一些,而全連接層的學習率可以適當提高。
深度學習參數優化和訓練技巧總結

這就是差分學習率的意思,在不同的層設置不同的學習率,可以提高神經網絡的訓練效果,具體的介紹可以查看下方的連接。
深度學習參數優化和訓練技巧總結

上面的示例圖來自:towardsdatascience.com/transfer-le…[10]

余弦退火(cosine annealing)和熱重啟的隨機梯度下降

「余弦」就是類似于余弦函數的曲線,「退火」就是下降,「余弦退火」就是學習率類似余弦函數慢慢下降。

「熱重啟」就是在學習的過程中,「學習率」慢慢下降然后突然再「回彈」(重啟)然后繼續慢慢下降。

兩個結合起來就是下方的學習率變化圖:
深度學習參數優化和訓練技巧總結

更多詳細的介紹可以查看知乎機器學習算法如何調參?這里有一份神經網絡學習速率設置指南[11]
以及相關論文SGDR: Stochastic Gradient Descent with Warm Restarts[12]

嘗試過擬合一個小數據集

這是一個經典的小trick了,但是很多人并不這樣做,可以嘗試一下。

關閉正則化/隨機失活/數據擴充,使用訓練集的一小部分,讓神經網絡訓練幾個周期。確保可以實現零損失,如果沒有,那么很可能什么地方出錯了。

多尺度訓練

多尺度訓練是一種「直接有效」的方法,通過輸入不同尺度的圖像數據集,因為神經網絡卷積池化的特殊性,這樣可以讓神經網絡充分地學習不同分辨率下圖像的特征,可以提高機器學習的性能。

也可以用來處理過擬合效應,在圖像數據集不是特別充足的情況下,可以先訓練小尺寸圖像,然后增大尺寸并再次訓練相同模型,這樣的思想在Yolo-v2的論文中也提到過:
深度學習參數優化和訓練技巧總結

需要注意的是:多尺度訓練并不是適合所有的深度學習應用,多尺度訓練可以算是特殊的數據增強方法,在圖像大小這一塊做了調整。如果有可能最好利用可視化代碼將多尺度后的圖像近距離觀察一下,「看看多尺度會對圖像的整體信息有沒有影響」,如果對圖像信息有影響的話,這樣直接訓練的話會誤導算法導致得不到應有的結果。

Cross Validation 交叉驗證

在李航的統計學方法中說到,交叉驗證往往是對實際應用中「數據不充足」而采用的,基本目的就是重復使用數據。在平常中我們將所有的數據分為訓練集和驗證集就已經是簡單的交叉驗證了,可以稱為1折交叉驗證。「注意,交叉驗證和測試集沒關系,測試集是用來衡量我們的算法標準的,不參與到交叉驗證中來。」

交叉驗證只針對訓練集和驗證集。

交叉驗證是Kaggle比賽中特別推崇的一種技巧,我們經常使用的是5-折(5-fold)交叉驗證,將訓練集分成5份,隨機挑一份做驗證集其余為訓練集,循環5次,這種比較常見計算量也不是很大。還有一種叫做leave-one-out cross validation留一交叉驗證,這種交叉驗證就是n-折交叉,n表示數據集的容量,這種方法只適合數據量比較小的情況,計算量非常大的情況很少用到這種方法。

吳恩達有一節課The nuts and bolts of building applications using deep learning[13]中也提到了。
深度學習參數優化和訓練技巧總結

優化算法

按理說不同的優化算法適合于不同的任務,不過我們大多數采用的優化算法還是是adam和SGD monmentum。

Adam 可以解決一堆奇奇怪怪的問題(有時 loss 降不下去,換 Adam 瞬間就好了),也可以帶來一堆奇奇怪怪的問題(比如單詞詞頻差異很大,當前 batch 沒有的單詞的詞向量也被更新;再比如Adam和L2正則結合產生的復雜效果)。用的時候要膽大心細,萬一遇到問題找各種魔改 Adam(比如 MaskedAdam[14], AdamW 啥的)搶救。

但看一些博客說adam的相比SGD,收斂快,但泛化能力差,更優結果似乎需要精調SGD。

adam,adadelta等, 在小數據上,我這里實驗的效果不如sgd, sgd收斂速度會慢一些,但是最終收斂后的結果,一般都比較好。

如果使用sgd的話,可以選擇從1.0或者0.1的學習率開始,隔一段時間,在驗證集上檢查一下,如果cost沒有下降,就對學習率減半. 我看過很多論文都這么搞,我自己實驗的結果也很好. 當然,也可以先用ada系列先跑,最后快收斂的時候,更換成sgd繼續訓練.同樣也會有提升.據說adadelta一般在分類問題上效果比較好,adam在生成問題上效果比較好。

adam收斂雖快但是得到的解往往沒有sgd momentum得到的解更好,如果不考慮時間成本的話還是用sgd吧。

adam是不需要特別調lr,sgd要多花點時間調lr和initial weights。

數據預處理方式

zero-center ,這個挺常用的.

深度學習參數優化和訓練技巧總結

PCA whitening,這個用的比較少。

訓練技巧

要做梯度歸一化,即算出來的梯度除以minibatch size
clip c(梯度裁剪): 限制最大梯度,其實是value = sqrt(w1^2 w2^2….),如果value超過了閾值,就算一個衰減系系數,讓value的值等于閾值: 5,10,15
dropout對小數據防止過擬合有很好的效果,值一般設為0.5
小數據上dropout sgd在我的大部分實驗中,效果提升都非常明顯.因此可能的話,建議一定要嘗試一下。
dropout的位置比較有講究, 對于RNN,建議放到輸入->RNN與RNN->輸出的位置.關于RNN如何用dropout,可以參考這篇論文:http://arxiv.org/abs/1409.2329[15]
除了gate之類的地方,需要把輸出限制成0-1之外,盡量不要用sigmoid,可以用tanh或者relu之類的激活函數.
sigmoid函數在-4到4的區間里,才有較大的梯度。之外的區間,梯度接近0,很容易造成梯度消失問題。
輸入0均值,sigmoid函數的輸出不是0均值的。
rnn的dim和embdding size,一般從128上下開始調整. batch size,一般從128左右開始調整. batch size合適最重要,并不是越大越好.
word2vec初始化,在小數據上,不僅可以有效提高收斂速度,也可以可以提高結果.
盡量對數據做shuffle
LSTM 的forget gate的bias,用1.0或者更大的值做初始化,可以取得更好的結果,來自這篇論文:
http://jmlr.org/proceedings/papers/v37/jozefowicz15.pdf[16]
我這里實驗設成1.0,可以提高收斂速度.實際使用中,不同的任務,可能需要嘗試不同的值.
Batch Normalization據說可以提升效果,參考論文:Accelerating Deep Network Training by Reducing Internal Covariate Shift
如果你的模型包含全連接層(MLP),并且輸入和輸出大小一樣,可以考慮將MLP替換成Highway Network,我嘗試對結果有一點提升,建議作為最后提升模型的手段,原理很簡單,就是給輸出加了一個gate來控制信息的流動,詳細介紹請參考論文:?http://arxiv.org/abs/1505.00387[17]
來自@張馨宇的技巧:一輪加正則,一輪不加正則,反復進行。
在數據集很大的情況下,一上來就跑全量數據。建議先用 1/100、1/10 的數據跑一跑,對模型性能和訓練時間有個底,外推一下全量數據到底需要跑多久。在沒有足夠的信心前不做大規模實驗。
subword 總是會很穩定地漲點,只管用就對了。
GPU 上報錯時盡量放在 CPU 上重跑,錯誤信息更友好。例如 GPU 報 “ERROR:tensorflow:Model diverged with loss = NaN” 其實很有可能是輸入 ID 超出了 softmax 詞表的范圍。
在確定初始學習率的時候,從一個很小的值(例如 1e-7)開始,然后每一步指數增大學習率(例如擴大1.05 倍)進行訓練。訓練幾百步應該能觀察到損失函數隨訓練步數呈對勾形,選擇損失下降最快那一段的學習率即可。
補充一個rnn trick,仍然是不考慮時間成本的情況下,batch size=1是一個很不錯的regularizer, 起碼在某些task上,這也有可能是很多人無法復現alex graves實驗結果的原因之一,因為他總是把batch size設成1。
注意實驗的可復現性和一致性,注意養成良好的實驗記錄習慣 ==> 不然如何分析出實驗結論。
超參上,learning rate 最重要,推薦了解 cosine learning rate 和 cyclic learning rate,其次是 batchsize 和 weight decay。當你的模型還不錯的時候,可以試著做數據增廣和改損失函數錦上添花了。

參考:

關于訓練神經網路的諸多技巧Tricks(完全總結版)[18]
你有哪些deep learning(rnn、cnn)調參的經驗?[19]
Bag of Tricks for Image Classification with Convolutional Neural Networks[20],trick 合集 1。
Must Know Tips/Tricks in Deep Neural Networks[21],trick 合集 2。
33條神經網絡訓練秘技[22],trick 合集 3。
26秒單GPU訓練CIFAR10[23],工程實踐。
Batch Normalization[24],雖然玄學,但是養活了很多煉丹師。
Searching for Activation Functions[25],swish 激活函數。

參考資料

[1] How Do You Find A Good Learning Rate: https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html

[2] Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186

[3] Pytorch庫Issue: https://github.com/pytorch/pytorch/issues/4534

[4] Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification: https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf

[5] Understanding the difficulty of training deep feedforward neural networks: http://proceedings.mlr.press/v9/glorot10a.html

[6] Xavier初始化論文: http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf

[7] He初始化論文: https://arxiv.org/abs/1502.01852

[8] https://arxiv.org/abs/1312.6120: https://arxiv.org/abs/1312.6120

[9] fastai中的圖像增強技術為什么相對比較好: https://oldpan.me/archives/fastai-1-0-quick-study

[10] towardsdatascience.com/transfer-le…: https://towardsdatascience.com/transfer-learning-using-differential-learning-rates-638455797f00

[11] 機器學習算法如何調參?這里有一份神經網絡學習速率設置指南: https://zhuanlan.zhihu.com/p/34236769

[12] SGDR: Stochastic Gradient Descent with Warm Restarts: https://arxiv.org/abs/1608.03983

[13] The nuts and bolts of building applications using deep learning: https://www.youtube.com/watch?v=F1ka6a13S9I

[14] MaskedAdam: https://www.zhihu.com/question/265357659/answer/580469438

[15] http://arxiv.org/abs/1409.2329: http://arxiv.org/abs/1409.2329

[16] http://jmlr.org/proceedings/papers/v37/jozefowicz15.pdf: http://jmlr.org/proceedings/papers/v37/jozefowicz15.pdf

[17] http://arxiv.org/abs/1505.00387: http://arxiv.org/abs/1505.00387

[18] 關于訓練神經網路的諸多技巧Tricks(完全總結版): https://juejin.im/post/5be5b0d7e51d4543b365da51

[19] 你有哪些deep learning(rnn、cnn)調參的經驗?: https://www.zhihu.com/question/41631631

[20] Bag of Tricks for Image Classification with Convolutional Neural Networks: https://arxiv.org/abs/1812.01187

[21] Must Know Tips/Tricks in Deep Neural Networks: http://lamda.nju.edu.cn/weixs/project/CNNTricks/CNNTricks.html

[22] 33條神經網絡訓練秘技: https://zhuanlan.zhihu.com/p/63841572

[23] 26秒單GPU訓練CIFAR10: https://zhuanlan.zhihu.com/p/79020733

[24] Batch Normalization: https://arxiv.org/abs/1502.03167%3Fcontext%3Dcs

[25] Searching for Activation Functions: https://arxiv.org/abs/1710.05941

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/164759.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/164759.shtml
英文地址,請注明出處:http://en.pswp.cn/news/164759.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

6.2.SDP協議

那今天呢?我們來介紹一下sdp協議,那實際上呢?sdp協議非常的簡單。我們如果拿到一個stp的文檔去看的話,那你要分閱里邊的所有的內容會覺得很枯燥,但實際上呢,如果我們按照這張圖所展示的結構去看stp的話。你…

Javascript每天一道算法題(十四)——合并數組區間_中等

文章目錄 1、問題2、示例3、解決方法(0)方法0——雙指針(錯誤思路)(1)方法1——雙指針(正確) 總結 1、問題 以數組 intervals 表示若干個區間的集合,其中單個區間為 inte…

怎么讀一個網絡的代碼

1.網絡代碼怎么來的? 我想要實現一個功能,這個功能是輸入一張圖像,返回一個類別結果。 所以很明確就有三個部分,一個是接受圖像輸入,一個是處理圖像得到處理結果,一個是對處理結果判斷生成結果。 現在想要使…

rocketmq 發送時異常:system busy 和 broker busy 解決方案

之前寫的解決方案,都是基于測試環境測試的.到生產環境之后,正常使用沒有問題,生產環境壓測時,又出現了system busy異常(簡直崩潰).最后在rocketmq群里大佬指導下,終于解決(希望是徹底解決). 下面直接給出結果: 目前通過生產環境各種參數修改測試得出: broker busy異常: 可通…

Using PeopleCode in Application Engine Programs在應用引擎程序中使用PeopleCode

This section provides an overview of PeopleCode and Application Engine programs and discusses how to: 本節概述了PeopleCode和應用程序引擎程序,并討論了如何: Decide when to use PeopleCode.決定何時使用PeopleCode。Consider the program environment.考…

Java之《ATM自動取款機》(面向對象)

《JAVA編程基礎》項目說明 一、項目名稱: 基于JAVA控制臺版本銀行自動取款機 項目要求: 實現銀行自動取款機的以下基本操作功能:讀卡、取款、查詢。(自動取款機中轉賬、修改密碼不作要求) 具體要求: 讀卡…

基于SSM的校園奶茶點單管理系統

基于SSM的校園奶茶點單管理系統的設計與實現~ 開發語言:Java數據庫:MySQL技術:SpringMyBatisSpringMVC工具:IDEA/Ecilpse、Navicat、Maven 系統展示 主頁 奶茶列表 登錄界面 管理員界面 用戶界面 摘要 隨著社會的發展和科技的進…

ubuntu搭建phpmyadmin+wordpress

Ubuntu搭建phpmyadmin wordpress Linux系統設置:Ubuntu 22配置apache2搭建phpmyadmin配置Nginx環境,搭建wordpress Linux系統設置:Ubuntu 22 配置apache2 安裝apache2 sudo apt -y install apache2設置端口號為8080 sudo vim /etc/apache…

paddle detection 訓練參數

#####################################基礎配置##################################### # 檢測算法使用YOLOv3,backbone使用MobileNet_v1,數據集使用roadsign_voc的配置文件模板,本配置文件默認使用單卡,單卡的batch_size=1 # 檢測模型的名稱 architecture: YOLOv3 # 根據…

【CCF-PTA】第03屆Scratch第05題 -- 統計出現次數最多的字

統計出現次數最多的字 【題目描述】 我國自古流傳下來不少膾炙人口的詩歌,各具特色,別具一格。有些詩只用寥寥幾個字,就能描繪出生動的意境。 請找出以下詩篇中出現次數最多的字,如果有多個字出現次數相同,則答案為…

Java中基于SSM框架的數據保存方法與日期處理

? 一、詳解 在SSM框架中,保存數據通常涉及到服務層和數據訪問層。服務層處理業務邏輯,而數據訪問層負責與數據庫進行交互。 二、代碼 Override public void save(Student student) { Date date new Date(); SimpleDateFormat format new Sim…

什么是LLC電路?

LLC電路是由2個電感和1個電容構成的諧振電路,故稱之為LLC; LLC電路主要由三個元件組成:兩個電感分別為變壓器一次側漏感(Lr)和勵磁電感(Lm),電容為變壓器一次側諧振電容(Cr)。這些元件構成了一個諧振回路,其中輸入電感…

【C/PTA】函數專項練習(四)

本文結合PTA專項練習帶領讀者掌握函數,刷題為主注釋為輔,在代碼中理解思路,其它不做過多敘述。 目錄 6-1 計算A[n]1/(1 A[n-1])6-2 遞歸實現順序輸出整數6-3 自然數的位數(遞歸版)6-4 分治法求解金塊問題6-5 漢諾塔6-6 重復顯示字符(遞歸版)…

字母異位詞分組

給你一個字符串數組,請你將 字母異位詞 組合在一起。可以按任意順序返回結果列表。 字母異位詞 是由重新排列源單詞的所有字母得到的一個新單詞。 示例 1: 輸入: strs [“eat”, “tea”, “tan”, “ate”, “nat”, “bat”] 輸出: [[“bat”],[“nat”,“tan…

Android MemoryFile 共享內存

應用場景: 跨進程傳輸大數據,如文件、圖片等; 技術選型: 共享內存–MemoryFile; 優點: 1. 共享內存沒有傳輸大小限制,所以和應用總的分配內存一樣(512MB)&#xff1…

Java 根據文件名獲取文件類型

比如文件名是“測試文件.png”,則獲取的文件類型就是 png 直接上一個通用的方法,拿去直接就能用。 // 比如入參文件名是“測試文件.png”,則出參就是 pngprivate String getFileSuffix(String fileName) {String[] fileStr fileName.split(&…

educoder中共享單車之數據可視化

第1關:繪制地圖 <%@ page language="java" contentType="text/html; charset=utf-8"pageEncoding="utf-8"%> <html> <head><meta http-equiv="Content-Type" content="text/html; charset=utf-8" /&…

專用設備上的SD卡插入電腦想讀取數據,提示要格式化?

環境&#xff1a; Win10 專業版 車載感應數據專用SD卡 問題描述&#xff1a; 專用設備上的SD&#xff0c;現在把SD卡從設備取出&#xff0c;用讀卡器插入電腦提示要格式化&#xff1f; 解決方案&#xff1a; 1.先進入PE查看SD分區情況&#xff0c;SD格式為ext4 查看文件…

lombok中使用@Builder構造器模式時的默認值問題

這里寫自定義目錄標題 問題case原因解決方案 文章參考來源&#xff1a;https://chenyongjun.vip/articles/107 問題case Lombok 使用廣泛&#xff0c;這里分享一個 Lombok Builder 小 case&#xff0c;今天自己踩了坑。 Data Builder public class User {private String name…

MLP 有哪些可學習的參數

多層感知機&#xff08;MLP&#xff09;的參數是需要在訓練過程中學習的。MLP是一種前饋神經網絡&#xff0c;其結構包括輸入層、多個隱藏層和輸出層。在訓練過程中&#xff0c;MLP通過反向傳播算法來調整網絡的權重&#xff0c;以最小化預測值與實際值之間的誤差。 MLP的學習…