知識蒸餾實戰:用PyTorch和預訓練模型提升小模型性能

在深度學習的浪潮中,我們常常追求更大、更深、更復雜的模型以達到最先進的性能。然而,這些“龐然大物”般的模型往往伴隨著高昂的計算成本和緩慢的推理速度,使得它們難以部署在資源受限的環境中,如移動設備或邊緣計算平臺。知識蒸餾(Knowledge Distillation)技術為此提供了一個優雅的解決方案:將一個大型、高性能的“教師模型”所學習到的“知識”遷移到一個小巧、高效的“學生模型”中。

本篇將一步步使用 PyTorch 實現一個知識蒸餾的案例,其中教師模型將采用預訓練模型。

什么是知識蒸餾?

知識蒸餾的核心思想是,訓練一個小型學生模型 (Student Model) 來模仿一個大型教師模型 (Teacher Model) 的行為。這種模仿不僅僅是學習教師模型對“硬標簽”(即真實標簽)的預測,更重要的是學習教師模型輸出的“軟標簽”(Soft Targets)。

  • 教師模型 (Teacher Model): 通常是一個已經訓練好的、性能優越的大型模型。例如,在計算機視覺領域,可以是 ImageNet 上預訓練的 ResNet、VGG 等。
  • 學生模型 (Student Model): 一個參數量較小、計算更高效的輕量級模型,我們希望它能達到接近教師模型的性能。
  • 軟標簽 (Soft Targets): 教師模型在輸出層(softmax之前,即logits)經過一個較高的“溫度”(Temperature, T)調整后的概率分布。高溫會使概率分布更平滑,從而揭示類別間的相似性信息,這些被稱為“暗知識”(Dark Knowledge)。
  • 硬標簽 (Hard Targets): 數據集的真實標簽。
  • 蒸餾損失 (Distillation Loss): 通常由兩部分組成:
    1. 學生模型在真實標簽上的損失(例如交叉熵損失)。
    2. 學生模型與教師模型軟標簽之間的損失(例如KL散度或均方誤差)。
      這兩部分損失通過一個超參數 a l p h a \\alpha alpha 來加權平衡。

PyTorch 實現步驟

接下來,我們將通過一個圖像分類的例子來演示如何實現知識蒸餾。假設我們的任務是對一個包含10個類別的圖像數據集進行分類。

1. 準備工作:導入庫和設置設備
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms # 用于數據預處理# 檢查是否有可用的 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用設備: {device}")
2. 定義教師模型 (Pre-trained ResNet18)

我們將使用 torchvision.models 中預訓練的 ResNet18 作為教師模型。為了適應我們自定義的分類任務(例如10分類),我們需要替換其原始的1000類全連接層。

class PretrainedTeacherModel(nn.Module):def __init__(self, num_classes, pretrained=True):super(PretrainedTeacherModel, self).__init__()# 加載預訓練的 ResNet18 模型# PyTorch 1.9+ 推薦使用 weights 參數if pretrained:self.resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)else:self.resnet = models.resnet18(weights=None) # 或者 models.resnet18(pretrained=False) for older versions# 獲取 ResNet18 原本的輸出特征數num_ftrs = self.resnet.fc.in_features# 替換最后的全連接層以適應我們的任務類別數self.resnet.fc = nn.Linear(num_ftrs, num_classes)def forward(self, x):return self.resnet(x)

在蒸餾過程中,教師模型的參數通常是固定的,不參與訓練。

3. 定義學生模型

學生模型應該是一個比教師模型更小、更輕量的網絡。這里我們定義一個簡單的卷積神經網絡 (CNN)。

class StudentCNNModel(nn.Module):def __init__(self, num_classes):super(StudentCNNModel, self).__init__()# 輸入通道數為3 (RGB圖像), 假設輸入圖像大小為 32x32self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # 32x32 -> 16x16self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # 16x16 -> 8x8# 展平后的特征數: 32 channels * 8 * 8self.fc = nn.Linear(32 * 8 * 8, num_classes)def forward(self, x):out = self.pool1(self.relu1(self.conv1(x)))out = self.pool2(self.relu2(self.conv2(x)))out = out.view(out.size(0), -1) # 展平out = self.fc(out)return out
4. 定義蒸餾損失函數

這是知識蒸餾的核心。損失函數結合了學生模型在硬標簽上的性能和與教師模型軟標簽的匹配程度。

  • L _ h a r d L\_{hard} L_hard: 學生模型輸出與真實標簽之間的交叉熵損失。
  • L _ s o f t L\_{soft} L_soft: 學生模型的軟化輸出與教師模型的軟化輸出之間的KL散度。
  • 總損失 L = a l p h a c d o t L _ h a r d + ( 1 ? a l p h a ) c d o t L _ s o f t c d o t T 2 L = \\alpha \\cdot L\_{hard} + (1 - \\alpha) \\cdot L\_{soft} \\cdot T^2 L=alphacdotL_hard+(1?alpha)cdotL_softcdotT2
    • T T T 是溫度參數。較高的 T T T 會使概率分布更平滑。
    • a l p h a \\alpha alpha 是平衡兩個損失項的權重。
    • L _ s o f t L\_{soft} L_soft 乘以 T 2 T^2 T2 是為了確保軟標簽損失的梯度與硬標簽損失的梯度在量級上大致相當。
