知識蒸餾 - 基于KL散度的知識蒸餾 HelloWorld 示例

知識蒸餾 - 基于KL散度的知識蒸餾 HelloWorld 示例

flyfish

知識蒸餾 - 蒸的什么

知識蒸餾 - 通過引入溫度參數T調整 Softmax 的輸出

知識蒸餾 - 對數函數的單調性

知識蒸餾 - 信息量的公式為什么是對數

知識蒸餾 - 根據真實事件的真實概率分布對其進行編碼

知識蒸餾 - 信息熵中的平均為什么是按概率加權的平均

知識蒸餾 - 自信息量是單個事件的信息量,而平均自信息量(即信息熵)是所有事件自信息量以其概率為權重的加權平均值

知識蒸餾 - 最小化KL散度與最小化交叉熵是完全等價的

知識蒸餾 - 基于KL散度的知識蒸餾 HelloWorld 示例

知識蒸餾的步驟如下:

  1. 訓練教師模型:使用常規交叉熵損失在訓練集上訓練深層教師模型,使其在任務上達到較好性能(作為知識的"來源")。

  2. 固定教師模型:將訓練好的教師模型設為評估模式(不更新參數),僅用于提供"軟標簽"知識。

  3. 初始化學生模型:創建輕量級學生模型(與基線學生模型初始化相同,保證公平對比)。

  4. 學生模型蒸餾訓練:
    輸入數據同時傳入教師模型和學生模型;
    教師模型輸出logits(不計算梯度),經溫度T軟化后得到軟概率分布(教師的"軟標簽");
    學生模型輸出logits,經相同溫度T軟化后得到對數概率分布;
    計算學生與教師軟分布的KL散度損失(衡量兩者差異,即蒸餾損失);
    計算學生與原始硬標簽(真實類別)的交叉熵損失;
    總損失為KL散度損失與交叉熵損失的加權和;
    基于總損失更新學生模型參數,教師模型參數保持不變。

  5. 重復訓練:迭代多輪,直至學生模型收斂,最終得到通過蒸餾學習了教師知識的輕量級模型。

使用的數據集

CIFAR-10 數據集:

  1. 介紹
    CIFAR-10(Canadian Institute for Advanced Research 10)是由加拿大高級研究所發布的小型圖像數據集,廣泛用于計算機視覺領域的入門級模型訓練和測試。

  2. 數據組成
    包含 10個類別 的彩色圖像,類別分別為:飛機(airplane)、汽車(automobile)、鳥類(bird)、貓(cat)、鹿(deer)、狗(dog)、青蛙(frog)、馬(horse)、船(ship)、卡車(truck)。
    圖像尺寸統一為 32×32像素,且為 RGB三通道(每個圖像包含 3×32×32=3072 個像素值,范圍 0-255)。
    數據集分為兩部分:

    • 訓練集:50,000 張圖像(每個類別 5,000 張);
    • 測試集:10,000 張圖像(每個類別 1,000 張)。
  3. 用途
    主要用于圖像分類任務的基準測試,適合驗證輕量級模型(如簡單CNN)的性能,這里用于驗證知識蒸餾對輕量級學生模型的性能提升)。

  4. 使用
    通過 datasets.CIFAR10 加載數據集時,指定 root(數據保存路徑)、train(是否為訓練集)、download(是否自動下載)和 transform(數據預處理方式),即可便捷地獲取預處理后的圖像數據和對應標簽。

模型設計

維度TeacherNet(教師)StudentNet(學生)設計目的
卷積層數量4層(更深)2層(更淺)教師通過更多層提取更復雜的特征,學生通過精簡結構降低部署成本。
卷積核數量初始128,后續64/32(更多)固定16(更少)教師用更多卷積核捕捉細節,學生用少量卷積核減少參數和計算量。
參數規模教師擬合能力強(作為知識源),學生輕量(適合邊緣設備部署)。
輸出維度10類(與學生一致)10類(與教師一致)確保兩者輸出分布可通過KL散度比較,實現知識遷移。

知識蒸餾的核心前提——用“強模型”指導“弱模型”學習

教師網絡(TeacherNet)

