深度學習三大謎團:集成、知識蒸餾和自蒸餾

深度學習三大謎團:集成、知識蒸餾和自蒸餾

轉自:https://mp.weixin.qq.com/s/DdgjJ-j6jHHleGtq8DlNSA

原文(英):https://www.microsoft.com/en-us/research/blog/three-mysteries-in-deep-learning-ensemble-knowledge-distillation-and-self-distillation/

集成(Ensemble,又稱模型平均)是一種"古老"而強大的方法。只需要對同一個訓練數據集上,幾個獨立訓練的神經網絡的輸出,簡單地求平均,便可以獲得比原有模型更高的性能。甚至只要這些模型初始化條件不同,即使擁有相同的架構,集成方法依然能夠將性能顯著提升。

但是,為什么只是簡單的"集成”,便能提升性能呢?

在這里插入圖片描述

目前已有的理論解釋大多只能適用于以下幾種情況:

(1)boosting:模型之間的組合系數是訓練出來的,而不能簡單地取平均;

(2)Bootstrap aggregation:每個模型的訓練數據集都不相同;

(3)每個模型的類型和體系架構都不相同;

(4)隨機特征或決策樹的集合。

但正如上面提到,在(1)模型系數只是簡單的求平均;(2)訓練數據集完全相同;(3)每個模型架構完全相同 下,集成的方法都能夠做到性能提升。

在這里插入圖片描述

論文鏈接:https://arxiv.org/pdf/2012.09816.pdf

來自微軟研究院機器學習與優化組的高級研究員朱澤園博士,以及卡內基梅隆大學機器學習系助理教授李遠志針對這一現象,在最新發表的論文**《在深度學習中理解集成,知識蒸餾和自蒸餾》**(Towards Understanding Ensemble, Knowledge Distillation, and Self-Distillation in Deep Learning)中,提出了一個理論問題:

在這里插入圖片描述

當我們簡單地對幾個獨立訓練的神經網絡求平均值時,“集成”是如何改善深度學習的測試性能的?尤其是當所有神經網絡具有相同的體系結構,使用相同的標準訓練算法(即具有相同學習率和樣本正則化的隨機梯度下降),在相同數據集上進行訓練時,即使所有單個模型都已經進行了100%訓練準確性?隨后,將集合的這種優越性能“蒸餾”到相同架構的單個神經網絡,為何能夠保持性能基本不變?

兩位作者分別從理論和實驗的角度給出了分析結果:

原因在于數據集中“多視圖”(Multi-view)數據的存在。

1、深度學習的三大謎團

謎團 1:集成

觀察結果顯示,使用不同隨機種子的學習網絡 F1,…F10F_1,\dots F_{10}F1?,F10?(盡管具有非常相似的測試性能)相關聯的函數非常不同。在這種情況下,使用“集成”的技術,僅需要獲取這些經過獨立訓練的網絡輸出的未加權平均值,就可以在許多深度學習應用中極大地提高測試時間的性能。(參見圖1)這意味著各個函數F1,…F10F_1,\dots F_{10}F1?,F10?一定是不同的。但是,為什么集成可以大幅提升性能呢?

如果直接訓練 (F1+?+F10)/10(F_1+\dots+F_{10})/10(F1?+?+F10?)/10,為什么性能提升就消失了?

在這里插入圖片描述

圖1:集成(Ensemble)提升了深度學習應用中的測試準確性,但是這種準確性的提高則無法通過直接訓練模型的平均值來實現。

謎團2:知識蒸餾

雖然集成可以極大地提升測試時間性能,但在推理時間(即測試時間)方面它變慢了10倍:我們需要計算10個神經網絡的輸出,而不是1個。當我們在低能耗的移動環境中部署此類模型時,這是一個嚴重的問題。

為了解決這個問題,研究者提出了一種叫做知識蒸餾的開創性技術。知識蒸餾指的是訓練另一個單獨的模型來匹配集成的輸出。在這里,一張貓的圖像上的集成(也稱為隱藏知識)輸出可能看起來像“ 80%貓+ 10%狗+ 10%汽車”,而真正的訓練標簽是“ 100%貓”。(請參見下面的圖2)

