本文將對【模型剪枝】基于DepGraph(依賴圖)完成復雜模型的一鍵剪枝 文章中剪枝的模型進行蒸餾訓練
一、邏輯蒸餾步驟
- 加載教師模型
- 定義蒸餾loss
- 計算蒸餾loss
- 正常訓練
二、代碼
1、加載教師模型
教師模型使用未進行剪枝,并且已經訓練好的原始模型。
teacher_model = torch.load('./logs/before_prune.pth', map_location=device)
2、定義蒸餾loss
分割和分類的loss,都是用的softmax。
import torch.nn.functional as F
import torch.nn as nn
# 蒸餾溫度
Tempature = 2
def KD_loss(teacher_pred, student_pred):t_p = F.softmax(teacher_pred / Tempature, dim=1)s_p = F.log_softmax(student_pred / Tempature, dim=1)return nn.KLDivLoss(reduction='mean')(s_p, t_p) * (Tempature ** 2)
3、 計算蒸餾loss
teacher_outputs = t_model(imgs)
# 蒸餾loss
soft_loss = KD_loss(teacher_outputs, outputs)
# 總loss = 蒸餾loss*alpha + 原學生模型loss*(1-alpha)
alpha = 0.9
all_loss = loss * (1 - alpha) + soft_loss * alpha
4、正常訓練
all_loss.backward()
用剪枝前訓練好的模型對剪枝后模型進行蒸餾訓練,訓練后測試效果如下: