nn.CrossEntropyLoss
在 PyTorch 中是處理多分類問題的常用損失函數,它是兩個函數 nn.LogSoftmax
和 nn.NLLLoss
(Negative Log Likelihood Loss)的組合。使用這個損失函數可以直接從模型得到原始的輸出分數(logits),而不需要單獨對輸出進行 Softmax
處理。下面詳細介紹這個損失函數的關鍵特點、工作原理和使用方式。
工作原理
nn.CrossEntropyLoss
首先對網絡的輸出應用 LogSoftmax
。這意味著網絡輸出的 logits(原始預測值)被轉換成概率的對數形式。然后,它使用這些對數概率和真實標簽計算 NLLLoss。
具體來說,公式可以表示為:
[ \text{Loss}(x, \text{class}) = -\log\left(\frac{\exp(x[\text{class}])}{\sum_j \exp(x[j])}\right) ]
[ \text{Loss}(x, \text{class}) = -x[\text{class}] + \log\left(\sum_j \exp(x[j])\right) ]
其中:
- ( x ) 是模型輸出的 logits。
- ( \text{class} ) 是真實的類別標簽(非 one-hot 編碼)。
參數詳解
- weight (Tensor, optional): 手動指定每個類的權重。如果給定,必須是一個長度為
C
的 Te