事實證明,經過這樣訓練的單個模型可以在很大程度上匹配10倍以上集成模型的測試時間性能。但是,這導致了更多的問題。

與匹配真實標簽相比,為什么匹配集成模型的輸出可以為我們提供更好的測試準確性?此外,我們在知識蒸餾后對模型進行集成學習可以進一步提高測試準確性嗎?

在這里插入圖片描述

圖2:知識蒸餾和自蒸餾也能夠提升深度學習的性能。

謎團3: 自蒸餾。

注意,知識蒸餾至少在直觀上是有意義的:teacher model有84.8% 的測試準確率,那么student model 可以達到83.8% 。

但接下來這個現象就讓人難以理解了,使用自蒸餾技術,也即老師的學生就是它自己:通過對具有相同架構的單個模型進行知識蒸餾,竟然可以提高測試準確率。想象一下: 訓練出一個測試準確率為81.5% 的單個模型,結果使用相同結構的模型進行自蒸餾一下,測試準確率竟然提高到了83.5%,這不是很奇怪么?

2、神經網絡集成與特征映射集成

大多數現有的集成理論只適用于單個模型之間存在根本性差異的情況(例如,決策樹支持不同變量的子集),或者在不同的數據集上訓練的情況(例如自舉)。

但這些理論顯然不能解釋前面提到的現象。上面提到,集成的模型,其訓練的框架是相同的,訓練的數據也是相同的 —— 唯一的區別只是訓練期間的隨機性。

或許,與深度學習中的集成最為相近的理論,應該是“隨機特征映射集成”(ensemble in random feature mappings)。這表現在兩個方面:一方面,將多個隨機特征的線性模型進行組合,可以提升測試時的性能,這很顯然,因為它增加了特征的數量;另一方面,在特定的參數區域中,神經網絡的權值可以非常接近它們的初始化(稱為神經正切核區域,或 NTK區域) ,結果網絡只在規定的特征映射上學習一個線性函數,這些特征映射完全由隨機初始化決定。

將這兩者結合起來,可以推測深度學習集成與隨機特征映射集成,在原理上是一致的。

這就引出了另外一個問題:集成和知識蒸餾在深度學習上,與在隨機特征映射(即NTK特征映射)上,是否會有相同的表現呢?

答案是否定的。

如下圖3所示,該圖比較了在深度學習/隨機特征映射中的集成和知識蒸餾的性能。

在這里插入圖片描述

圖3: 集成在隨機特征映射上有效(但是出于與深度學習完全不同的原因) ,而知識蒸餾在隨機特征映射中不起作用。

可以看出,通過集成的方式,無論是在深度學習中,還是在隨機特征映射中,都能夠得到較好的性能;而在隨機特征映射中,知識蒸餾的性能顯然要比單個模型的性能還要差。

這就很明顯地說明:集成和蒸餾,原理上并不相同。

具體來說:與在深度學習情況不同,在隨機特征映射中,集成的優越性能不能蒸餾到單個模型上。

在圖3中,神經正切核(NTK)模型的集成,在 CIFAR-10數據集上達到了70.54%的準確率,但經過知識蒸餾后,它下降到了66.01% ,甚至比單個模型的66.68% 的測試準確率還要低。

在深度學習中,直接訓練模型的平均值 (𝐹1+?+𝐹10)/10(𝐹_1+\dots+𝐹_{10})/10(F1?+?+F10?)/10 與訓練單個模型 𝐹𝑖𝐹_𝑖Fi? 相比沒有任何優勢;而在隨機特征映射中,訓練平均值的效果優于單個模型及其集成。

在圖3中,NTK 模型的集成的準確率為 70.54% ,而直接訓練10個模型的平均值準確率為72.86%。

為什么會這樣呢?

