【深度學習】學習率及多種選擇策略

學習率是最影響性能的超參數之一,如果我們只能調整一個超參數,那么最好的選擇就是它。相比于其它超參數學習率以一種更加復雜的方式控制著模型的有效容量,當學習率最優時,模型的有效容量最大。本文從手動選擇學習率到使用預熱機制介紹了很多學習率的選擇策略。

這篇文章記錄了我對以下問題的理解:

  • 學習速率是什么?學習速率有什么意義?
  • 如何系統地獲得良好的學習速率?
  • 我們為什么要在訓練過程中改變學習速率?
  • 當使用預訓練模型時,我們該如何解決學習速率的問題?

本文的大部分內容都是以 fast.ai 研究員寫的內容 [1], [2], [5] 和 [3] 為基礎的。本文是一個更為簡潔的版本,通過本文可以快速獲取這些文章的主要內容。如果您想了解更多詳情,請參閱參考資料。

首先,什么是學習速率?

學習速率是指導我們該如何通過損失函數的梯度調整網絡權重的超參數。學習率越低,損失函數的變化速度就越慢。雖然使用低學習率可以確保我們不會錯過任何局部極小值,但也意味著我們將花費更長的時間來進行收斂,特別是在被困在高原區域的情況下。

下述公式表示了上面所說的這種關系。

  1. new_weight = existing_weight — learning_rate * gradient

理解深度學習中的學習率及多種選擇策略

采用小學習速率(頂部)和大學習速率(底部)的梯度下降。來源:Coursera 上吳恩達(Andrew Ng)的機器學習課程。

一般而言,用戶可以利用過去的經驗(或其他類型的學習資料)直觀地設定學習率的最佳值。

因此,想得到最佳學習速率是很難做到的。下圖演示了配置學習速率時可能遇到的不同情況。

理解深度學習中的學習率及多種選擇策略

不同學習速率對收斂的影響(圖片來源:cs231n)

此外,學習速率對模型收斂到局部極小值(也就是達到最好的精度)的速度也是有影響的。因此,從正確的方向做出正確的選擇意味著我們可以用更短的時間來訓練模型。

  
  1. Less training time, lesser money spent on GPU cloud compute. 😃

有更好的方法選擇學習速率嗎?

在「訓練神經網絡的周期性學習速率」[4] 的 3.3 節中,Leslie N. Smith 認為,用戶可以以非常低的學習率開始訓練模型,在每一次迭代過程中逐漸提高學習率(線性提高或是指數提高都可以),用戶可以用這種方法估計出最佳學習率。

理解深度學習中的學習率及多種選擇策略

在每一個 mini-batch 后提升學習率

如果我們對每次迭代的學習進行記錄,并繪制學習率(對數尺度)與損失,我們會看到,隨著學習率的提高,從某個點開始損失會停止下降并開始提高。在實踐中,學習速率的理想情況應該是從圖的左邊到最低點(如下圖所示)。在本例中,是從 0.001 到 0.01。

理解深度學習中的學習率及多種選擇策略

上述方法看似有用,但該如何應用呢?

目前,上述方法在 fast.ai 包中作為一個函數進行使用。fast.ai 包是由 Jeremy Howard 開發的一種高級 pytorch 包(就像 Keras 之于 Tensorflow)。

在訓練神經網絡之前,只需輸入以下命令即可開始找到最佳學習速率。

  
  1. # learn is an instance of Learner class or one of derived classes like ConvLearner
  2. learn.lr_find()
  3. learn.sched.plot_lr()

使之更好

現在我們已經知道了什么是學習速率,那么當我們開始訓練模型時,怎樣才能系統地得到最理想的值呢。接下來,我們將介紹如何利用學習率來改善模型的性能。

傳統的方法

一般而言,當已經設定好學習速率并訓練模型時,只有等學習速率隨著時間的推移而下降,模型才能最終收斂。

然而,隨著梯度達到高原,訓練損失會更難得到改善。在 [3] 中,Dauphin 等人認為,減少損失的難度來自鞍點,而不是局部最低點。

理解深度學習中的學習率及多種選擇策略

誤差曲面中的鞍點。鞍點是函數上的導數為零但不是軸上局部極值的點。(圖片來源:safaribooksonline)

所以我們該如何解決這個問題?

我們可以采取幾種辦法。[1] 中是這么說的:

…無需使用固定的學習速率,并隨著時間的推移而令它下降。如果訓練不會改善損失,我們可根據一些周期函數 f 來改變每次迭代的學習速率。每個 Epoch 的迭代次數都是固定的。這種方法讓學習速率在合理的邊界值之間周期變化。這是有益的,因為如果我們卡在鞍點上,提高學習速率可以更快地穿越鞍點。

