理解知識蒸餾中的散度損失函數(KLDivergence/kldivloss )-以DeepSeek為例

1. 知識蒸餾簡介

什么是知識蒸餾?

知識蒸餾(Knowledge Distillation)是一種模型壓縮技術,目標是讓一個較小的模型(學生模型,Student Model)學習一個較大、性能更優的模型(教師模型,Teacher Model)的知識。這樣,我們可以在保持較高準確率的同時,大幅減少計算和存儲成本。

為什么需要知識蒸餾?

  • 降低計算成本:大模型(如 DeepSeek、GPT-4)通常計算量巨大,不適合部署到移動設備或邊緣設備上。
  • 加速推理:較小的模型可以更快地推理,減少延遲。
  • 減少內存占用:適用于資源受限的環境,如嵌入式設備或低功耗服務器。

知識蒸餾的核心思想是:學生模型不僅僅學習教師模型的硬標簽(one-hot labels),更重要的是學習教師模型輸出的概率分布,從而獲得更豐富的表示能力。

2. KL 散度的數學原理

2.1 KL 散度公式

在知識蒸餾過程中,我們通常使用Kullback-Leibler 散度(KL Divergence) 來衡量兩個概率分布(教師模型和學生模型)之間的差異。

2.2 直觀理解

KL 散度可以理解為如果用分布 Q 來近似分布 P,會損失多少信息

  • 當 KL 散度為 0,表示兩個分布完全相同。
  • KL 散度不是對稱的,即 D_{KL}(P || Q) \neq D_{KL}(Q || P)

3. DeepSeek 中的 KL 散度應用

DeepSeek 作為一個強大的開源大語言模型(LLM),在模型蒸餾時廣泛使用了 KL 散度。例如,在訓練較小版本的 DeepSeek 時,研究人員采用了溫度標度(Temperature Scaling) 來調整教師模型的輸出,使其更適合學生模型學習。

教師模型的 softmax 輸出使用溫度參數 TT 進行調整:

當 T 增大時,softmax 輸出的概率分布變得更平滑,從而讓學生模型更容易學習教師模型的知識。

在 DeepSeek 的蒸餾過程中,常見的損失函數是加權組合:

其中:

  • 第一項是 KL 散度損失,使得學生模型的輸出接近教師模型。
  • 第二項是交叉熵損失,確保學生模型仍然學習真實標簽。
  • λ是一個超參數,控制兩者的平衡。

4. 代碼示例:用 Keras 進行知識蒸餾

下面我們用 TensorFlow/Keras 訓練一個簡單的學生模型,讓它學習一個教師模型的知識。

4.1 定義教師模型

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers# 構建一個簡單的教師模型
teacher_model = keras.Sequential([layers.Dense(128, activation="relu", input_shape=(784,)),layers.Dense(10, activation="softmax")
])

4.2 訓練教師模型

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = x_train.reshape(-1, 784) / 255.0, x_test.reshape(-1, 784) / 255.0
y_train, y_test = keras.utils.to_categorical(y_train, 10), keras.utils.to_categorical(y_test, 10)teacher_model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
teacher_model.fit(x_train, y_train, epochs=5, batch_size=32, validation_data=(x_test, y_test))

4.3 讓教師模型生成 soft labels

temperature = 5.0
def soft_targets(logits):return tf.nn.softmax(logits / temperature)y_teacher = soft_targets(teacher_model.predict(x_train))

4.4 訓練學生模型

student_model = keras.Sequential([layers.Dense(64, activation="relu", input_shape=(784,)),layers.Dense(10, activation="softmax")
])student_model.compile(optimizer="adam",loss=tf.keras.losses.KLDivergence(),  # 使用 KL 散度metrics=["accuracy"]
)student_model.fit(x_train, y_teacher, epochs=5, batch_size=32, validation_data=(x_test, y_test))

5. 真實應用場景

5.1 輕量級大模型

  • DistilBERT:使用 BERT 作為教師模型進行蒸餾,訓練更小的 Transformer。
  • TinyBERT:針對任務優化蒸餾,提高學生模型的表現。
  • DeepSeek-Chat 小模型:使用 KL 散度訓練高效版本,提高推理速度。

5.2 知識蒸餾的優勢

  • 可以訓練更小的模型,適用于移動端、嵌入式設備。
  • 學生模型比直接訓練的模型泛化性更強,能更好地模仿教師模型。
  • 結合 KL 散度 + 交叉熵 可以提升訓練效果。

結論

KL 散度損失是知識蒸餾的核心,它讓學生模型學習教師模型的概率分布,從而獲得更好的表現。DeepSeek 這樣的 LLM 在蒸餾過程中廣泛使用 KL 散度,使得較小模型也能高效推理。希望本文能幫助你理解 KL 散度在知識蒸餾中的應用!

其它

代碼示例一,

假設我們有兩個概率分布 p(真實分布)和 q(預測分布),我們使用 KLDivergence 計算它們之間的 KL 散度損失。

import tensorflow as tf
import numpy as np# 定義 KLDivergence 損失函數
kl_loss = tf.keras.losses.KLDivergence()# 真實分布 p (標簽)
p = np.array([0.1, 0.4, 0.5], dtype=np.float32)# 預測分布 q
q = np.array([0.2, 0.3, 0.5], dtype=np.float32)# 計算 KL 散度損失
loss_value = kl_loss(p, q)print(f'KL Divergence Loss: {loss_value.numpy()}')

代碼示例二,

一個完整的 Keras 代碼示例,展示了如何在分類任務中使用 KLDivLoss 作為損失函數。這個示例使用一個簡單的神經網絡對 手寫數字 MNIST 數據集 進行分類,并使用 KLDivLoss 計算真實分布和模型預測分布之間的散度。

import tensorflow as tf
from tensorflow import keras
from keras import layers
import numpy as np# 加載 MNIST 數據集
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()# 歸一化數據到 [0,1] 之間
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0# 將標簽轉換為概率分布 (one-hot 編碼)
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)# 構建一個簡單的神經網絡模型
model = keras.Sequential([layers.Flatten(input_shape=(28, 28)),layers.Dense(128, activation="relu"),layers.Dense(10, activation="softmax")  # 輸出層用 softmax 歸一化
])# 編譯模型,使用 KLDivLoss 作為損失函數
model.compile(optimizer="adam",loss=tf.keras.losses.KLDivergence(),metrics=["accuracy"])# 訓練模型
model.fit(x_train, y_train, epochs=5, batch_size=32, validation_data=(x_test, y_test))# 評估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test Accuracy: {test_acc:.4f}")

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/67738.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/67738.shtml
英文地址,請注明出處:http://en.pswp.cn/web/67738.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Electron使用WebAassembly實現CRC-8 MAXIM校驗

Electron使用WebAssembly實現CRC-8 MAXIM校驗 將C/C語言代碼,經由WebAssembly編譯為庫函數,可以在JS語言環境進行調用。這里介紹在Electron工具環境使用WebAssembly調用CRC-8 MAXIM格式校驗的方式。 CRC-8 MAXIM校驗函數WebAssebly源文件 C語言實現CR…

Vue3.0實戰:大數據平臺可視化

文章目錄 創建vue3.0項目項目初始化項目分辨率響應式設置項目頂部信息條創建頁面主體創建全局引入echarts和axios后臺接口創建express銷售總量圖實現完整項目下載項目任何問題都可在評論區,或者直接私信即可。 創建vue3.0項目 創建項目: vue create vueecharts選擇第三項:…

vector容器(詳解)

本文最后是模擬實現全部講解,文章穿插有彩色字體,是我總結的技巧和關鍵 1.vector的介紹及使用 1.1 vector的介紹 https://cplusplus.com/reference/vector/vector/(vector的介紹) 了解 1. vector是表示可變大小數組的序列容器。…

Ubuntu 下 nginx-1.24.0 源碼分析 ngx_debug_init();

目錄 ngx_debug_init() 函數: NGX_LINUX 的定義: ngx_debug_init() 函數: ngx_debug_init() 函數定義在 src\os\unix 目錄下的 ngx_linux_config.h 中 #define ngx_debug_init() 也就是說這個環境下的 main 函數中的 ngx_debug_init() 這…

Airflow:深入理解Apache Airflow Task

Apache Airflow是一個開源工作流管理平臺,支持以編程方式編寫、調度和監控工作流。由于其靈活性、可擴展性和強大的社區支持,它已迅速成為編排復雜數據管道的首選工具。在這篇博文中,我們將深入研究Apache Airflow 中的任務概念,探…

開發環境搭建-4:WSL 配置 docker 運行環境

在 WSL 環境中構建:WSL2 (2.3.26.0) Oracle Linux 8.7 官方鏡像 基本概念說明 容器技術 利用 Linux 系統的 文件系統(UnionFS)、命名空間(namespace)、權限管理(cgroup),虛擬出一…

JavaScript 基礎 - 7

