1 緣起
最近要準備升級材料,里面有一骨碌是介紹LLM相關技術的,知識蒸餾就是其中一個點,
不過,只分享了蒸餾過程,沒有講述來龍去脈,比如沒有講解Softmax為什么引入T、損失函數為什么使用KL散度,想再進一步整理細節部分分享出來,這是其一。
其二是,近3個月沒有寫文章了,重拾筆頭。
今年已經制定寫作計劃,后面會陸續分享出來。
2 原理
2.1 簡介
知識蒸餾是一種模型壓縮方法,其中一個小模型被訓練來模仿一個預訓練的大模型(或模型集合)。這種訓練設置有時被稱為“教師-學生”模式,其中大模型是教師,小模型是學生。

該方法最早由Bucila等人在2006年提出,并由Hinton等人在2015年進行了推廣。Distiller中的實現基于后者的論文。在這里,我們將提供該方法的概述。更多信息,可以參考該論文https://arxiv.org/abs/1503.02531。

2.2 為什么引入T
在蒸餾過程中,直接計算教師模型概率分布,即教師模型上的softmax函數的輸出,然而,在許多情況下,這個概率分布中正確類別的概率非常高,而其他類別的概率非常接近于0,導致學生模型學習到的信息并沒有提供比數據集中已經提供的真實標簽更多的信息。Softmax如下:
p i = e z i ∑ j N e z j p_{i}= \frac{e^{z_{i} } }{\sum_{j}^{N}e^{z_{j} } } pi?=∑jN?ezj?ezi??
為了解決這個問題,Hinton等人在2015年引入了“softmax溫度”的概念。類別i的概率pi從logits z計算得出,公式如下:
p i = e z i T ∑ j N e z j T p_{i}= \frac{e^{\frac{z_{i} }{T} } }{\sum_{j}^{N}e^{\frac{z_{j} }{T}} } pi?=∑jN?eTzj??eTzi???
?
其中,T是溫度參數,用于控制概率分布的平滑程度。當T較高時,概率分布會更加平滑,從而提供更多的信息,有助于學生模型更好地學習教師模型的知識。
2.3 蒸餾過程
為提升學生模型的性能,學些到更多的信息,引入T,最終的蒸餾過程如下圖所示,知識蒸餾有兩個Loss,即學生模型與教師模型的 L o s s d i s t i l l a t i o n Loss_{distillation} Lossdistillation?,學生模型與真實值的 L o s s s t u d e n t Loss_{student} Lossstudent?,其中,教師模型的預測值為軟標簽,學生模型溫度t的的預測值為軟預測值,學生模型T=1的預測值為硬預測值。

學生損失函數:
L o s s s t u d e n t = ? ∑ i = 1 N y i l o g q i ( 1 ) Loss_{student}=-\sum_{i=1}^{N}y_{i}logq_{i}^{(1)} Lossstudent?=?i=1∑N?yi?logqi(1)?
蒸餾損失函數:
L o s s d i s t i l l a t i o n = ? t 2 ∑ i = 1 N p i ( t ) l o g q i ( t ) Loss_{distillation}=-t^{2} \sum_{i=1}^{N}p_{i}^{(t)}logq_{i}^{(t)} Lossdistillation?=?t2i=1∑N?pi(t)?logqi(t)?
最終損失函數:
L o s s t o t a l = ( 1 ? α ) L o s s s t u d e n t + α L o s s d i s i l l a t i o n Loss_{total}=(1-\alpha )Loss_{student}+\alpha Loss_{disillation} Losstotal?=(1?α)Lossstudent?+αLossdisillation?
2.4 LLM蒸餾損失函數
LLM蒸餾損失函數使用KL散度。
L o s s d i s t i l l a t i o n ? L L M = ? t 2 K L ( p ( t ) ∣ ∣ q ( t ) ) = ? t 2 ∑ i = 1 N p i ( t ) l o g p i ( t ) q i ( t ) Loss_{distillation-LLM}=-t^{2}KL(p^{(t)}||q^{(t)})=-t^{2}\sum_{i=1}^{N}p_{i}^{(t)}log\frac{p_{i}^{(t)} }{q_{i}^{(t)} } Lossdistillation?LLM?=?t2KL(p(t)∣∣q(t))=?t2i=1∑N?pi(t)?logqi(t)?pi(t)??
2.4.1 為什么使用KL散度
KL散度的概念來源于概率論和信息論中。KL散度又被稱為:相對熵、互熵、鑒別信息、Kullback熵、Kullback-Leible散度(即KL散度的簡寫)。KL 散度比交叉熵更適合作為蒸餾損失,因為當學生模型完美匹配教師模型時,蒸餾損失會為零,而交叉熵卻不為零,直接使用交叉熵作為蒸餾損失可能會導致損失隨著 batch 波動,所以使用KL散度作為蒸餾損失函數。
3 小結
(1)Softmax引入T用于計算學生和老師預測值概率分布。
(2)使用KL散度計算學生與老師損失;
(3)T=1計算學生模型預測值概率分布。
4 參考
https://zhuanlan.zhihu.com/p/692216196
https://blog.csdn.net/keeppractice/article/details/145419077
https://intellabs.github.io/distiller/knowledge_distillation.html
https://hsinjhao.github.io/2019/05/22/KL-DivergenceIntroduction/