在 [2] 中,Leslie 提出了一種「三角」方法,這種方法可以在每次迭代之后重新開始調整學習速率。

理解深度學習中的學習率及多種選擇策略

Leslie N. Smith 提出的「Triangular」和「Triangular2」學習率周期變化的方法。左圖中,LR 的最小值和最大值保持不變。右圖中,每個周期之后 LR 最小值和最大值之間的差減半。

另一種常用的方法是由 Loshchilov&Hutter [6] 提出的預熱重啟(Warm Restarts)隨機梯度下降。這種方法使用余弦函數作為周期函數,并在每個周期最大值時重新開始學習速率。「預熱」是因為學習率重新開始時并不是從頭開始的,而是由模型在最后一步收斂的參數決定的 [7]。

下圖展示了伴隨這種變化的過程,該過程將每個周期設置為相同的時間段。

理解深度學習中的學習率及多種選擇策略

SGDR 圖,學習率 vs 迭代次數。

因此,我們現在可以通過周期性跳過「山脈」的辦法縮短訓練時間(下圖)。

理解深度學習中的學習率及多種選擇策略

比較固定 LR 和周期 LR(圖片來自 ruder.io)

研究表明,使用這些方法除了可以節省時間外,還可以在不調整的情況下提高分類準確性,而且可以減少迭代次數。

遷移學習中的學習速率

在 fast.ai 課程中,非常重視利用預訓練模型解決 AI 問題。例如,在解決圖像分類問題時,會教授學生如何使用 VGG 或 Resnet50 等預訓練模型,并將其連接到想要預測的圖像數據集。

我們采取下面的幾個步驟,總結了 fast.ai 是如何完成模型構建(該程序不要與 fast.ai 包混淆)的:

1. 啟用數據增強,precompute = True

2. 使用 lr_find() 找到損失仍在降低的最高學習速率

3. 從預計算激活值到最后一層訓練 1~2 個 Epoch

4. 在 cycle_len = 1 的情況下使用數據增強(precompute=False)訓練最后一層 2~3 次

5. 修改所有層為可訓練狀態

6. 將前面層的學習率設置得比下一個較高層低 3~10 倍

7. 再次使用 lr_find()

8. 在 cycle_mult=2 的情況下訓練整個網絡,直到過度擬合

從上面的步驟中,我們注意到步驟 2、5 和 7 提到了學習速率。這篇文章的前半部分已經基本涵蓋了上述步驟中的第 2 項——如何在訓練模型之前得出最佳學習率。

在下文中,我們會通過 SGDR 來了解如何通過重啟學習速率來減少訓練時間和提高準確性,以避免梯度接近零。

在最后一節中,我們將重點介紹差異學習(differential learning),以及如何在訓練帶有預訓練模型中應用差異學習確定學習速率。

什么是差異學習

差異學習(different learning)在訓練期間為網絡中的不同層設置不同的學習速率。這種方法與人們常用的學習速率配置方法相反,常用的方法是訓練時在整個網絡中使用相同的學習速率。

理解深度學習中的學習率及多種選擇策略

在寫這篇文章的時候,Jeremy 和 Sebastian Ruder 發表的一篇論文深入探討了這個問題。所以我估計差異學習速率現在有一個新的名字——差別性的精調。😃

為了更清楚地說明這個概念,我們可以參考下面的圖。在下圖中將一個預訓練模型分成 3 組,每個組的學習速率都是逐漸增加的。

理解深度學習中的學習率及多種選擇策略

具有差異學習速率的簡單 CNN 模型。圖片來自 [3]

這種方法的意義在于,前幾個層通常會包含非常細微的數據細節,比如線和邊,我們一般不希望改變這些細節并想保留它的信息。因此,無需大量改變權重。

相比之下,在后面的層,以綠色以上的層為例,我們可以從中獲得眼球、嘴巴或鼻子等數據的細節特征,但我們可能不需要保留它們。

這種方法與其他微調方法相比如何?

在 [9] 中提出,微調整個模型太過昂貴,因為有些模型可能超過了 100 層。因此人們通常一次一層地對模型進行微調。

然而,這樣的調整對順序有要求,不具并行性,且因為需要通過數據集進行微調,導致模型會在小數據集上過擬合。

下表證明 [9] 中引入的方法能夠在各種 NLP 分類任務中提高準確度且降低錯誤率。

理解深度學習中的學習率及多種選擇策略

參考文獻:

[1] Improving the way we work with learning rate.

