【NLP 50、損失函數 KL散度】

目錄

一、定義與公式

1.核心定義

2.數學公式

3.KL散度與交叉熵的關系

二、使用場景

1.生成模型與變分推斷

2.知識蒸餾

3.模型評估與優化

4.信息論與編碼優化

三、原理與特性

1.信息論視角

?2.優化目標

3.?局限性

四、代碼示例

代碼運行流程

核心代碼解析


抵達夢想靠的不是狂熱的想象,而是謙卑的務實,甚至你自己都看不起的可憐的隱忍

????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????—— 25.3.27

一、定義與公式

1.核心定義

????????KL散度(相對熵)是衡量兩個概率分布?P?和?Q?之間差異的非對稱性指標。它量化了當用分布?Q?近似真實分布?P?時的信息損失

非對稱性:,即P和Q的順序不能交換

非負性:,當且僅當P = Q時取等號

2.數學公式

離散形式:

連續形式:

其中,P是真實分布,Q是近似分布

3.KL散度與交叉熵的關系

KL散度可以分解為交叉熵H(P,Q與P的熵H(P):

交叉熵常用于分類任務,而KL散度更關注分布間的信息差異


二、使用場景

1.生成模型與變分推斷

變分自編碼器(VAE)?:通過最小化,使編碼器輸出的隱變量分布Q(z|x)逼近先驗分布P(z)

生成對抗網絡(GAN)?:輔助衡量生成分布與真實分布的差異

2.知識蒸餾

????????將復雜教師模型的輸出概率(軟標簽)作為監督信號,指導學生模型學習,損失函數中常包含KL散度項

3.模型評估與優化

?多模態分布對齊:在推薦系統中對齊用戶行為分布與模型預測分布?

異常檢測:通過KL散度衡量測試數據分布與正常數據分布的偏離程度

4.信息論與編碼優化

最小化編碼長度:KL散度表示用?Q?編碼?P?時所需的額外比特數


三、原理與特性

1.信息論視角

?信息增益:KL散度表示從?Q?中獲取?P?的信息時需要增加的“驚訝度”(Surprisal)。

?凸性:KL散度是凸函數,可通過梯度下降法優化。

?2.優化目標

?前向KL散度?DKL?(P∥Q):要求?Q?覆蓋?P?的主要模式,避免?Q?的“零概率陷阱”(即?Q(x)=0?但?P(x)>0?會導致無窮大)?反向KL散度?DKL?(Q∥P):鼓勵?Q?聚焦于?P?的單一主峰,適用于稀疏分布近似。

3.?局限性

?非對稱性:需根據任務選擇方向(如VAE使用前向KL,部分GAN變體使用反向KL)

數值穩定性:需避免?Q(x)=0?或極端概率值,可通過平滑或溫度參數(Temperature Scaling)調整。


四、代碼示例

代碼運行流程

KL散度計算流程
├── 1. 輸入預處理
│   ├── a. 獲取學生/教師模型原始輸出
│   │   ├─ student_logits: 形狀(batch=32, classes=10)
│   │   └─ teacher_logits: 同左[1,3](@ref)
│   └── b. 溫度參數初始化
│       └─ temperature=5.0 (默認值)
├── 2. 概率變換
│   ├── a. 溫度縮放
│   │   ├─ student_logits → student_logits / 5.0
│   │   └─ teacher_logits → teacher_logits / 5.0
│   ├── b. 概率歸一化
│   │   ├─ student_probs = log_softmax(...)  # 對數空間
│   │   └─ teacher_probs = softmax(...)      # 線性空間
├── 3. 損失計算
│   ├── a. 初始化KLDivLoss
│   │   └─ reduction='batchmean' (符合數學期望)
│   ├── b. 執行KL散度計算
│   │   └─ KL(student_probs || teacher_probs)
│   └── c. 梯度補償
│       └─ 乘以temperature2=25 恢復梯度幅值
└── 4. 結果輸出└── 打印損失值 (標量Tensor轉float)

student_logits:學生模型的原始輸出(未歸一化),形狀為?(batch_size, num_classes),表示每個樣本的預測得分

teacher_logits:教師模型的原始輸出(未歸一化),作為知識蒸餾的監督信號,形狀同student_logits

temperature:溫度縮放參數,軟化概率分布(值越大分布越平滑,值越小越接近原始分布)

student_probs:學生模型經溫度縮放后的對數概率

teacher_probs:教師模型經溫度縮放后的概率

loss:?KL散度損失的計算結果,表示學生模型輸出分布與教師模型輸出分布之間的差異程度。該值是一個標量(Scalar),用于指導反向傳播優化學生模型的參數

batch_size:表示 ?單次輸入模型的樣本數量,即一次前向傳播和反向傳播處理32個樣本。

nums_classes:?表示 ?分類任務的類別總數,即模型需區分的不同標簽種類數。

F.log_softmax():將輸入張量通過Softmax函數歸一化為概率分布后,再對每個元素取自然對數,常用于分類任務的損失計算(如交叉熵損失)。

參數名類型說明默認值
?**input**Tensor輸入張量必填
?**dim**int指定歸一化的維度(如dim=1表示按行計算)必填

F.softmax():將輸入張量通過指數函數歸一化為概率分布,輸出值范圍為(0,1)且和為1。

參數名類型說明默認值
?**input**Tensor輸入張量必填
?**dim**int歸一化維度(如dim=0按列歸一化)必填

nn.KLDivLoss():計算兩個概率分布之間的Kullback-Leibler散度(KL散度),用于衡量分布差異。

參數名類型說明可選值默認值
?**reduction**str損失聚合方式'none',?'mean',?'sum',?'batchmean''mean'

torch.randn():生成服從標準正態分布(均值為0,標準差為1)的隨機數張量,常用于初始化權重或生成噪聲數據。

參數名類型說明默認值
?***size**int或tuple張量形狀(如(3,4)生成3行4列矩陣)必填
?**dtype**torch.dtype數據類型(如torch.float32None(自動推斷)
?**device**torch.device設備(如'cuda'CPU
?**requires_grad**bool是否需要梯度跟蹤False

item():PyTorch中torch.Tensor類的方法,用于從單元素張量中提取Python標量值(如intfloat等)

核心代碼解析

loss = nn.KLDivLoss(reduction='batchmean')(student_probs, teacher_probs) * (temperature ?** 2)

nn.KLDivLoss(reduction='batchmean')計算學生模型輸出 (student_probs) 與教師模型輸出 (teacher_probs) 之間的 ?KL散度,衡量兩者的概率分布差異

?????????參數?reduction='batchmean'將每個樣本的KL散度求和后除以批量大小 (batch_size),確保損失值符合KL散度的數學定義

???mean對所有元素取平均(總和除以元素總數)。

???sum直接求和。

???none保留每個樣本的獨立損失值。

(student_probs, teacher_probs):輸入參數student_probs 和 teacher_probs

* (temperature ?** 2):溫度縮放與梯度補償

? ? ? ? 溫度的作用:軟化概率分布:高溫值會使教師模型的概率分布更平滑,避免過度關注高置信度類別

?????????為何乘以?temperature2①?梯度補償:溫度縮放會縮小梯度的幅值,乘以?temperature2?可恢復原始梯度量級,確保優化方向正確? ②?數學推導:KL散度計算中,溫度參數會引入縮放因子?T1?,反向傳播時梯度需乘以?T2?以抵消縮放效應。

import torch
import torch.nn as nn
import torch.nn.functional as F# 定義KL散度損失函數(帶溫度參數)
def kl_div_loss_with_temperature(student_logits, teacher_logits, temperature=5.0):# 對logits應用溫度縮放student_probs = F.log_softmax(student_logits / temperature, dim=-1)teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)# 計算KL散度loss = nn.KLDivLoss(reduction='batchmean')(student_probs, teacher_probs) * (temperature ?** 2)return loss# 模擬輸入數據
batch_size, num_classes = 32, 10
student_logits = torch.randn(batch_size, num_classes)  # 學生模型輸出(未歸一化)
teacher_logits = torch.randn(batch_size, num_classes)  # 教師模型輸出(未歸一化)# 計算損失
loss = kl_div_loss_with_temperature(student_logits, teacher_logits)
print(f"KL散度損失: {loss.item()}")

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/73742.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/73742.shtml
英文地址,請注明出處:http://en.pswp.cn/web/73742.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

使用QT畫帶有透明效果的圖

分辨率&#xff1a;24X24 最大圓 代碼: #include <QApplication> #include <QImage> #include <QPainter>int main(int argc, char *argv[]) {QImage image(QSize(24,24),QImage::Format_ARGB32);image.fill(QColor(0,0,0,0));QPainter paint(&image);…

【Unity網絡編程知識】使用Socket實現簡單TCP通訊

1、Socket的常用屬性和方法 創建Socket TCP流套接字 Socket socketTcp new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); 1.1 常用屬性 1&#xff09;套接字的連接狀態 socketTcp.Connected 2&#xff09;獲取套接字的類型 socketTcp.So…

青少年編程與數學 02-013 初中數學知識點 02課題、概要

青少年編程與數學 02-013 初中數學知識點 02課題、概要 一、數與代數二、圖形與幾何三、統計與概率四、綜合與實踐五、課程理念與目標 根據2022年版義務教育數學課程標準&#xff0c;初中數學知識點可以總結為以下四大領域。 一、數與代數 數與式 有理數與實數&#xff1a;理解…

深入探索 libarchive

深入探索 libarchive&#xff1a;跨平臺歸檔處理的終極解決方案 一、背景與歷史沿革 1.1 歸檔處理的演進之路 從1979年tar格式的誕生到現代云存儲時代&#xff0c;歸檔技術經歷了四個關鍵階段&#xff1a; Unix時代&#xff1a;tar/cpio主導系統備份互聯網黎明期&#xff1…

2025最新“科研創新與智能化轉型“暨AI智能體開發與大語言模型的本地化部署、優化技術實踐

第一章、智能體(Agent)入門 1、智能體&#xff08;Agent&#xff09;概述&#xff08;什么是智能體&#xff1f;智能體的類型和應用場景、典型的智能體應用&#xff0c;如&#xff1a;Google Data Science Agent等&#xff09; 2、智能體&#xff08;Agent&#xff09;與大語…

Yolo_v8的安裝測試

前言 如何安裝Python版本的Yolo&#xff0c;有一段時間不用了&#xff0c;Yolo的版本也在不斷地發展&#xff0c;所以重新安裝了運行了一下&#xff0c;記錄了下來&#xff0c;供參考。 一、搭建環境 1.1、創建Pycharm工程 首先創建好一個空白的工程&#xff0c;如下圖&…

時尚界正在試圖用AI,創造更多沖擊力

數字藝術正以深度融合的方式&#xff0c;在時尚、游戲、影視等行業實現跨界合作&#xff0c;催生了多樣化的商業模式&#xff0c;為創作者和品牌帶來更多機會&#xff0c;數字藝術更是突破了傳統藝術的限制&#xff0c;以趣味觸達用戶&#xff0c;尤其吸引了年輕一代的消費群體…

藍橋杯省模擬賽 01串個數

問題描述 請問有多少個長度為 24 的 01 串&#xff0c;滿足任意 5 個連續的位置中不超過 3 個位置的值為 1。 所有長度為24的01串組合有2*24種 思路&#xff1a;遍歷所有長度為24的01串組合&#xff0c;選擇出符合題意的 #include<iostream> #include<cmath> us…

【軟考備考】系統架構設計論文完整范文示例

本文由AI輔助創造 題目:基于微服務與云原生的智慧政務平臺架構設計與實踐 摘要(約300字) 本文以某省級智慧政務平臺建設項目為背景,針對傳統政務系統存在的"信息孤島"、擴展性差、維護成本高等問題,提出了一套基于微服務與云原生技術的解決方案。通過領域驅動…

數據庫原理及應用mysql版陳業斌實驗二

&#x1f3dd;?專欄&#xff1a;Mysql_貓咪-9527的博客-CSDN博客 &#x1f305;主頁&#xff1a;貓咪-9527-CSDN博客 “欲窮千里目&#xff0c;更上一層樓。會當凌絕頂&#xff0c;一覽眾山小。” 目錄 實驗二單表查詢 1.實驗數據如下 student 表&#xff08;學生表&#…

SDL —— 將sdl渲染畫面嵌入Qt窗口顯示(附:源碼)

?? SDL/SDL2 相關技術、疑難雜癥文章合集(掌握后可自封大俠 ?_?)(記得收藏,持續更新中…) 效果 使用QWidget加載了SDL的窗口,渲染器使用硬件加速跑GPU的。支持Qt窗口縮放或顯示隱藏均不影響SDL的圖像刷新。 ? 操作步驟 1、在創建C++空工程時加入SDL,引入頭文件時需…

C語言之鏈表增刪查改

1.知識百科 鏈表&#xff08;Linked List&#xff09;是計算機科學中一種基礎的數據結構&#xff0c;通過節點&#xff08;Node&#xff09;的鏈式連接來存儲數據。每個節點包含兩部分&#xff1a;存儲數據的元素和指向下一個節點的指針&#xff08;單鏈表&#xff09;或前后兩…

Windows環境下AnythingLLM安裝與Ollama+DeepSeek集成指南

前面已經完成了Ollama的安裝并下載了deepseek大模型包&#xff0c;下面介紹如何與anythingLLM 集成 Windows環境下AnythingLLM安裝與OllamaDeepSeek集成指南 一、安裝準備 1. 硬件要求 如上文說明 2. 前置條件 已安裝Ollama并下載DeepSeek模型&#xff08;如deepseek-r1:…

當貝AI知識庫評測 AI如何讓知識檢索快人一步

近日,國內領先的人工智能服務商當貝AI正式推出“個人知識庫”功能,這一創新性工具迅速引發行業關注。在信息爆炸的時代,如何高效管理個人知識資產、快速獲取精準答案成為用戶的核心需求。當貝AI通過將“閉卷考試”變為“開卷考試”的獨特設計,為用戶打造了一個高度個性化的智能…

HarmonyOS NEXT——【鴻蒙原生應用加載Web頁面】

鴻蒙客戶端加載Web頁面&#xff1a; 在鴻蒙原生應用中&#xff0c;我們需要使用前端頁面做混合開發&#xff0c;方法之一是使用Web組件直接加載前端頁面&#xff0c;其中WebView提供了一系列相關的方法適配鴻蒙原生與web之間的使用。 效果 web頁面展示&#xff1a; Column()…

嵌入式開發場景中Shell腳本執行方式的對比

?Shell腳本執行方式對比表? ?執行方式??命令示例??是否需要執行權限??是否啟動子Shell??環境變量影響范圍??適用場景??嵌入式開發中的典型應用??直接執行腳本?./script.sh是是子Shell內有效獨立運行的腳本&#xff0c;需固定環境自動化構建腳本&#xff08;…

MES系統需要采集的數據及如何采集

?數據采集在企業信息化建設中占據著舉足輕重的地位&#xff0c;是實現物料跟蹤、生產計劃制定、產品歷史記錄維護以及其他生產管理活動的基石。數據的準確性和實時性直接關系到企業信息化能否成功落地&#xff0c;是企業邁向高效生產的關鍵因素。 數據收集對于MES制造執行系統…

閉環管理:借助數字化管理平臺實現客戶反饋的價值升級

在競爭激烈的市場環境中&#xff0c;客戶反饋已成為企業優化服務、提升競爭力的核心資源。如何高效處理客戶反饋&#xff0c;將其轉化為企業持續改進的動力&#xff0c;是每個企業面臨的重要課題。作為服務管理數字化轉型服務商&#xff0c;瑞云服務云為大中型企業提供了一套完…

C++Primer學習(13.6 對象移動)

13.6 對象移動 新標準的一個最主要的特性是可以移動而非拷貝對象的能力。如我們在13.1.1節(第440頁)中所見&#xff0c;很多情況下都會發生對象拷貝。在其中某些情況下&#xff0c;對象拷貝后就立即被銷毀了。在這些情況下&#xff0c;移動而非拷貝對象會大幅度提升性能。 如我…

Uni-app頁面信息與元素影響解析

獲取窗口信息uni.getWindowInfo {pixelRatio: 3safeArea:{bottom: 778height: 731left: 0right: 375top: 47width: 375}safeAreaInsets: {top: 47, left: 0, right: 0, bottom: 34},screenHeight: 812,screenTop: 0,screenWidth: 375,statusBarHeight: 47,windowBottom: 0,win…