為什么 nn.CrossEntropyLoss
= LogSoftmax
+ nn.NLLLoss
?
在使用 PyTorch 時,我們經常聽說 nn.CrossEntropyLoss
是 LogSoftmax
和 nn.NLLLoss
的組合。這句話聽起來簡單,但背后到底是怎么回事?為什么這兩個分開的功能加起來就等于一個完整的交叉熵損失?今天我們就從數學公式到代碼實現,徹底搞清楚它們的聯系。
1. 先認識三個主角
要理解這個等式,先得知道每個部分的定義和作用:
nn.CrossEntropyLoss
:交叉熵損失,直接接受未歸一化的 logits,計算模型預測與真實標簽的差距,適用于多分類任務。LogSoftmax
:將 logits 轉為對數概率(log probabilities),輸出范圍是負值。nn.NLLLoss
:負對數似然損失,接受對數概率,計算正確類別的負對數值。
表面上看,nn.CrossEntropyLoss
是一個獨立的損失函數,而 LogSoftmax
和 nn.NLLLoss
是兩步操作。為什么說它們本質上是一回事呢?答案藏在數學公式和計算邏輯里。
2. 數學上的拆解
讓我們從交叉熵的定義開始,逐步推導。
(1) 交叉熵的數學形式
交叉熵(Cross-Entropy)衡量兩個概率分布的差異。在多分類任務中:
- ( p p p ):真實分布,通常是 one-hot 編碼(比如
[0, 1, 0]
表示第 1 類)。 - ( q q q ):預測分布,是模型輸出的概率(比如
[0.2, 0.5, 0.3]
)。
交叉熵公式為:
H ( p , q ) = ? ∑ c = 1 C p c log ? ( q c ) H(p, q) = -\sum_{c=1}^{C} p_c \log(q_c) H(p,q)=?c=1∑C?pc?log(qc?)
對于 one-hot 編碼,( p c p_c pc? ) 在正確類別上為 1,其他為 0,所以簡化為:
H ( p , q ) = ? log ? ( q correct ) H(p, q) = -\log(q_{\text{correct}}) H(p,q)=?log(qcorrect?)
其中 ( q correct q_{\text{correct}} qcorrect? ) 是正確類別對應的預測概率。對 ( N N N ) 個樣本取平均,損失為:
Loss = ? 1 N ∑ i = 1 N log ? ( q i , y i ) \text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} \log(q_{i, y_i}) Loss=?N1?i=1∑N?log(qi,yi??)
這正是交叉熵損失的核心。
(2) 從 logits 到概率
神經網絡輸出的是原始分數(logits),比如 ( z = [ z 1 , z 2 , z 3 ] z = [z_1, z_2, z_3] z=[z1?,z2?,z3?] )。要得到概率 ( q q q ),需要用 Softmax:
q j = e z j ∑ k = 1 C e z k q_j = \frac{e^{z_j}}{\sum_{k=1}^{C} e^{z_k}} qj?=∑k=1C?ezk?ezj??
交叉熵損失變成:
Loss = ? 1 N ∑ i = 1 N log ? ( e z i , y i ∑ k = 1 C e z i , k ) \text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} \log\left(\frac{e^{z_{i, y_i}}}{\sum_{k=1}^{C} e^{z_{i,k}}}\right) Loss=?N1?i=1∑N?log(∑k=1C?ezi,k?ezi,yi???)
這就是 nn.CrossEntropyLoss
的數學形式。
(3) 分解為兩步
現在我們把這個公式拆開:
-
第一步:LogSoftmax
計算對數概率:
log ? ( q j ) = log ? ( e z j ∑ k = 1 C e z k ) = z j ? log ? ( ∑ k = 1 C e z k ) \log(q_j) = \log\left(\frac{e^{z_j}}{\sum_{k=1}^{C} e^{z_k}}\right) = z_j - \log\left(\sum_{k=1}^{C} e^{z_k}\right) log(qj?)=log(∑k=1C?ezk?ezj??)=zj??log(k=1∑C?ezk?)
這正是LogSoftmax
的定義。它把 logits ( z z z ) 轉為對數概率 ( log ? ( q ) \log(q) log(q) )。 -
第二步:NLLLoss
有了對數概率 ( log ? ( q ) \log(q) log(q) ),取出正確類別的值,取負號并平均:
NLL = ? 1 N ∑ i = 1 N log ? ( q i , y i ) \text{NLL} = -\frac{1}{N} \sum_{i=1}^{N} \log(q_{i, y_i}) NLL=?N1?i=1∑N?log(qi,yi??)
這就是nn.NLLLoss
的公式。
組合起來:
LogSoftmax
把 logits 轉為 ( log ? ( q ) \log(q) log(q) )。nn.NLLLoss
對 ( log ? ( q ) \log(q) log(q) ) 取負號,計算損失。- 兩步合起來正好是 ( ? log ? ( q correct ) -\log(q_{\text{correct}}) ?log(qcorrect?) ),與交叉熵一致。
3. PyTorch 中的實現驗證
從數學上看,nn.CrossEntropyLoss
的確可以分解為 LogSoftmax
和 nn.NLLLoss
。我們用代碼驗證一下:
import torch
import torch.nn as nn# 輸入數據
logits = torch.tensor([[1.0, 2.0, 0.5], [0.1, 0.5, 2.0]]) # [batch_size, num_classes]
target = torch.tensor([1, 2]) # 真實類別索引# 方法 1:直接用 nn.CrossEntropyLoss
ce_loss_fn = nn.CrossEntropyLoss()
ce_loss = ce_loss_fn(logits, target)
print("CrossEntropyLoss:", ce_loss.item())# 方法 2:LogSoftmax + nn.NLLLoss
log_softmax = nn.LogSoftmax(dim=1)
nll_loss_fn = nn.NLLLoss()
log_probs = log_softmax(logits) # 計算對數概率
nll_loss = nll_loss_fn(log_probs, target)
print("LogSoftmax + NLLLoss:", nll_loss.item())
運行結果:兩個輸出的值完全相同(比如 0.75)。這證明 nn.CrossEntropyLoss
在內部就是先做 LogSoftmax
,再做 nn.NLLLoss
。
4. 為什么 PyTorch 這么設計?
既然 nn.CrossEntropyLoss
等價于 LogSoftmax
+ nn.NLLLoss
,為什么 PyTorch 提供了兩種方式?
-
便利性:
nn.CrossEntropyLoss
是一個“一體式”工具,直接輸入 logits 就能用,適合大多數場景,省去手動搭配的麻煩。 -
模塊化:
LogSoftmax
和nn.NLLLoss
分開設計,給開發者更多靈活性:- 你可以在模型里加
LogSoftmax
,只用nn.NLLLoss
計算損失。 - 可以單獨調試對數概率(比如打印
log_probs
)。 - 在某些自定義損失中,可能需要用到獨立的
LogSoftmax
。
- 你可以在模型里加
-
數值穩定性:
nn.CrossEntropyLoss
內部優化了計算,避免了分開操作時可能出現的溢出問題(比如 logits 很大時,Softmax 的分母溢出)。
5. 為什么不直接用 Softmax?
你可能好奇:為什么不用 Softmax
+ 對數 + 取負,而是用 LogSoftmax
?
答案是數值穩定性:
- 單獨計算
Softmax
(指數運算)可能導致溢出(比如 ( e 1000 e^{1000} e1000 ))。 LogSoftmax
把指數和對數合并為 ( z j ? log ? ( ∑ e z k ) z_j - \log(\sum e^{z_k}) zj??log(∑ezk?) ),計算更穩定。
6. 使用場景對比
-
nn.CrossEntropyLoss
:- 輸入:logits。
- 場景:標準多分類任務(圖像分類、文本分類)。
- 優點:簡單直接。
-
LogSoftmax
+nn.NLLLoss
:- 輸入:logits 需手動轉為對數概率。
- 場景:需要顯式控制 Softmax,或者模型已輸出對數概率。
- 優點:靈活性高。
7. 小結:為什么等價?
- 數學上:交叉熵 ( ? log ? ( q correct ) -\log(q_{\text{correct}}) ?log(qcorrect?) ) 可以拆成兩步:
LogSoftmax
:從 logits 到 ( log ? ( q ) \log(q) log(q) )。nn.NLLLoss
:從 ( log ? ( q ) \log(q) log(q) ) 到 ( ? log ? ( q correct ) -\log(q_{\text{correct}}) ?log(qcorrect?) )。
- 實現上:
nn.CrossEntropyLoss
把這兩步封裝成一個函數,結果一致。 - 設計上:PyTorch 提供兩種方式,滿足不同需求。
所以,nn.CrossEntropyLoss
= LogSoftmax
+ nn.NLLLoss
不是巧合,而是交叉熵計算的自然分解。理解這一點,能幫助你更靈活地使用 PyTorch 的損失函數。
8. 彩蛋:手動推導
想自己驗證?試試手動計算:
- logits
[1.0, 2.0, 0.5]
,目標是 1。 - Softmax:
[0.23, 0.63, 0.14]
。 - LogSoftmax:
[-1.47, -0.47, -1.97]
。 - NLL:
-(-0.47) = 0.47
。 - 直接用
nn.CrossEntropyLoss
,結果一樣!
希望這篇博客解開了你的疑惑!
后記
2025年2月28日18點51分于上海,在grok3 大模型輔助下完成。