關于JS函數部分的學習和一個案例的練習 1 函數封裝 抽取相同部分代碼封裝 優點 提高代碼復用性:封裝好的函數可以在多個地方被重復調用,避免了重復編寫相同的代碼。例如,編寫一個計算兩個數之和的函數,在多個不同的計算場景中都…

詳解u3d之AssetBundle

一.AssetBundle的概念 “AssetBundle”可以指兩種不同但相關的東西。 1.1 AssetBundle指的是u3d在磁盤上生成的存放資源的目錄 目錄包含兩種類型文件(下文簡稱AB包): 一個序列化文件,其中包含分解為各個對象并寫入此單個文件的資源。資源文件&#x…

微信登錄模塊封裝

文章目錄 1.資質申請2.combinations-wx-login-starter1.目錄結構2.pom.xml 引入okhttp依賴3.WxLoginProperties.java 屬性配置4.WxLoginUtil.java 后端通過 code 獲取 access_token的工具類5.WxLoginAutoConfiguration.java 自動配置類6.spring.factories 激活自動配置類 3.com…

DeepSeek 介紹及對外國的影響

DeepSeek 簡介 DeepSeek(深度求索)是一家專注實現 AGI(人工通用智能)的中國科技公司,2023 年成立,總部位于杭州,在北京設有研發中心。與多數聚焦具體應用(如人臉識別、語音助手&…

MySQL數據庫(二)- SQL

目錄 ?編輯 一 DDL (一 數據庫操作 1 查詢-數據庫(所有/當前) 2 創建-數據庫 3 刪除-數據庫 4 使用-數據庫 (二 表操作 1 創建-表結構 2 查詢-所有表結構名稱 3 查詢-表結構內容 4 查詢-建表語句 5 添加-字段名數據類型 6 修改-字段數據類…

ARM嵌入式學習--第十天(UART)

--UART介紹 UART(Universal Asynchonous Receiver and Transmitter)通用異步接收器,是一種通用串行數據總線,用于異步通信。該總線雙向通信,可以實現全雙工傳輸和接收。在嵌入式設計中,UART用來與PC進行通信,包括與監控…

面試題-消失的數字-異或

消失的數字 數組nums包含從0到n的所有整數,但其中缺了一個。請編寫代碼找出那個缺失的整數。你有辦法在 O(n) 時間內完成嗎? 示例: 輸入:[3,0,1] 輸出:2 int missingNumber(int* nums, int numsSize) {}分析 本題對…

數據結構與算法之棧: LeetCode 1685. 有序數組中差絕對值之和 (Ts版)

有序數組中差絕對值之和 https://leetcode.cn/problems/sum-of-absolute-differences-in-a-sorted-array/description/ 描述 給你一個 非遞減 有序整數數組 nums 請你建立并返回一個整數數組 result,它跟 nums 長度相同,且result[i] 等于 nums[i] 與數…

筆試-排列組合

應用 一個長度為[1, 50]、元素都是字符串的非空數組,每個字符串的長度為[1, 30],代表非負整數,元素可以以“0”開頭。例如:[“13”, “045”,“09”,“56”]。 將所有字符串排列組合,拼起來組成…

Python3 OS模塊中的文件/目錄方法說明十七

一. 簡介 前面文章簡單學習了 Python3 中 OS模塊中的文件/目錄的部分函數。 本文繼續來學習 OS 模塊中文件、目錄的操作方法:os.walk() 方法、os.write()方法 二. Python3 OS模塊中的文件/目錄方法 1. os.walk() 方法 os.walk() 方法用于生成目錄樹中的文件名&a…

[Java]抽象類

1. 什么是抽象類? 1.1 定義: 抽象類是一個不能實例化的類,它是用來作為其他類的基類的。抽象類可以包含抽象方法和非抽象方法。抽象方法沒有方法體,子類必須重寫這些方法并提供具體的實現。抽象類可以有構造方法、成員變量、靜態…

css三角圖標

案例三角&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title><s…

跨越通信障礙:深入了解ZeroMQ的魅力

在復雜的分布式系統開發中&#xff0c;進程間通信就像一座橋梁&#xff0c;連接著各個獨立運行的進程&#xff0c;讓它們能夠協同工作。然而&#xff0c;傳統的通信方式往往伴隨著復雜的設置、高昂的性能開銷以及有限的靈活性&#xff0c;成為了開發者們前進道路上的 “絆腳石”…

深入解析 COUNT(DISTINCT) OVER(ORDER BY):原理、問題與高效替代方案

目錄 一、累計去重需求場景 二、COUNT(DISTINCT) OVER(ORDER BY) 語法解析 2.1 基礎語法 2.2 執行原理 三、三大核心問題分析