Pytorch為什么 nn.CrossEntropyLoss = LogSoftmax + nn.NLLLoss?

為什么 nn.CrossEntropyLoss = LogSoftmax + nn.NLLLoss

在使用 PyTorch 時,我們經常聽說 nn.CrossEntropyLossLogSoftmaxnn.NLLLoss 的組合。這句話聽起來簡單,但背后到底是怎么回事?為什么這兩個分開的功能加起來就等于一個完整的交叉熵損失?今天我們就從數學公式到代碼實現,徹底搞清楚它們的聯系。

1. 先認識三個主角

要理解這個等式,先得知道每個部分的定義和作用:

  • nn.CrossEntropyLoss:交叉熵損失,直接接受未歸一化的 logits,計算模型預測與真實標簽的差距,適用于多分類任務。
  • LogSoftmax:將 logits 轉為對數概率(log probabilities),輸出范圍是負值。
  • nn.NLLLoss:負對數似然損失,接受對數概率,計算正確類別的負對數值。

表面上看,nn.CrossEntropyLoss 是一個獨立的損失函數,而 LogSoftmaxnn.NLLLoss 是兩步操作。為什么說它們本質上是一回事呢?答案藏在數學公式和計算邏輯里。

2. 數學上的拆解

讓我們從交叉熵的定義開始,逐步推導。

(1) 交叉熵的數學形式

交叉熵(Cross-Entropy)衡量兩個概率分布的差異。在多分類任務中:

  • ( p p p ):真實分布,通常是 one-hot 編碼(比如 [0, 1, 0] 表示第 1 類)。
  • ( q q q ):預測分布,是模型輸出的概率(比如 [0.2, 0.5, 0.3])。

交叉熵公式為:

H ( p , q ) = ? ∑ c = 1 C p c log ? ( q c ) H(p, q) = -\sum_{c=1}^{C} p_c \log(q_c) H(p,q)=?c=1C?pc?log(qc?)

對于 one-hot 編碼,( p c p_c pc? ) 在正確類別上為 1,其他為 0,所以簡化為:

H ( p , q ) = ? log ? ( q correct ) H(p, q) = -\log(q_{\text{correct}}) H(p,q)=?log(qcorrect?)

其中 ( q correct q_{\text{correct}} qcorrect? ) 是正確類別對應的預測概率。對 ( N N N ) 個樣本取平均,損失為:

Loss = ? 1 N ∑ i = 1 N log ? ( q i , y i ) \text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} \log(q_{i, y_i}) Loss=?N1?i=1N?log(qi,yi??)

這正是交叉熵損失的核心。

(2) 從 logits 到概率

神經網絡輸出的是原始分數(logits),比如 ( z = [ z 1 , z 2 , z 3 ] z = [z_1, z_2, z_3] z=[z1?,z2?,z3?] )。要得到概率 ( q q q ),需要用 Softmax:

q j = e z j ∑ k = 1 C e z k q_j = \frac{e^{z_j}}{\sum_{k=1}^{C} e^{z_k}} qj?=k=1C?ezk?ezj??

交叉熵損失變成:

Loss = ? 1 N ∑ i = 1 N log ? ( e z i , y i ∑ k = 1 C e z i , k ) \text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} \log\left(\frac{e^{z_{i, y_i}}}{\sum_{k=1}^{C} e^{z_{i,k}}}\right) Loss=?N1?i=1N?log(k=1C?ezi,k?ezi,yi???)

這就是 nn.CrossEntropyLoss 的數學形式。

(3) 分解為兩步

現在我們把這個公式拆開:

  • 第一步:LogSoftmax
    計算對數概率:
    log ? ( q j ) = log ? ( e z j ∑ k = 1 C e z k ) = z j ? log ? ( ∑ k = 1 C e z k ) \log(q_j) = \log\left(\frac{e^{z_j}}{\sum_{k=1}^{C} e^{z_k}}\right) = z_j - \log\left(\sum_{k=1}^{C} e^{z_k}\right) log(qj?)=log(k=1C?ezk?ezj??)=zj??log(k=1C?ezk?)
    這正是 LogSoftmax 的定義。它把 logits ( z z z ) 轉為對數概率 ( log ? ( q ) \log(q) log(q) )。

  • 第二步:NLLLoss
    有了對數概率 ( log ? ( q ) \log(q) log(q) ),取出正確類別的值,取負號并平均:
    NLL = ? 1 N ∑ i = 1 N log ? ( q i , y i ) \text{NLL} = -\frac{1}{N} \sum_{i=1}^{N} \log(q_{i, y_i}) NLL=?N1?i=1N?log(qi,yi??)
    這就是 nn.NLLLoss 的公式。