class DistillationLoss(nn.Module):def __init__(self, alpha, temperature):super(DistillationLoss, self).__init__()self.alpha = alphaself.temperature = temperatureself.criterion_hard = nn.CrossEntropyLoss() # 硬標簽損失# reduction='batchmean' 會將KL散度在batch維度上取平均,這在很多實現中是常見的self.criterion_soft = nn.KLDivLoss(reduction='batchmean') # 軟標簽損失def forward(self, student_logits, teacher_logits, labels):# 硬標簽損失loss_hard = self.criterion_hard(student_logits, labels)# 軟標簽損失# 使用 softmax 和 temperature 來計算軟標簽和軟預測# 注意:KLDivLoss期望的輸入是 (log_probs, probs)soft_teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)soft_student_log_probs = F.log_softmax(student_logits / self.temperature, dim=1)# 計算KL散度損失loss_soft = self.criterion_soft(soft_student_log_probs, soft_teacher_probs) * (self.temperature ** 2)# 總損失loss = self.alpha * loss_hard + (1 - self.alpha) * loss_softreturn loss
5. 訓練流程

現在我們將所有部分組合起來進行訓練。

# --- 示例參數 ---
num_classes = 10  # 假設我們的任務是10分類
img_channels = 3
img_height = 32
img_width = 32learning_rate = 0.001
num_epochs = 20 # 實際應用中需要更多 epochs 和真實數據
batch_size = 32
temperature = 4.0 # 蒸餾溫度
alpha = 0.3       # 硬標簽損失的權重# --- 實例化模型 ---
teacher_model = PretrainedTeacherModel(num_classes=num_classes, pretrained=True).to(device)
teacher_model.eval() # 教師模型設為評估模式,不更新其權重student_model = StudentCNNModel(num_classes=num_classes).to(device)# --- 準備優化器和損失函數 ---
optimizer = optim.Adam(student_model.parameters(), lr=learning_rate) # 只優化學生模型的參數
distillation_criterion = DistillationLoss(alpha=alpha, temperature=temperature).to(device)# --- 生成一些虛擬圖像數據進行演示 ---
# !!! 警告: 實際應用中必須使用真實數據加載器 (DataLoader) 和正確的預處理 !!!
# 預訓練模型通常對輸入有特定的歸一化要求。
# 例如,ImageNet預訓練模型通常使用:
# normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# 并且輸入尺寸也需要匹配,或進行適當調整。
# 本例中學生模型接收 32x32 輸入,教師模型(ResNet)通常處理更大圖像如 224x224。
# 為簡化,我們假設教師模型能處理學生模型的輸入尺寸,或者在教師模型前對輸入進行適配。
dummy_inputs = torch.randn(batch_size, img_channels, img_height, img_width).to(device)
dummy_labels = torch.randint(0, num_classes, (batch_size,)).to(device)print("開始訓練學生模型...")
# --- 訓練學生模型 ---
for epoch in range(num_epochs):student_model.train() # 學生模型設為訓練模式# 獲取教師模型的輸出 (logits)with torch.no_grad(): # 教師模型的權重不更新# 如果教師模型和學生模型期望的輸入尺寸不同,需要適配# teacher_input_adjusted = F.interpolate(dummy_inputs, size=(224, 224), mode='bilinear', align_corners=False) # 示例調整# teacher_logits = teacher_model(teacher_input_adjusted)teacher_logits = teacher_model(dummy_inputs) # 假設教師模型可以處理此尺寸或已適配# 前向傳播 - 學生模型student_logits = student_model(dummy_inputs)# 計算蒸餾損失loss = distillation_criterion(student_logits, teacher_logits, dummy_labels)# 反向傳播和優化optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 5 == 0 or epoch == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')print("學生模型訓練完成!")# (可選) 保存學生模型
# torch.save(student_model.state_dict(), 'student_cnn_distilled.pth')
# print("蒸餾后的學生CNN模型已保存。")