TeacherNet被設計為“性能更強”的網絡,通過更多的卷積層和卷積核提取更豐富的圖像特征,作為知識的“來源”。其結構可分為特征提取部分(self.features)分類部分(self.classifier)

1. 特征提取部分(self.features)

由4個卷積層(Conv2d)、2個池化層(MaxPool2d)和ReLU激活函數組成,逐步提取圖像的層級特征:

  • 第一層nn.Conv2d(3, 128, kernel_size=3, padding=1)
    輸入:3通道(RGB圖像),輸出:128個特征圖(卷積核),卷積核3x3,邊緣填充1(保證輸出尺寸與輸入一致,32x32)。
    作用:提取最基礎的圖像特征(如邊緣、紋理)。

  • 第二層nn.Conv2d(128, 64, kernel_size=3, padding=1)
    輸入128個特征圖,輸出64個,進一步壓縮特征并提煉細節。

  • 第一次池化nn.MaxPool2d(kernel_size=2, stride=2)
    將特征圖尺寸從32x32壓縮到16x16(減少計算量,保留關鍵特征)。

  • 第三、四層卷積nn.Conv2d(64, 64, ...)nn.Conv2d(64, 32, ...)
    繼續深化特征提取,最終輸出32個特征圖。

  • 第二次池化nn.MaxPool2d(...)
    特征圖尺寸從16x16壓縮到8x8。

2. 分類部分(self.classifier)

將特征提取得到的特征圖轉換為類別概率:

  • 先通過torch.flatten(x, 1)將8x8的32個特征圖展平為一維向量:32 * 8 * 8 = 2048(維度)。
  • 再通過全連接層處理:2048 → 512 → 10(10為CIFAR-10的類別數),中間用ReLU激活和Dropout(0.1)防止過擬合。

學生網絡(StudentNet)

StudentNet被設計為“輕量級”網絡,參數更少、結構更簡單(便于部署),但其輸出維度與教師網絡一致(均為10類),確保能通過KL散度學習教師的知識。結構同樣分為特征提取和分類兩部分:

1. 特征提取部分(self.features)

僅含2個卷積層和2個池化層,參數遠少于教師網絡:

  • 第一層nn.Conv2d(3, 16, kernel_size=3, padding=1)

    • 輸入3通道,輸出僅16個特征圖(遠少于教師的128),同樣保留32x32尺寸。
  • 第一次池化MaxPool2d將尺寸壓縮到16x16。

  • 第二層卷積nn.Conv2d(16, 16, ...)

    • 保持16個特征圖,進一步提取特征。
  • 第二次池化:尺寸壓縮到8x8。

2. 分類部分(self.classifier)
  • 展平后特征維度:16 * 8 * 8 = 1024(遠小于教師的2048)。
  • 全連接層:1024 → 256 → 10(隱藏層維度也遠小于教師的512)。

完整的 HelloWorld 示例


