目錄
一、損失函數的基本概念
二、常用損失函數及實現
1. 均方誤差損失(MSELoss)
2. 平均絕對誤差損失(L1Loss/MAELoss)
3. 交叉熵損失(CrossEntropyLoss)
4. 二元交叉熵損失(BCELoss)
三、損失函數選擇指南
四、損失函數在訓練中的應用
五、總結
損失函數是深度學習模型訓練的核心組件,它量化了模型預測值與真實值之間的差異,指導模型參數的更新方向。本文將結合 PyTorch 代碼實例,詳細講解常用損失函數的原理、適用場景及實現方法。
一、損失函數的基本概念
損失函數(Loss Function)又稱代價函數(Cost Function),是衡量模型預測結果與真實標簽之間差異的指標。在模型訓練過程中,通過優化算法(如梯度下降)最小化損失函數,使模型逐漸逼近最優解。
損失函數的選擇取決于具體任務類型:
- 回歸任務:預測連續值(如房價、溫度)
- 分類任務:預測離散類別(如圖片分類、垃圾郵件識別)
- 其他任務:如生成任務、序列標注等
二、常用損失函數及實現
1. 均方誤差損失(MSELoss)
均方誤差損失是回歸任務中最常用的損失函數,計算預測值與真實值之間平方差的平均值。
數學公式: 其中,
為真實值,
為預測值,n為樣本數量。
代碼實現:
import torch
import torch.nn as nn# 初始化MSE損失函數
mse_loss = nn.MSELoss()# 示例數據
y_true = torch.tensor([3.0, 5.0, 2.5]) # 真實值
y_pred = torch.tensor([2.5, 5.0, 3.0]) # 預測值# 計算損失
loss = mse_loss(y_pred, y_true)
print(f'MSE Loss: {loss.item()}') # 輸出:MSE Loss: 0.0833333358168602
特點:
- 對異常值敏感,因為會對誤差進行平方
- 是凸函數,存在唯一全局最小值
- 適用于大多數回歸任務
2. 平均絕對誤差損失(L1Loss/MAELoss)
平均絕對誤差計算預測值與真實值之間絕對差的平均值,對異常值的敏感性低于 MSE。
數學公式:
代碼實現:
# 初始化L1損失函數
l1_loss = nn.L1Loss()# 計算損失
loss = l1_loss(y_pred, y_true)
print(f'L1 Loss: {loss.item()}') # 輸出:L1 Loss: 0.25
特點:
- 對異常值更穩健
- 梯度在零點處不連續,可能影響收斂速度
- 適用于存在異常值的回歸場景
3. 交叉熵損失(CrossEntropyLoss)
交叉熵損失是多分類任務的標準損失函數,在 PyTorch 中內置了 Softmax 操作,直接作用于模型輸出的 logits。
數學公式: 其中,C為類別數,
為真實標簽的 one-hot 編碼,
為經過 Softmax 處理的預測概率。
代碼實現:
def test_cross_entropy():# 模型輸出的logits(未經過softmax)logits = torch.tensor([[1.5, 2.0, 0.5], [0.5, 1.0, 1.5]])# 真實標簽(類別索引)labels = torch.tensor([1, 2]) # 第一個樣本屬于類別1,第二個樣本屬于類別2# 初始化交叉熵損失函數criterion = nn.CrossEntropyLoss()loss = criterion(logits, labels)print(f'Cross Entropy Loss: {loss.item()}') # 輸出:Cross Entropy Loss: 0.6422222256660461test_cross_entropy()
計算過程解析:
- 對 logits 應用 Softmax 得到概率分布
- 計算真實類別對應的負對數概率
- 取平均值作為最終損失
特點:
- 自動包含 Softmax 操作,無需手動添加
- 適用于多分類任務(類別互斥)
- 標簽格式為類別索引(非 one-hot 編碼)
4. 二元交叉熵損失(BCELoss)
二元交叉熵損失用于二分類任務,需要配合 Sigmoid 激活函數使用,確保輸入值在 (0,1) 范圍內。
數學公式:
代碼實現:
def test_bce_loss():# 模型輸出(已通過sigmoid處理)y_pred = torch.tensor([[0.7], [0.2], [0.9], [0.7]])# 真實標簽(0或1)y_true = torch.tensor([[1], [0], [1], [0]], dtype=torch.float)# 方法1:使用BCELossbce_loss = nn.BCELoss()loss1 = bce_loss(y_pred, y_true)# 方法2:使用functional接口loss2 = nn.functional.binary_cross_entropy(y_pred, y_true)print(f'BCELoss: {loss1.item()}') # 輸出:BCELoss: 0.47234177589416504print(f'Functional BCELoss: {loss2.item()}') # 輸出:Functional BCELoss: 0.47234177589416504test_bce_loss()
變種:BCEWithLogitsLoss
對于未經過 Sigmoid 處理的 logits,推薦使用BCEWithLogitsLoss
,它內部會自動應用 Sigmoid,數值穩定性更好:
# 對于logits輸入(未經過sigmoid)
logits = torch.tensor([[0.8], [-0.5], [1.2], [0.6]])
bce_with_logits_loss = nn.BCEWithLogitsLoss()
loss = bce_with_logits_loss(logits, y_true)
三、損失函數選擇指南
任務類型 | 推薦損失函數 | 特點 |
---|---|---|
回歸任務 | MSELoss | 對異常值敏感,適用于大多數回歸場景 |
回歸任務(含異常值) | L1Loss | 對異常值穩健,梯度不連續 |
多分類任務 | CrossEntropyLoss | 內置 Softmax,處理互斥類別 |
二分類任務 | BCELoss/BCEWithLogitsLoss | 配合 Sigmoid 使用,輸出概率值 |
多標簽分類 | BCEWithLogitsLoss | 每個類別獨立判斷,可同時屬于多個類別 |
四、損失函數在訓練中的應用
以圖像分類任務為例,展示損失函數在完整訓練流程中的使用:
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 數據預處理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加載MNIST數據集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 定義簡單的全連接網絡
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(28*28, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28*28)x = torch.relu(self.fc1(x))x = self.fc2(x) # 輸出logits,不使用softmaxreturn x# 初始化模型、損失函數和優化器
model = SimpleNet()
criterion = nn.CrossEntropyLoss() # 多分類任務
optimizer = optim.Adam(model.parameters(), lr=0.001)# 訓練循環
def train(epochs=5):model.train()for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:# 前向傳播outputs = model(images)loss = criterion(outputs, labels)# 反向傳播和優化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()# 打印每輪的平均損失avg_loss = running_loss / len(train_loader)print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}')train()
五、總結
損失函數的選擇直接影響模型的訓練效果和收斂速度,關鍵要點:
- 回歸任務優先選擇 MSELoss,存在異常值時考慮 L1Loss
- 多分類任務使用 CrossEntropyLoss,無需手動添加 Softmax
- 二分類任務推薦使用 BCEWithLogitsLoss,數值穩定性更好
- 訓練過程中需監控損失變化,判斷模型是否收斂或過擬合
合理選擇損失函數并配合適當的優化器,才能充分發揮模型的學習能力。在實際應用中,可根據具體任務特點和數據分布嘗試不同的損失函數,選擇表現最佳的方案。