組合起來

  • LogSoftmax 把 logits 轉為 ( log ? ( q ) \log(q) log(q) )。
  • nn.NLLLoss 對 ( log ? ( q ) \log(q) log(q) ) 取負號,計算損失。
  • 兩步合起來正好是 ( ? log ? ( q correct ) -\log(q_{\text{correct}}) ?log(qcorrect?) ),與交叉熵一致。
3. PyTorch 中的實現驗證

從數學上看,nn.CrossEntropyLoss 的確可以分解為 LogSoftmaxnn.NLLLoss。我們用代碼驗證一下:

import torch
import torch.nn as nn# 輸入數據
logits = torch.tensor([[1.0, 2.0, 0.5], [0.1, 0.5, 2.0]])  # [batch_size, num_classes]
target = torch.tensor([1, 2])  # 真實類別索引# 方法 1:直接用 nn.CrossEntropyLoss
ce_loss_fn = nn.CrossEntropyLoss()
ce_loss = ce_loss_fn(logits, target)
print("CrossEntropyLoss:", ce_loss.item())# 方法 2:LogSoftmax + nn.NLLLoss
log_softmax = nn.LogSoftmax(dim=1)
nll_loss_fn = nn.NLLLoss()
log_probs = log_softmax(logits)  # 計算對數概率
nll_loss = nll_loss_fn(log_probs, target)
print("LogSoftmax + NLLLoss:", nll_loss.item())

運行結果:兩個輸出的值完全相同(比如 0.75)。這證明 nn.CrossEntropyLoss 在內部就是先做 LogSoftmax,再做 nn.NLLLoss

4. 為什么 PyTorch 這么設計?

既然 nn.CrossEntropyLoss 等價于 LogSoftmax + nn.NLLLoss,為什么 PyTorch 提供了兩種方式?

  • 便利性
    nn.CrossEntropyLoss 是一個“一體式”工具,直接輸入 logits 就能用,適合大多數場景,省去手動搭配的麻煩。

  • 模塊化
    LogSoftmaxnn.NLLLoss 分開設計,給開發者更多靈活性:

    • 你可以在模型里加 LogSoftmax,只用 nn.NLLLoss 計算損失。
    • 可以單獨調試對數概率(比如打印 log_probs)。
    • 在某些自定義損失中,可能需要用到獨立的 LogSoftmax
  • 數值穩定性
    nn.CrossEntropyLoss 內部優化了計算,避免了分開操作時可能出現的溢出問題(比如 logits 很大時,Softmax 的分母溢出)。

5. 為什么不直接用 Softmax?

你可能好奇:為什么不用 Softmax + 對數 + 取負,而是用 LogSoftmax
答案是數值穩定性:

  • 單獨計算 Softmax(指數運算)可能導致溢出(比如 ( e 1000 e^{1000} e1000 ))。
  • LogSoftmax 把指數和對數合并為 ( z j ? log ? ( ∑ e z k ) z_j - \log(\sum e^{z_k}) zj??log(ezk?) ),計算更穩定。
6. 使用場景對比
  • nn.CrossEntropyLoss

    • 輸入:logits。
    • 場景:標準多分類任務(圖像分類、文本分類)。
    • 優點:簡單直接。
  • LogSoftmax + nn.NLLLoss

    • 輸入:logits 需手動轉為對數概率。
    • 場景:需要顯式控制 Softmax,或者模型已輸出對數概率。
    • 優點:靈活性高。
7. 小結:為什么等價?
  • 數學上:交叉熵 ( ? log ? ( q correct ) -\log(q_{\text{correct}}) ?log(qcorrect?) ) 可以拆成兩步:
    1. LogSoftmax:從 logits 到 ( log ? ( q ) \log(q) log(q) )。
    2. nn.NLLLoss:從 ( log ? ( q ) \log(q) log(q) ) 到 ( ? log ? ( q correct ) -\log(q_{\text{correct}}) ?log(qcorrect?) )。
  • 實現上nn.CrossEntropyLoss 把這兩步封裝成一個函數,結果一致。
  • 設計上:PyTorch 提供兩種方式,滿足不同需求。

所以,nn.CrossEntropyLoss = LogSoftmax + nn.NLLLoss 不是巧合,而是交叉熵計算的自然分解。理解這一點,能幫助你更靈活地使用 PyTorch 的損失函數。

8. 彩蛋:手動推導

想自己驗證?試試手動計算:

  • logits [1.0, 2.0, 0.5],目標是 1。
  • Softmax:[0.23, 0.63, 0.14]
  • LogSoftmax:[-1.47, -0.47, -1.97]
  • NLL:-(-0.47) = 0.47
  • 直接用 nn.CrossEntropyLoss,結果一樣!

