知識蒸餾 - 基于KL散度的知識蒸餾 HelloWorld 示例 采用PyTorch 內置函數F.kl_div的實現方式
flyfish
kl_div 是 Kullback-Leibler Divergence的英文縮寫。
其中,KL 對應提出該概念的兩位學者(Kullback 和 Leibler)的姓氏首字母“div”是 divergence(散度)的縮寫。
F.kl_div(logQ, P, reduction='sum')
等價于 torch.sum(P * (torch.log(P) - logQ))
import torch
import torch.nn.functional as F# 1. 定義示例輸入(教師和學生的logits)
teacher_logits = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
student_logits = torch.tensor([[1.2, 2.1, 2.9], [3.8, 5.2, 6.1]], dtype=torch.float32)
T = 2.0 # 溫度參數
batch_size = teacher_logits.size(0)# 2. 溫度軟化處理
teacher_scaled = teacher_logits / T
student_scaled = student_logits / T# 3. 計算分布
teacher_soft = F.softmax(teacher_scaled, dim=-1) # 教師分布 P
student_log_soft = F.log_softmax(student_scaled, dim=-1) # 學生對數分布 log Q# 4. 兩種方式計算KL散度
# 方式1:手動計算(原始公式)
manual_kl = torch.sum(teacher_soft * (torch.log(teacher_soft) - student_log_soft)) / batch_size
manual_kl *= T**2 # 溫度補償# 方式2:使用PyTorch自帶的F.kl_div
# 注意:F.kl_div(input=logQ, target=P, reduction='sum') 對應 sum(P*(logP - logQ))
torch_kl = F.kl_div(student_log_soft, teacher_soft, reduction='sum') / batch_size
torch_kl *= T**2 # 溫度補償# 5. 結果對比
print("===== 教師分布 P (softmax后) =====")
print(teacher_soft)
print("\n===== 學生對數分布 logQ (log_softmax后) =====")
print(student_log_soft)
print("\n===== KL散度計算結果 =====")
print(f"手動計算: {manual_kl.item():.6f}")
print(f"F.kl_div計算: {torch_kl.item():.6f}")
print(f"兩者是否等價 (誤差<1e-6): {torch.allclose(manual_kl, torch_kl, atol=1e-6)}")
===== 教師分布 P (softmax后) =====
tensor([[0.1863, 0.3072, 0.5065],[0.1863, 0.3072, 0.5065]])===== 學生對數分布 logQ (log_softmax后) =====
tensor([[-1.5909, -1.1409, -0.7409],[-1.8200, -1.1200, -0.6700]])===== KL散度計算結果 =====
手動計算: 0.008507
F.kl_div計算: 0.008507
兩者是否等價 (誤差<1e-6): True
說明:
1.輸入設置:構造了教師和學生模型的logits(模擬不同的預測結果),并設置溫度參數T=2.0
。
2.分布計算:
教師分布teacher_soft
:通過softmax
得到概率分布 PPP
學生對數分布student_log_soft
:通過log_softmax
得到 log?Q\log QlogQ
3.兩種KL計算方式:
手動計算:嚴格按照公式 KL(P∥Q)=∑P?(log?P?log?Q)\text{KL}(P \parallel Q) = \sum P \cdot (\log P - \log Q)KL(P∥Q)=∑P?(logP?logQ) 實現,除以批次大小后乘以溫度平方補償。
F.kl_div計算:直接調用PyTorch函數,注意參數順序為(logQ, P)
,使用reduction='sum'
確保與手動計算的求和邏輯一致。
4.等價性驗證:通過torch.allclose
檢查兩者結果是否在允許的浮點數誤差范圍內(1e-6)一致。