知識蒸餾 - 通過引入溫度參數T調整 Softmax 的輸出
flyfish
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np# 設置中文字體支持
plt.rcParams["font.family"] = ['AR PL UMing CN'] # Linux
plt.rcParams['axes.unicode_minus'] = False # 解決負號顯示問題# 模擬模型輸出的logits
logits = torch.tensor([10.0, 4.0, 1.0])# 定義不同的溫度值
temperatures = [0.5, 1.0, 5.0, 10.0, 20.0]# 計算不同溫度下的softmax輸出
results = {}
for T in temperatures:soft_labels = F.softmax(logits / T, dim=0)results[T] = soft_labels.numpy()# 打印結果(保留四位小數)
print("原始logits:", logits.numpy())
for T, soft_labels in results.items():# 使用列表推導式和格式化字符串保留四位小數formatted_probs = [f"{p:.4f}" for p in soft_labels]print(f"溫度 T={T} 時的軟標簽: [{', '.join(formatted_probs)}]")# 可視化不同溫度下的概率分布
plt.figure(figsize=(14, 7))
x = np.arange(len(logits))
width = 0.8 / len(temperatures)for i, (T, soft_labels) in enumerate(results.items()):bars = plt.bar(x + i * width - 0.4 + width/2, soft_labels, width, label=f'T={T}')# 在每個柱子上方添加保留四位小數的概率值for bar in bars:height = bar.get_height()plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,f'{height:.4f}', ha='center', va='bottom', rotation=90)plt.xticks(x, ['貓', '狗', '狐貍'])
plt.ylabel('概率')
plt.title('不同溫度T下的softmax概率分布')
plt.legend()
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.ylim(0, 1.1) # 調整y軸范圍,使標簽顯示完整
plt.tight_layout()
plt.show()
原始logits: [10. 4. 1.]
溫度 T=0.5 時的軟標簽: [1.0000, 0.0000, 0.0000]
溫度 T=1.0 時的軟標簽: [0.9974, 0.0025, 0.0001]
溫度 T=5.0 時的軟標簽: [0.6819, 0.2054, 0.1127]
溫度 T=10.0 時的軟標簽: [0.5114, 0.2807, 0.2079]
溫度 T=20.0 時的軟標簽: [0.4204, 0.3115, 0.2681]
低溫(T=0.5):分布極陡峭,幾乎只保留最大值對應的類別(貓)
標準溫度(T=1.0):接近傳統 softmax,突出最大值但保留少量其他類別概率
高溫(T=10.0):分布非常平滑,所有類別概率接近均等
對于給定的logits向量z=[z1,z2,...,zk]\mathbf{z} = [z_1, z_2, ..., z_k]z=[z1?,z2?,...,zk?](其中ziz_izi?是模型對第iii類的原始輸出分數,比如代碼中的logits = [10.0, 4.0, 1.0]
),以及溫度參數TTT,第iii類的軟標簽概率pip_ipi?計算公式為:
pi=ezi/T∑j=1kezj/T p_i = \frac{e^{z_i / T}}{\sum_{j=1}^{k} e^{z_j / T}} pi?=∑j=1k?ezj?/Tezi?/T?
解釋:
ziz_izi?:代碼中的logits[i]
(如logits[0] = 10.0
對應“貓”的原始分數);
TTT:代碼中的溫度參數(如T=0.5,1.0,5.0
等);
ezi/Te^{z_i / T}ezi?/T:對“原始分數除以溫度”做指數運算(代碼中由F.softmax
內部實現);
分母∑j=1kezj/T\sum_{j=1}^{k} e^{z_j / T}∑j=1k?ezj?/T:所有類別的指數結果之和,用于歸一化(確保所有概率之和為1);
pip_ipi?:最終的軟標簽概率(代碼中soft_labels[i]
,如“貓”在T=5.0
時的概率約為0.6811)。
作用:
通過溫度TTT縮放logits的“差異幅度”:
當T→0+T \to 0^+T→0+時,指數部分對大的ziz_izi?更敏感,概率分布會極度陡峭(接近硬標簽);
當T→+∞T \to +\inftyT→+∞時,所有zi/Tz_i / Tzi?/T趨近于0,指數結果趨近于1,概率分布會趨近均勻(所有類別概率接近相等)。
如T=0.5
時“貓”的概率接近1,T=20
時三類概率更均勻。
在知識蒸餾(Knowledge Distillation)中,引入溫度參數TTT 調整 Softmax 輸出的核心目的是獲取更有信息量的“軟標簽”(Soft Labels),以便讓學生模型(Student Model)更好地學習教師模型(Teacher Model)的“知識”。溫度TTT 的核心作用是通過“軟化”教師模型的輸出分布,保留更多關于類別間關系的細粒度知識,讓學生模型能更有效地學習教師的經驗。
原因
1. 原始 Softmax(T=1T=1T=1)的局限性
原始 Softmax 函數的公式為:
pi=ezi∑jezj
p_i = \frac{e^{z_i}}{\sum_{j} e^{z_j}}
pi?=∑j?ezj?ezi??
其中ziz_izi? 是模型輸出的 logits(未歸一化的分數)。
當模型對正確類別有較高置信度時(比如教師模型很“確信”某個樣本是“貓”),原始 Softmax 的輸出會極度集中在最大 logits 對應的類別上,其他類別的概率幾乎為 0(例如:p貓≈0.999p_{\text{貓}} \approx 0.999p貓?≈0.999,p狗≈0.001p_{\text{狗}} \approx 0.001p狗?≈0.001,p狐貍≈0p_{\text{狐貍}} \approx 0p狐貍?≈0)。
這種“陡峭”的概率分布(接近硬標簽)丟失了很多有價值的信息:教師模型可能認為“狗”比“狐貍”更接近“貓”(即p狗>p狐貍p_{\text{狗}} > p_{\text{狐貍}}p狗?>p狐貍?),但原始 Softmax 會將這種差異壓縮到幾乎不可見。
2. 溫度TTT 的作用:“軟化”概率分布,保留更多知識
當引入溫度TTT 后,Softmax 公式變為:
pi=ezi/T∑jezj/T
p_i = \frac{e^{z_i / T}}{\sum_{j} e^{z_j / T}}
pi?=∑j?ezj?/Tezi?/T?
當T>1T > 1T>1 時:logits 被“縮放”(除以TTT),導致指數函數的“敏感度”降低,不同類別的概率差異被拉平(分布更平緩)。
例如,教師模型對“貓”“狗”“狐貍”的 logits 為 [10, 4, 1]:
T=1T=1T=1 時,輸出可能是 [0.997, 0.002, 0.001](幾乎只有“貓”有概率);
T=10T=10T=10 時,輸出可能是 [0.607, 0.242, 0.151](保留了“狗比狐貍更接近貓”的信息)。
這種“軟化”的軟標簽包含了教師模型對類別間相似性的判斷(哪些類別容易混淆、哪些類別差異大),這些信息比單純的硬標簽(如“貓”)更豐富,能幫助學生模型學習到更魯棒的特征。
3. 知識蒸餾中的配合使用
在知識蒸餾中,教師模型用高溫TTT 生成軟標簽,學生模型在訓練時既學習軟標簽(用相同的TTT),也學習原始硬標簽(可選)。推理時,學生模型再用T=1T=1T=1 輸出最終的硬預測。
通過這種方式,學生模型不僅學到了“正確答案”,還學到了教師模型的“推理過程”(如何權衡不同類別的可能性),從而在參數更少的情況下達到接近教師模型的性能。