交叉熵損失函數通常用于多類分類損失函數計算。計算公式如下:
P為真實值,Q為預測值。
使用tensorflow計算
import tensorflow as tf
import keras# 創建一個示例數據集
# 假設有3個樣本,每個樣本有4個特征,共2個類別
# 目標標簽為稀疏表示(每個樣本的類別用類別索引表示)
y_true = tf.constant([0, 1, 1]) # 真實標簽
y_pred = tf.constant([[0.9, 0.1], [0.2, 0.8], [0.3, 0.7]]) # 模型預測# 使用SparseCategoricalCrossentropy計算損失
loss_fn = keras.losses.SparseCategoricalCrossentropy()
loss = loss_fn(y_true, y_pred)
print("損失值:", loss.numpy())
# 損失值: 0.22839303
手動計算
真實標簽:[0,1,1]寫成one-hot形式 [[1,0],[0,1],[0,1]]
預測標簽(經過softmax):[[0.9, 0.1], [0.2, 0.8], [0.3, 0.7]]
import mathh1 = (1 * math.log(0.9) + 0 * math.log(0.1))
h2 = (0 * math.log(0.2) + 1 * math.log(0.8))
h3 = (0 * math.log(0.3) + 1 * math.log(0.7))
h = -(h1 + h2 + h3) / 3
print("損失值:", h)
# 損失值: 0.22839300363692283
注意多標簽分類和多類分類區別與聯系。
參考
- 交叉熵函數Cross_EntropyLoss()的詳細計算過程
- 多標簽分類任務中的損失函數