文章目錄
- 背景介紹
- F.nll_loss
- 什么是負對數似然損失?
- 應用場景
- nn.CrossEntropyLoss
- 簡化工作流程
- 內部機制
- 區別與聯系
背景介紹
無論是圖像分類、文本分類還是其他類型的分類任務,交叉熵損失(Cross Entropy Loss)都是最常用的一種損失函數。它衡量的是模型預測的概率分布與真實標簽之間的差異。在 PyTorch 中,有兩個特別值得注意的實現:F.nll_loss
和 nn.CrossEntropyLoss
。
F.nll_loss
什么是負對數似然損失?
F.nll_loss
是負對數似然損失(Negative Log Likelihood Loss),主要用于多類分類問題。它的輸入是對數概率(log-probabilities),這意味著在使用 F.nll_loss
之前,我們需要先對模型的輸出應用 log_softmax
函數,將原始輸出轉換為對數概率形式。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset# 創建一些虛擬數據
features = torch.randn(100, 20) # 假設有100個樣本,每個樣本有20個特征
labels = torch.randint(0, 3, (100,)) # 假設有3個類別# 創建數據加載器
dataset = TensorDataset(features, labels)
data_loader = DataLoader(dataset, batch_size=10, shuffle=True)class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(20, 3) # 輸入維度為20,輸出維度為3(對應3個類別)def forward(self, x):return self.fc(x)model_nll = SimpleModel()
optimizer = torch.optim.SGD(model_nll.parameters(), lr=0.01)for inputs, targets in data_loader:optimizer.zero_grad() # 清除梯度outputs = model_nll(inputs) # 模型前向傳播log_softmax_outputs = F.log_softmax(outputs, dim=1) # 應用 log_softmaxloss = F.nll_loss(log_softmax_outputs, targets) # 計算 nll_lossloss.backward() # 反向傳播optimizer.step() # 更新權重print(f"Batch Loss with F.nll_loss: {loss.item():.4f}")
應用場景
由于 F.nll_loss
需要預先計算 log_softmax
,這為用戶提供了一定程度的靈活性,尤其是在需要復用 log_softmax
結果的情況下。
nn.CrossEntropyLoss
簡化工作流程
相比之下,nn.CrossEntropyLoss
更加直接和易用。它結合了 log_softmax
和 nll_loss
的功能,因此可以直接接受未經歸一化的原始輸出作為輸入,內部自動完成這兩個步驟。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset# 創建一些虛擬數據
features = torch.randn(100, 20) # 假設有100個樣本,每個樣本有20個特征
labels = torch.randint(0, 3, (100,)) # 假設有3個類別# 創建數據加載器
dataset = TensorDataset(features, labels)
data_loader = DataLoader(dataset, batch_size=10, shuffle=True)class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(20, 3) # 輸入維度為20,輸出維度為3(對應3個類別)def forward(self, x):return self.fc(x)model_ce = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model_ce.parameters(), lr=0.01)for inputs, targets in data_loader:optimizer.zero_grad() # 清除梯度outputs = model_ce(inputs) # 模型前向傳播loss = criterion(outputs, targets) # 直接計算交叉熵損失,內部包含 log_softmaxloss.backward() # 反向傳播optimizer.step() # 更新權重print(f"Batch Loss with nn.CrossEntropyLoss: {loss.item():.4f}")
內部機制
實際上,nn.CrossEntropyLoss
= log_softmax
+ nll_loss
。這種設計簡化了用戶的代碼編寫過程,特別是當不需要對中間結果進行額外操作時。
區別與聯系
-
輸入要求:
F.nll_loss
要求輸入為log_softmax
后的結果;而nn.CrossEntropyLoss
可以直接接受未經softmax
處理的原始輸出。 -
靈活性:如果需要對
log_softmax
結果進行進一步處理或調試,那么F.nll_loss
提供了更大的靈活性。 -
便捷性:對于大多數用戶而言,
nn.CrossEntropyLoss
因其簡潔性和內置的log_softmax
步驟,是更方便的選擇。