變量名解釋
logits:未經過normalize(未經過激活函數處理)的原始分數,例如一個mlp將特征映射到num_target_class維的輸出tensor就是logits。
probs:probabilities的簡寫,logits經過sigmoid函數,就變成了分布在0-1之間的概率值probs。
Binary Cross-Entropy Loss
Binary Cross-Entropy Loss,簡稱為BCE loss,即二元交叉熵損失。
二元交叉熵損失是一種用于二分類問題的損失函數。它衡量的是模型預測的概率分布與真實標簽的概率分布之間的差異。在二分類問題中,每個樣本的標簽只有兩種可能的狀態,通常表示為 0(負類)和 1(正類)。
其公式為:
BCE?Loss? = ? 1 N ∑ i = 1 N [ y i log ? ( p i ) + ( 1 ? y i ) log ? ( 1 ? p i ) ] \text { BCE Loss }=-\frac{1}{N} \sum_{i=1}^N\left[y_i \log \left(p_i\right)+\left(1-y_i\right) \log \left(1-p_i\right)\right] ?BCE?Loss?=?N1?i=1∑N?[yi?log(pi?)+(1?yi?)log(1?pi?)]
其中:
- N N N是數據集的樣本數量。
- y i y_i yi?是第 i {i} i個樣本的真實標簽,取值為 0 或 1,即第 i {i} i個樣本要么屬于類別 0(負類),要么屬于類別 1(正類)。
- p i p_i pi?是第 i {i} i個樣本屬于類別 1(正類)的概率
- log ? \log log是是自然對數。
當真實標簽 y i = 1 y_i=1 yi?=1 時,損失函數的第一部分 y i log ? ( p i ) y_i \log \left(p_i\right) yi?log(pi?) 起作用,第二部分為 0 。此時, 如果預測概率 p i p_i pi? 接近 1 (接近真實標簽 y i = 1 y_i=1 yi?=1), 那么 log ? ( p i ) \log \left(p_i\right) log(pi?) 接近 0 , 損失較小;如果 p i p_i pi? 接近 0 (即模型預測錯誤),那么 log ? ( p i ) \log \left(p_i\right) log(pi?) 會變得成絕對值很大的負數,導致取反后loss很大。
當真實標簽 y i = 0 y_i=0 yi?=0 時,損失函數的第二部分 ( 1 ? y i ) log ? ( 1 ? p i ) \left(1-y_i\right) \log \left(1-p_i\right) (1?yi?)log(1?pi?)起作用,第一部分為 0。此時,預測概率 p i p_i pi?越接近于 0,整體loss越小。
Pytorch手動實現
import torch
import torch.nn.functional as Fdef manual_binary_cross_entropy_with_logits(logits, targets):# 使用 Sigmoid 函數將 logits 轉換為概率probs = torch.sigmoid(logits)# 計算二元交叉熵損失loss = - torch.mean(targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs))return loss# logits和targets可以是任意shape的tensor,只要shape相同即可
logits = torch.tensor([0.2, -0.4, 1.2, 0.8])
targets = torch.tensor([0., 1., 1., 0.])
assert logits.shape == targets.shape# 使用 PyTorch 的 F.binary_cross_entropy_with_logits 函數計算損失
loss_pytorch = F.binary_cross_entropy_with_logits(logits, targets)# 使用手動實現的函數計算損失
loss_manual = manual_binary_cross_entropy_with_logits(logits, targets)print(f'Loss (PyTorch): {loss_pytorch.item()}')
print(f'Loss (Manual): {loss_manual.item()}')
F.binary_cross_entropy 與 F.binary_cross_entropy_with_logits的區別
F.binary_cross_entropy的輸入是probs
F.binary_cross_entropy_with_logits的輸入是logits