"""知識蒸餾(KL散度實現)
"""import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets# 設置設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用{device}設備")######################################################################
# 數據加載
######################################################################
# 數據預處理
transforms_cifar = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加載CIFAR-10數據集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar
)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar
)# 數據加載器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2
)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2
)######################################################################
# 模型定義
######################################################################
# 教師模型(較深網絡)
class TeacherNet(nn.Module):def __init__(self, num_classes=10):super(TeacherNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 128, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(128, 64, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(64, 32, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(nn.Linear(2048, 512),nn.ReLU(),nn.Dropout(0.1),nn.Linear(512, num_classes))def forward(self, x):x = self.features(x)x = torch.flatten(x, 1)x = self.classifier(x)return x# 學生模型(輕量級網絡)
class StudentNet(nn.Module):def __init__(self, num_classes=10):super(StudentNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(16, 16, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(nn.Linear(1024, 256),nn.ReLU(),nn.Dropout(0.1),nn.Linear(256, num_classes))def forward(self, x):x = self.features(x)x = torch.flatten(x, 1)x = self.classifier(x)return x######################################################################
# 核心函數(KL散度蒸餾)
######################################################################
def train_baseline(model, train_loader, epochs, learning_rate, device):"""常規交叉熵訓練(用于訓練教師模型)"""criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=learning_rate)model.train()for epoch in range(epochs):running_loss = 0.0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"輪次 {epoch+1}/{epochs}, 損失: {running_loss/len(train_loader):.4f}")def train_distillation(teacher, student, train_loader, epochs, learning_rate, T, kl_weight, ce_weight, device):"""基于KL散度的知識蒸餾訓練"""ce_criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(student.parameters(), lr=learning_rate)teacher.eval()  # 教師模型固定student.train() # 學生模型訓練for epoch in range(epochs):running_loss = 0.0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()# 教師輸出(不計算梯度)with torch.no_grad():teacher_logits = teacher(inputs)# 學生輸出student_logits = student(inputs)# 計算KL散度損失(知識蒸餾核心)# 教師分布:softmax(teacher_logits / T)# 學生分布:log_softmax(student_logits / T)teacher_soft = nn.functional.softmax(teacher_logits / T, dim=-1)student_soft = nn.functional.log_softmax(student_logits / T, dim=-1)kl_loss = torch.sum(teacher_soft * (teacher_soft.log() - student_soft)) / inputs.size(0)kl_loss *= T ** 2  # 溫度縮放補償# 計算硬標簽損失ce_loss = ce_criterion(student_logits, labels)# 總損失:KL散度損失 + 交叉熵損失total_loss = kl_weight * kl_loss + ce_weight * ce_losstotal_loss.backward()optimizer.step()running_loss += total_loss.item()print(f"蒸餾輪次 {epoch+1}/{epochs}, 總損失: {running_loss/len(train_loader):.4f}")def test(model, test_loader, device):"""測試模型準確率"""model.eval()correct = 0total = 0with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f"測試準確率: {accuracy:.2f}%")return accuracy######################################################################
# 執行流程
######################################################################
if __name__ == "__main__":# 1. 訓練教師模型torch.manual_seed(42)teacher = TeacherNet().to(device)print("===== 訓練教師模型 =====")train_baseline(teacher, train_loader, epochs=10, learning_rate=0.001, device=device)teacher_acc = test(teacher, test_loader, device)# 2. 初始化學生模型(與基線對比)torch.manual_seed(42)student_baseline = StudentNet().to(device)print("\n===== 訓練基線學生模型(無蒸餾) =====")train_baseline(student_baseline, train_loader, epochs=10, learning_rate=0.001, device=device)student_baseline_acc = test(student_baseline, test_loader, device)# 3. 用KL散度蒸餾訓練學生模型torch.manual_seed(42)student_distill = StudentNet().to(device)print("\n===== 用KL散度蒸餾訓練學生模型 =====")train_distillation(teacher=teacher,student=student_distill,train_loader=train_loader,epochs=10,learning_rate=0.001,T=2,  # 溫度參數kl_weight=0.25,  # KL散度損失權重ce_weight=0.75,  # 交叉熵損失權重device=device)student_distill_acc = test(student_distill, test_loader, device)# 4. 結果對比print("\n===== 最終結果 =====")print(f"教師模型準確率: {teacher_acc:.2f}%")print(f"基線學生模型準確率: {student_baseline_acc:.2f}%")print(f"KL蒸餾學生模型準確率: {student_distill_acc:.2f}%")

KL散度的數學定義

對于兩個概率分布 PPP(教師模型的輸出分布)和 QQQ(學生模型的輸出分布),KL散度的公式為:
KL(P∥Q)=EP[log?P?log?Q]=∑xP(x)?(log?P(x)?log?Q(x))\text{KL}(P \parallel Q) = \mathbb{E}_P \left[ \log P - \log Q \right] = \sum_x P(x) \cdot \left( \log P(x) - \log Q(x) \right) KL(PQ)=EP?[logP?logQ]=x?P(x)?(logP(x)?logQ(x))
其中:

  • P(x)P(x)P(x) 是教師模型輸出的概率分布(經溫度軟化后);
  • Q(x)Q(x)Q(x) 是學生模型輸出的概率分布(經溫度軟化后);
  • EP\mathbb{E}_PEP? 表示對分布 PPP 取期望(即對所有樣本平均)。

代碼中KL散度公式的體現

train_distillation函數中,kl_loss的計算完全對應上述公式,具體如下:

  1. 定義分布 PPPQQQ

教師模型的輸出(logits)經溫度 TTT 軟化后,通過softmax得到概率分布 PPP

teacher_soft = nn.functional.softmax(teacher_logits / T, dim=-1)  # P(x)

學生模型的輸出(logits)經相同溫度 TTT 軟化后,通過log_softmax得到 log?Q(x)\log Q(x)logQ(x)(因為log_softmax等價于對softmax的結果取對數):

student_soft = nn.functional.log_softmax(student_logits / T, dim=-1)  # log Q(x)
  1. 計算KL散度的核心部分 (重點看這里)

代碼中通過以下一行實現KL散度的求和與期望計算:

   kl_loss = torch.sum(teacher_soft * (teacher_soft.log() - student_soft)) / inputs.size(0)

teacher_soft.log() 對應 log?P(x)\log P(x)logP(x)
student_soft 對應 log?Q(x)\log Q(x)logQ(x)
teacher_soft * (teacher_soft.log() - student_soft) 對應 P(x)?(log?P(x)?log?Q(x))P(x) \cdot (\log P(x) - \log Q(x))P(x)?(logP(x)?logQ(x))
torch.sum(...) 對應公式中的求和 ∑x\sum_xx?
除以 inputs.size(0)(批量大小)對應取期望 EP\mathbb{E}_PEP?(對批量內樣本平均)。

  1. 溫度補償

最后乘以 T2T^2T2 是為了抵消溫度對梯度的影響(Hinton原始論文中的標準處理),但不改變KL散度的核心公式:

   kl_loss *= T **2  # 溫度縮放補償

重點說完了,說總體,讓學生同時學習“教師的軟標簽知識”和“真實硬標簽任務”

1. 函數參數:控制蒸餾過程的關鍵變量

def train_distillation(teacher, student, train_loader, epochs, learning_rate, T, kl_weight, ce_weight, device):
  • teacher:已訓練好的教師模型(提供知識的“強模型”);
  • student:需要訓練的學生模型(需要學習知識的“弱模型”);
  • train_loader:訓練數據加載器(提供輸入和真實標簽);
  • epochs:訓練輪次(遍歷數據集的次數);
  • learning_rate:學生模型的學習率;
  • T:溫度參數(控制教師/學生輸出分布的“平滑度”,值越大分布越平緩);
  • kl_weight:KL散度損失(蒸餾損失)的權重;
  • ce_weight:交叉熵損失(硬標簽損失)的權重;
  • device:訓練設備(GPU/CPU)。

2. 初始化:損失函數與優化器

ce_criterion = nn.CrossEntropyLoss()  # 硬標簽損失函數(真實標簽任務)
optimizer = optim.Adam(student.parameters(), lr=learning_rate)  # 只優化學生模型參數
  • CrossEntropyLoss計算學生模型對“真實硬標簽”的損失(保證學生不偏離基礎任務);
  • 優化器僅針對student.parameters(),因為教師模型參數固定,不需要更新。

3. 模型模式設置:固定教師,訓練學生

teacher.eval()  # 教師設為評估模式(關閉dropout等訓練特有的層,輸出穩定)
student.train()  # 學生設為訓練模式(啟用參數更新)
  • 教師模型處于eval模式:確保其輸出穩定(不受訓練時隨機因素影響),且不計算梯度(節省資源);
  • 學生模型處于train模式:允許其參數通過反向傳播更新。

4. 核心訓練循環:每輪迭代訓練

for epoch in range(epochs):  # 按輪次遍歷running_loss = 0.0  # 累計每輪總損失for inputs, labels in train_loader:  # 按批量遍歷數據# 數據轉移到設備inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()  # 清零梯度(避免上一批次梯度殘留)
  • 外層循環控制訓練輪次(epochs),確保模型多次學習數據;
  • 內層循環按批量處理數據,每次處理一個batch的輸入(圖像)和標簽(真實類別);
  • optimizer.zero_grad():清除上一批次計算的梯度,避免干擾當前批次。

5. 教師模型輸出:提供“軟標簽知識”

# 教師輸出(不計算梯度,固定參數)
with torch.no_grad():teacher_logits = teacher(inputs)
  • with torch.no_grad():禁用教師模型的梯度計算(教師參數不更新,節省內存和計算資源);
  • teacher_logits:教師模型的原始輸出(未經過softmax的“對數幾率”),后續將用于生成“軟標簽”。

6. 學生模型輸出:學習并生成待優化的預測

# 學生輸出(需要計算梯度,更新參數)
student_logits = student(inputs)
  • student_logits:學生模型的原始輸出(對數幾率),后續用于計算與教師的差異(蒸餾損失)和與真實標簽的差異(硬標簽損失)。

7. 計算KL散度損失:蒸餾的核心(學習教師知識)

# 教師分布:softmax(teacher_logits / T) → 軟化的概率分布(軟標簽)
teacher_soft = nn.functional.softmax(teacher_logits / T, dim=-1)
# 學生分布:log_softmax(student_logits / T) → 軟化的對數概率分布
student_soft = nn.functional.log_softmax(student_logits / T, dim=-1)# 計算KL散度:衡量學生分布與教師分布的差異
kl_loss = torch.sum(teacher_soft * (teacher_soft.log() - student_soft)) / inputs.size(0)
kl_loss *= T ** 2  # 溫度縮放補償(Hinton論文標準操作)
  • 為什么用溫度T?
    溫度越高(如T=10),softmax輸出越平滑(小概率值被放大),教師的“知識細節”(如類別間的相似性)更明顯;溫度為1時,等價于普通softmax(硬標簽)。
  • KL散度公式對應
    嚴格遵循KL散度定義 KL(P∥Q)=∑P?(log?P?log?Q)\text{KL}(P \parallel Q) = \sum P \cdot (\log P - \log Q)KL(PQ)=P?(logP?logQ),其中 PPP 是教師軟分布(teacher_soft),QQQ 是學生軟分布(student_soft對應的概率)。
  • 除以inputs.size(0):對批量內樣本取平均,得到單樣本的KL損失;
  • 乘以T2:補償溫度對梯度的影響(溫度升高會使梯度變小,乘以T2可抵消這一效應)。

8. 計算硬標簽損失:保證學生不偏離真實任務

# 學生對真實標簽的交叉熵損失
ce_loss = ce_criterion(student_logits, labels)
  • 用學生的原始輸出(student_logits)與真實標簽(labels)計算交叉熵損失,確保學生在學習教師知識的同時,不忘記“正確答案”。

9. 總損失:平衡教師知識與真實任務

# 加權求和:KL損失(教師知識) + 交叉熵損失(真實標簽)
total_loss = kl_weight * kl_loss + ce_weight * ce_loss
  • 通過kl_weightce_weight控制兩者的重要性(例如代碼中kl_weight=0.25ce_weight=0.75,表示更側重真實標簽)。

10. 反向傳播與參數更新:只優化學生

total_loss.backward()  # 計算總損失對學生參數的梯度
optimizer.step()  # 根據梯度更新學生參數
  • 梯度僅從學生模型計算(教師模型無梯度),最終只有學生模型的參數被優化。

11. 損失監控:跟蹤訓練過程

running_loss += total_loss.item()  # 累計批量損失
# 每輪結束后打印平均損失
print(f"蒸餾輪次 {epoch+1}/{epochs}, 總損失: {running_loss/len(train_loader):.4f}")
  • 實時監控損失變化,判斷模型是否在有效學習(通常損失應逐漸下降并趨于穩定)。
類型本質代碼對應作用
教師的軟標簽知識教師輸出的平滑概率分布teacher_soft(經T軟化的softmax)傳遞類別間的相似性知識,幫助學生學習更魯棒的特征
真實硬標簽任務數據集中的真實類別標簽labelsce_loss保證學生不偏離基礎分類任務,避免完全“模仿教師錯誤”

還可以繼續的學習

“瞎折騰”參數和模型,看看知識蒸餾到底在什么情況下最管用,以及它的“能力邊界”在哪兒

1. 調“溫度”試試 溫度參數 TTT 對蒸餾效果的影響

做法:把代碼里的 T 改成1、5、10這些不同的數,其他別動,看看學生模型準確率變不變。 例如固定其他參數(如kl_weight=0.25),測試不同溫度值(如 T=1,2,5,10,20T=1, 2, 5, 10, 20T=1,2,5,10,20)下學生模型的準確率。
目的:溫度就像“老師說話的委婉程度”——溫度低,老師只說“這個一定對”(硬邦邦);溫度高,老師會說“這個可能對,那個也有點像”(更細膩)。試試哪種“說話方式”能讓學生學得更好。 理解溫度如何調節教師“軟標簽”的平滑度。

2. 調“聽老師”和“聽標準答案”的比例 損失權重(kl_weightkl\_weightkl_weightce_weightce\_weightce_weight)的平衡實驗

做法:把 kl_weight(聽老師的權重)和 ce_weight(聽標準答案的權重)換成不同組合,看學生成績。 例如固定 T=2T=2T=2,測試不同權重組合(如 (0.1, 0.9), (0.5, 0.5), (0.9, 0.1))對學生性能的影響
目的:學生既要學老師的“經驗”,又不能完全忘了“標準答案”。試試多聽老師點好,還是多信標準答案點好。 分析“教師知識”與“真實標簽”的權重如何影響學生學習。

3. 老師自己學得好不好,影響學生嗎? 教師模型性能對蒸餾的影響

做法:先讓老師少學幾輪(比如只學5輪,學得差點),或者多學幾輪(比如學20輪,學得好點),再讓它教學生,看學生成績差多少。
目的:想知道“老師越厲害,學生是不是一定越厲害”。比如老師自己考80分,能不能教出考75分的學生?老師考90分,學生能到85分嗎?

4. 學生太“笨”了,老師還能教好嗎? 學生模型復雜度的極限測試

做法:把學生模型改得更簡單(比如少一層卷積,少點參數),再用蒸餾訓練,看它還能不能比自己學(不蒸餾)強。
目的:測試蒸餾的“底線”——如果學生太簡單(比如只有一層神經網絡),老師再厲害,是不是也教不會?

5. 換種“衡量學生和老師差異的方式”行不行? 與其他蒸餾損失函數的對比

做法:不用現在的KL散度,換成“均方誤差”(MSE)來算學生和老師的差異,其他不變,看學生成績變不變。
目的:現在用的KL散度就像“比較兩個概率分布像不像”,換種方式(比如直接比數值差),會不會更好用?

6. 給數據“加戲”,學生學得更好嗎? 數據增強對蒸餾的影響

做法:訓練時給圖片加些變化(比如隨機裁剪一塊、左右翻轉),再做蒸餾,看學生在測試集上表現會不會更好。
目的:就像學生平時做難題練習,考試時更從容。給訓練數據加變化,是不是能讓學生和老師都學更扎實?

7. 學生學幾輪效果最好? 蒸餾輪次的影響

做法:把蒸餾的 epochs 改成5、20這些數,看看學5輪、10輪、20輪,學生成績是不是一直漲,還是學太久反而變差。
目的:避免“學過頭”——就像人做題,做10套題可能進步快,做100套可能記住答案了,但換題就不會了。

8. 學生是不是真的“又小又能打”? 模型輕量化指標的驗證

做法:算一算老師、普通學生、蒸餾學生的“參數數量”(模型大小)和“做題速度”(每秒處理多少張圖)。
目的:蒸餾的核心是“讓小模型有大模型的本事”。得驗證一下:蒸餾后的學生是不是確實比老師小很多,但成績接近;同時跑起來比老師快。

9. 學生能“舉一反三”嗎? 跨數據集泛化測試

做法:用CIFAR-10訓練好的蒸餾學生,去做類似的題(比如CIFAR-100里的部分類別),看它比普通學生表現好多少。
目的:好的學生不僅會做學過的題,還能應付新題。蒸餾是不是能讓學生學到更通用的“解題思路”?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/94238.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/94238.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/94238.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

從結構到交互:HTML5進階開發全解析——語義化標簽、Canvas繪圖與表單設計實戰

一、語義化標簽進階&#xff1a;重構頁面結構的「邏輯語言」 在 HTML5 的舞臺上&#xff0c;語義化標簽是熠熠生輝的主角&#xff0c;它們為網頁賦予了清晰的邏輯結構&#xff0c;使其更易被搜索引擎理解和被開發者維護。其中&#xff0c;<section>與<article>標簽…

標準七層網絡協議和TCP/IP四層協議的區別

分別是什么? OSI七層協議是國際標準組織制定的標準協議。其中七層分別是物理層,數據鏈路層,網絡層,傳輸層,會話層,表示層,應用層。 TCP/IP協議是美國軍方在后期網絡技術的發展中提出來的符合目前現狀的協議。其中四層分別是網絡接口層對應七層中的物理層和數據鏈路層,…

前端面試手撕題目全解析

以下是前端面試中常遭遇的“手撕”基礎題目匯總&#xff0c;涵蓋 HTML→JS→Vue→React&#xff0c;每題附經典實現&#xff0f;原理解析&#xff0c;可現場答題或后端總結。 HTML 基礎題 &#x1f4dd; 語義化卡片&#xff08;Semantic Card ARIA&#xff09; <article cl…

道格拉斯-普克算法 - 把一堆復雜的線條變得簡單,同時盡量保持原來的樣子

道格拉斯-普克算法 - 把一堆復雜的線條變得簡單&#xff0c;同時盡量保持原來的樣子 flyfish 道格拉斯-普克算法&#xff08;Douglas-Peucker Algorithm解決的問題其實很日常&#xff1a;把一堆復雜的線條&#xff08;比如地圖上的道路、河流&#xff0c;或者GPS記錄的軌跡&…

團購商城 app 系統架構分析

一、引言 團購商城 APP 作為一種融合了電子商務與團購模式的應用程序&#xff0c;近年來在市場上取得了顯著的發展。它為用戶提供了便捷的購物體驗&#xff0c;同時也為商家創造了更多的銷售機會。一個完善且高效的系統架構是保障團購商城 APP 穩定運行、提供優質服務的基礎。本…

【AI平臺】n8n入門7:本地n8n更新

?0、前言 目標&#xff1a;本地n8n部署后&#xff0c;有新版本&#xff0c;然后進行更新。官方文檔&#xff1a;Docker | n8n Docs特別說明&#xff1a; n8n鏡像更新后&#xff0c;容器重建&#xff0c;所以之前在n8n配置的東西&#xff0c;就莫有了&#xff0c;工作流提前導…

還在使用Milvus向量庫?2025-AI智能體選型架構防坑指南

前言說明&#xff1a;數據來源&#xff1a;主要基于 Milvus&#xff08;v2.3&#xff09;和 Qdrant&#xff08;v1.8&#xff09;的最新穩定版&#xff0c;參考官方文檔、GitHub Issues、CNCF報告、以及第三方評測&#xff08;如DB-Engines、TechEmpower&#xff09;。評估原則…

3-verilog的使用-1

verilog的使用-1 1.判斷上升沿 reg s_d0; reg s_d1; wire signal_up ; //判斷信號的上升沿 assign signal_up (~touch_key_d1) & touch_key_d0; always (posedge clk or negedge rst_n) beginif(rst_n 1b0) begins_d0< 1b0;s_d1< 1b0;endelse begins_d0&…

ESXI虛擬交換機 + H3C S5120交換機 + GR5200路由器組網筆記

文章目錄一、組網拓撲與核心邏輯1. 拓撲結構2. 核心邏輯二、詳細規劃方案1. VLAN 與 IP 地址規劃2. 設備連接規劃三、配置步驟1. H3C S5120 交換機配置&#xff08;VLAN 與端口&#xff09;2. H3C GR5200 路由器配置&#xff08;路由、網關、NAT&#xff09;3. ESXi 虛擬交換機…

python的駕校培訓預約管理系統

前端開發框架:vue.js 數據庫 mysql 版本不限 后端語言框架支持&#xff1a; 1 java(SSM/springboot)-idea/eclipse 2.NodejsVue.js -vscode 3.python(flask/django)–pycharm/vscode 4.php(thinkphp/laravel)-hbuilderx 數據庫工具&#xff1a;Navicat/SQLyog等都可以 該系統通…

webrtc弱網-QualityScaler 源碼分析與算法原理

一. 核心功能QualityScaler 是 WebRTC 中用于動態調整視頻編碼質量的模塊&#xff0c;主要功能包括&#xff1a;QP 監控&#xff1a;持續監測編碼器輸出的量化參數&#xff08;QP&#xff09;丟幀率分析&#xff1a;跟蹤媒體優化和編碼器導致的丟幀情況自適應決策&#xff1a;根…

Maven 快照(SNAPSHOT)

Maven 快照(SNAPSHOT) 引言 Maven 快照(SNAPSHOT)是 Maven 中的一個重要概念,主要用于版本管理。它允許開發者在構建過程中使用尚未發布的版本。本文將詳細介紹 Maven 快照的原理、用途以及如何在項目中配置和使用快照。 Maven 快照原理 Maven 快照是版本號的一部分,…

2025-0803學習記錄20——畢業論文快速整理成小論文

本科畢業論文寫好啦&#xff0c;但是C導要我整理成一篇約8000字的小論文&#xff0c;準備投稿。畢業論文到投稿的小論文&#xff0c;這其實是從“全景展示”到“聚焦精煉”的過程。目前我已經有完整的大論文&#xff08;約6萬字&#xff09;&#xff0c;材料是充足的&#xff0…

VUE2 學習筆記16 插槽、Vuex

插槽在編寫組件時&#xff0c;可能存在這種情況&#xff0c;頁面需要顯示不同的內容&#xff0c;但是頁面結構是類似的&#xff0c;在這種情況下&#xff0c;雖然也可以使用傳參來進行&#xff0c;但傳參時&#xff0c;還需要編寫props等邏輯&#xff0c;略顯重復&#xff0c;而…

IntelliJ IDEA開發編輯器摸魚看股票數據

在IDEA的插件市場中心搜索stock&#xff0c;檢索結果里面的插件&#xff0c;點擊安裝即可安裝后的效果

Linux Deepin深度操作系統應用商店加載失敗,安裝星火應用商店

Linux Deepin國產操作系統優點 Deepin&#xff08;原名Linux Deepin&#xff09;是一款由中國團隊開發的Linux發行版&#xff0c;基于Debian stable分支&#xff0c;以美觀易用的界面和本土化體驗著稱。以下是其核心優點總結&#xff1a; 1. 極致美觀的界面設計 Deepin Deskt…

postgresql創建只讀用戶并授權

postgresql創建只讀用戶并授權 CREATE USER yk WITH ENCRYPTED PASSWORD <your_password>;GRANT USAGE ON SCHEMA public to yk; GRANT SELECT ON ALL TABLES IN SCHEMA public TO yk;根據以上創建的用戶&#xff0c;出現一個問題&#xff0c;對新建的表沒有查詢權限&am…

pytest vs unittest: 區別與優缺點比較

主要區別特性pytestunittest起源第三方庫Python標準庫語法風格更簡潔的Pythonic語法基于Java風格的JUnit測試發現自動發現測試需要繼承TestCase類斷言方式使用Python原生assert使用各種assert方法(assertEqual等)夾具系統強大的fixture系統簡單的setUp/tearDown方法參數化測試內…

Boost.Asio學習(5):c++的協程

協程是什么&#xff1f;協程就是可以“暫停”和“繼續”的函數&#xff0c;像在函數里打個斷點&#xff0c;然后以后可以從斷點繼續運行&#xff0c;而不是重新開始。線程 vs 協程&#xff1a;類比想象你在寫小說&#xff1a;線程&#xff1a;你開了 3 個作者&#xff08;線程&…

Linux 中,命令查看系統版本和內核信息

在 Linux 中&#xff0c;可以通過以下命令查看系統版本和內核信息&#xff1a;1. 查看內核版本uname -a或精簡顯示&#xff1a;uname -r # 只顯示內核版本示例輸出&#xff1a;Linux ubuntu 5.4.0-135-generic #152-Ubuntu SMP Tue Nov 15 08:12:21 UTC 2022 x86_64 x86_64 x8…