知識蒸餾 - 基于KL散度的知識蒸餾 HelloWorld 示例
flyfish
知識蒸餾 - 蒸的什么
知識蒸餾 - 通過引入溫度參數T調整 Softmax 的輸出
知識蒸餾 - 對數函數的單調性
知識蒸餾 - 信息量的公式為什么是對數
知識蒸餾 - 根據真實事件的真實概率分布對其進行編碼
知識蒸餾 - 信息熵中的平均為什么是按概率加權的平均
知識蒸餾 - 自信息量是單個事件的信息量,而平均自信息量(即信息熵)是所有事件自信息量以其概率為權重的加權平均值
知識蒸餾 - 最小化KL散度與最小化交叉熵是完全等價的
知識蒸餾 - 基于KL散度的知識蒸餾 HelloWorld 示例
知識蒸餾的步驟如下:
-
訓練教師模型:使用常規交叉熵損失在訓練集上訓練深層教師模型,使其在任務上達到較好性能(作為知識的"來源")。
-
固定教師模型:將訓練好的教師模型設為評估模式(不更新參數),僅用于提供"軟標簽"知識。
-
初始化學生模型:創建輕量級學生模型(與基線學生模型初始化相同,保證公平對比)。
-
學生模型蒸餾訓練:
輸入數據同時傳入教師模型和學生模型;
教師模型輸出logits(不計算梯度),經溫度T軟化后得到軟概率分布(教師的"軟標簽");
學生模型輸出logits,經相同溫度T軟化后得到對數概率分布;
計算學生與教師軟分布的KL散度損失(衡量兩者差異,即蒸餾損失);
計算學生與原始硬標簽(真實類別)的交叉熵損失;
總損失為KL散度損失與交叉熵損失的加權和;
基于總損失更新學生模型參數,教師模型參數保持不變。 -
重復訓練:迭代多輪,直至學生模型收斂,最終得到通過蒸餾學習了教師知識的輕量級模型。
使用的數據集
CIFAR-10 數據集:
-
介紹
CIFAR-10(Canadian Institute for Advanced Research 10)是由加拿大高級研究所發布的小型圖像數據集,廣泛用于計算機視覺領域的入門級模型訓練和測試。 -
數據組成
包含 10個類別 的彩色圖像,類別分別為:飛機(airplane)、汽車(automobile)、鳥類(bird)、貓(cat)、鹿(deer)、狗(dog)、青蛙(frog)、馬(horse)、船(ship)、卡車(truck)。
圖像尺寸統一為 32×32像素,且為 RGB三通道(每個圖像包含 3×32×32=3072 個像素值,范圍 0-255)。
數據集分為兩部分:- 訓練集:50,000 張圖像(每個類別 5,000 張);
- 測試集:10,000 張圖像(每個類別 1,000 張)。
-
用途
主要用于圖像分類任務的基準測試,適合驗證輕量級模型(如簡單CNN)的性能,這里用于驗證知識蒸餾對輕量級學生模型的性能提升)。 -
使用
通過datasets.CIFAR10
加載數據集時,指定root
(數據保存路徑)、train
(是否為訓練集)、download
(是否自動下載)和transform
(數據預處理方式),即可便捷地獲取預處理后的圖像數據和對應標簽。
模型設計
維度 | TeacherNet(教師) | StudentNet(學生) | 設計目的 |
---|---|---|---|
卷積層數量 | 4層(更深) | 2層(更淺) | 教師通過更多層提取更復雜的特征,學生通過精簡結構降低部署成本。 |
卷積核數量 | 初始128,后續64/32(更多) | 固定16(更少) | 教師用更多卷積核捕捉細節,學生用少量卷積核減少參數和計算量。 |
參數規模 | 大 | 小 | 教師擬合能力強(作為知識源),學生輕量(適合邊緣設備部署)。 |
輸出維度 | 10類(與學生一致) | 10類(與教師一致) | 確保兩者輸出分布可通過KL散度比較,實現知識遷移。 |
知識蒸餾的核心前提——用“強模型”指導“弱模型”學習
教師網絡(TeacherNet)
TeacherNet
被設計為“性能更強”的網絡,通過更多的卷積層和卷積核提取更豐富的圖像特征,作為知識的“來源”。其結構可分為特征提取部分(self.features) 和分類部分(self.classifier):
1. 特征提取部分(self.features)
由4個卷積層(Conv2d
)、2個池化層(MaxPool2d
)和ReLU激活函數組成,逐步提取圖像的層級特征:
-
第一層:
nn.Conv2d(3, 128, kernel_size=3, padding=1)
輸入:3通道(RGB圖像),輸出:128個特征圖(卷積核),卷積核3x3,邊緣填充1(保證輸出尺寸與輸入一致,32x32)。
作用:提取最基礎的圖像特征(如邊緣、紋理)。 -
第二層:
nn.Conv2d(128, 64, kernel_size=3, padding=1)
輸入128個特征圖,輸出64個,進一步壓縮特征并提煉細節。 -
第一次池化:
nn.MaxPool2d(kernel_size=2, stride=2)
將特征圖尺寸從32x32壓縮到16x16(減少計算量,保留關鍵特征)。 -
第三、四層卷積:
nn.Conv2d(64, 64, ...)
和nn.Conv2d(64, 32, ...)
繼續深化特征提取,最終輸出32個特征圖。 -
第二次池化:
nn.MaxPool2d(...)
特征圖尺寸從16x16壓縮到8x8。
2. 分類部分(self.classifier)
將特征提取得到的特征圖轉換為類別概率:
- 先通過
torch.flatten(x, 1)
將8x8的32個特征圖展平為一維向量:32 * 8 * 8 = 2048
(維度)。 - 再通過全連接層處理:
2048 → 512 → 10
(10為CIFAR-10的類別數),中間用ReLU激活和Dropout(0.1)防止過擬合。
學生網絡(StudentNet)
StudentNet
被設計為“輕量級”網絡,參數更少、結構更簡單(便于部署),但其輸出維度與教師網絡一致(均為10類),確保能通過KL散度學習教師的知識。結構同樣分為特征提取和分類兩部分:
1. 特征提取部分(self.features)
僅含2個卷積層和2個池化層,參數遠少于教師網絡:
-
第一層:
nn.Conv2d(3, 16, kernel_size=3, padding=1)
- 輸入3通道,輸出僅16個特征圖(遠少于教師的128),同樣保留32x32尺寸。
-
第一次池化:
MaxPool2d
將尺寸壓縮到16x16。 -
第二層卷積:
nn.Conv2d(16, 16, ...)
- 保持16個特征圖,進一步提取特征。
-
第二次池化:尺寸壓縮到8x8。
2. 分類部分(self.classifier)
- 展平后特征維度:
16 * 8 * 8 = 1024
(遠小于教師的2048)。 - 全連接層:
1024 → 256 → 10
(隱藏層維度也遠小于教師的512)。
完整的 HelloWorld 示例
"""知識蒸餾(KL散度實現)
"""import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets# 設置設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用{device}設備")######################################################################
# 數據加載
######################################################################
# 數據預處理
transforms_cifar = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加載CIFAR-10數據集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar
)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar
)# 數據加載器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2
)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2
)######################################################################
# 模型定義
######################################################################
# 教師模型(較深網絡)
class TeacherNet(nn.Module):def __init__(self, num_classes=10):super(TeacherNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 128, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(128, 64, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(64, 32, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(nn.Linear(2048, 512),nn.ReLU(),nn.Dropout(0.1),nn.Linear(512, num_classes))def forward(self, x):x = self.features(x)x = torch.flatten(x, 1)x = self.classifier(x)return x# 學生模型(輕量級網絡)
class StudentNet(nn.Module):def __init__(self, num_classes=10):super(StudentNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(16, 16, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(nn.Linear(1024, 256),nn.ReLU(),nn.Dropout(0.1),nn.Linear(256, num_classes))def forward(self, x):x = self.features(x)x = torch.flatten(x, 1)x = self.classifier(x)return x######################################################################
# 核心函數(KL散度蒸餾)
######################################################################
def train_baseline(model, train_loader, epochs, learning_rate, device):"""常規交叉熵訓練(用于訓練教師模型)"""criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=learning_rate)model.train()for epoch in range(epochs):running_loss = 0.0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"輪次 {epoch+1}/{epochs}, 損失: {running_loss/len(train_loader):.4f}")def train_distillation(teacher, student, train_loader, epochs, learning_rate, T, kl_weight, ce_weight, device):"""基于KL散度的知識蒸餾訓練"""ce_criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(student.parameters(), lr=learning_rate)teacher.eval() # 教師模型固定student.train() # 學生模型訓練for epoch in range(epochs):running_loss = 0.0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()# 教師輸出(不計算梯度)with torch.no_grad():teacher_logits = teacher(inputs)# 學生輸出student_logits = student(inputs)# 計算KL散度損失(知識蒸餾核心)# 教師分布:softmax(teacher_logits / T)# 學生分布:log_softmax(student_logits / T)teacher_soft = nn.functional.softmax(teacher_logits / T, dim=-1)student_soft = nn.functional.log_softmax(student_logits / T, dim=-1)kl_loss = torch.sum(teacher_soft * (teacher_soft.log() - student_soft)) / inputs.size(0)kl_loss *= T ** 2 # 溫度縮放補償# 計算硬標簽損失ce_loss = ce_criterion(student_logits, labels)# 總損失:KL散度損失 + 交叉熵損失total_loss = kl_weight * kl_loss + ce_weight * ce_losstotal_loss.backward()optimizer.step()running_loss += total_loss.item()print(f"蒸餾輪次 {epoch+1}/{epochs}, 總損失: {running_loss/len(train_loader):.4f}")def test(model, test_loader, device):"""測試模型準確率"""model.eval()correct = 0total = 0with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f"測試準確率: {accuracy:.2f}%")return accuracy######################################################################
# 執行流程
######################################################################
if __name__ == "__main__":# 1. 訓練教師模型torch.manual_seed(42)teacher = TeacherNet().to(device)print("===== 訓練教師模型 =====")train_baseline(teacher, train_loader, epochs=10, learning_rate=0.001, device=device)teacher_acc = test(teacher, test_loader, device)# 2. 初始化學生模型(與基線對比)torch.manual_seed(42)student_baseline = StudentNet().to(device)print("\n===== 訓練基線學生模型(無蒸餾) =====")train_baseline(student_baseline, train_loader, epochs=10, learning_rate=0.001, device=device)student_baseline_acc = test(student_baseline, test_loader, device)# 3. 用KL散度蒸餾訓練學生模型torch.manual_seed(42)student_distill = StudentNet().to(device)print("\n===== 用KL散度蒸餾訓練學生模型 =====")train_distillation(teacher=teacher,student=student_distill,train_loader=train_loader,epochs=10,learning_rate=0.001,T=2, # 溫度參數kl_weight=0.25, # KL散度損失權重ce_weight=0.75, # 交叉熵損失權重device=device)student_distill_acc = test(student_distill, test_loader, device)# 4. 結果對比print("\n===== 最終結果 =====")print(f"教師模型準確率: {teacher_acc:.2f}%")print(f"基線學生模型準確率: {student_baseline_acc:.2f}%")print(f"KL蒸餾學生模型準確率: {student_distill_acc:.2f}%")
KL散度的數學定義
對于兩個概率分布 PPP(教師模型的輸出分布)和 QQQ(學生模型的輸出分布),KL散度的公式為:
KL(P∥Q)=EP[log?P?log?Q]=∑xP(x)?(log?P(x)?log?Q(x))\text{KL}(P \parallel Q) = \mathbb{E}_P \left[ \log P - \log Q \right] = \sum_x P(x) \cdot \left( \log P(x) - \log Q(x) \right) KL(P∥Q)=EP?[logP?logQ]=x∑?P(x)?(logP(x)?logQ(x))
其中:
- P(x)P(x)P(x) 是教師模型輸出的概率分布(經溫度軟化后);
- Q(x)Q(x)Q(x) 是學生模型輸出的概率分布(經溫度軟化后);
- EP\mathbb{E}_PEP? 表示對分布 PPP 取期望(即對所有樣本平均)。
代碼中KL散度公式的體現
在train_distillation
函數中,kl_loss
的計算完全對應上述公式,具體如下:
- 定義分布 PPP 和 QQQ
教師模型的輸出(logits)經溫度 TTT 軟化后,通過softmax
得到概率分布 PPP:
teacher_soft = nn.functional.softmax(teacher_logits / T, dim=-1) # P(x)
學生模型的輸出(logits)經相同溫度 TTT 軟化后,通過log_softmax
得到 log?Q(x)\log Q(x)logQ(x)(因為log_softmax
等價于對softmax
的結果取對數):
student_soft = nn.functional.log_softmax(student_logits / T, dim=-1) # log Q(x)
- 計算KL散度的核心部分 (重點看這里)
代碼中通過以下一行實現KL散度的求和與期望計算:
kl_loss = torch.sum(teacher_soft * (teacher_soft.log() - student_soft)) / inputs.size(0)
teacher_soft.log()
對應 log?P(x)\log P(x)logP(x);
student_soft
對應 log?Q(x)\log Q(x)logQ(x);
teacher_soft * (teacher_soft.log() - student_soft)
對應 P(x)?(log?P(x)?log?Q(x))P(x) \cdot (\log P(x) - \log Q(x))P(x)?(logP(x)?logQ(x));
torch.sum(...)
對應公式中的求和 ∑x\sum_x∑x?;
除以 inputs.size(0)
(批量大小)對應取期望 EP\mathbb{E}_PEP?(對批量內樣本平均)。
- 溫度補償
最后乘以 T2T^2T2 是為了抵消溫度對梯度的影響(Hinton原始論文中的標準處理),但不改變KL散度的核心公式:
kl_loss *= T **2 # 溫度縮放補償
重點說完了,說總體,讓學生同時學習“教師的軟標簽知識”和“真實硬標簽任務”
1. 函數參數:控制蒸餾過程的關鍵變量
def train_distillation(teacher, student, train_loader, epochs, learning_rate, T, kl_weight, ce_weight, device):
teacher
:已訓練好的教師模型(提供知識的“強模型”);student
:需要訓練的學生模型(需要學習知識的“弱模型”);train_loader
:訓練數據加載器(提供輸入和真實標簽);epochs
:訓練輪次(遍歷數據集的次數);learning_rate
:學生模型的學習率;T
:溫度參數(控制教師/學生輸出分布的“平滑度”,值越大分布越平緩);kl_weight
:KL散度損失(蒸餾損失)的權重;ce_weight
:交叉熵損失(硬標簽損失)的權重;device
:訓練設備(GPU/CPU)。
2. 初始化:損失函數與優化器
ce_criterion = nn.CrossEntropyLoss() # 硬標簽損失函數(真實標簽任務)
optimizer = optim.Adam(student.parameters(), lr=learning_rate) # 只優化學生模型參數
- 用
CrossEntropyLoss
計算學生模型對“真實硬標簽”的損失(保證學生不偏離基礎任務); - 優化器僅針對
student.parameters()
,因為教師模型參數固定,不需要更新。
3. 模型模式設置:固定教師,訓練學生
teacher.eval() # 教師設為評估模式(關閉dropout等訓練特有的層,輸出穩定)
student.train() # 學生設為訓練模式(啟用參數更新)
- 教師模型處于
eval
模式:確保其輸出穩定(不受訓練時隨機因素影響),且不計算梯度(節省資源); - 學生模型處于
train
模式:允許其參數通過反向傳播更新。
4. 核心訓練循環:每輪迭代訓練
for epoch in range(epochs): # 按輪次遍歷running_loss = 0.0 # 累計每輪總損失for inputs, labels in train_loader: # 按批量遍歷數據# 數據轉移到設備inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad() # 清零梯度(避免上一批次梯度殘留)
- 外層循環控制訓練輪次(
epochs
),確保模型多次學習數據; - 內層循環按批量處理數據,每次處理一個
batch
的輸入(圖像)和標簽(真實類別); optimizer.zero_grad()
:清除上一批次計算的梯度,避免干擾當前批次。
5. 教師模型輸出:提供“軟標簽知識”
# 教師輸出(不計算梯度,固定參數)
with torch.no_grad():teacher_logits = teacher(inputs)
with torch.no_grad()
:禁用教師模型的梯度計算(教師參數不更新,節省內存和計算資源);teacher_logits
:教師模型的原始輸出(未經過softmax的“對數幾率”),后續將用于生成“軟標簽”。
6. 學生模型輸出:學習并生成待優化的預測
# 學生輸出(需要計算梯度,更新參數)
student_logits = student(inputs)
student_logits
:學生模型的原始輸出(對數幾率),后續用于計算與教師的差異(蒸餾損失)和與真實標簽的差異(硬標簽損失)。
7. 計算KL散度損失:蒸餾的核心(學習教師知識)
# 教師分布:softmax(teacher_logits / T) → 軟化的概率分布(軟標簽)
teacher_soft = nn.functional.softmax(teacher_logits / T, dim=-1)
# 學生分布:log_softmax(student_logits / T) → 軟化的對數概率分布
student_soft = nn.functional.log_softmax(student_logits / T, dim=-1)# 計算KL散度:衡量學生分布與教師分布的差異
kl_loss = torch.sum(teacher_soft * (teacher_soft.log() - student_soft)) / inputs.size(0)
kl_loss *= T ** 2 # 溫度縮放補償(Hinton論文標準操作)
- 為什么用溫度T?
溫度越高(如T=10),softmax輸出越平滑(小概率值被放大),教師的“知識細節”(如類別間的相似性)更明顯;溫度為1時,等價于普通softmax(硬標簽)。 - KL散度公式對應:
嚴格遵循KL散度定義 KL(P∥Q)=∑P?(log?P?log?Q)\text{KL}(P \parallel Q) = \sum P \cdot (\log P - \log Q)KL(P∥Q)=∑P?(logP?logQ),其中 PPP 是教師軟分布(teacher_soft
),QQQ 是學生軟分布(student_soft
對應的概率)。 - 除以
inputs.size(0)
:對批量內樣本取平均,得到單樣本的KL損失; - 乘以
T2
:補償溫度對梯度的影響(溫度升高會使梯度變小,乘以T2可抵消這一效應)。
8. 計算硬標簽損失:保證學生不偏離真實任務
# 學生對真實標簽的交叉熵損失
ce_loss = ce_criterion(student_logits, labels)
- 用學生的原始輸出(
student_logits
)與真實標簽(labels
)計算交叉熵損失,確保學生在學習教師知識的同時,不忘記“正確答案”。
9. 總損失:平衡教師知識與真實任務
# 加權求和:KL損失(教師知識) + 交叉熵損失(真實標簽)
total_loss = kl_weight * kl_loss + ce_weight * ce_loss
- 通過
kl_weight
和ce_weight
控制兩者的重要性(例如代碼中kl_weight=0.25
、ce_weight=0.75
,表示更側重真實標簽)。
10. 反向傳播與參數更新:只優化學生
total_loss.backward() # 計算總損失對學生參數的梯度
optimizer.step() # 根據梯度更新學生參數
- 梯度僅從學生模型計算(教師模型無梯度),最終只有學生模型的參數被優化。
11. 損失監控:跟蹤訓練過程
running_loss += total_loss.item() # 累計批量損失
# 每輪結束后打印平均損失
print(f"蒸餾輪次 {epoch+1}/{epochs}, 總損失: {running_loss/len(train_loader):.4f}")
- 實時監控損失變化,判斷模型是否在有效學習(通常損失應逐漸下降并趨于穩定)。
類型 | 本質 | 代碼對應 | 作用 |
---|---|---|---|
教師的軟標簽知識 | 教師輸出的平滑概率分布 | teacher_soft (經T軟化的softmax) | 傳遞類別間的相似性知識,幫助學生學習更魯棒的特征 |
真實硬標簽任務 | 數據集中的真實類別標簽 | labels 和ce_loss | 保證學生不偏離基礎分類任務,避免完全“模仿教師錯誤” |
還可以繼續的學習
“瞎折騰”參數和模型,看看知識蒸餾到底在什么情況下最管用,以及它的“能力邊界”在哪兒
1. 調“溫度”試試 溫度參數 TTT 對蒸餾效果的影響
做法:把代碼里的 T
改成1、5、10這些不同的數,其他別動,看看學生模型準確率變不變。 例如固定其他參數(如kl_weight=0.25
),測試不同溫度值(如 T=1,2,5,10,20T=1, 2, 5, 10, 20T=1,2,5,10,20)下學生模型的準確率。
目的:溫度就像“老師說話的委婉程度”——溫度低,老師只說“這個一定對”(硬邦邦);溫度高,老師會說“這個可能對,那個也有點像”(更細膩)。試試哪種“說話方式”能讓學生學得更好。 理解溫度如何調節教師“軟標簽”的平滑度。
2. 調“聽老師”和“聽標準答案”的比例 損失權重(kl_weightkl\_weightkl_weight 和 ce_weightce\_weightce_weight)的平衡實驗
做法:把 kl_weight
(聽老師的權重)和 ce_weight
(聽標準答案的權重)換成不同組合,看學生成績。 例如固定 T=2T=2T=2,測試不同權重組合(如 (0.1, 0.9), (0.5, 0.5), (0.9, 0.1)
)對學生性能的影響
目的:學生既要學老師的“經驗”,又不能完全忘了“標準答案”。試試多聽老師點好,還是多信標準答案點好。 分析“教師知識”與“真實標簽”的權重如何影響學生學習。
3. 老師自己學得好不好,影響學生嗎? 教師模型性能對蒸餾的影響
做法:先讓老師少學幾輪(比如只學5輪,學得差點),或者多學幾輪(比如學20輪,學得好點),再讓它教學生,看學生成績差多少。
目的:想知道“老師越厲害,學生是不是一定越厲害”。比如老師自己考80分,能不能教出考75分的學生?老師考90分,學生能到85分嗎?
4. 學生太“笨”了,老師還能教好嗎? 學生模型復雜度的極限測試
做法:把學生模型改得更簡單(比如少一層卷積,少點參數),再用蒸餾訓練,看它還能不能比自己學(不蒸餾)強。
目的:測試蒸餾的“底線”——如果學生太簡單(比如只有一層神經網絡),老師再厲害,是不是也教不會?
5. 換種“衡量學生和老師差異的方式”行不行? 與其他蒸餾損失函數的對比
做法:不用現在的KL散度,換成“均方誤差”(MSE)來算學生和老師的差異,其他不變,看學生成績變不變。
目的:現在用的KL散度就像“比較兩個概率分布像不像”,換種方式(比如直接比數值差),會不會更好用?
6. 給數據“加戲”,學生學得更好嗎? 數據增強對蒸餾的影響
做法:訓練時給圖片加些變化(比如隨機裁剪一塊、左右翻轉),再做蒸餾,看學生在測試集上表現會不會更好。
目的:就像學生平時做難題練習,考試時更從容。給訓練數據加變化,是不是能讓學生和老師都學更扎實?
7. 學生學幾輪效果最好? 蒸餾輪次的影響
做法:把蒸餾的 epochs
改成5、20這些數,看看學5輪、10輪、20輪,學生成績是不是一直漲,還是學太久反而變差。
目的:避免“學過頭”——就像人做題,做10套題可能進步快,做100套可能記住答案了,但換題就不會了。
8. 學生是不是真的“又小又能打”? 模型輕量化指標的驗證
做法:算一算老師、普通學生、蒸餾學生的“參數數量”(模型大小)和“做題速度”(每秒處理多少張圖)。
目的:蒸餾的核心是“讓小模型有大模型的本事”。得驗證一下:蒸餾后的學生是不是確實比老師小很多,但成績接近;同時跑起來比老師快。
9. 學生能“舉一反三”嗎? 跨數據集泛化測試
做法:用CIFAR-10訓練好的蒸餾學生,去做類似的題(比如CIFAR-100里的部分類別),看它比普通學生表現好多少。
目的:好的學生不僅會做學過的題,還能應付新題。蒸餾是不是能讓學生學到更通用的“解題思路”?