關鍵點與最佳實踐

  1. 數據預處理: 對于預訓練的教師模型,其輸入數據必須經過與預訓練時相同的預處理(如歸一化、尺寸調整)。這是確保教師模型發揮其最佳性能并傳遞有效知識的關鍵。
  2. 輸入兼容性: 確保教師模型和學生模型接收的輸入在語義上是一致的。如果它們的網絡結構原生接受不同尺寸的輸入,你可能需要調整輸入數據(例如,通過插值 F.interpolate)以適應教師模型,或者確保兩個模型都能處理相同的輸入。
  3. 超參數調優: alpha, temperature, learning_rate 等超參數對蒸餾效果至關重要。通常需要通過實驗來找到最佳組合。較高的 temperature 可以讓學生學習到更多類別間的細微差別,但過高可能會導致信息模糊。
  4. 教師模型的選擇: 教師模型越強大,通常能傳遞的知識越多。但也要考慮其推理成本(即使只在訓練時)。
  5. 學生模型的設計: 學生模型不應過于簡單,以至于無法吸收教師的知識;也不應過于復雜,從而失去蒸餾的意義。
  6. 訓練時長: 知識蒸餾通常需要足夠的訓練輪次才能讓學生模型充分學習。
  7. 不僅僅是 Logits: 本文介紹的是最常見的基于 Logits 的蒸餾。還有其他蒸餾方法,例如匹配教師模型和學生模型中間層的特征表示(Feature Distillation),這有時能帶來更好的效果。

在這里插入圖片描述


本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/81815.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/81815.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/81815.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

python:mysql全局大覽(保姆級教程)

