原理
交叉熵損失函數是深度學習中分類問題常用的損失函數,特別適用于多分類問題。它通過度量預測分布與真實分布之間的差異,來衡量模型輸出的準確性。
交叉熵的數學公式
交叉熵的定義如下:
C r o s s E n t r o y L o s s = ? ∑ i = 1 N y i ? l o g ( y ^ i ) \begin{equation} CrossEntroyLoss = -\sum_{i=1}^{N}y_i \cdot log(\hat{y}_i) \end{equation} CrossEntroyLoss=?i=1∑N?yi??log(y^?i?)??
- N N N:類別數
- y i y_i yi?:真實的標簽(用 one-hot 編碼表示,只有目標類別對應的位置為 1,其他位置為 0)。
- y ^ i \hat{y}_i y^?i??:模型的預測概率,即
softmax
的輸出值。
對于單個樣本:
L o s s = ? l o g ( y ^ c ) \begin{equation} Loss = -log(\hat{y}_c) \end{equation} Loss=?log(y^?c?)??
其中 c c c是真實類別的索引。
解釋:
- 如果模型的預測概率 y ^ c \hat{y}_c y^?c?越接近1,則 ? l o g ( y ^ c ) -log(\hat{y}_c) ?log(y^?c?)越小,損失越大。
- 如果 y ^ c \hat{y}_c y^?c?越接近0,則 ? l o g ( y ^ c ) -log(\hat{y}_c) ?log(y^?c?)?越大,損失越大。
交叉熵損失和softmax函數的關系
-
模型通常輸出logits(未歸一化的分數),例如 [ z 1 , z 2 , ? , z N ] [z_1,z_2,\cdots,z_N] [z1?,z2?,?,zN?]。
-
softmax函數將logits轉化為概率分布:
y ^ i = z z i ∑ j = 1 N e z j \begin{equation} \hat{y}_i = \dfrac{z^{z_i}}{\sum_{j=1}^N e^{z_j}} \end{equation} y^?i?=∑j=1N?ezj?zzi???? -
交叉熵損失結合 softmax,用來計算預測分布與真實分布之間的差異。
在 PyTorch 的 CrossEntropyLoss
中,softmax 和交叉熵是結合在一起實現的,因此你不需要手動調用 softmax。
特性
應用場景:
- 多分類任務,例如圖像分類、文本分類等。
- 真實標簽通常以整數形式存儲(如
0, 1, 2
)。
數值穩定性:
- 由于 softmax 和交叉熵結合在一起,可以避免單獨計算 softmax 導致的數值不穩定問題。
Pytorch中的實現
構造函數
PyTorch 提供了 torch.nn.CrossEntropyLoss
:
torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduction='mean')
參數說明:
weight
:用于對不同類別賦予不同的權重。ignore_index
:指定忽略某些類別的損失(通常用于處理 padding)。reduction
:決定損失的輸出形式:'mean'
(默認):返回損失的均值。'sum'
:返回損失的總和。'none'
:返回每個樣本的損失值。
使用示例
1、單樣本交叉熵損失
import torch
import torch.nn as nn# 模型的輸出 logits 和真實標簽
logits = torch.tensor([[2.0, 1.0, 0.1]]) # 未經過 softmax 的輸出
labels = torch.tensor([0]) # 真實標簽(類別索引)# 定義交叉熵損失函數
criterion = nn.CrossEntropyLoss()# 計算損失
loss = criterion(logits, labels)
print("CrossEntropyLoss:", loss.item())
解釋:
logits
是未歸一化的分數。labels
是類別索引(如類別 0)。- 內部會先對 logits 應用 softmax,再計算交叉熵損失。
計算細節:
a)、給定的數據
- logits: [ 2.0 1.0 0.1 ] \begin{bmatrix} 2.0 & 1.0 & 0.1 \end{bmatrix} [2.0?1.0?0.1?]
- 這是模型輸出的未歸一化分數(logits)。
- labels: [ 0 ] \begin{bmatrix} 0 \end{bmatrix} [0?]
- 真實標簽,表示類別索引(0 表示第一類)。
b)、CrossEntropyLoss 的計算公式,交叉熵損失公式如下:
L o s s = ? 1 N ∑ i = 1 N l o g ( e x p ( l o g i t y i ) ∑ j e x p ( l o g i t j ) ) \begin{equation} Loss = -\dfrac{1}{N} \sum_{i=1}^{N} log \left( \dfrac{exp(logit_{y_i})}{\sum_j exp(logit_j)} \right) \end{equation} Loss=?N1?i=1∑N?log(∑j?exp(logitj?)exp(logityi??)?)??
其中:
- N N N: 樣本數量(在這里為 1)。
- l o g i t j logit_j logitj?: 第 j j j類的 logit 值。
- y i y_i yi?: 樣本 i i i? 的真實類別索引。
c)、具體的步驟
step 1:softmax計算概率分布
softmax函數將logits轉換為概率分布:
s o f t m a x ( z i ) = e x p ( z i ) ∑ j e x p ( z j ) \begin{equation} softmax(z_i) = \dfrac{exp(z_i)}{\sum_j exp(z_j)} \end{equation} softmax(zi?)=∑j?exp(zj?)exp(zi?)???
對于logits: [ 2.0 1.0 0.1 ] \begin{bmatrix} 2.0 & 1.0 & 0.1 \end{bmatrix} [2.0?1.0?0.1?],計算如下:
- 計算每個元素的指數:
e x p ( 2.0 ) = e 2 ≈ 7.389 , e x p ( 1.0 ) = e 1 ≈ 2.718 , e x p ( 0.1 ) = e 0.1 ≈ 1.105 \begin{equation} exp(2.0)=e^2 \approx 7.389, \quad exp(1.0)=e^1 \approx 2.718, \quad exp(0.1)=e^{0.1} \approx 1.105 \end{equation} exp(2.0)=e2≈7.389,exp(1.0)=e1≈2.718,exp(0.1)=e0.1≈1.105??
- 求和:
s u m = 7.389 + 2.718 + 1.105 ≈ 11.212 \begin{equation} sum = 7.389 + 2.718 + 1.105 \approx 11.212 \end{equation} sum=7.389+2.718+1.105≈11.212??
- 計算每個類別的概率:
s o f t m a x ( 2.0 ) = 7.389 11.212 ≈ 0.659 , s o f t m a x ( 1.0 ) = 2.718 11.212 ≈ 0.242 , s o f t m a x ( 0.1 ) = 1.105 11.212 ≈ 0.099 \begin{equation} softmax(2.0)=\dfrac{7.389}{11.212} \approx 0.659,\quad softmax(1.0)=\dfrac{2.718}{11.212} \approx 0.242,\quad softmax(0.1)=\dfrac{1.105}{11.212} \approx 0.099 \end{equation} softmax(2.0)=11.2127.389?≈0.659,softmax(1.0)=11.2122.718?≈0.242,softmax(0.1)=11.2121.105?≈0.099??
概率分布為:
[ 0.659 0.242 0.099 ] \begin{bmatrix} 0.659 & 0.242 & 0.099 \end{bmatrix} [0.659?0.242?0.099?]
step 2:取真實標簽對應的概率
真實標簽 y = 0 y=0 y=0,對應的概率為第一個類別的softmax輸出:
P ( y = 0 ) = 0.659 \begin{equation} P(y=0)=0.659 \end{equation} P(y=0)=0.659??
step 3:計算交叉熵損失
根據交叉熵公式,損失為:
L o s s = ? l o g ( P ( y = 0 ) ) = ? l o g ( 0.659 ) \begin{equation} Loss = -log(P(y=0)) = -log(0.659) \end{equation} Loss=?log(P(y=0))=?log(0.659)??
計算對數值:
l o g ( 0.659 ) ≈ ? 0.416 \begin{equation} log(0.659) \approx -0.416 \end{equation} log(0.659)≈?0.416??
因此,損失為:
L o s s = 0.416 \begin{equation} Loss = 0.416 \end{equation} Loss=0.416??
2、多樣本交叉熵損失
logits = torch.tensor([[1.5, 0.3, 2.1], [2.0, 1.0, 0.1], [0.1, 2.2, 1.0]]) # Batch size = 3labels = torch.tensor([2, 0, 1]) # Batch size = 3# 定義交叉熵損失函數
criterion = nn.CrossEntropyLoss()# 計算損失
loss = criterion(logits, labels)
print("CrossEntropyLoss:", loss.item())
a)、給定的數據
logits(未歸一化的分數):
l o g i t s = [ 1.5 0.3 2.1 2.0 1.0 0.1 0.1 2.2 1.0 ] logits = \begin{bmatrix} 1.5 & 0.3 & 2.1 \\ 2.0 & 1.0 & 0.1 \\ 0.1 & 2.2 & 1.0 \end{bmatrix} logits= ?1.52.00.1?0.31.02.2?2.10.11.0? ?
labels(真實標簽的索引):
l a b e l s = [ 2 0 1 ] labels = \begin{bmatrix} 2 & 0 & 1 \end{bmatrix} labels=[2?0?1?]
- 第一行對應的類別2
- 第二行對應的類別0
- 第三行對應的類別1
b)、交叉熵損失函數
L o s s = ? 1 N ∑ i = 1 N l o g ( e x p ( l o g i t i , y i ) ∑ j e x p ( l o g i t i , j ) ) \begin{equation} Loss = -\dfrac{1}{N} \sum_{i=1}^{N} log \left( \dfrac{exp(logit_{i,y_i})}{\sum_j exp(logit_{i,j})} \right) \end{equation} Loss=?N1?i=1∑N?log(∑j?exp(logiti,j?)exp(logiti,yi??)?)??
其中:
- N = 3 N=3 N=3: 是批量大小。
- l o g i t i , j logit_{i,j} logiti,j?: 是樣本 i i i對類別 j j j的預測分數。
- y i y_i yi?: 樣本 i i i?? 的真實類別索引。
c)、逐行計算softmax概率和交叉熵損失
step 1:第一行 l o g i t s = [ 1.5 0.3 2.1 ] logits = \begin{bmatrix} 1.5 & 0.3 & 2.1 \end{bmatrix} logits=[1.5?0.3?2.1?]?,真實標簽 = 2
-
計算softmax:
- 計算每個分數的指數值:
e x p ( 1.5 ) ≈ 4.481 , e x p ( 0.3 ) ≈ 1.350 , e x p ( 2.1 ) ≈ 8.165 \begin{equation} exp(1.5) \approx 4.481, \quad exp(0.3) \approx 1.350, \quad exp(2.1) \approx 8.165 \end{equation} exp(1.5)≈4.481,exp(0.3)≈1.350,exp(2.1)≈8.165??
- 求和
s u m = 4.481 + 1.350 + 8.165 ≈ 13.996 \begin{equation} sum = 4.481 + 1.350 + 8.165 \approx 13.996 \end{equation} sum=4.481+1.350+8.165≈13.996??
- 計算每個類別的概率
P ( 0 ) = 4.481 13.996 ≈ 0.32 , P ( 1 ) = 1.350 13.996 ≈ 0.096 , P ( 2 ) = 8.165 13.996 ≈ 0.583 \begin{equation} P(0) = \dfrac{4.481}{13.996} \approx 0.32,\quad P(1) = \dfrac{1.350}{13.996} \approx 0.096,\quad P(2) = \dfrac{8.165}{13.996} \approx 0.583 \end{equation} P(0)=13.9964.481?≈0.32,P(1)=13.9961.350?≈0.096,P(2)=13.9968.165?≈0.583??
-
取真實類別2的概率:
P ( y = 2 ) = 0.583 \begin{equation} P(y=2) = 0.583 \end{equation} P(y=2)=0.583??
- 計算損失:
L o s s 1 = ? l o g ( 0.583 ) ≈ 0.540 \begin{equation} Loss_1 = -log(0.583) \approx 0.540 \end{equation} Loss1?=?log(0.583)≈0.540??
step 2:第二行 l o g i t s = [ 2.0 1.0 0.1 ] logits = \begin{bmatrix} 2.0 & 1.0 & 0.1 \end{bmatrix} logits=[2.0?1.0?0.1?]?,真實標簽 = 0
-
計算softmax:
- 計算每個分數的指數值:
e x p ( 2.0 ) ≈ 7.389 , e x p ( 1.0 ) ≈ 2.718 , e x p ( 0.1 ) ≈ 1.105 \begin{equation} exp(2.0) \approx 7.389, \quad exp(1.0) \approx 2.718, \quad exp(0.1) \approx 1.105 \end{equation} exp(2.0)≈7.389,exp(1.0)≈2.718,exp(0.1)≈1.105??
- 求和
s u m = 7.389 + 2.718 + 1.105 ≈ 11.212 \begin{equation} sum = 7.389 + 2.718 + 1.105 \approx 11.212 \end{equation} sum=7.389+2.718+1.105≈11.212??
- 計算每個類別的概率
P ( 0 ) = 7.389 11.212 ≈ 0.659 , P ( 1 ) = 2.718 11.212 ≈ 0.242 , P ( 2 ) = 1.105 11.212 ≈ 0.099 \begin{equation} P(0) = \dfrac{7.389}{11.212} \approx 0.659,\quad P(1) = \dfrac{2.718}{11.212} \approx 0.242,\quad P(2) = \dfrac{1.105}{11.212} \approx 0.099 \end{equation} P(0)=11.2127.389?≈0.659,P(1)=11.2122.718?≈0.242,P(2)=11.2121.105?≈0.099??
-
取真實類別0的概率:
P ( y = 0 ) = 0.659 \begin{equation} P(y=0) = 0.659 \end{equation} P(y=0)=0.659??
- 計算損失:
L o s s 2 = ? l o g ( 0.659 ) ≈ 0.417 \begin{equation} Loss_2 = -log(0.659) \approx 0.417 \end{equation} Loss2?=?log(0.659)≈0.417??
step 3:第二行 l o g i t s = [ 0.1 2.2 1.0 ] logits = \begin{bmatrix} 0.1 & 2.2 & 1.0 \end{bmatrix} logits=[0.1?2.2?1.0?]?,真實標簽 = 1
-
計算softmax:
- 計算每個分數的指數值:
e x p ( 0.1 ) ≈ 1.105 , e x p ( 2.2 ) ≈ 9.025 , e x p ( 1.0 ) ≈ 2.718 \begin{equation} exp(0.1) \approx 1.105, \quad exp(2.2) \approx 9.025, \quad exp(1.0) \approx 2.718 \end{equation} exp(0.1)≈1.105,exp(2.2)≈9.025,exp(1.0)≈2.718??
- 求和
s u m = 1.105 + 9.025 + 2.718 ≈ 12.848 \begin{equation} sum = 1.105 + 9.025 + 2.718 \approx 12.848 \end{equation} sum=1.105+9.025+2.718≈12.848??
- 計算每個類別的概率
P ( 0 ) = 1.105 12.848 ≈ 0.086 , P ( 1 ) = 9.025 12.848 ≈ 0.703 , P ( 2 ) = 2.718 12.848 ≈ 0.211 \begin{equation} P(0) = \dfrac{1.105}{12.848} \approx 0.086,\quad P(1) = \dfrac{9.025}{12.848} \approx 0.703,\quad P(2) = \dfrac{2.718}{12.848} \approx 0.211 \end{equation} P(0)=12.8481.105?≈0.086,P(1)=12.8489.025?≈0.703,P(2)=12.8482.718?≈0.211??
-
取真實類別1的概率:
P ( y = 1 ) = 0.703 \begin{equation} P(y=1) = 0.703 \end{equation} P(y=1)=0.703??
- 計算損失:
L o s s 3 = ? l o g ( 0.703 ) ≈ 0.353 \begin{equation} Loss_3 = -log(0.703) \approx 0.353 \end{equation} Loss3?=?log(0.703)≈0.353??
d)、批量損失
將每個樣本的損失平均:
L o s s = L o s s 1 + L o s s 2 + L o s s 3 3 = 0.540 + 0.417 + 0.353 3 ≈ 0.437 \begin{equation} Loss = \dfrac{Loss_1 + Loss_2 + Loss_3}{3} = \dfrac{0.540 + 0.417 + 0.353}{3} \approx 0.437 \end{equation} Loss=3Loss1?+Loss2?+Loss3??=30.540+0.417+0.353?≈0.437??
3、帶權重的交叉熵
在某些情況下,類別分布不平衡,可以為不同類別設置權重:
weights = torch.tensor([1.0, 2.0, 3.0]) # 類別權重
criterion = nn.CrossEntropyLoss(weight=weights)loss = criterion(logits, labels)
print("Weighted CrossEntropyLoss:", loss.item())
4、示例:在神經網絡中的應用
import torch
import torch.nn as nn
import torch.optim as optim# 定義一個簡單的神經網絡
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc = nn.Linear(4, 3) # 輸入 4 維特征,輸出 3 類def forward(self, x):return self.fc(x)# 模型、損失函數和優化器
model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)# 輸入數據和標簽
inputs = torch.tensor([[0.5, 1.2, -1.3, 0.8], [0.3, -0.7, 1.0, 1.5]]) # Batch size = 2
labels = torch.tensor([0, 2]) # 兩個樣本對應的真實類別# 前向傳播
outputs = model(inputs)# 計算損失
loss = criterion(outputs, labels)
print("Loss:", loss.item())# 反向傳播和優化
loss.backward()
optimizer.step()
5、總結
-
交叉熵損失函數用于度量預測分布與真實分布之間的差異,是分類問題中的核心工具。
-
在 PyTorch 中,
torch.nn.CrossEntropyLoss
結合了 softmax 和交叉熵計算,使用簡單且高效。 -
可以通過參數調整(如權重)來適應不平衡數據集。
整合的代碼
import torch
import torch.nn as nn
import torch.optim as optimdef single_instance_CrossEntropyLoss():# 模型的輸出 logits 和真實標簽logits = torch.tensor([[2.0, 1.0, 0.1]]) # 未經過 softmax 的輸出labels = torch.tensor([0]) # 真實標簽(類別索引)# 定義交叉熵損失函數criterion = nn.CrossEntropyLoss()# 計算損失loss = criterion(logits, labels)print("CrossEntropyLoss:", loss.item())def multi_instance_CrossEntropyLoss():logits = torch.tensor([[1.5, 0.3, 2.1], [2.0, 1.0, 0.1], [0.1, 2.2, 1.0]]) # Batch size = 3labels = torch.tensor([2, 0, 1]) # Batch size = 3# 定義交叉熵損失函數criterion = nn.CrossEntropyLoss()# 計算損失loss = criterion(logits, labels)print("CrossEntropyLoss:", loss.item())# 定義一個簡單的神經網絡
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc = nn.Linear(4, 3) # 輸入 4 維特征,輸出 3 類def forward(self, x):return self.fc(x)def apply_deepLearning_CrossEntropyLoss():# 模型、損失函數和優化器model = SimpleNet()criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.01)# 輸入數據和標簽inputs = torch.tensor([[0.5, 1.2, -1.3, 0.8], [0.3, -0.7, 1.0, 1.5]]) # Batch size = 2labels = torch.tensor([0, 2]) # 兩個樣本對應的真實類別# 前向傳播outputs = model(inputs)# 計算損失loss = criterion(outputs, labels)print("Loss:", loss.item())# 反向傳播和優化loss.backward()optimizer.step()if __name__ == "__main__":print("*" * 30)single_instance_CrossEntropyLoss()multi_instance_CrossEntropyLoss()apply_deepLearning_CrossEntropyLoss()