【經典簡讀】知識蒸餾(Knowledge Distillation) 經典之作
轉自:【經典簡讀】知識蒸餾(Knowledge Distillation) 經典之作
作者:潘小小
知識蒸餾是一種模型壓縮方法,是一種基于“教師-學生網絡思想”的訓練方法,由于其簡單,有效,在工業界被廣泛應用。這一技術的理論來自于2015年Hinton發表的一篇神作:Distilling the Knowledge in a Neural Network。Knowledge Distillation,簡稱KD,顧名思義,就是將已經訓練好的模型包含的知識(”Knowledge”),蒸餾(“Distill”)提取到另一個模型里面去。今天,我們就來簡單讀一下這篇論文,力求用簡單的語言描述論文作者的主要思想。
在本文中,我們將從背景和動機講起,然后著重介紹“知識蒸餾”的方法,最后我會討論“溫度“這個名詞:
- 溫度: 我們都知道“蒸餾”需要在高溫下進行,那么這個“蒸餾”的溫度代表了什么,又是如何選取合適的溫度?
本文的內容由以下幾個部分組成
-
介紹
- 論文提出的背景
- “思想歧路“
-
知識蒸餾的理論依據
- Teacher Model和Student Model
- 知識蒸餾的關鍵點
- softmax函數
-
知識蒸餾的具體方法
- 通用的知識蒸餾方法
- 一種特殊情形: 直接match logits
-
關于"溫度"的討論
- 溫度的特點
- 溫度代表了什么,如何選取合適的溫度?
-
參考
1. 介紹
1.1. 論文提出的背景
雖然在一般情況下,我們不會去區分訓練和部署使用的模型,但是訓練和部署之間存在著一定的不一致性:
-
在訓練過程中,我們需要使用復雜的模型,大量的計算資源,以便從非常大、高度冗余的數據集中提取出信息。在實驗中,效果最好的模型往往規模很大,甚至由多個模型集成得到。而大模型不方便部署到服務中去,常見的瓶頸如下:
- 推斷速度慢
- 對部署資源要求高(內存,顯存等)
-
在部署時,我們對延遲以及計算資源都有著嚴格的限制。
因此,模型壓縮(在保證性能的前提下減少模型的參數量)成為了一個重要的問題。而”模型蒸餾“屬于模型壓縮的一種方法。
插句題外話,我們可以從模型參數量和訓練數據量之間的相對關系來理解underfitting和overfitting。AI領域的從業者可能對此已經習以為常,但是為了力求讓小白也能讀懂本文,還是引用我同事的解釋(我印象很深)形象地說明一下:
模型就像一個容器,訓練數據中蘊含的知識就像是要裝進容器里的水。當數據知識量(水量)超過模型所能建模的范圍時(容器的容積),加再多的數據也不能提升效果(水再多也裝不進容器),因為模型的表達空間有限(容器容積有限),就會造成underfitting;而當模型的參數量大于已有知識所需要的表達空間時(容積大于水量,水裝不滿容器),就會造成overfitting,即模型的variance會增大(想象一下搖晃半滿的容器,里面水的形狀是不穩定的)。
1.2. “思想歧路”
上面容器和水的比喻非常經典和貼切,但是會引起一個誤解: 人們在直覺上會覺得,要保留相近的知識量,必須保留相近規模的模型。也就是說,一個模型的參數量基本決定了其所能捕獲到的數據內蘊含的“知識”的量。
這樣的想法是基本正確的,但是需要注意的是:
- 模型的參數量和其所能捕獲的“知識“量之間并非穩定的線性關系(下圖中的1),而是接近邊際收益逐漸減少的一種增長曲線(下圖中的2和3)
- 完全相同的模型架構和模型參數量,使用完全相同的訓練數據,能捕獲的“知識”量并不一定完全相同,另一個關鍵因素是訓練的方法。合適的訓練方法可以使得在模型參數總量比較小時,盡可能地獲取到更多的“知識”(下圖中的3與2曲線的對比).
2. 知識蒸餾的理論依據
2.1. Teacher Model和Student Model
知識蒸餾使用的是Teacher—Student模型,其中teacher是“知識”的輸出者,student是“知識”的接受者。知識蒸餾的過程分為2個階段:
- 原始模型訓練: 訓練"Teacher模型", 簡稱為Net-T,它的特點是模型相對復雜,也可以由多個分別訓練的模型集成而成。我們對"Teacher模型"不作任何關于模型架構、參數量、是否集成方面的限制,唯一的要求就是,對于輸入X, 其都能輸出Y,其中Y經過softmax的映射,輸出值對應相應類別的概率值。
- 精簡模型訓練: 訓練"Student模型", 簡稱為Net-S,它是參數量較小、模型結構相對簡單的單模型。同樣的,對于輸入X,其都能輸出Y,Y經過softmax映射后同樣能輸出對應相應類別的概率值。
在本論文中,作者將問題限定在分類問題下,或者其他本質上屬于分類問題的問題,該類問題的共同點是模型最后會有一個softmax層,其輸出值對應了相應類別的概率值。
2.2. 知識蒸餾的關鍵點
如果回歸機器學習最最基礎的理論,我們可以很清楚地意識到一點(而這一點往往在我們深入研究機器學習之后被忽略): 機器學習最根本的目的在于訓練出在某個問題上泛化能力強的模型。
- 泛化能力強: 在某問題的所有數據上都能很好地反應輸入和輸出之間的關系,無論是訓練數據,還是測試數據,還是任何屬于該問題的未知數據。
而現實中,由于我們不可能收集到某問題的所有數據來作為訓練數據,并且新數據總是在源源不斷的產生,因此我們只能退而求其次,訓練目標變成在已有的訓練數據集上建模輸入和輸出之間的關系。由于訓練數據集是對真實數據分布情況的采樣,訓練數據集上的最優解往往會多少偏離真正的最優解(這里的討論不考慮模型容量)。
而在知識蒸餾時,由于我們已經有了一個泛化能力較強的Net-T,我們在利用Net-T來蒸餾訓練Net-S時,可以直接讓Net-S去學習Net-T的泛化能力。
一個很直白且高效的遷移泛化能力的方法就是:使用softmax層輸出的類別的概率來作為“soft target”。
【KD的訓練過程和傳統的訓練過程的對比】
- 傳統training過程(hard targets): 對ground truth求極大似然
- KD的training過程(soft targets): 用large model的class probabilities作為soft targets
KD的訓練過程為什么更有效?
softmax層的輸出,除了正例之外,負標簽也帶有大量的信息,比如某些負標簽對應的概率遠遠大于其他負標簽。而在傳統的訓練過程(hard target)中,所有負標簽都被統一對待。也就是說,KD的訓練方式使得每個樣本給Net-S帶來的信息量大于傳統的訓練方式。
【舉個例子】
在手寫體數字識別任務MNIST中,輸出類別有10個。假設某個輸入的“2”更加形似"3",softmax的輸出值中"3"對應的概率為0.1,而其他負標簽對應的值都很小,而另一個"2"更加形似"7","7"對應的概率為0.1。這兩個"2"對應的hard target的值是相同的,但是它們的soft target卻是不同的,由此我們可見soft target蘊含著比hard target多的信息。并且soft target分布的熵相對高時,其soft target蘊含的知識就更豐富。
這就解釋了為什么通過蒸餾的方法訓練出的Net-S相比使用完全相同的模型結構和訓練數據只使用hard target的訓練方法得到的模型,擁有更好的泛化能力。
2.3. softmax函數
先回顧一下原始的softmax函數:
qi=exp?(zi)∑jexp?(zj)q_i=\frac{\exp(z_i)}{\sum_j\exp(z_j)} qi?=∑j?exp(zj?)exp(zi?)?
但要是直接使用softmax層的輸出值作為soft target, 這又會帶來一個問題: 當softmax輸出的概率分布熵相對較小時,負標簽的值都很接近0,對損失函數的貢獻非常小,小到可以忽略不計。因此**“溫度”**這個變量就派上了用場。
下面的公式時加了溫度這個變量之后的softmax函數:
qi=exp?(zi/T)∑jexp?(zj/T)q_i=\frac{\exp(z_i/T)}{\sum_j\exp(z_j/T)} qi?=∑j?exp(zj?/T)exp(zi?/T)?
- 這里的T就是溫度。
- 原來的softmax函數是T = 1的特例。 T越高,softmax的output probability distribution越趨于平滑,其分布的熵越大,負標簽攜帶的信息會被相對地放大,模型訓練將更加關注負標簽
3. 知識蒸餾的具體方法
3.1. 通用的知識蒸餾方法
- 第一步是訓練Net-T;第二步是在高溫T下,蒸餾Net-T的知識到Net-S
訓練Net-T的過程很簡單,下面詳細講講第二步:高溫蒸餾的過程。高溫蒸餾過程的目標函數由distill loss(對應soft target)和student loss(對應hard target)加權得到。示意圖如上。
L=αLsoft+βLhardL=\alpha L_{soft}+\beta L_{hard}L=αLsoft?+βLhard?
-
viv_ivi?: Net-T的logits
-
ziz_izi?: Net-S的logits
-
piTp_i^TpiT?: Net-T的在溫度=T下的softmax輸出在第i類上的值
-
qiTq_i^TqiT?: Net-S的在溫度=T下的softmax輸出在第i類上的值
-
cic_ici?: 在第i類上的ground truth值, ci∈{0,1}, 正標簽取1,負標簽取0.
-
NNN: 總標簽數量
Net-T 和 Net-S 同時輸入 transfer set (這里可以直接復用訓練Net-T用到的training set), 用Net-T產生的softmax distribution (with high temperature) 來作為soft target,Net-S在相同溫度T條件下的softmax輸出和soft target的cross entropy就是Loss函數的第一部分 LsoftL_{soft}Lsoft?:
Lsoft=?∑jNpjTlog??(qjT)L_{soft}=?\sum_j^Np_j^T\log?(q_j^T) Lsoft?=?j∑N?pjT?log?(qjT?)
其中,piT=exp?(vi/T)∑kNexp?(vk/T),qiT=exp?(zi/T)∑kNexp?(zk/T)p_i^T=\frac{\exp(v_i/T)}{\sum_k^N\exp(v_k/T)},\ q_i^T=\frac{\exp(z_i/T)}{\sum_k^N\exp(z_k/T)}piT?=∑kN?exp(vk?/T)exp(vi?/T)?,?qiT?=∑kN?exp(zk?/T)exp(zi?/T)? 。
Net-S在T=1的條件下的softmax輸出和ground truth的cross entropy就是Loss函數的第二部分 LhardL_{hard}Lhard?。
Lhard=?∑jNcjlog?(qjT=1)L_{hard}=-\sum_j^Nc_j\log(q_j^{T=1}) Lhard?=?j∑N?cj?log(qjT=1?)
其中,qjT=1=exp?(zi)∑kNexp?(zk)q^{T=1}_j=\frac{\exp(z_i)}{\sum_k^N\exp(z_k)}qjT=1?=∑kN?exp(zk?)exp(zi?)?
第二部分Loss LhardL_{hard}Lhard? 的必要性其實很好理解: Net-T也有一定的錯誤率,使用ground truth可以有效降低錯誤被傳播給Net-S的可能。打個比方,老師雖然學識遠遠超過學生,但是他仍然有出錯的可能,而這時候如果學生在老師的教授之外,可以同時參考到標準答案,就可以有效地降低被老師偶爾的錯誤“帶偏”的可能性。
【討論】
實驗發現第二部分所占比重比較小的時候,能產生最好的結果,這是一個經驗的結論。一個可能的原因是,由于soft target產生的gradient與hard target產生的gradient之間有與 T 相關的比值。原論文中只是一筆帶過,我在下面補充了一些簡單的推導。(ps. 下面推導可能有些錯誤,如果有讀者能夠正確推出來請私信我~)
Soft Target
Lsoft=?∑jNpjTlog?(qjT)=?∑jNzj/T×exp?(vj/T)∑kNexp?(vk/T)(1∑kNexp?(zk/T)?exp?(zj/T)(∑kNexp?(zk/T))2)≈?1T∑kNexp?(vk/T)(∑jNzjexp?(vj/T)∑kNexp?(zk/T)?∑jNzjexp?(zj/T)exp?(vj/T)(∑kNexp?(zk/T))2)\begin{aligned} L_{soft}&=-\sum_j^Np_j^T\log(q_j^T)\\ &=-\sum_j^N\frac{z_j/T\times\exp(v_j/T)}{\sum_k^N\exp(v_k/T)}(\frac{1}{\sum_k^N\exp(z_k/T)}-\frac{\exp(z_j/T)}{(\sum_k^N\exp(z_k/T))^2})\\ &\approx-\frac{1}{T\sum_k^N\exp(v_k/T)}(\frac{\sum_j^Nz_j\exp(v_j/T)}{\sum_k^N\exp(z_k/T)}-\frac{\sum_j^Nz_j\exp(z_j/T)\exp(v_j/T)}{(\sum_k^N\exp(z_k/T))^2}) \end{aligned} Lsoft??=?j∑N?pjT?log(qjT?)=?j∑N?∑kN?exp(vk?/T)zj?/T×exp(vj?/T)?(∑kN?exp(zk?/T)1??(∑kN?exp(zk?/T))2exp(zj?/T)?)≈?T∑kN?exp(vk?/T)1?(∑kN?exp(zk?/T)∑jN?zj?exp(vj?/T)??(∑kN?exp(zk?/T))2∑jN?zj?exp(zj?/T)exp(vj?/T)?)?
Hard Target
Lhard=?∑jNcjlog?(qjT=1)=?(∑jNcjzj∑kNexp?(zk)?∑jNcjzjexp?(zj)(∑kNexp?(zk))2)L_{hard}=-\sum_j^Nc_j\log(q^{T=1}_j)=-(\frac{\sum_j^Nc_jz_j}{\sum_{k}^N\exp(z_k)}-\frac{\sum_j^Nc_jz_j\exp(z_j)}{(\sum_k^N\exp(z_k))^2}) Lhard?=?j∑N?cj?log(qjT=1?)=?(∑kN?exp(zk?)∑jN?cj?zj???(∑kN?exp(zk?))2∑jN?cj?zj?exp(zj?)?)
由于 ?Lsoft?zi\frac{\partial{L_{soft}}}{\partial{z_i}}?zi??Lsoft?? 的magnitude大約是 ?Lhard?zi\frac{\partial{L_{hard}}}{\partial{z_i}}?zi??Lhard?? 的 1T2\frac{1}{T^2}T21? ,因此在同時使用soft target和hard target的時候,需要在soft target之前乘上 T2T^2T2 的系數,這樣才能保證soft target和hard target貢獻的梯度量基本一致。
【注意】 在Net-S訓練完畢后,做inference時其softmax的溫度T要恢復到1.
3.2. 一種特殊情形: 直接match logits(不經過softmax)
直接match logits指的是,直接使用softmax層的輸入logits(而不是輸出)作為soft targets,需要最小化的目標函數是Net-T和Net-S的logits之間的平方差。
直接上結論: 直接match logits的做法是 T→∞T\rightarrow\inftyT→∞ 的情況下的特殊情形。
由單個case貢獻的loss,推算出對應在Net-S每個logit ziz_izi?上的gradient:
?Lsoft?zi=1T(qi?pi)=1T(exp?(zi/T)∑jexp?(zj/T)?exp?(vi/T)∑jexp?(vj/T))\frac{\partial{L_{soft}}}{\partial z_i}=\frac{1}{T}(q_i-p_i)=\frac{1}{T}(\frac{\exp(z_i/T)}{\sum_j\exp(z_j/T)}-\frac{\exp(v_i/T)}{\sum_j\exp(v_j/T)}) ?zi??Lsoft??=T1?(qi??pi?)=T1?(∑j?exp(zj?/T)exp(zi?/T)??∑j?exp(vj?/T)exp(vi?/T)?)
當 T→∞T\rightarrow \inftyT→∞ 時,我們使用 1+x/T1+x/T1+x/T 來近似 exp?(x/T)\exp(x/T)exp(x/T) ,于是得到
?Lsoft?zi≈1T(1+zi/TN+∑jzj/T?1+vi/TN+∑jvj/T)\frac{\partial L_{soft}}{\partial z_i}\approx\frac{1}{T}(\frac{1+z_i/T}{N+\sum_jz_j/T}-\frac{1+v_i/T}{N+\sum_jv_j/T}) ?zi??Lsoft??≈T1?(N+∑j?zj?/T1+zi?/T??N+∑j?vj?/T1+vi?/T?)
如果再加上 logits 是零均值的假設 ∑jzj=∑jvj=0\sum_jz_j=\sum_jv_j=0∑j?zj?=∑j?vj?=0 。那么上面的公式可以簡化成:
?Lsoft?zi≈1NT2(zi?vi)\frac{\partial L_{soft}}{\partial z_i}\approx\frac{1}{NT^2}(z_i-v_i) ?zi??Lsoft??≈NT21?(zi??vi?)
等價于 minimize 以下損失函數
Lsoft′=1/2(zi?vi)2L'_{soft}={1}/{2}(z_i-v_i)^2 Lsoft′?=1/2(zi??vi?)2
4. 關于"溫度"的討論
【問題】 我們都知道“蒸餾”需要在高溫下進行,那么這個“蒸餾”的溫度代表了什么,又是如何選取合適的溫度?如下圖所示,隨著溫度T的增大,概率分布的熵逐漸增大。
4.1. 溫度的特點
在回答這個問題之前,先討論一下溫度T的特點
- 原始的softmax函數是 T=1 時的特例, T<1 時,概率分布比原始更“陡峭”, T>1 時,概率分布比原始更“平緩”。
- 溫度越高,softmax上各個值的分布就越平均(思考極端情況: (i) T=∞\infty∞ , 此時softmax的值是平均分布的;(ii) T=0,此時softmax的值就相當于 argmax , 即最大的概率處的值趨近于1,而其他值趨近于0)
- 不管溫度T怎么取值,Soft target都有忽略相對較小的 pip_ipi? 攜帶的信息的傾向
4.2. 溫度代表了什么,如何選取合適的溫度?
溫度的高低改變的是Net-S訓練過程中對負標簽的關注程度: 溫度較低時,對負標簽的關注,尤其是那些顯著低于平均值的負標簽的關注較少;而溫度較高時,負標簽相關的值會相對增大,Net-S會相對多地關注到負標簽。
實際上,負標簽中包含一定的信息,尤其是那些值顯著高于平均值的負標簽。但由于Net-T的訓練過程決定了負標簽部分比較noisy,并且負標簽的值越低,其信息就越不可靠。因此溫度的選取比較empirical,本質上就是在下面兩件事之中取舍:
- 從有部分信息量的負標簽中學習 --> 溫度要高一些
- 防止受負標簽中噪聲的影響 -->溫度要低一些
總的來說,T的選擇和Net-S的大小有關,Net-S參數量比較小的時候,相對比較低的溫度就可以了(因為參數量小的模型不能capture all knowledge,所以可以適當忽略掉一些負標簽的信息)
5. 參考
- 深度壓縮之蒸餾模型 - 風雨兼程的文章 - 知乎 https://zhuanlan.zhihu.com/p/24337627
- 知識蒸餾Knowledge Distillation - 船長的文章 - 知乎 https://zhuanlan.zhihu.com/p/83456418
- https://towardsdatascience.com/knowledge-distillation-simplified-dd4973dbc764
- https://nervanasystems.github.io/distiller/knowledge_distillation.html