[2] The Cyclical Learning Rate technique.

[3] Transfer Learning using differential learning rates.

[4] Leslie N. Smith. Cyclical Learning Rates for Training Neural Networks.

[5] Estimating an Optimal Learning Rate for a Deep Neural Network

[6] Stochastic Gradient Descent with Warm Restarts

[7] Optimization for Deep Learning Highlights in 2017

[8] Lesson 1 Notebook, fast.ai Part 1 V2

[9] Fine-tuned Language Models for Text Classification

原文鏈接:https://towardsdatascience.com/understanding-learning-rates-and-how-it-improves-performance-in-deep-learning-d0d4059c1c10

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/162782.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/162782.shtml
英文地址,請注明出處:http://en.pswp.cn/news/162782.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

qt msvc2010 qdatetime.h:122: error: C2589: “(”:“::”右邊的非法標記

報錯內容: C:\Qt\Qt5.4.0\5.4.0\msvc2010_opengl\include\QtCore\qdatetime.h:114: error: C2589: “(”:“::”右邊的非法標記 C:\Qt\Qt5.4.0\5.4.0\msvc2010_opengl\include\QtCore\qdatetime.h:114: error: C2059: 語法錯誤:“::” 解決方法: 打開qd…

2023小紅書Android面試之旅

一面 自我介紹 看你寫了很多文章,拿你理解最深刻的一篇出來講一講 講了Binder相關內容 Binder大概分了幾層 哪些方法調用會涉及到Binder通信 大概講一下startActivity的流程,包括與AMS的交互 全頁面停留時長埋點是怎么做的 我在項目中做過的內容&am…

RocketMQ-NameServer詳解

前言 ? RocketMQ架構上主要分為四部分, Broker、Producer、Consumer、NameServer,其他三個都會與NameServer進行通信。 Producer: ? **消息發布的角色,可集群部署。**通過NameServer集群獲得Topic的路由信息,包括Topic下面有哪些Queue&a…

PTA-病毒感染檢測

人的DNA和病毒DNA均表示成由一些字母組成的字符串序列。然后檢測某種病毒DNA序列是否在患者的DNA序列中出現過,如果出現過,則此人感染了該病毒,否則沒有感染。例如,假設病毒的DNA序列為baa,患者1的DNA序列為aaabbba&am…

數據結構與算法編程題15

設計一個算法&#xff0c;通過遍歷一趟&#xff0c;將鏈表中所有結點的鏈接方向逆轉&#xff0c;仍利用原表的存儲空間。 #include <iostream> using namespace std;typedef int Elemtype; #define ERROR 0; #define OK 1;typedef struct LNode {Elemtype data; …

【從入門到起飛】JavaSE—多線程(3)(生命周期,線程安全問題,同步方法)

&#x1f38a;專欄【JavaSE】 &#x1f354;喜歡的詩句&#xff1a;路漫漫其修遠兮&#xff0c;吾將上下而求索。 &#x1f386;音樂分享【如愿】 &#x1f384;歡迎并且感謝大家指出小吉的問題&#x1f970; 文章目錄 &#x1f354;生命周期&#x1f384;線程的安全問題&#…

【Leetcode合集】1410. HTML 實體解析器

1410. HTML 實體解析器 1410. HTML 實體解析器 代碼倉庫地址&#xff1a; https://github.com/slience-me/Leetcode 個人博客 &#xff1a;https://slienceme.xyz 編寫一個函數來查找字符串數組中的最長公共前綴。 如果不存在公共前綴&#xff0c;返回空字符串 ""…

YOLOv7獨家改進: Inner-IoU基于輔助邊框的IoU損失,高效結合 GIoU, DIoU, CIoU,SIoU 等 | 2023.11

??????本文獨家改進:Inner-IoU引入尺度因子 ratio 控制輔助邊框的尺度大小用于計算損失,并與現有的基于 IoU ( GIoU, DIoU, CIoU,SIoU )損失進行有效結合 推薦指數:5顆星 新穎指數:5顆星 收錄: YOLOv7高階自研專欄介紹: http://t.csdnimg.cn/tYI0c …

開發抖音小游戲什么技術

開發抖音小游戲&#xff0c;使用以下技術可能會相對簡單&#xff1a; HTML5&#xff1a;HTML5 是一種用于創建網頁和應用程序的標準標記語言。它具有豐富的功能和靈活性&#xff0c;可以在各種設備和平臺上運行&#xff0c;包括移動設備和瀏覽器。HTML5 提供了許多游戲開發所需…

大模型AI Agent 前沿調研

前言 大模型技術百花齊放&#xff0c;越來越多&#xff0c;同時大模型的落地也在緊鑼密鼓的進行著&#xff0c;其中Agent智能體這個概念可謂是火的一灘糊涂。 今天就分享一些Agent相關的前沿研究&#xff08;僅限基于大模型的AI Agent研究&#xff09;&#xff0c;包括一些論…

完美解決AttributeError: module ‘numpy‘ has no attribute ‘typeDict‘

文章目錄 前言一、完美解決辦法安裝低版本1.21或者1.19.3都可以總結 前言 這個問題從表面看就是和numpy庫相關&#xff0c;所以是小問題&#xff0c;經過來回調試安裝numpy&#xff0c;發現是因為目前的版本太高&#xff0c;因此我們直接安裝低版本numpy。也不用專門卸載目前的…

Qt全球峰會2023中國站 參會概要

Qt全球峰會2023中國站 參會概要 前言峰會議程簽到 & Demo 演示開場致辭Qt Group 產品總監演講&#xff08;產品開發的趨勢-開放的軟件、工具和框架&#xff09;產品戰略QtQuick or QtWidgets&#xff08;c or qml&#xff09;Qt如何定義AI個人看法 Qt 在券商數字化轉型和信…

【MySQL】內連接和外連接

內連接和外連接 前言正式開始內連接外連接左外連接右外連接 前言 前一篇講多表查詢的時候講過笛卡爾積&#xff0c;其實笛卡爾積就算一種連接&#xff0c;不過前一篇講的時候并沒有細說連接相關的內容&#xff0c;本篇就來詳細說說表的連接有哪些。 本篇博客中主要用到的還是…

快速去除Word文檔密碼,全面解決你的困擾

如果你忘記了Word文檔密碼&#xff0c;或者想解密和去除Word文檔密碼&#xff0c;下面是簡單步驟&#xff1a;第一步&#xff0c;百度搜索【密碼帝官網】找到官方網站&#xff1b;第二步&#xff0c;點擊“立即開始”&#xff0c;進入用戶中心&#xff0c;上傳需要解密的文件。…

中部A股第一城,長沙如何贏商?

文|智能相對論 作者|范柔絲 長沙的馬路&#xff0c;都很有故事。 一條解放西路&#xff0c;是全國人民都爭相打卡的娛樂地標&#xff1b;一條太平街&#xff0c;既承載了歷史的厚重又演繹著現代的鮮活...... 但如果來到河西的桐梓坡路&#xff0c;風景會變得截然不同。 沿…

應用軟件安全編程--28SSL 連接時要進行服務器身份驗證

當進行SSL 連接時&#xff0c;服務器身份驗證處于禁用狀態。在某些使用SSL 連接的庫中&#xff0c;默認情況下不 驗證服務器證書。這相當于信任所有證書。 對 SSL 連接時要進行服務器身份驗證的情況&#xff0c;示例1給出了不規范用法(Java 語言)示例。示例2 給出了規范用法(J…

安裝MySQL搭建論壇

課前默寫&#xff1a; 1、nginx配置文件的區域有哪些 ①全局區域 ②events區域 ③http區域 2、區域模塊的作用 全局區域模塊主要是用戶和工作進程 events區域模塊配置最大連接數時需先配置:vim /etc/limits.conf 因為系統默認最大是1024 http區域模塊&#xff1a;代理地…

BUUCTF [HBNIS2018]excel破解 1

BUUCTF:https://buuoj.cn/challenges 題目描述&#xff1a; 得到的 flag 請包上 flag{} 提交。來源&#xff1a; https://github.com/hebtuerror404/CTF_competition_warehouse_2018 密文&#xff1a; 下載附件&#xff0c;得到一個attachment.xls文件。 解題思路&#xff…

計算機視覺的應用19-基于pytorch框架搭建卷積神經網絡CNN的衛星地圖分類問題實戰應用

大家好&#xff0c;我是微學AI&#xff0c;今天給大家介紹一下計算機視覺的應用19-基于pytorch框架搭建卷積神經網絡CNN的衛星地圖分類問題實戰應用。隨著遙感技術和衛星圖像獲取能力的快速發展&#xff0c;衛星圖像分類任務成為了計算機視覺研究中一個重要的挑戰。為了促進這一…

git的用法

目錄 一、為什么需要git 二、git基本操作 2.1、初始化git倉庫 2.2、配置本地倉庫的name和email 2.3、認識工作區、暫存區、版本庫 三、git的實際操作 3.1 提交文件 3.2 查看git狀態以及具體的修改 3.3 git版本回退 git reset 3.1 撤銷修改 四、git分支管理 4.…