本文目錄: 一、關于數據庫**二、sql語言分類**三、數據庫增刪改查操作**四、庫中表增刪改查操作**五、表中記錄插入**六、表約束**七、單表查詢**八、多表查詢**(一)外鍵約束**(二)連結查詢**1.交叉連接(笛…

Android framework 問題記錄

一、休眠喚醒,很快熄屏 1.1 問題描述 機器休眠喚醒后,沒有按照約定的熄屏timeout 進行熄屏,很快就熄屏(約2s~3s左右) 1.2 原因分析: 抓取相關log,打印休眠背光 相關調用棧 //具體打印調用棧…

怎么利用JS根據坐標判斷構成單個多邊形是否合法

怎么利用JS根據坐標判斷構成單個多邊形是否合法 引言 在GIS(地理信息系統)、游戲開發、計算機圖形學等領域,判斷一組坐標點能否構成合法的簡單多邊形(Simple Polygon)是一個常見需求。合法多邊形需要滿足幾何學上的基本規則,本文將詳細介紹如何使用JavaScript實現這一判…

sqlite的拼接字段的方法(sqlite沒有convert函數)

我在sqlserver 操作方式&#xff1a; /// <summary>///獲取當前門店工資列表/// </summary>/// <param name"wheres">其他條件</param>/// <param name"ThisMendian">當前門店</param>/// <param name"IsNotU…

構建高效移動端網頁調試流程:以 WebDebugX 為核心的工具、技巧與實戰經驗

現代前端開發早已不僅僅局限于桌面瀏覽器。隨著 Hybrid 應用、小程序、移動 Web 的廣泛應用&#xff0c;開發者日常面臨的一個關鍵挑戰是&#xff1a;如何在移動設備上快速定位并解決問題&#xff1f; 這不再是“打開 DevTools 查查 Console”的問題&#xff0c;而是一個關于設…

新興技術與安全挑戰

7.1 云原生安全(K8s安全、Serverless防護) 核心風險與攻擊面 Kubernetes配置錯誤: 風險:默認開放Dashboard未授權訪問(如kubectl proxy未鑒權)。防御:啟用RBAC,限制ServiceAccount權限。Serverless函數注入: 漏洞代碼(AWS Lambda):def lambda_handler(event, cont…

《算法筆記》11.7小節——動態規劃專題->背包問題 問題 C: 貨幣系統

題目描述 母牛們不但創建了他們自己的政府而且選擇了建立了自己的貨幣系統。 [In their own rebellious way],&#xff0c;他們對貨幣的數值感到好奇。 傳統地&#xff0c;一個貨幣系統是由1,5,10,20 或 25,50, 和 100的單位面值組成的。 母牛想知道有多少種不同的方法來用貨幣…

SN生成流水號并且打亂

目前公司的產品會通過sn綁定賬號&#xff0c;但是會出現一個問題&#xff0c;流水號會容易被人猜出來導致被他人在未授權的情況下使用&#xff0c;所以開發了一個生成流水號后打亂的python程序&#xff0c;比如輸入sn的前11位后&#xff0c;后面的字符所有的排列組合有26^4方種…

msq基礎

一、檢索數據 SELECT語句 1.檢索單個列 SELECT prod_name FROM products 上述語句用SELECT語句從products表中檢索一個名prod_name的列&#xff0c;所需列名在SELECT關鍵字之后給出&#xff0c;FROM關鍵字指出從其中檢索數據的表名 &#xff08;返回數據的順序可能是數據…

【回溯 剪支 狀態壓縮】# P10419 [藍橋杯 2023 國 A] 01 游戲|普及+

本文涉及知識點 C回溯 位運算、狀態壓縮、枚舉子集匯總 P10419 [藍橋杯 2023 國 A] 01 游戲 題目描述 小藍最近玩上了 01 01 01 游戲&#xff0c;這是一款帶有二進制思想的棋子游戲&#xff0c;具體來說游戲在一個大小為 N N N\times N NN 的棋盤上進行&#xff0c;棋盤…

2025華為OD機試真題+全流程解析+備考攻略+經驗分享+Java/python/JavaScript/C++/C/GO六種語言最佳實現

華為OD全流程解析&#xff0c;備考攻略 快捷目錄 華為OD全流程解析&#xff0c;備考攻略一、什么是華為OD&#xff1f;二、什么是華為OD機試&#xff1f;三、華為OD面試流程四、華為OD薪資待遇及職級體系五、ABCDE卷類型及特點六、題型與考點七、機試備考策略八、薪資與轉正九、…

深入解析DICOM標準:文件結構、元數據、影像數據與應用

&#x1f9d1; 博主簡介&#xff1a;CSDN博客專家、CSDN平臺優質創作者&#xff0c;高級開發工程師&#xff0c;數學專業&#xff0c;10年以上C/C, C#, Java等多種編程語言開發經驗&#xff0c;擁有高級工程師證書&#xff1b;擅長C/C、C#等開發語言&#xff0c;熟悉Java常用開…

Visual Studio 2022 插件推薦

Visual Studio 2022 插件推薦 Visual Studio 2022 (簡稱 VS2022) 是一款強大的 IDE&#xff0c;適合各類系統組件、框架和應用的開發。插件是接入 VS2022 最重要的擴展方式之一&#xff0c;它們可以大幅提升開發效率、優化代碼質量&#xff0c;并提供強大的調試和分析功能。 …

OBS Studio:windows免費開源的直播與錄屏軟件

OBS Studio是一款免費、開源且跨平臺的直播與錄屏軟件。其支持 Windows、macOS 和 Linux。OBS適用于&#xff0c;有直播需求的人群或錄屏需求的人群。 Stars 數64,323Forks 數8413 主要特點 推流&#xff1a;OBS Studio 支持將視頻實時推流至多個平臺&#xff0c;如 YouTube、…

SCAU--平衡樹

3 平衡樹 Time Limit:1000MS Memory Limit:65535K 題型: 編程題 語言: G;GCC;VC;JAVA;PYTHON 描述 平衡樹并不是平衡二叉排序樹。 這里的平衡指的是左右子樹的權值和差距盡可能的小。 給出n個結點二叉樹的中序序列w[1],w[2],…,w[n]&#xff0c;請構造平衡樹&#xff0c…

Docker容器鏡像與容器常用操作指南

一、鏡像基礎操作 搜索鏡像 docker search <鏡像名>在Docker Hub中查找公開鏡像&#xff0c;例如&#xff1a; docker search nginx拉取鏡像 docker pull <鏡像名>:<標簽>從倉庫拉取鏡像到本地&#xff0c;標簽默認為latest&#xff1a; docker pull nginx:a…

TDengine 更多安全策略

簡介 上一節我們介紹了 TDengine 安全部署配置建議&#xff0c;除了傳統的這些配置外&#xff0c;TDengine 還有其他的安全策略&#xff0c;例如 IP 白名單、審計日志、數據加密等&#xff0c;這些都是 TDengine Enterprise 特有功能&#xff0c;其中白名單功能在 3.2.0.0 版本…

小白入門:GitHub 遠程倉庫使用全攻略

一、Git 核心概念 1. 三個工作區域 工作區&#xff08;Working Directory&#xff09;&#xff1a;實際編輯文件的地方。 暫存區&#xff08;Staging Area&#xff09;&#xff1a;準備提交的文件集合&#xff08;使用git add操作&#xff09;。 本地倉庫&#xff08;Local…

[創業之路-370]:企業戰略管理案例分析-10-戰略制定-差距分析的案例之小米

戰略制定-差距分析的案例之小米 在戰略制定過程中&#xff0c;小米通過差距分析明確自身與市場機會之間的差距&#xff0c;并制定針對性戰略&#xff0c;實現快速發展。以下以小米在智能手機市場的機會差距分析為例&#xff0c;說明其戰略制定過程。 一、市場機會識別與差距分…

Index-AniSora模型論文速讀:基于人工反饋的動漫視頻生成

Aligning Anime Video Generation with Human Feedback 一、引言 論文開頭指出&#xff0c;盡管視頻生成模型不斷涌現&#xff0c;但動漫視頻生成面臨動漫數據稀缺和運動模式異常的挑戰&#xff0c;導致生成視頻存在運動失真和閃爍偽影等問題&#xff0c;難以滿足人類偏好。現…