在多分類任務中,模型輸出一個概率分布,常用的損失函數是 Categorical Cross Entropy(多類交叉熵)。本文將帶你理解其數學本質、應用場景、數值穩定性及完整 Python 實現。
📘 一、什么是 Categorical Cross Entropy?
多類交叉熵損失函數 衡量的是預測的概率分布 與真實類別分布
之間的距離。通常用于 Softmax 輸出層 + 多分類問題。
🧮 二、數學公式
設:
:真實標簽(獨熱編碼 One-Hot)
:模型預測概率(Softmax 輸出)
則多類交叉熵定義為:?
?
含義:
若
且其他為 0(獨熱編碼),則只考慮正確類別對應的概率;
預測越接近真實標簽,對應損失越小。
🧑?💻 三、Python 實現(含數值穩定)
函數實現如下:
import mathdef categorical_cross_entropy(y_true, y_pred):"""計算多類交叉熵損失(適用于獨熱編碼標簽)參數:y_true (List[float]):真實標簽(One-Hot)y_pred (List[float]):預測概率(Softmax 輸出)返回:float:交叉熵損失值"""epsilon = 1e-15 # 防止 log(0)y_pred = [min(max(p, epsilon), 1 - epsilon) for p in y_pred]return -sum(y * math.log(p) for y, p in zip(y_true, y_pred))# 示例:3類分類問題
y_true = [0, 1, 0]
y_pred = [0.2, 0.7, 0.1]loss = categorical_cross_entropy(y_true, y_pred)
print("Categorical Cross Entropy:", loss)
? 輸出示例:?
Categorical Cross Entropy: 0.35667494393873245
?? 四、為什么需要 Epsilon 防止 log(0)?
在預測中,某些類概率可能非常接近 0(例如 1e-20),直接對其取對數會:
產生
math domain error
;導致梯度爆炸或模型不穩定。
因此我們設置:
epsilon = 1e-15
y_pred = max(min(p, 1 - epsilon), epsilon)
🔄 五、與 Binary Cross Entropy 的區別?
項目 | Binary Cross Entropy | Categorical Cross Entropy |
---|---|---|
應用場景 | 二分類或多標簽 | 多分類(單標簽) |
標簽格式 | 0 或 1 | 獨熱編碼 |
輸出層 | Sigmoid | Softmax |
🧠 六、實際應用場景
圖像分類(如 CIFAR-10、ImageNet)
文本分類(如新聞分類、情感分析)
多類別實體識別(NER)
📌 七、總結
Categorical Cross Entropy 是多分類任務的首選損失函數;
與 Softmax 輸出層配合使用;
一定要做 數值穩定性處理(加 epsilon);
真實標簽應為 One-Hot 向量;
預測越準,損失越小。
?