交叉熵損失F.cross_entropy在分類模型中的應用

一、核心思想:通過概率分布懲罰錯誤


交叉熵損失的本質是:
比較模型預測的概率分布 vs 真實標簽的概率分布,懲罰兩者之間的差異。

例如:

  • 真實標簽:圖像 0 → 文本 0(獨熱編碼 [1, 0, 0, ...])
  • 模型預測:[0.1, 0.2, 0.3, 0.4, ...](預測文本 0 的概率僅 0.1)

此時損失會很大,因為預測分布與真實分布差異大。

二、分步解析交叉熵懲罰機制


1. 相似度矩陣 → 概率分布


假設 sim_i2t 是一個 [3, 6] 的矩陣(3 個圖像 × 6 個文本):

# 示例相似度矩陣(簡化版,僅展示對角線高相似度)
sim_i2t = torch.tensor([[5.0, 1.0, 1.0, 1.0, 1.0, 1.0],  # 圖像0 → 文本0是正樣本[1.0, 5.0, 1.0, 1.0, 1.0, 1.0],  # 圖像1 → 文本1是正樣本[1.0, 1.0, 5.0, 1.0, 1.0, 1.0]   # 圖像2 → 文本2是正樣本
])

通過 softmax 將相似度轉換為概率分布:

probs = F.softmax(sim_i2t, dim=1)  # 對每行做softmax
print(probs)

輸出結果:

tensor([[0.94, 0.02, 0.02, 0.02, 0.02, 0.02],  # 預測文本0概率最高(正確)[0.02, 0.94, 0.02, 0.02, 0.02, 0.02],  # 預測文本1概率最高(正確)[0.02, 0.02, 0.94, 0.02, 0.02, 0.02]   # 預測文本2概率最高(正確)
])

2. 真實標簽的概率分布


假設 targets = [0, 1, 2],轉換為獨熱編碼:

# 獨熱編碼(簡化版,僅展示核心邏輯)
one_hot = torch.zeros_like(probs)
for i, t in enumerate(targets):one_hot[i, t] = 1.0print(one_hot)

輸出結果

