一、核心思想:通過概率分布懲罰錯誤
交叉熵損失的本質是:
比較模型預測的概率分布 vs 真實標簽的概率分布,懲罰兩者之間的差異。
例如:
- 真實標簽:圖像 0 → 文本 0(獨熱編碼 [1, 0, 0, ...])
- 模型預測:[0.1, 0.2, 0.3, 0.4, ...](預測文本 0 的概率僅 0.1)
此時損失會很大,因為預測分布與真實分布差異大。
二、分步解析交叉熵懲罰機制
1. 相似度矩陣 → 概率分布
假設 sim_i2t 是一個 [3, 6] 的矩陣(3 個圖像 × 6 個文本):
# 示例相似度矩陣(簡化版,僅展示對角線高相似度)
sim_i2t = torch.tensor([[5.0, 1.0, 1.0, 1.0, 1.0, 1.0], # 圖像0 → 文本0是正樣本[1.0, 5.0, 1.0, 1.0, 1.0, 1.0], # 圖像1 → 文本1是正樣本[1.0, 1.0, 5.0, 1.0, 1.0, 1.0] # 圖像2 → 文本2是正樣本
])
通過 softmax 將相似度轉換為概率分布:
probs = F.softmax(sim_i2t, dim=1) # 對每行做softmax
print(probs)
輸出結果:
tensor([[0.94, 0.02, 0.02, 0.02, 0.02, 0.02], # 預測文本0概率最高(正確)[0.02, 0.94, 0.02, 0.02, 0.02, 0.02], # 預測文本1概率最高(正確)[0.02, 0.02, 0.94, 0.02, 0.02, 0.02] # 預測文本2概率最高(正確)
])
2. 真實標簽的概率分布
假設 targets = [0, 1, 2],轉換為獨熱編碼:
# 獨熱編碼(簡化版,僅展示核心邏輯)
one_hot = torch.zeros_like(probs)
for i, t in enumerate(targets):one_hot[i, t] = 1.0print(one_hot)
輸出結果:
tensor([[1.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 圖像0的正樣本是文本0[0.0, 1.0, 0.0, 0.0, 0.0, 0.0], # 圖像1的正樣本是文本1[0.0, 0.0, 1.0, 0.0, 0.0, 0.0] # 圖像2的正樣本是文本2
])
3. 計算交叉熵損失
交叉熵損失公式:
對于上述例子:
- 圖像 0 的損失:-log(0.94) ≈ 0.06
- 圖像 1 的損失:-log(0.94) ≈ 0.06
- 圖像 2 的損失:-log(0.94) ≈ 0.06
平均損失:(0.06 + 0.06 + 0.06) / 3 ≈ 0.06
實際函數內部:
# 1. 對預測值應用softmax,轉換為概率分布
probs = F.softmax(sim_i2t, dim=1)# 2. 對每個樣本,取出目標類別對應的概率
# 例如:
# - 第0個樣本的目標類別是0,取出probs[0, 0]
# - 第1個樣本的目標類別是1,取出probs[1, 1]
# - 第2個樣本的目標類別是2,取出probs[2, 2]
target_probs = probs[torch.arange(len(targets)), targets]# 3. 計算負對數似然
nll = -torch.log(target_probs)# 4. 求平均值得到最終損失
loss = nll.mean()
三、標簽平滑如何調整懲罰
標簽平滑(label_smoothing=0.1)會將:
- 正樣本的概率從 1.0 調整為 0.9
- 負樣本的概率從 0.0 調整為 0.1 / (類別數-1)
例如,對于圖像 0(正樣本是文本 0):
- 原始標簽:[1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
- 平滑后標簽:[0.9, 0.02, 0.02, 0.02, 0.02, 0.02]
此時損失計算變為:
實際函數內部:當使用label_smoothing=0.1時,函數內部會將目標概率分布從嚴格的獨熱編碼調整為平滑分布:
def cross_entropy_with_label_smoothing(logits, targets, smoothing=0.1):num_classes = logits.size(1)# 計算平滑后的目標分布# - 正樣本概率: 1.0 - smoothing + (smoothing / num_classes)# - 負樣本概率: smoothing / num_classessmooth_targets = torch.full_like(logits, smoothing / (num_classes - 1))smooth_targets[torch.arange(len(targets)), targets] = 1.0 - smoothing + (smoothing / num_classes)# 對預測值應用log_softmaxlog_probs = F.log_softmax(logits, dim=1)# 計算交叉熵(等價于F.kl_div(log_probs, smooth_targets))loss = (-smooth_targets * log_probs).sum(dim=1).mean()return loss
四、懲罰機制可視化
假設模型預測錯誤(圖像 0 預測文本 1 的概率最高):
# 錯誤預測的情況
bad_probs = torch.tensor([[0.1, 0.8, 0.05, 0.05, 0.0, 0.0], # 錯誤:預測文本1概率最高[0.02, 0.94, 0.02, 0.02, 0.02, 0.0], # 正確[0.02, 0.02, 0.94, 0.02, 0.02, 0.0] # 正確
])# 計算交叉熵損失(無標簽平滑)
loss = -torch.log(bad_probs[0, 0]) # 圖像0的損失:-log(0.1) ≈ 2.3
print(f"錯誤預測的損失: {loss.item():.4f}") # 損失遠大于正確預測的0.06
輸出結果:
錯誤預測的損失: 2.3026
五、總結
交叉熵損失的懲罰機制是:
- 對正樣本:預測概率越低,懲罰越大(損失呈對數增長)
- 對負樣本:預測概率越高,懲罰越大
- 標簽平滑:減輕對極端預測的懲罰,防止過擬合
通過這種方式,模型被強制學習到:
- 正樣本對的相似度要盡可能高
- 負樣本對的相似度要盡可能低
這就是對比學習中 “拉近正樣本、推遠負樣本” 的核心實現方式!