主要原因在于,神經網絡是使用分層特征學習,盡管每個模型 𝐹𝑖𝐹_𝑖Fi? 有不同的初始化,但在每一層它們都擁有相同的特征集合。因此,與單個模型相比,多個模型的平均模型,并沒有增加其特征集合的大小。

在隨機特征映射中,每個 𝐹𝑖𝐹_𝑖Fi? 都使用了一組完全不同的規定特征。因此,無論是使用集成的方式,還是直接求平均的方式,都能夠帶來一些性能優勢,但由于特征的稀缺性,在蒸餾后,性能必然會有一定下降。

3、集成與減少單個模型的方差

除了隨機特征的集成外,還有人推測認為,由于神經網絡的高度復雜性,每個單獨的模型 𝐹𝑖𝐹_𝑖Fi? 可能學習到一個函數 𝐹𝑖(𝑥)=𝑦+ξ𝑖𝐹_𝑖(𝑥)=𝑦+ξ_𝑖Fi?(x)=y+ξi?ξ𝑖ξ_𝑖ξi? 是某種噪聲,這種噪聲取決于訓練過程中使用的隨機性。

經典的統計學認為,如果所有的 ξ𝑖ξ_𝑖ξi? 是大致獨立的,那么求取他們的平均值能夠大大減少噪音量。

因此,“集成能夠減少方差”真的是集成能提高提高性能的原因嗎?

證據表明,在深度學習的背景下,這種減少方差來提升性能的假設是值得懷疑的:

1. 集成并不能無限制地提高測試的準確性。

集成超過100個單個模型通常與集成10個單個模型基本沒有差別。因此,100 ξ𝑖ξ_𝑖ξi? 的平均值與10 ξ𝑖ξ_𝑖ξi? 的平均值相比,方差不再減小,表明 ξ𝑖ξ_𝑖ξi? 可能是不獨立的,而且有可能存在偏差,因此均值不為零。在 ξ𝑖ξ_𝑖ξi? 不獨立的情況下,很難討論求得這些 ξ𝑖ξ_𝑖ξi? 的平均值能夠減少多少偏差。

2. 即使理想情況下,我們認為ξ𝑖ξ_𝑖ξi?是相互獨立的,那么這就表明ξ𝑖ξ_𝑖ξi?是有偏或異號的。

于是我們可以將 𝐹𝑖𝐹_𝑖Fi? 寫成:
𝐹𝑖(x)=𝑦+ξ+ξ𝑖𝐹_𝑖(x)=𝑦+ξ+ξ_𝑖 Fi?(x)=y+ξ+ξi?
ξξξ 是一個固定誤差,ξ𝑖ξ_𝑖ξi? 則指每個模型的獨立誤差。于是在集成之后,期望的網絡輸出將接近 y+ξy + ξy+ξ,這會有一個固定的偏差 ξξξ

在這種情況下,為什么知識蒸餾會有效呢?那么,為什么這個帶有偏差 ξξξ (也被稱為隱藏知識)的輸出會優于原來的訓練呢?

3. 集成學習并不總是能夠提高準確性

在圖4中,我們可以看到神經網絡的集成學習并不總是能夠提高測試的準確性,至少在輸入類似高斯分布的情況下是這樣。換句話說,在這些網絡中,求平均值不會帶來任何準確性的增益。

綜上來看,我們需要更深入地理解深度學習中的集成,而不只是認為“集成能夠減少方差”這么簡單。

在這里插入圖片描述

圖4: 當輸入類似高斯分布時,實驗表明集成并不能提高測試的準確性。

4、多視圖數據:深度學習中集成的一種新方法

圖4表明,在非結構化隨機輸入的情況下,集成并不湊效。在我們最新的工作中,我們從數據中找到了集成之所以能夠在深度學習中有效的原因所在。

通常,在一個數據集中(以視覺數據集為例),一個對象通常會有多個視角(muti-view)的數據。以“car”為例,一個汽車的數據集中,通常會有從各個角度拍攝的車輛的照片,通常我們僅需要通過車頭燈、車輪或車窗等其中的一個特征,便可以對汽車進行分類了;即使在圖片中有些特征因為拍攝角度的原因而缺失了,也沒有太大的關系。例如從正前方拍攝的汽車,圖像中便沒有車輪,但這并不妨礙我們識別出“car"。