希望這篇博客解開了你的疑惑!

后記

2025年2月28日18點51分于上海,在grok3 大模型輔助下完成。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/70968.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/70968.shtml
英文地址,請注明出處:http://en.pswp.cn/web/70968.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

rabbitmq 延時隊列

要使用 RabbitMQ Delayed Message Plugin 實現延時隊列,首先需要確保插件已安裝并啟用。以下是實現延時隊列的步驟和代碼示例。 1. 安裝 RabbitMQ Delayed Message Plugin 首先,確保你的 RabbitMQ 安裝了 rabbitmq-delayed-message-exchange 插件。你可…

在 Vue 單文件組件(SFC)中,標簽的顯式關閉與隱式關閉有著重要的區別

一、顯式關閉標簽 1、定義&#xff1a; 所有的 HTML 標簽都必須有一個對應的結束標簽。 自閉合標簽也必須使用 / 來關閉。 <template> <div> <p>這是一個段落</p> <img src"image.png"…

第四屆大數據、區塊鏈與經濟管理國際學術會議

重要信息 官網&#xff1a;www.icbbem.com 時間&#xff1a;2025年3月14-16日 地點&#xff1a;中國-武漢 &#xff08;線上召開&#xff09; 簡介 第四屆大數據、區塊鏈與經濟管理國際學術會議(ICBBEM 2025)&#xff0c;將于2025年3月14-16日在中國湖北省武漢市召開。…

每日十個計算機專有名詞 (7)

Metasploit 詞源&#xff1a;Meta&#xff08;超越&#xff0c;超出&#xff09; exploit&#xff08;漏洞利用&#xff09; Metasploit 是一個安全測試框架&#xff0c;用來幫助安全專家&#xff08;也叫滲透測試人員&#xff09;發現和利用計算機系統中的漏洞。你可以把它想…

使用Docker Compose部署 MySQL8

MySQL 8 是一個功能強大的關系型數據庫管理系統,而 Docker 則是一個流行的容器化平臺。結合使用它們可以極大地簡化 MySQL 8 的部署過程,并且確保開發環境和生產環境的一致性。 安裝 Docker 和 Docker Compose 首先,確保你的機器上已經安裝了 Docker 和 Docker Compose。 …

mamba_ssm和causal-conv1d詳細安裝教程

1.前言 Mamba是近年來在深度學習領域出現的一種新型結構&#xff0c;特別是在處理長序列數據方面表現優異。在本文中&#xff0c;我將介紹如何在 Linux 系統上安裝并配置 mamba_ssm 虛擬環境。由于官方指定mamba_ssm適用于 PyTorch 版本高于 1.12 且 CUDA 版本大于 11.6 的環境…

c++中初始化列表的使用

在 C 中&#xff0c;初始化列表是在構造函數的定義中&#xff0c;用于對類的成員變量進行初始化的一種方式。它緊跟在構造函數的參數列表之后&#xff0c;使用冒號 : 分隔&#xff0c;各成員變量的初始化用逗號 , 分隔。下面詳細介紹初始化列表及其參數的含義。 基本語法 clas…

《Linux系統編程篇》System V信號量實現生產者與消費者問題(Linux 進程間通信(IPC))——基礎篇(拓展思維)

文章目錄 &#x1f4da; **生產者-消費者問題**&#x1f511; **問題分析**&#x1f6e0;? **詳細實現&#xff1a;生產者-消費者****步驟 1&#xff1a;定義信號量和緩沖區****步驟 2&#xff1a;創建信號量****步驟 3&#xff1a;生產者進程****步驟 4&#xff1a;消費者進程…

利用 Python 爬蟲進行跨境電商數據采集

1 引言2 代理IP的優勢3 獲取代理IP賬號4 爬取實戰案例---&#xff08;某電商網站爬取&#xff09;4.1 網站分析4.2 編寫代碼4.3 優化代碼 5 總結 1 引言 在數字化時代&#xff0c;數據作為核心資源蘊含重要價值&#xff0c;網絡爬蟲成為企業洞察市場趨勢、學術研究探索未知領域…

HONOR榮耀MagicBook 15 2021款 獨顯(BOD-WXX9,BDR-WFH9HN)原廠Win10系統

適用型號&#xff1a;【BOD-WXX9】 MagicBook 15 2021款 i7 獨顯 MX450 16GB512GB (BDR-WFE9HN) MagicBook 15 2021款 i5 獨顯 MX450 16GB512GB (BDR-WFH9HN) MagicBook 15 2021款 i5 集顯 16GB512GB (BDR-WFH9HN) 鏈接&#xff1a;https://pan.baidu.com/s/1S6L57ADS18fnJZ1…

