?第一部分:引言與背景——為什么需要知識提煉?
一、模型壓縮的背景
隨著深度學習的發展,模型變得越來越大(如 ResNet152、BERT、ViT、GPT 等),其參數量動輒數億甚至上百億。這些大模型雖然性能強大,但也帶來以下問題:
問題 | 描述 |
---|---|
存儲成本高 | 占用大量內存、存儲資源 |
推理速度慢 | 計算量大,難以部署到邊緣設備 |
能耗高 | 大模型耗電多,部署在移動端不可行 |
工業部署難 | 需要簡化模型以適應生產場景 |
為了解決上述問題,人們提出了模型壓縮技術(Model Compression),主要包括:
-
網絡剪枝(Pruning)
-
量化(Quantization)
-
知識提煉(Knowledge Distillation)
其中,知識提煉是一種高效且易于實現的方法,能夠讓一個“小模型(學生)”在大模型(教師)的指導下學習,從而保持性能的同時大幅減少計算資源消耗。
二、知識提煉的核心思想
知識提煉(Knowledge Distillation)最早由 Hinton 等人于 2015 年提出,其核心思想如下:
-
構建一個性能強的大模型,作為“教師模型(Teacher Model)”;
-
訓練一個輕量的小模型,作為“學生模型(Student Model)”;
-
學生模型不僅學習真實標簽(hard label),還要模仿教師模型的輸出(soft label);
-
教師模型輸出的 soft label 包含了樣本間的“類間關系”等隱藏知識。
?通俗理解:
教師模型輸出的概率分布包含了更多的“知識”,學生模型模仿這種分布,就像學生不只學習考試答案,還要理解老師是如何思考的。
三、圖示理解
真實標簽 y↓+----------+| Teacher | => soft label z_T+----------+↓+----------+| Student | => 模擬 z_T + y+----------+
學生模型的目標是同時:
-
擬合真實標簽(監督學習常規)
-
模擬教師模型的輸出(提取知識)
四、知識提煉的優勢
優勢 | 描述 |
---|---|
性能提升 | 在模型尺寸不變的前提下,準確率通常顯著提升 |
參數更少 | 學生模型通常更小、更輕便 |
訓練更快 | 學生模型收斂更快,因為學習目標更具體 |
遷移能力強 | 可以用于跨結構的遷移,例如 CNN → Transformer |
第二部分(加強版):Hinton 知識提煉機制的深入解析
1. 背景復盤:為什么用教師模型的輸出?
傳統監督學習里,我們用one-hot標簽訓練模型,比如貓狗分類,標簽向量是:
y=[0,1,0,0,0](假設第2類是正確類別)
但這其實只告訴模型:
-
“正確類別是第2類”;
-
“其他類別都不對”。
它沒有告訴模型:
-
第3類和第4類與第2類有多相似;
-
哪些類別容易混淆,哪些完全不同。
教師模型輸出的是一個概率分布:
q=[0.1,0.7,0.1,0.05,0.05]
這表示它認為第2類最可能,但第1類和第3類也有一定可能,帶來了更多“類別間的關系信息”,這就是暗知識(Dark Knowledge)。
學生模型學習這個概率分布,能學到:
-
類別間相似性的知識;
-
更豐富的語義結構。
2. Softmax 函數與溫度調節的數學細節
2.1 標準 Softmax
假設某輸入樣本,模型輸出 logits(未歸一化的得分向量):
z=[z1,z2,…,zn]
Softmax 轉換成概率:
這保證所有 pi? 之和為 1,且最大值對應預測類別。
2.2 引入溫度系數 TTT
溫度調節是通過除以溫度參數 TTT 來調整 logits 的“平滑度”:
-
當 T=1 時,是正常 softmax;
-
當 T>1時,所有概率趨向均勻分布,分布更“軟化”;
-
當 T<1時,分布更“尖銳”,趨近 one-hot。
2.3 數學意義
溫度調節后的 softmax 的梯度尺度和概率分布的形狀都會變化:
-
分布更平滑,教師輸出中小概率類別的信息更明顯;
-
梯度變小,因此需要在損失函數中乘以 T2T^2T2 保持梯度大小。
3. KL 散度作為相似度度量
知識提煉中用來衡量學生和教師預測概率分布差異的主要指標是 KL 散度(Kullback-Leibler Divergence)。
-
P通常是教師模型的軟標簽分布 qteacher(T)
-
Q 是學生模型的軟標簽分布 qstudent(T)
KL 散度越小,說明兩分布越接近,學生更好地“模仿”教師。
4. 完整損失函數推導
訓練學生模型的目標是最小化:
-
第一項是學生對真實標簽的交叉熵損失,確保學習基礎知識;
-
第二項是學生模仿教師軟標簽的 KL 散度損失,提取暗知識;
-
T2 是梯度縮放因子;
-
α 控制兩部分權重,通常 0.5~0.9。
5. 為什么要結合真實標簽和軟標簽?
-
僅用真實標簽訓練,學生模型效果差,學習不到類別間信息;
-
僅用軟標簽訓練,軟標簽雖然包含暗知識,但有噪聲,可能導致欠擬合;
-
兩者結合,既保證模型對真是標簽的準確學習,也能從教師軟標簽中獲取更多細節和類別關系。
6. 舉個具體的數值例子
假設某樣本的教師 logits 是:
z=[10,2,1]
溫度 T=2T = 2T=2 時,計算 softmax 概率:
具體計算:
總和約為:
148.41+2.718+1.649=152.78148.41 + 2.718 + 1.649 = 152.78148.41+2.718+1.649=152.78
所以對應的概率:
相比溫度為1時,概率分布更“平滑”,其他類別概率增加,暗含類別相似度信息。
7. 為什么乘以 T2 調整梯度?
從梯度角度看,softmax 輸出中分母含有 T,梯度規模會縮小,直接訓練效果變差。
論文中證明,乘以 T2 能使梯度大小恢復到合理水平,防止溫度增大時梯度過小影響收斂。
?
第三部分:Soft Label 與 Temperature 的數學原理、可視化和調參技巧
1. Soft Label 的數學原理
1.1 什么是 Soft Label?
-
硬標簽(Hard Label) 是傳統的 one-hot 標簽,比如:
y=[0,0,1,0] -
軟標簽(Soft Label) 是教師模型經過 softmax(尤其是帶溫度調節的 softmax)后產生的概率分布:
q=[0.1,0.3,0.5,0.1]它不僅告訴模型“哪個類別是對的”,還告訴模型“對每個類別的置信度”。
1.2 Soft Label 的優勢
-
包含類別之間的相似性信息,模型學習更細致的決策邊界;
-
有利于提升學生模型的泛化能力,減少過擬合;
-
傳遞教師模型的“暗知識”。
2. Temperature 的數學原理與影響
2.1 Softmax 函數帶溫度 T 的定義
對于 logits 向量 z=[z1,...,zn],softmax with temperature:
2.2 溫度對概率分布的影響
-
T=1 :標準 softmax,正常概率分布;
-
T>1 :概率分布更平滑,降低最大類別概率,增加其他類別概率,信息更豐富;
-
T<1?:概率分布更尖銳,趨近 one-hot。
2.3 可視化示意(假設3類 logits)
類別 | logits | Softmax T=1 | Softmax T=3 |
---|---|---|---|
A | 3 | 0.84 | 0.53 |
B | 1 | 0.11 | 0.24 |
C | 0 | 0.05 | 0.23 |
隨著 T 增大,概率趨向均勻,更“軟”。
3. Temperature 調節的梯度效應
-
當 T 增大時,softmax 輸出趨于均勻分布,導致輸出概率對 logits 的梯度變小;
-
為了補償梯度變小的影響,訓練時損失項中乘以 T2,保持梯度量級;
-
這個乘法的數學證明在 Hinton 論文中詳細說明。
4. 實際調參建議
參數 | 說明 | 建議范圍 | 作用 |
---|---|---|---|
溫度 T | 控制 softmax 平滑程度 | 2 ~ 5 | 提取教師更多暗知識,平滑輸出 |
權重 α | 真實標簽損失和軟標簽損失權重 | 0.5 ~ 0.9 | 平衡硬標簽監督和軟標簽監督 |
-
先選 T=3,α=0.7 作為默認值;
-
若學生學習緩慢,嘗試調大 T;
-
若學生效果偏差大,嘗試調小 T;
-
權重 α 可根據學生模型容量調節。
5. Soft Label 在訓練中的作用示例
-
訓練初期,學生模型受軟標簽引導,更快學到類別間相似關系;
-
訓練后期,硬標簽保證學生收斂到正確分類;
-
整體提高模型的魯棒性和泛化能力。
第四部分:知識提煉的主要策略與發展分支
1. 輸出層蒸餾(Response-Based Distillation)?
?核心理念:
輸出層蒸餾是最基礎、最經典的一種知識提煉方式。它的核心思想是:
用**教師模型輸出的 soft label(概率分布)**去訓練學生模型,而不是傳統的 one-hot label。
?數學形式:
-
軟標簽定義(帶溫度 softmax):
對于教師模型的輸出 logits z,引入溫度參數 T,得到 soft label
- ?
同理,學生模型輸出 logits 為 zs?,其 softmax 為:
-
損失函數:
-
第一項是普通的交叉熵損失,使用硬標簽;
-
第二項是 soft label 的 KL 散度損失;
-
T2 是梯度縮放因子(因為 softmax 的梯度在高溫度下會縮小)。
-
代表方法:
-
Hinton 等人 2015 年的經典論文《Distilling the Knowledge in a Neural Network》。
-
適用于大多數分類任務,尤其是圖像分類、文本分類等。
?優點:
-
實現簡單,適用于任何網絡結構;
-
尤其適合做模型壓縮;
-
適用于 CNN、MLP、Transformer 等模型。
?缺點:
-
不適用于教師/學生模型結構差異特別大的情況;
-
只蒸餾了輸出,忽略了教師的中間層“過程知識”;
-
對復雜任務(如檢測、分割、強化學習)信息量不足。
?PyTorch 實現核心片段:
import torch.nn.functional as Fdef distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):hard_loss = F.cross_entropy(student_logits, labels) # 硬標簽損失soft_loss = F.kl_div(F.log_softmax(student_logits / T, dim=1),F.softmax(teacher_logits / T, dim=1),reduction='batchmean') * (T * T) # soft 標簽損失return alpha * hard_loss + (1 - alpha) * soft_loss
?適用場景:
-
輕量模型訓練(MobileNet、ResNet18);
-
需要壓縮部署的大模型;
-
模型蒸餾初學者優選方案;
-
醫學圖像中的分類任務(病灶判斷、異常檢測)中也可用。
?2. 特征層蒸餾(Feature-Based Distillation)?
?核心理念
輸出層蒸餾只關注最終的 soft label,但一個深度模型在前向傳播過程中會在多個中間層提取到大量結構化的信息。特征層蒸餾的目標就是讓學生模型模仿教師模型中間層產生的特征圖或激活圖(activation maps),從而學習“如何處理信息”,而不只是“最終決策”。
數學原理
設:
-
教師模型某一中間層輸出特征圖為?
-
學生模型對應層輸出為?
由于 Ct≠Cs?,一般需要加入對齊層(1×1卷積)使得維度一致,再計算 L2 損失或MSE 損失:
其中 g(?)g(\cdot)g(?) 是一個通道匹配層(linear 或 conv)。
?代表方法:FitNet(Romero et al., 2014)
-
這是第一個系統性使用中間層作為知識來源的 KD 方法;
-
引入了“hint layer”(教師)和“guided layer”(學生);
-
通過加入 MSE 損失監督學生模仿中間層特征。
?優點:
-
提取更多“過程知識”;
-
對于視覺任務(分類、分割、檢測)具有明顯提升;
-
對于教師輸出信息不夠豐富的任務(如 softmax 輸出接近 one-hot)特別有效。
?缺點:
-
對教師和學生結構匹配要求較高(中間層需要對齊);
-
通道維度不一致時需要額外適配;
-
大模型中 feature map 通常非常大,訓練顯存消耗增加。
?PyTorch 實現示意
# 假設學生輸出為 F_s,教師輸出為 F_t
# 維度:B×C×H×W,需要 reshape/conv 對齊通道維度import torch.nn as nnconv_proj = nn.Conv2d(in_channels=F_s.shape[1], out_channels=F_t.shape[1], kernel_size=1)def feature_loss(F_s, F_t):F_s_proj = conv_proj(F_s)return F.mse_loss(F_s_proj, F_t)
可以使用多個中間層,每層都加蒸餾損失再求和或加權求和。
適用場景
-
圖像分類任務(特別是輕量化學生模型);
-
目標檢測任務(如 Faster R-CNN, YOLO)中的 backbone 提煉;
-
圖像分割任務(如 UNet、DeepLab)中用于 encoder 層蒸餾;
-
醫學圖像分割中的跨模型遷移(教師為大ResNet,學生為輕量CNN)尤其常見。
?實戰建議
策略 | 建議 |
---|---|
特征提取層選擇 | 教師中靠近輸出的深層最優 |
特征對齊方式 | 使用 1x1 卷積或全連接層 |
蒸餾損失類型 | L2(MSE)、Smooth L1、Cosine |
特征歸一化 | 可以加 BN 或 LayerNorm |
?3. 關系蒸餾(Relational Knowledge Distillation,RKD)
?核心理念:
前兩種知識提煉方式都要求學生模型模仿教師模型的具體輸出或中間特征,這在教師和學生結構差異較大時會非常困難。
關系蒸餾的核心思想是:
不再直接學習特征,而是學習樣本之間的相對關系(如距離、角度、相似性等)。
即:如果教師模型認為樣本 xi和 xj 距離近、而 xi和 xk? 距離遠,那么學生模型也應當有類似的“感知”。
數學形式與構造:
(1)距離保持(Distance-wise Loss):
計算教師模型中,樣本 i,j的歐氏距離:
學生模型也計算相應的距離:
然后最小化兩者的差異:
(2)角度保持(Angle-wise Loss):
構建三元組 i,j,ki, ,定義教師模型中兩向量的夾角:
同理學生計算角度 aijksa?,損失為:
?代表方法:
-
?RKD(Relational Knowledge Distillation, CVPR 2019):提出了上述兩類距離/角度關系;
-
?CRD(Contrastive Representation Distillation, ICLR 2020):用對比學習方式提煉關系信息;
-
?PKD(Patient KD):提煉教師多個層之間的信息傳遞關系。
?優點:
-
不依賴教師和學生特征結構是否一致;
-
適用于異構架構(如 CNN 教師 → Transformer 學生);
-
能表達更多“任務內結構性”。
缺點:
-
計算復雜度高(關系對數隨樣本數量平方增長);
-
蒸餾信號間接、泛化效果受限于關系設計;
-
對 batch size 較小的訓練不太友好(因 pair/triplet 關系較少)。
?PyTorch 實現示意:
以下是基于距離關系的 RKD 實現:
def pairwise_dist(features):n = features.size(0)dist = torch.cdist(features, features, p=2) # B×B 距離矩陣return distdef distance_loss(student_feat, teacher_feat):with torch.no_grad():d_t = pairwise_dist(teacher_feat)mean_t = d_t[d_t > 0].mean()d_t = d_t / mean_td_s = pairwise_dist(student_feat)mean_s = d_s[d_s > 0].mean()d_s = d_s / mean_sloss = F.smooth_l1_loss(d_s, d_t)return loss
可選擴展:加入角度損失、圖結構約束等。
適用場景
場景 | 原因/建議 |
---|---|
結構差異大的模型(如 ViT ? CNN) | 無需層對齊、特征維度相同 |
小樣本學習 / 數據不均衡 | 學習樣本間的關系信息可以緩解 overfitting |
醫學圖像中“相對特征分布”重要的任務 | 如分類器分不清具體區域,但能感知病灶間的相關性 |
?實戰技巧:
-
構造 triplet 時要注意采樣策略(Hard negative mining 會提升效果);
-
對高維特征做降維或歸一化有助于穩定蒸餾;
-
可以與 feature-based 方法聯合使用(并行優化)。
?4. 注意力蒸餾(Attention-Based Distillation)
核心理念:
注意力蒸餾的主要思想是:
與其直接蒸餾特征本身,不如蒸餾模型關注的位置/區域/通道信息,也就是注意力信息。
教師模型中往往對特定空間區域或通道更關注(如病灶區域、邊緣區域),這些關注信息可以幫助學生模型更好地聚焦于關鍵區域,從而提升性能。
?注意力的提取方式(多種):
?方法一:激活圖的注意力(Activation-based Attention)
來源論文:Attention Transfer (Zagoruyko and Komodakis, CVPR 2017)
-
將中間層特征圖做絕對值平方后再求和:
-
得到一個二維 attention map,再進行歸一化處理:
-
對比教師和學生的 attention:
方法二:通道注意力(Channel-based Attention)
-
對每個通道做平均池化得到通道權重;
-
使用 cosine 相似度、KL 散度 或 MSE 計算損失;
-
用于模型結構差異較大的情況(如學生層通道較少)時需通道映射。
?方法三:Transformer 中的注意力權重(Self-Attention)
適用于 ViT、Swin Transformer 等結構:
-
-
對多個 head 的注意力進行平均;
-
計算對應的學生注意力損失:
代表方法:
名稱 | 簡介 |
---|---|
AT(Attention Transfer) | 最早系統化地使用注意力圖進行知識提煉 |
A2KD(Adaptive Attention KD) | 對 attention 蒸餾進行動態加權 |
ViTKD、Swin-KD | 將 attention 蒸餾擴展至 Transformer 架構 |
?優點:
-
更輕量,蒸餾信號提取成本低;
-
更易于解釋,關注區域可視化;
-
在醫學圖像分割/檢測任務中常有較大提升(尤其對邊界關注);
-
和 feature-based 可組合使用(feature loss + attention loss)。
?缺點:
-
attention map 的質量依賴于教師網絡設計;
-
不同結構之間的 attention 定義不一致(CNN vs Transformer);
-
蒸餾信號相對較弱(只有注意力而沒有語義細節)。
?PyTorch 示例:空間注意力蒸餾
def compute_attention_map(feat): # feat: B x C x H x Watt = feat.pow(2).mean(dim=1, keepdim=True) # 空間 attentionnorm_att = att / (att.sum(dim=(2, 3), keepdim=True) + 1e-6)return norm_attdef attention_loss(student_feat, teacher_feat):att_s = compute_attention_map(student_feat)att_t = compute_attention_map(teacher_feat)return F.mse_loss(att_s, att_t)
?適用場景
場景 | 說明 |
---|---|
醫學圖像分割任務 | 注意力能精準定位病灶區域,適合蒸餾 |
Transformer 模型蒸餾 | 自注意力矩陣能自然被蒸餾 |
腦出血邊緣區域學習 | 利用注意力提高學生網絡對邊緣區域的辨識度 |
?實戰建議
-
可用于中間層或每個 stage 輸出;
-
結合輸出層蒸餾更有效(總損失 = output loss + attention loss);
-
attention 可視化對調試訓練非常有幫助(可以看模型學到了什么)。
?5. 多教師蒸餾(Multi-Teacher Distillation)
?核心理念:
現實中我們可能擁有多個優秀的預訓練模型(例如在不同數據上訓練的模型,或結構不同的高性能模型)。
**多教師蒸餾(MTD)**的目標是:
學生模型不僅學習單一教師模型的知識,而是融合多個教師的知識,共同引導學生學習更豐富、更泛化的表示。
?基本策略結構
設有 N 個教師模型,輸出為 T1,T2,...,TN,學生輸出為 S,有以下策略:
策略一:平均融合(Logits Averaging)
直接將多個教師輸出的 soft logits 平均,作為學生的監督目標:
優點:簡單直接
缺點:忽略不同教師之間的能力差異
策略二:加權融合(Weighted Averaging)
為不同教師分配不同權重 αi?(手動設定或訓練得到):
適合教師質量差異明顯、或模型結構差異較大的情況。
?策略三:自蒸餾與教師選擇(Online Multi-Teacher)
來源:DML(Deep Mutual Learning, CVPR 2018)
多個學生模型互為教師,輪流提取彼此知識:
-
每個學生在訓練過程中學習來自其他學生的輸出;
-
每一步中都更新彼此的參數,相當于雙向蒸餾、三角蒸餾。
適用于:
-
無強教師模型的場景;
-
多模型協同訓練,提升全體性能。
?策略四:圖蒸餾(Graph-Based Distillation)
-
構建教師之間的圖結構,建模教師間“知識傳遞”的路徑;
-
使用圖卷積或圖注意力聚合所有教師輸出;
-
輸出聚合后傳給學生學習。
?代表方法:
方法 | 簡介 |
---|---|
DML | 學生互為教師的雙向蒸餾(CVPR 2018) |
TRKD | 基于 transformer 的多教師蒸餾框架 |
GKT | 圖結構下的教師蒸餾方式 |
AKD | 自適應選擇最優教師的注意力蒸餾方法 |
優點:
-
能夠獲得多個模型的綜合表達能力;
-
提升學生的泛化能力,特別是不同領域知識融合;
-
適合 ensemble 壓縮和異構教師蒸餾場景。
?缺點:
-
實現較復雜(特別是動態加權和圖建模);
-
多模型推理成本高(訓練初期需多個教師并行推理);
-
教師間沖突時,融合策略可能降低性能(需設計教師篩選機制)。
?PyTorch 實現示意:平均融合型多教師 KD
def multi_teacher_kd_loss(student_logits, teacher_logits_list, temperature=4.0):# 平均多個教師輸出teacher_avg = sum(teacher_logits_list) / len(teacher_logits_list)kd_loss = F.kl_div(F.log_softmax(student_logits / temperature, dim=1),F.softmax(teacher_avg / temperature, dim=1),reduction='batchmean') * (temperature ** 2)return kd_loss
?適用場景
應用場景 | 推薦理由 |
---|---|
多數據集/多任務知識遷移 | 不同教師在不同任務上訓練,合并效果佳 |
醫學圖像多模態融合 | 如 CT、MRI、PET 模型合成統一學生 |
結構差異教師合成 | CNN、Transformer 教師模型合一,產出輕量學生 |
模型壓縮或部署優化 | 將多個大模型集成壓縮為一個輕量學生 |
?實戰建議:
-
如果多個教師性能相差大,建議加權或篩選;
-
教師輸出之間差異大時,可用 attention 或 gating 函數融合;
-
可與 feature distillation 或 attention distillation 聯合使用。
?6. 自蒸餾(Self-Distillation)
核心理念:
前面講的知識蒸餾方法都依賴一個**“外部教師模型”**。
但在實際情況中,我們有時沒有預訓練好的大模型可用,或者不方便部署多個模型。
**自蒸餾(Self-Distillation)**的核心思想是:
在一個模型內部進行知識提煉,當前模型的部分輸出(或早期訓練狀態)作為“教師”引導“學生”部分(或后續狀態)學習。
?主要實現方式
?方式一:同一模型不同層之間蒸餾(Intermediate Self-Distillation)
-
把模型的深層輸出作為“教師”;
-
把淺層輸出作為“學生”;
-
讓淺層特征或輸出盡量靠近深層輸出。
例如,在一個分類網絡中:
-
將第4個卷積塊的輸出作為教師;
-
將第2個卷積塊的輸出映射成同樣維度;
-
然后用 MSE 損失函數約束兩者相似。
?損失函數形式:
其中 hi是第 i?層的輸出,L是最深層。
方式二:不同 epoch 的模型蒸餾(Temporal Self-Distillation)
-
把當前 epoch 模型的輸出當作學生;
-
把之前 epoch 的模型(固定權重)當作教師;
-
學生模仿過去某一狀態的模型行為。
這種方式有點類似“學生向自己請教”,讓當前模型別忘了之前學到的好知識,避免過擬合。
?方式三:多頭輸出自蒸餾(Multi-head Self-Distillation)
-
為同一模型添加多個分支(auxiliary heads),例如在不同位置添加分類器;
-
主分支作為“教師”,輔助分支為“學生”,彼此互相引導學習;
-
提高整個模型訓練穩定性和特征共享能力。
?代表方法:
方法 | 簡介 |
---|---|
BYOT (Bring Your Own Teacher, NeurIPS 2020) | 模型內部多個預測頭互為師生 |
DEKD (Deep Embedding KD, CVPR 2021) | 不同層輸出相互約束,輔助深層學習 |
Revisit KD | 利用同一模型多個階段輸出做訓練信號的反向傳播 |
優點:
-
不需要外部教師模型,部署簡單;
-
無額外推理成本;
-
提高模型收斂速度與性能;
-
提升淺層特征的表達能力(→ 提高中小模型的性能);
-
能配合 Mixup、數據增強、Dropout 進一步提升效果。
?缺點:
-
蒸餾信號可能不夠強,需配合設計輔助機制;
-
若模型結構過于簡單,蒸餾效果不明顯;
-
多頭/多層輸出時,可能引入訓練不穩定性。
?PyTorch 示例:中間層自蒸餾
def self_distillation_loss(features_list):teacher_feat = features_list[-1].detach() # 最深層作為teacherloss = 0.0for i in range(len(features_list) - 1):student_feat = features_list[i]loss += F.mse_loss(student_feat, teacher_feat)return loss
?適用場景
場景 | 應用理由 |
---|---|
沒有教師模型 | 自蒸餾可作為輕量級替代 |
單模型部署限制 | 不增加推理復雜度 |
醫學圖像訓練數據小 | 模型內結構間互相學習,緩解過擬合 |
Transformer 微調(如 ViT) | 自蒸餾對不同 block 輸出進行一致性約束效果顯著 |
?實戰技巧
-
盡量配合輔助損失(如分類 loss + 自蒸餾 loss);
-
模型結構需支持多輸出或中間層抽取;
-
可結合 warmup 階段控制蒸餾信號引入節奏。
7. 對比蒸餾(Contrastive Distillation)
?核心理念:
對比蒸餾結合了對比學習(Contrastive Learning)與知識蒸餾的優點。
傳統 KD 往往要求學生輸出模仿教師某個具體的目標(如 softmax、特征圖),
但對比蒸餾的目標是:
保持學生模型在表示空間中與教師模型的結構一致性:即相似樣本之間的距離近,不相似樣本之間的距離遠。
這種方式不直接模仿具體輸出,而是模仿教師在表示空間中的“相對關系”。
?舉個例子:
-
給定一張醫學圖像 x,
-
教師提取特征 fT(x),學生提取特征 fS(x),
-
對于同類別樣本 xi?,應有:
-
而不同類別樣本的相似性應盡可能低。
這種方式本質上是用教師的“結構性表示”來指導學生。
?實現方式概覽:
?方法一:Teacher-Guided Contrastive Loss
典型公式(以 InfoNCE 為基礎):
其中:
-
ziS?:學生模型提取的 anchor 特征;
-
ziT?:教師模型中 anchor 的正樣本(正對);
-
其他 zjT?:教師模型中的負樣本;
-
τ:溫度系數。
?方法二:Relational Knowledge Distillation (RKD)
來源:ECCV 2018
使用特征之間的距離關系(歐式距離/角度)進行對比:
-
計算樣本對之間的距離差異:
-
保持角度方向一致性(可選):
?方法三:CRD(Contrastive Representation Distillation)
來源:ICLR 2021
-
使用大型特征字典(memory bank),采樣正負樣本;
-
引入投影頭將學生特征映射到與教師相同維度;
-
在投影空間中進行對比。
優勢在于:可以更穩定地學習“結構知識”,比直接 mimick 更魯棒。
?PyTorch 示例(InfoNCE形式)
import torch.nn.functional as Fdef contrastive_kd_loss(student_feat, teacher_feat, temperature=0.07):# 假設特征 shape: B x D,歸一化student_feat = F.normalize(student_feat, dim=1)teacher_feat = F.normalize(teacher_feat, dim=1)logits = torch.matmul(student_feat, teacher_feat.T) / temperaturelabels = torch.arange(student_feat.size(0)).to(logits.device)return F.cross_entropy(logits, labels)
?優點:
-
不依賴具體標簽,適合無監督或自監督任務;
-
更關注樣本之間的結構表示,對分類邊界更敏感;
-
在特征蒸餾中表現優越;
-
對圖像檢索、醫學圖像嵌入任務尤其有效。
?缺點:
-
訓練成本較高(需負樣本對比/內存字典);
-
需要對比學習框架支持,代碼實現復雜;
-
溫度超參數 τ\tauτ 較為敏感;
-
與數據增強策略關系緊密,需謹慎設計。
?適用場景
場景 | 理由 |
---|---|
無監督學習 | 不依賴標簽,直接使用對比關系進行知識提取 |
醫學圖像檢索、嵌入式特征學習 | 保留語義關系更重要 |
圖神經網絡、Transformer 等高維結構蒸餾 | 表征空間關系更關鍵 |
高分辨率圖像處理任務 | 對特征表示結構敏感,提升泛化能力 |
?總結
對比蒸餾強調結構性、相對性,是目前較先進的一種蒸餾方式,特別適合與自監督、對比學習結合使用。