在這里插入圖片描述

圖5: 在CIFAR-10數據集上進行訓練的 ResNet-34第23層的一些通道的可視化

這種現象在多數數據中都會存在,其中每類數據都具有多個視角的特征,這種結構被稱為“多視圖”(multi-view)。

在大多數數據中,幾乎所有的視圖特征都會顯示出來;但在某些數據中,卻可能缺少一些視圖特征。

更廣泛地說,這種“多視圖”結構事實上,不僅在原始數據中存在,在中間層抽取的特征集合中也會存在。

在這種“多視圖”結構下進行訓練,網絡會:

1)根據學習過程中的隨機性,快速學習這些視圖特征的一個子集;

2)會使用這些視圖特征,記下剩余那些少量不能正確分類的數據。

第一點意味著,如果將不同網絡進行集成,將能夠把學習到的視圖特征聚合起來,從而達到更高的測試精度

第二點意味著,單個模型不能學習所有的視圖特性,不是因為它們沒有足夠的容量,而是因為沒有足夠的訓練數據;大多數數據已經被現有的視圖特征正確分類,因此在訓練階段,它們基本上不提供梯度。

5、知識蒸餾: 強制單個模型學習多個視圖

基于上述視角,我們可以再來分析知識蒸餾是如何工作的。

在現實生活的場景中,一些汽車圖像可能看起來“更像一只貓”:例如,一些汽車圖像的前燈可能看起來像貓眼。當這種情況發生時,集成模型可以提供有意義的隱藏知識,例如**“汽車圖像 X 有10% 像一只貓。”**

這里是個關鍵點。在訓練單個神經網絡模型時,如果沒有學習“前燈”視圖,剩下的視圖或許仍然有可能根據別的視圖將圖像 x 標記為汽車,但它卻無法匹配隱藏知識“圖像 X 有10% 像貓”。

而在知識蒸餾的過程中,蒸餾模型會學習每一個可能的視圖特征,來匹配集成的性能。需要注意的是,深度學習中知識蒸餾的關鍵是,作為一個神經網絡,單個模型在特征學習中能夠學習到集成的所有特征。這與實驗中觀察到的情況是一致的。(見圖6)

在這里插入圖片描述

圖6: 知識蒸餾已經從集成中學習了大部分視圖特性,因此在知識蒸餾之后對模型進行集成學習不會帶來更多的性能提升。

6、自蒸餾: 集成與知識蒸餾的隱性結合

這個解釋也可以用到知識自蒸餾中——訓練一個模型來匹配另一個相同的架構的模型(但使用不同的隨機種子)的輸出,在某種程度上也能提高性能。

簡單來理解,自蒸餾是知識蒸餾的一種特殊情況。

假設我們使用模型 𝐹2𝐹_2F2? 從一個隨機的初始化開始,來匹配另外一個模型 𝐹1𝐹_1F1? 的輸出。在這個過程中 𝐹2𝐹_2F2? 一方面會學習 𝐹1𝐹_1F1? 已經學習到特征子集,另一方面其能夠學習到的特征子集也會受其隨機初始化的影響。

這個過程,可以看做是:首先對兩個單獨的模型 𝐹1𝐹_1F1?𝐹2𝐹_2F2? 進行集成學習,然后蒸餾成 𝐹2𝐹_2F2?

最終的 𝐹2𝐹_2F2? 可能不一定涵蓋數據集中所有可學習的視圖,但它至少有學習所有視圖(通過兩個單個模型的集成學習數據庫來覆蓋)的潛力。這就是自蒸餾模型測試時性能提升的來源!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/532361.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/532361.shtml
英文地址,請注明出處:http://en.pswp.cn/news/532361.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

在墻上找垂直線_墻上如何快速找水平線