c語言實現三子棋小游戲(涉及二維數組、函數、循環、常量、動態取地址等知識點)

使用C語言實現一個三子棋小游戲 涉及知識點&#xff1a;二維數組、自定義函數、自帶函數庫、循環、常量、動態取地址等等 一些細節點&#xff1a; 1、引入自定義頭文件&#xff0c;需要用""雙引號包裹文件名&#xff0c;目的是為了和官方頭文件的<>區分開。…

C語言數據類型及其使用 (帶示例)

目錄 1. 基本數據類型 整型 浮點型 字符型 2. 構造數據類型 數組 結構體 聯合體&#xff08;共用體&#xff09; 枚舉類型 3. 指針類型 4. 空類型 在 C 語言中&#xff0c;數據類型是非常重要的概念&#xff0c;它決定了數據在內存中的存儲方式、占用空間大小以及可…

Web自動化之Selenium添加網站Cookies實現免登錄

在使用Selenium進行Web自動化時&#xff0c;添加網站Cookies是實現免登錄的一種高效方法。通過模擬瀏覽器行為&#xff0c;我們可以將已登錄狀態的Cookies存儲起來&#xff0c;并在下次自動化測試或爬蟲任務中直接加載這些Cookies&#xff0c;從而跳過登錄步驟。 Cookies簡介 …

NAT 技術:網絡中的 “地址魔術師”

目錄 一、性能瓶頸&#xff1a;NAT 的 “阿喀琉斯之踵” &#xff08;一&#xff09;數據包處理延遲 &#xff08;二&#xff09;高并發下的性能損耗 二、應用兼容性&#xff1a;NAT 帶來的 “適配難題” &#xff08;一&#xff09;端到端通信的困境 &#xff08;二&…

php序列化與反序列化

文章目錄 基礎知識魔術方法&#xff1a;在序列化和反序列化過程中自動調用的方法什么是 __destruct() 方法&#xff1f;何時觸發 __destruct() 方法&#xff1f;用途&#xff1a;語法示例&#xff1a; 反序列化漏洞利用前提條件一些繞過策略繞過__wakeup函數繞過正則匹配繞過相…

docker 占用系統空間太大了,整體遷移到掛載的其他磁盤|【當前普通用戶使用docker時,無法指定鏡像、容器安裝位置【無法指定】】

文章目錄 前言【核心步驟皆為 大模型生成的方案】總結步驟應該是&#xff1a;詳細步驟如下1. **停止 Docker 服務**2. **備份原數據&#xff08;防止遷移失敗&#xff09;**3. **遷移數據到新磁盤**4. **修改 Docker 配置文件**5. **重啟 Docker 服務**6. **驗證容器和鏡像**7.…

設計后端返回給前端的返回體

目錄 1、為什么要設計返回體&#xff1f; 2、返回體包含哪些內容&#xff08;如何設計&#xff09;&#xff1f; 舉例 3、總結 1、為什么要設計返回體&#xff1f; 在設計后端返回給前端的返回體時&#xff0c;通常需要遵循一定的規范&#xff0c;以確保前后端交互的清晰性…

Springboot 自動化裝配的原理

Springboot 自動化裝配的原理 SpringBoot 主要作用為&#xff1a;起步依賴、自動裝配。而為了實現這種功能&#xff0c;SpringBoot 底層主要使用了 SpringBootApplication 注解。 首先&#xff0c;SpringBootApplication 是一個復合注解&#xff0c;它結合了 Configuration、…

基于vue框架的游戲博客網站設計iw282(程序+源碼+數據庫+調試部署+開發環境)帶論文文檔1萬字以上,文末可獲取,系統界面在最后面。

系統程序文件列表 項目功能&#xff1a;用戶,博客信息,資源共享,游戲視頻,游戲照片 開題報告內容 基于FlaskVue框架的游戲博客網站設計開題報告 一、項目背景與意義 隨著互聯網技術的飛速發展和游戲產業的不斷壯大&#xff0c;游戲玩家對游戲資訊、攻略、評測等內容的需求日…

算法-二叉樹篇13-路徑總和

路徑總和 力扣題目鏈接 題目描述 給你二叉樹的根節點 root 和一個表示目標和的整數 targetSum 。判斷該樹中是否存在 根節點到葉子節點 的路徑&#xff0c;這條路徑上所有節點值相加等于目標和 targetSum 。如果存在&#xff0c;返回 true &#xff1b;否則&#xff0c;返回…