一、什么是邏輯(logit)知識蒸餾
Feature-based
蒸餾原理是知識蒸餾中的一種重要方法,其關鍵在于利用教師模型的隱藏層特征來指導學生模型的學習過程。這種蒸餾方式旨在使學生模型能夠學習到教師模型在特征提取和表示方面的能力,從而提升其性能。
具體來說,Feature-based
蒸餾通過比較教師模型和學生模型在某一或多個隱藏層的特征表示來實現知識的遷移。在訓練過程中,教師模型的隱藏層特征被提取出來,并作為監督信號來指導學生模型相應層的特征學習。通過優化兩者在特征層面的差異(如使用均方誤差、余弦相似度等作為損失函數),可以使學生模型逐漸逼近教師模型的特征表示能力。
這種蒸餾方式有幾個顯著的優勢。首先,它充分利用了教師模型在特征提取方面的優勢,幫助學生模型學習到更具判別性的特征表示。其次,通過比較特征層面的差異,可以更加細致地指導學生模型的學習過程,使其在保持較高性能的同時減小模型復雜度。最后,Feature-based
蒸餾可以與其他蒸餾方式相結合,形成更為復雜的蒸餾策略,以進一步提升模型性能。
需要注意的是,在選擇進行Feature-based
蒸餾的隱藏層時,需要謹慎考慮。不同層的特征具有不同的語義信息和抽象程度,因此選擇合適的層進行蒸餾對于最終效果至關重要。此外,蒸餾過程中的損失函數和權重設置也需要根據具體任務和數據集進行調整。
綜上所述,Feature-based
蒸餾原理是通過利用教師模型的隱藏層特征來指導學生模型的學習過程,從而實現知識的遷移和模型性能的提升。這種方法在深度學習領域具有廣泛的應用前景,尤其在需要提高模型特征提取能力的場景中表現出色。
二、如何進行多任務模型的知識蒸餾
(1)加載學生和教師模型
(2)定義分割蒸餾損失,定義檢測蒸餾損失
(3)計算分割蒸餾損失,計算檢測蒸餾損失
(4)計算學生模型的分割,檢測損失
(5)計算總損失,反向傳播
三、實現代碼
(1)加載學生和教師模型
# 學生模型
model = torch.load(args.student_model, map_location=device)
# 教師模型
teacher_model = YourModel(task="multi")
teacher_model.load_state_dict(torch.load(args.teacher_model, map_location=device))
(2)定義分割蒸餾損失,定義檢測蒸餾損失
分割損失,參考:【知識蒸餾】語義分割模型邏輯蒸餾實戰,對剪枝的模型進行蒸餾訓練
# ------------ seg logit distill loss -------------#
def seg_logit_distill_loss(t_pred, s_pred, tempature = 2):KD = nn.KLDivLoss(reduction='mean')t_p = F.softmax(t_pred / tempature, dim=1)s_p = F.log_softmax(s_pred / tempature, dim=1)loss = KD(s_p, t_p) * (tempature ** 2)return loss
檢測損失,參考:【知識蒸餾】yolov5邏輯蒸餾和特征蒸餾實戰
# ------------ det logit distill loss -------------#
def det_logit_distill_loss(t_pred,s_pred,tempature=1):L2 = nn.MSELoss(reduction="none")t_lobj = L2(s_pred[..., 4], t_pred[..., 4]).mean()t_lBox = L2(s_pred[..., :4], t_pred[..., :4]).mean()t_lcls = L2(s_pred[..., 5:], t_pred[..., 5:]).mean()return (t_lobj + t_lBox + t_lcls) * tempature
(3)計算分割蒸餾loss,計算檢測蒸餾損失
with torch.no_grad():teacher_outputs = teacher_model(images)
# 分割蒸餾loss
teacher_seg_output = teacher_outputs.get("seg")
student_seg_output = predictions.get("seg")
seg_soft_loss = seg_logit_distill_loss(teacher_seg_output, student_seg_output)
# 檢測蒸餾loss
teacher_det_output = teacher_outputs.get("det")
student_det_output = predictions.get("det")
det_soft_loss = det_logit_distill_loss(teacher_det_output, student_det_output)
(4)計算學生模型的分割,檢測損失
det_loss = calc_det_loss(...)
seg_loss = CE_Loss(...)
(5)計算總損失,反向傳播
seg_distill_loss = seg_loss * (1 - seg_alpha) + seg_soft_loss * seg_alpha
det_distill_loss = det_loss * (1 - det_alpha) + det_soft_loss * det_alpha
loss = det_distill_loss * Ratio_det + seg_distill_loss * Ratio_seg
loss.backward()