在裝修房子的時候,墻面的面積一般都很大,所以在施工的時候要找準水平線很重要,那么一般施工人員是如何在墻上快速找水平線的呢?今天小編就來告訴大家幾種找水平線的方法。一、如何快速找水平線1、用一根透明的軟管,長度…

百度地圖mysql打點_關于百度地圖連接MYSQL的問題,謝謝啦!

該樓層疑似違規已被系統折疊 隱藏此樓查看此樓大家好,剛使用百度地圖API,請教大家一個問題,謝啦!我需要從我的數據庫中取出字段為"city"的所有數據,然后通過bdGEO()函數在地圖上標注這些城市,我是…

PyTorch中的torch.nn.Parameter() 詳解

PyTorch中的torch.nn.Parameter() 詳解 今天來聊一下PyTorch中的torch.nn.Parameter()這個函數,筆者第一次見的時候也是大概能理解函數的用途,但是具體實現原理細節也是云里霧里,在參考了幾篇博文,做過幾個實驗之后算是清晰了&am…

Vision Transformer(ViT)PyTorch代碼全解析(附圖解)

Vision Transformer(ViT)PyTorch代碼全解析 最近CV領域的Vision Transformer將在NLP領域的Transormer結果借鑒過來,屠殺了各大CV榜單。本文將根據最原始的Vision Transformer論文,及其PyTorch實現,將整個ViT的代碼做一…

hdfs的副本數為啥增加了_HDFS詳解之塊大小和副本數

1.HDFSHDFS : 偽分布式(學習)NNDNSNNsbin/start-dfs.sh(開啟hdfs使用的腳本)bin/hdfs dfs -ls (輸入命令加前綴bin/hdfs dfs)2.block(塊)dfs.blocksize : 134217728(字節) / 128M 官網默認一個塊的大小128M*舉例理解塊1個文件 130M,默認一個塊的大小128M…

Linux下的ELF文件、鏈接、加載與庫(含大量圖文解析及例程)

Linux下的ELF文件、鏈接、加載與庫 鏈接是將將各種代碼和數據片段收集并組合為一個單一文件的過程,這個文件可以被加載到內存并執行。鏈接可以執行與編譯時,也就是在源代碼被翻譯成機器代碼時;也可以執行于加載時,也就是被加載器加…

mysql gender_Mysql第一彈

1、創建數據庫pythoncreate database python charsetutf8;2、設計班級表結構為id、name、isdelete,編寫創建表的語句create table classes(id int unsigned auto_increment primary key not null,name varchar(10),isdelete bit default 0);向班級表中插入數據pytho…

python virtualenv nginx_Ubuntu下搭建Nginx+supervisor+pypy+virtualenv

系統:Ubuntu 14.04 LTS搭建python的運行環境:NginxSupervisorPypyVirtualenv軟件說明:Nginx:通過upstream進行負載均衡Supervisor:管理python進程Pypy:用Python實現的Python解釋器PyPy is a fast, complian…

如何設置mysql表中文亂碼_php mysql表中文亂碼問題如何解決

為避免mysql中出現中文亂碼,建議在創建數據庫時指定編碼格式:復制代碼 代碼示例:create database zzjz CHARACTER SET gbk COLLATE gbk_chinese_ci;create table zz_employees (employeeid int unsigned not null auto_increment primary key,name varch…

java 按鈕 監聽_Button的四種監聽方式

Button按鈕設置點擊的四種監聽方式注:加粗放大的都是改變的代碼1.使用匿名內部類的形式進行設置使用匿名內部類的形式,直接將需要設置的onClickListener接口對象初始化,內部的onClick方法會在按鈕被點擊的時候執行第一個活動的java代碼&#…

java int轉bitmap_Java Base64位編碼與String字符串的相互轉換,Base64與Bitmap的相互轉換實例代碼...

首先是網上大神給的類package com.duanlian.daimengmusic.utils;public final class Base64Util {private static final int BASELENGTH 128;private static final int LOOKUPLENGTH 64;private static final int TWENTYFOURBITGROUP 24;private static final int EIGHTBIT …

linux查看java虛擬機內存_深入理解java虛擬機(linux與jvm內存關系)

本文轉載自美團技術團隊發表的同名文章https://tech.meituan.com/linux-jvm-memory.html一, linux與進程內存模型要理解jvm最重要的一點是要知道jvm只是linux的一個進程,把jvm的視野放大,就能很好的理解JVM細分的一些概念下圖給出了硬件系統進程三個層面內存之間的關系.從硬件上…

java 循環stringbuffer_java常用類-----StringBuilder和StringBuffer的用法

一、可變字符常用方法package cn.zxg.PackgeUse;/*** 測試StringBuilder,StringBuffer可變字符序列常用方法*/public class TestStringBuilder2 {public static void main(String[] args) {StringBuilder sbnew StringBuilder();for(int i0;i<26;i){char temp(char)(ai);sb.…

java function void_Java8中你可能不知道的一些地方之函數式接口實戰

什么時候可以使用 Lambda&#xff1f;通常 Lambda 表達式是用在函數式接口上使用的。從 Java8 開始引入了函數式接口&#xff0c;其說明比較簡單&#xff1a;函數式接口(Functional Interface)就是一個有且僅有一個抽象方法&#xff0c;但是可以有多個非抽象方法的接口。 java8…

java jvm內存地址_JVM--Java內存區域

Java虛擬機在執行Java程序的過程中會把它所管理的內存劃分為若干個不同的數據區域&#xff0c;如圖&#xff1a;1.程序計數器可以看作是當前線程所執行的字節碼的行號指示器&#xff0c;通俗的講就是用來指示執行哪條指令的。為了線程切換后能恢復到正確的執行位置Java多線程是…

java情人節_情人節寫給女朋友Java Swing代碼程序

馬上又要到情人節了&#xff0c;再不解風情的人也得向女友表示表示。作為一個程序員&#xff0c;示愛的時候自然也要用我們自己的方式。這里給大家上傳一段我在今年情人節的時候寫給女朋友的一段簡單的Java Swing代碼&#xff0c;主要定義了一個對話框&#xff0c;讓女友選擇是…

java web filter鏈_filter過濾鏈:Filter鏈是如何構建的?

在一個Web應用程序中可以注冊多個Filter程序&#xff0c;每個Filter程序都可以針對某一個URL進行攔截。如果多個Filter程序都對同一個URL進行攔截&#xff0c;那么這些Filter就會組成一個Filter鏈(也叫過濾器鏈)。Filter鏈用FilterChain對象來表示&#xff0c;FilterChain對象中…

java web 應用技術與案例教程_《Java Web應用開發技術與案例教程》怎么樣_目錄_pdf在線閱讀 - 課課家教育...

出版說明前言第1章 java Web應用開發技術概述1.1 Java Web應用開發技術簡介1.1.1 Java Web應用1.1.2 Java Web應用開發技術1.2 Java Web開發環境及開發工具1.2.1 JDK的下載與安裝1.2.2 Tomcat服務器的安裝和配置1.2.3 MyEclipse集成開發工具的安裝與操作1.3 Java Web應用程序的…

java環境變量自動設置_自動設置Java環境變量

echo offSETLOCALENABLEDELAYEDEXPANSIONfor /f "tokens2* delims " %%i in(reg query "HKLM\Software\JavaSoft\Java Development Kit" /s ^|find /I"JavaHome") do (echo 找到目錄 %%jset /p isOK該目錄是不是JDK^(JavaDevelopment Kit^)的安裝…

mysql運行狀態監控研究內容_如何監控mysql主從的運行狀態shell腳本實例介紹

如何監控mysql主從的運行狀態shell腳本實例介紹。#!/bin/bash#define mysql variablemysql_user”root”mysql_pass”123456″email_addr”slavecentos.bz”mysql_statusnetstat -nl | awk ‘NR>2{if ($4 ~ /.*:3306/) {print “Yes”;exit 0}}’if [ "$mysql_status&q…