tensor([[1.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # 圖像0的正樣本是文本0[0.0, 1.0, 0.0, 0.0, 0.0, 0.0],  # 圖像1的正樣本是文本1[0.0, 0.0, 1.0, 0.0, 0.0, 0.0]   # 圖像2的正樣本是文本2
])

3. 計算交叉熵損失

交叉熵損失公式:

對于上述例子:

  • 圖像 0 的損失:-log(0.94) ≈ 0.06
  • 圖像 1 的損失:-log(0.94) ≈ 0.06
  • 圖像 2 的損失:-log(0.94) ≈ 0.06

平均損失:(0.06 + 0.06 + 0.06) / 3 ≈ 0.06

實際函數內部:

# 1. 對預測值應用softmax,轉換為概率分布
probs = F.softmax(sim_i2t, dim=1)# 2. 對每個樣本,取出目標類別對應的概率
# 例如:
# - 第0個樣本的目標類別是0,取出probs[0, 0]
# - 第1個樣本的目標類別是1,取出probs[1, 1]
# - 第2個樣本的目標類別是2,取出probs[2, 2]
target_probs = probs[torch.arange(len(targets)), targets]# 3. 計算負對數似然
nll = -torch.log(target_probs)# 4. 求平均值得到最終損失
loss = nll.mean()

三、標簽平滑如何調整懲罰


標簽平滑(label_smoothing=0.1)會將:

  • 正樣本的概率從 1.0 調整為 0.9
  • 負樣本的概率從 0.0 調整為 0.1 / (類別數-1)

例如,對于圖像 0(正樣本是文本 0):

  • 原始標簽:[1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
  • 平滑后標簽:[0.9, 0.02, 0.02, 0.02, 0.02, 0.02]

此時損失計算變為:

實際函數內部:當使用label_smoothing=0.1時,函數內部會將目標概率分布從嚴格的獨熱編碼調整為平滑分布:

def cross_entropy_with_label_smoothing(logits, targets, smoothing=0.1):num_classes = logits.size(1)# 計算平滑后的目標分布# - 正樣本概率: 1.0 - smoothing + (smoothing / num_classes)# - 負樣本概率: smoothing / num_classessmooth_targets = torch.full_like(logits, smoothing / (num_classes - 1))smooth_targets[torch.arange(len(targets)), targets] = 1.0 - smoothing + (smoothing / num_classes)# 對預測值應用log_softmaxlog_probs = F.log_softmax(logits, dim=1)# 計算交叉熵(等價于F.kl_div(log_probs, smooth_targets))loss = (-smooth_targets * log_probs).sum(dim=1).mean()return loss

四、懲罰機制可視化


假設模型預測錯誤(圖像 0 預測文本 1 的概率最高):

# 錯誤預測的情況
bad_probs = torch.tensor([[0.1, 0.8, 0.05, 0.05, 0.0, 0.0],  # 錯誤:預測文本1概率最高[0.02, 0.94, 0.02, 0.02, 0.02, 0.0],  # 正確[0.02, 0.02, 0.94, 0.02, 0.02, 0.0]   # 正確
])# 計算交叉熵損失(無標簽平滑)
loss = -torch.log(bad_probs[0, 0])  # 圖像0的損失:-log(0.1) ≈ 2.3
print(f"錯誤預測的損失: {loss.item():.4f}")  # 損失遠大于正確預測的0.06

輸出結果:

錯誤預測的損失: 2.3026

五、總結


交叉熵損失的懲罰機制是:

  • 對正樣本:預測概率越低,懲罰越大(損失呈對數增長)
  • 對負樣本:預測概率越高,懲罰越大
  • 標簽平滑:減輕對極端預測的懲罰,防止過擬合

通過這種方式,模型被強制學習到:

  • 正樣本對的相似度要盡可能高
  • 負樣本對的相似度要盡可能低

這就是對比學習中 “拉近正樣本、推遠負樣本” 的核心實現方式!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/89511.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/89511.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/89511.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

測試學習之——Pytest Day3

引言Pytest 作為 Python 中最受歡迎的測試框架之一,以其簡潔的語法、強大的功能和豐富的插件生態系統,極大地提升了自動化測試的效率和可維護性。在本文中,我們將深入探討 Pytest 的兩大核心特性:Fixture 和插件管理,幫…

控制Vue對話框顯示隱藏

正確做法 — 使用 Vue 數據驅動控制顯隱你不需要手動設置 display: block&#xff0c;因為 Element Plus 的 <el-dialog> 是基于 v-model 或 :visible.sync 控制的。&#x1f527; 修改模板部分&#xff1a;將原來的&#xff1a;<el-dialog title"報文詳情"…

直播帶貨與開源AI智能名片鏈動2+1模式S2B2C商城小程序:重塑電商營銷新格局

摘要&#xff1a;本文聚焦于直播帶貨對互聯網供需關系的深刻影響&#xff0c;分析其如何改變傳統電商營銷模式&#xff0c;實現從“人找貨”到“貨找人”的轉變。同時&#xff0c;引入開源AI智能名片鏈動21模式S2B2C商城小程序這一創新概念&#xff0c;探討其在直播帶貨背景下的…

Jmeter 性能測試響應時間過長怎么辦?

當 JMeter 性能測試中出現 響應時間過長 的問題時&#xff0c;需要從 測試腳本、服務器、網絡、JMeter配置 等多方面排查和優化。以下是詳細的解決步驟和思路&#xff1a; B站最新性能進階&#xff0c;學會這些jmeter性能測試技能&#xff0c;更助于正確設計、執行和分析性能測…

COZE官方文檔基礎知識解讀第三期 —— prompt(提示詞)

COZE官方文檔基礎知識解讀第三期 —— prompt&#xff08;提示詞&#xff09; 對于初步接觸PE&#xff08;prompt engineering&#xff09; 的小伙伴們&#xff0c;你們可以去火山方舟提供的prompt工具&#xff0c;用工具&#xff08;其余的prompt網站https://www.promptinggu…

代碼隨想錄算法訓練營第三十二天|動態規劃理論基礎、LeetCode 509. 斐波那契數、70. 爬樓梯、746. 使用最小花費爬樓梯

目錄 LeetCode 509. 斐波那契數 70. 爬樓梯 746. 使用最小花費爬樓梯 感想 文檔講解&#xff1a;代碼隨想錄 動態規劃&#xff0c;英文&#xff1a;Dynamic Programming&#xff0c;簡稱DP&#xff0c;如果某一問題有很多重疊子問題&#xff0c;使用動態規劃是最有效的。 …

SpringMVC3

一、JSON 與參數傳遞1.1JSON 是什么- JSON 是字符串&#xff1a;比如 {"name":"zhangsan","password":"123456","age":15} 就是一個 JSON 字符串&#xff0c;它用來在前后端、服務間傳遞數據。- JSON 庫&#xff1a;Fastj…

查看.bin二進制文件的方式(HxD十六進制編輯器的安裝)

文章目錄Windows 系統上安裝 HxD 十六進制編輯器的步驟。**HxD 是一款免費、輕量級的工具&#xff0c;適合查看和編輯 .bin 等二進制文件。****PS:實際安裝過程中會發現找不到Windows11的版本&#xff0c;安裝windows10的即可&#xff0c;并且沒有區別setup版和portable版**安裝…

Linux系統性能優化與監控

系統性能優化與監控是保障 Linux 服務器穩定運行的核心技術&#xff0c;涉及 ??CPU、內存、磁盤 I/O、網絡、進程?? 等多維度的指標分析、問題定位與優化策略。以下從??監控工具與指標??、??常見問題診斷??、??優化方法??三個層面詳細講解&#xff0c;并結合?…

如何在 React + TypeScript 中實現 JSON 格式化功能

如何在 React TypeScript 中實現 JSON 格式化功能 作為前端開發者&#xff0c;我們經常需要處理 JSON 數據。無論是 API 調試、配置文件編輯還是數據轉換&#xff0c;能夠格式化 JSON 是一項基本但非常有用的技能。本文將詳細介紹如何在 React 和 TypeScript 環境中實現 JSON…

Mac連接服務器Docker容器全攻略

蘋果電腦( macOS 系統 )連接服務器、配置容器,整體思路和 Linux 終端操作更貼近,以下結合 macOS 特點,詳細分步說明,以 Docker 容器 + 常見 Linux 服務器( 如 CentOS、Ubuntu )為例: 一、連接服務器(SSH 方式, macOS 終端原生支持 ) 1. 準備信息 找運維或云平臺…

【字節跳動】數據挖掘面試題0019:帶貨直播間推薦:現在有一個帶貨的直播間,怎么把它精準地推送給有需要的用戶

文章大綱 帶貨直播間推薦系統:原理、算法與實踐 一、推薦系統在帶貨直播中的重要性 二、數據收集與處理 1. 用戶數據 2. 直播間數據 3. 用戶行為數據 4. 數據處理與特征工程 三、推薦算法實現 1. 基于內容的推薦 2. 基于協同過濾的推薦 3. 基于知識圖譜的推薦 4. 混合推薦算法…

Windows10筆記本電腦開啟BIOS

文章目錄什么是BIOS一、方案一&#xff1a;快捷鍵進入二、方案二&#xff08;推薦&#xff09;各品牌快捷鍵大全什么是BIOS BIOS 全拼為 BasicInputOutputSystem, 即基本輸入/輸出系統,是計算機中非常基礎而且重要的程序。把這一段程序存放在一個不需要電源的記憶體(芯片)中,就…

NFS、iSCSI 和lnmp部署操作

目錄 &#xff08;一&#xff09;基礎配置 1.NFS服務安裝 2.修改配置文件 3.重載配置文件 4.查看共享目錄 5.客戶端掛載 6.更換共享目錄 7.基礎實驗 &#xff08;二&#xff09;布置lnmp平臺 1.php 安裝軟件 檢測 2.連接MySQL 測試 3.軟件實施 軟件安裝配置 &…

Redis深度解析:從緩存原理到高并發實戰

第一部分&#xff1a;Redis核心概念與架構設計1.1 Redis本質解析Redis&#xff08;Remote Dictionary Server&#xff09;作為開源的內存數據結構存儲系統&#xff0c;其核心價值在于&#xff1a;內存優先架構&#xff1a;數據主要存儲在內存中&#xff0c;讀寫性能達到10萬 QP…

【NLP輿情分析】基于python微博輿情分析可視化系統(flask+pandas+echarts) 視頻教程 - 微博類別信息爬取

大家好&#xff0c;我是java1234_小鋒老師&#xff0c;最近寫了一套【NLP輿情分析】基于python微博輿情分析可視化系統(flaskpandasecharts)視頻教程&#xff0c;持續更新中&#xff0c;計劃月底更新完&#xff0c;感謝支持。今天講解架構搭建 視頻在線地址&#xff1a; 2026…

GD32/STM32嵌入CMSIS-DSP的庫(基于Keil)

當你要用到三角函數、開方、矩陣運算等復雜的數學運算時&#xff0c;可以選擇用C庫的math.h里面的函數&#xff0c;如果要求速度快的話就得用CMSIS-DSP庫里面的函數了&#xff0c;因為CMSIS-DSP庫充分運用了CM4內核的浮點運算單元&#xff08;若有&#xff09;和DSP相關的指令&…

頁面登錄阻止瀏覽器提醒是否保存密碼

一、原因 使用input的type"password"類型&#xff0c;瀏覽器會提醒是否記住密碼。 二、解決 取消type"password" 三、實現輸入密碼*代替 通過input輸入框&#xff0c;監聽輸入值&#xff0c;進行替換成*符號&#xff0c;避免使用input的type"password…

【iOS】dyld加載流程——應用程序的加載

目錄 前言 編譯過程與動靜態庫 編譯過程 動靜態庫 dyld &#x1f4cc; 什么是 dyld&#xff1f; dyld_shared_cache: dyld加載流程 _dyld_start dyldbootstrap::start dyld::main() 配置環境變量 共享緩存 主程序的初始化 插入動態庫 link主程序 link動態庫 弱…

從零開始,手把手教你本地部署Stable Diffusion AI繪畫(Win最新版)

本號之前有發過一篇win平臺的教程&#xff0c;由于是去年10月發布的&#xff0c;而Al繪畫技術發展很快&#xff0c;那篇教程已經有些不適用了&#xff0c;有些同學執行到第二步就出錯了。 應廣大同學的期望&#xff0c;我更新一版新版詳細教程。 一、前言 1.為什么要本地部署…