深度學習-損失函數

目錄

1. 線性回歸損失函數

1.1 MAE損失

1.2 MSE損失

2. CrossEntropyLoss

2.1 信息量

2.2 信息熵

2.3 KL散度

2.4 交叉熵

3. BCELoss

4. 總結


1. 線性回歸損失函數

1.1 MAE損失

MAE(Mean Absolute Error,平均絕對誤差)通常也被稱為 L1-Loss,通過對預測值和真實值之間的絕對差取平均值來衡量他們之間的差異。。

MAE的公式如下:


\text{MAE} = \frac{1}{n} \sum_{i=1}^{n} \left| y_i - \hat{y}_i \right|

其中:

  • n 是樣本的總數。

  • y_i 是第 i 個樣本的真實值。

  • \hat{y}_i是第 i 個樣本的預測值。

  • \left| y_i - \hat{y}_i \right|是真實值和預測值之間的絕對誤差。

特點

  1. 魯棒性:與均方誤差(MSE)相比,MAE對異常值(outliers)更為魯棒,因為它不會像MSE那樣對較大誤差平方敏感。

  2. 物理意義直觀:MAE以與原始數據相同的單位度量誤差,使其易于解釋。

應用場景: MAE通常用于需要對誤差進行線性度量的情況,尤其是當數據中可能存在異常值時,MAE可以避免對異常值的過度懲罰。

使用torch.nn.L1Loss即可計算MAE:

import torch
import torch.nn as nn
?
# 初始化MAE損失函數
mae_loss = nn.L1Loss()
?
# 假設 y_true 是真實值, y_pred 是預測值
y_true = torch.tensor([3.0, 5.0, 2.5])
y_pred = torch.tensor([2.5, 5.0, 3.0])
?
# 計算MAE
loss = mae_loss(y_pred, y_true)
print(f'MAE Loss: {loss.item()}')

1.2 MSE損失

均方差損失,也叫L2Loss。

MSE(Mean Squared Error,均方誤差)通過對預測值和真實值之間的誤差平方取平均值,來衡量預測值與真實值之間的差異。

MSE的公式如下:


\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left( y_i - \hat{y}_i \right)^2

其中:

  • n 是樣本的總數。

  • y_i 是第 i 個樣本的真實值。

  • \hat{y}_i 是第 i 個樣本的預測值。

  • \left( y_i - \hat{y}_i \right)^2 是真實值和預測值之間的誤差平方。

特點

  1. 平方懲罰:因為誤差平方,MSE 對較大誤差施加更大懲罰,所以 MSE 對異常值更為敏感。

  2. 凸性:MSE 是一個凸函數(國際的叫法,國內叫凹函數),這意味著它具有一個唯一的全局最小值,有助于優化問題的求解。

應用場景

MSE被廣泛應用在神經網絡中。

使用 torch.nn.MSELoss 可以實現:

import torch
import torch.nn as nn
?
# 初始化MSE損失函數
mse_loss = nn.MSELoss()
?
# 假設 y_true 是真實值, y_pred 是預測值
y_true = torch.tensor([3.0, 5.0, 2.5])
y_pred = torch.tensor([2.5, 5.0, 3.0])
?
# 計算MSE
loss = mse_loss(y_pred, y_true)
print(f'MSE Loss: {loss.item()}')

2. CrossEntropyLoss

2.1 信息量

信息量用于衡量一個事件所包含的信息的多少。信息量的定義基于事件發生的概率:事件發生的概率越低,其信息量越大。其量化公式:

對于一個事件x,其發生的概率為 P(x),信息量I(x) 定義為:

性質

  1. 非負性:I(x)≥0。

  2. 單調性:P(x)越小,I(x)越大。

2.2 信息熵

信息熵是信息量的期望值。熵越高,表示隨機變量的不確定性越大;熵越低,表示隨機變量的不確定性越小。

公式由數學中的期望推導而來:

其中:

-logP(x_i)是信息量,P(x_i)是信息量對應的概率

2.3 KL散度

KL散度用于衡量兩個概率分布之間的差異。它描述的是用一個分布 Q來近似另一個分布 P時,所損失的信息量。KL散度越小,表示兩個分布越接近。

對于兩個離散概率分布 P和 Q,KL散度定義為:

其中:P 是真實分布,Q是近似分布。

2.4 交叉熵

對KL散度公式展開:

由上述公式可知,P是真實分布,H(P)是常數,所以KL散度可以用H(P,Q)來表示;H(P,Q)叫做交叉熵。

如果將P換成y,Q換成\hat{y},則交叉熵公式為:

其中:

  • C 是類別的總數。

  • y 是真實標簽的one-hot編碼向量,表示真實類別。

  • \hat{y} 是模型的輸出(經過 softmax 后的概率分布)。

  • y_i 是真實類別的第 i 個元素(0 或 1)。

  • \hat{y}_i 是預測的類別概率分布中對應類別 i 的概率。

函數曲線圖:

特點:

  1. 概率輸出:CrossEntropyLoss 通常與 softmax 函數一起使用,使得模型的輸出表示為一個概率分布(即所有類別的概率和為 1)。PyTorch 的 nn.CrossEntropyLoss 已經內置了 Softmax 操作。如果我們在輸出層顯式地添加 Softmax,會導致重復應用 Softmax,從而影響模型的訓練效果。

  2. 懲罰錯誤分類:該損失函數在真實類別的預測概率較低時,會施加較大的懲罰,這樣模型在訓練時更注重提升正確類別的預測概率。

  3. 多分類問題中的標準選擇:在大多數多分類問題中,CrossEntropyLoss 是首選的損失函數。

應用場景:

CrossEntropyLoss 廣泛應用于各種分類任務,包括圖像分類、文本分類等,尤其是在神經網絡模型中。

nn.CrossEntropyLoss基本原理:

由交叉熵公式可知:


\text{Loss}(y, \hat{y}) = - \sum_{i=1}^{C} y_i \log(\hat{y}_i)

因為y_i是one-hot編碼,其值不是1便是0,又是乘法,所以只要知道1對應的index就可以了,展開后:


\text{Loss}(y, \hat{y}) = - \log(\hat{y}_m)

其中,m表示真實類別。

因為神經網絡最后一層分類總是接softmax,所以可以把\hat{y}_m直接看為是softmax后的結果。


\text{Loss}(i) = - \log(softmax(x_i))
?

所以,CrossEntropyLoss 實質上是兩步的組合:Cross Entropy = Log-Softmax + NLLLoss

  • Log-Softmax:對輸入 logits 先計算對數 softmax:log(softmax(x))

  • NLLLoss(Negative Log-Likelihood):對 log-softmax 的結果計算負對數似然損失。簡單理解就是求負數。原因是概率值通常在 0 到 1 之間,取對數后會變成負數。為了使損失值為正數,需要取負數。

對于softmax(x_i),在softmax介紹中了解到,需要減去最大值以確保數值穩定。


\mathrm{Softmax}(x_i)=\frac{e^{x_i-\max(x)}}{\sum_{j=1}^ne^{x_j-\max(x)}}

則:


LogSoftmax(x_i) =log(\frac{e^{x_i-\max(x)}}{\sum_{j=1}^ne^{x_j-\max(x)}})\\ =x_i-\max(x)-log(\sum_{j=1}^ne^{x_j-\max(x)})

所以:


\text{Loss}(i) = - (x_i-\max(x)-log(\sum_{j=1}^ne^{x_j-\max(x)}))
?

總的交叉熵損失函數是所有樣本的平均值:


\ell(x, y) = \begin{cases} \frac{\sum_{n=1}^N l_n}{N}, & \text{if reduction} = \text{`mean';}\\ \sum_{n=1}^N l_n, & \text{if reduction} = \text{`sum'.} \end{cases}

示例代碼如下:

import torch
import torch.nn as nn
?
# 假設有三個類別,模型輸出是未經softmax的logits
logits = torch.tensor([[1.5, 2.0, 0.5], [0.5, 1.0, 1.5]])
?
# 真實的標簽
labels = torch.tensor([1, 2]) ?# 第一個樣本的真實類別為1,第二個樣本的真實類別為2
?
# 初始化CrossEntropyLoss
# 參數:reduction:mean-平均值,sum-總和
criterion = nn.CrossEntropyLoss()
?
# 計算損失
loss = criterion(logits, labels)
print(f'Cross Entropy Loss: {loss.item()}')

在這個例子中,CrossEntropyLoss 直接作用于未經 softmax 處理的 logits 輸出和真實標簽,PyTorch 內部會自動應用 softmax 激活函數,并計算交叉熵損失。

分析示例中的代碼:

logits = torch.tensor([[1.5, 2.0, 0.5], [0.5, 1.0, 1.5]])

第一個樣本的得分是 [1.5, 2.0, 0.5],分別對應類別 0、1 和 2 的得分。

第二個樣本的得分是 [0.5, 1.0, 1.5],分別對應類別 0、1 和 2 的得分

labels = torch.tensor([1, 2])

第一個樣本的真實類別是 1。

第二個樣本的真實類別是 2。

CrossEntropyLoss 的計算過程可以分為以下幾個步驟:

(1) LogSoftmax 操作

首先,對每個樣本的 logits 應用 LogSoftmax 函數,將 logits 轉換為概率分布。LogSoftmax 函數的公式是: LogSoftmax(x_i) =x_i-\max(x)-log(\sum_{j=1}^ne^{x_j-\max(x)})

對于第一個樣本 [1.5, 2.0, 0.5]

減去最大值:

x_i-\max(x)=[1.5-2.0,2.0-2.0,0.5-2.0]=[-0.5,0,-1.5]

計算e^{x_i-\max(x)}

求和并取對數:

計算 log_softmax

對于第二個樣本 [0.5, 1.0, 1.5]

減去最大值:

x_i-\max(x)=[0.5-1.5,1.0-1.5,1.5-1.5]=[-1.0,-0.5,0]

計算e^{x_i-\max(x)}

求和并取對數:

計算 log_softmax

(2) 計算每個樣本的損失

接下來,根據真實標簽 z_t 計算每個樣本的交叉熵損失。交叉熵損失的公式是:

對于第一個樣本:

  • 真實類別是 1,對應的 softmax 值是 -0.6041。

對于第二個樣本:

  • 真實類別是 2,對應的 softmax 值是 -0.6803。

(3) 計算平均損失

最后,計算所有樣本的平均損失:

3. BCELoss

二分類交叉熵損失函數,使用在輸出層使用sigmoid激活函數進行二分類時。

由交叉熵公式:

\text{CELoss}(y, \hat{y}) = - \sum_{i=1}^{C} y_i \log(\hat{y}_i)

對于二分類問題,真實標簽 y的值為(0 或 1),假設模型預測為正類的概率為 \hat{y},則:

所以:

示例:

import torch
import torch.nn as nn
?
# y 是模型的輸出,已經被sigmoid處理過,確保其值域在(0,1)
y = torch.tensor([[0.7], [0.2], [0.9], [0.7]])
# targets 是真實的標簽,0或1
t = torch.tensor([[1], [0], [1], [0]], dtype=torch.float)
?
# 計算損失方式一:
bceLoss = nn.BCELoss()
loss1 = bceLoss(y, t)
?
#計算損失方式二: 兩種方式結果相同
loss2 = nn.functional.binary_cross_entropy(y, t)
?
print(loss1, loss2)

逐樣本計算

樣本y_it_i計算項 t_i * log(y_i) + (1-t_i) * log(1-y_i)
10.711*log(0.7) + 0*log(0.3) ≈ -0.3567
20.200*log(0.2) + 1*log(0.8) ≈ -0.2231
30.911*log(0.9) + 0*log(0.1) ≈ -0.1054
40.700*log(0.7) + 1*log(0.3) ≈ -1.2040

計算最終損失

4. 總結

  • 當輸出層使用softmax多分類時,使用交叉熵損失函數;

  • 當輸出層使用sigmoid二分類時,使用二分類交叉熵損失函數, 比如在邏輯回歸中使用;

  • 當功能為線性回歸時,使用均方差損失-L2 loss;

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/77988.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/77988.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/77988.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

第六篇:linux之解壓縮、軟件管理

第六篇:linux之解壓縮、軟件管理 文章目錄 第六篇:linux之解壓縮、軟件管理一、解壓和壓縮1、window壓縮包與linux壓縮包能否互通?2、linux下壓縮包的類型3、打包與壓縮 二、軟件管理1、rpm1、什么是rpm?2、rpm包名組成部分3、如何…

Redis 鍵管理

Redis 鍵管理 以下從鍵重命名、隨機返回鍵、鍵過期機制和鍵遷移四個維度展開詳細說明,結合 Redis 核心命令與底層邏輯進行深入分析: 一、鍵重命名 1. ?RENAME?? 與 ?RENAMENX?? **RENAME key newkey?**: 功能:強制重命名…

OpenCV 模板匹配方法詳解

文章目錄 1. 什么是模板匹配?2. 模板匹配的原理2.1數學表達 3. OpenCV 實現模板匹配3.1基本步驟 4. 模板匹配的局限性5. 總結 1. 什么是模板匹配? 模板匹配(Template Matching)是計算機視覺中的一種基礎技術,用于在目…

TextCNN 模型文本分類實戰:深度學習在自然語言處理中的應用

在自然語言處理(NLP)領域,文本分類是研究最多且應用最廣泛的任務之一。從情感分析到主題識別,文本分類技術在眾多場景中都發揮著重要作用。最近,我參與了一次基于 TextCNN 模型的文本分類實驗,從數據準備到…

Qt-創建模塊化.pri文件

文章目錄 一、.pri文件的作用與基本結構作用基本結構 二、創建.pri文件如何添加模塊代碼? 一、.pri文件的作用與基本結構 作用 在Qt開發中,.pri文件(Project Include File)是一種配置包含文件,用于模塊化管理和復用項…

SpringCloud組件——Eureka

一.背景 1.問題提出 我們在一個父項目下寫了兩個子項目,需要兩個子項目之間相互調用。我們可以發送HTTP請求來獲取我們想要的資源,具體實現的方法有很多,可以用HttpURLConnection、HttpClient、Okhttp、 RestTemplate等。 舉個例子&#x…

EAL4+與等保2.0:解讀中國網絡安全雙標準

EAL4與等保2.0:解讀中國網絡安全雙標準 在當今數字化時代,網絡安全已成為各個行業不可忽視的重要議題。特別是在金融、政府、醫療等領域,保護信息的安全性和隱私性顯得尤為關鍵。在中國,EAL4和等級保護2.0(簡稱“等保…

FFmpeg+Nginx+VLC打造M3U8直播

一、視頻直播的技術原理和架構方案 直播模型一般包括三個模塊:主播方、服務器端和播放端 主播放創造視頻,加美顏、水印、特效、采集后推送給直播服務器 播放端: 直播服務器端:收集主播端的視頻推流,將其放大后推送給…

【Redis】緩存三劍客問題實踐(上)

本篇對緩存三劍客問題進行介紹和解決方案說明,下篇將進行實踐,有需要的同學可以跳轉下篇查看實踐篇:(待發布) 緩存三劍客是什么? 緩存三劍客指的是在分布式系統下使用緩存技術最常見的三類典型問題。它們分…

Flink 2.0 編譯

文章目錄 Flink 2.0 編譯第一個問題 java 版本太低maven 版本太低maven 版本太高開始編譯擴展多版本jdk 配置 Flink 2.0 編譯 看到Flink2.0 出來了,想去玩玩,看看怎么樣,當然第一件事,就是編譯代碼,但是沒想到這么多問…

獲取印度股票市場列表、查詢IPO信息以及通過WebSocket實時接收數據

為了對接印度股票市場,獲取市場列表、查詢IPO信息、查看漲跌排行榜以及通過WebSocket實時接收數據等步驟。 1. 獲取市場列表 首先,您需要獲取支持的市場列表,這有助于了解哪些市場可以交易或監控。 請求方法:GETURL&#xff1a…

云原生--CNCF-1-云原生計算基金會介紹(云原生生態的發展目標和未來)

1、CNCF定義與背景 云原生計算基金會(Cloud Native Computing Foundation,CNCF)是由Linux基金會于2015年12月發起成立的非營利組織,旨在推動云原生技術的標準化、開源生態建設和行業協作。其核心目標是通過開源項目和社區協作&am…

【Rust 精進之路之第5篇-數據基石·下】復合類型:元組 (Tuple) 與數組 (Array) 的定長世界

系列: Rust 精進之路:構建可靠、高效軟件的底層邏輯 作者: 碼覺客 發布日期: 2025-04-20 引言:從原子到分子——組合的力量 在上一篇【數據基石上】中,我們仔細研究了 Rust 的四種基本標量類型&#xff1…

MongoDB 集合名稱映射問題

項目場景 在使用 Spring Data MongoDB 進行開發時,定義了一個名為 CompetitionSignUpLog 的實體類,并創建了對應的 Repository 接口。需要明確該實體類在 MongoDB 中實際對應的集合名稱是 CompetitionSignUpLog 還是 competitionSignUpLog。 問題描述 …

物聯網 (IoT) 安全簡介

什么是物聯網安全? 物聯網安全是網絡安全的一個分支領域,專注于保護、監控和修復與物聯網(IoT)相關的威脅。物聯網是指由配備傳感器、軟件或其他技術的互聯設備組成的網絡,這些設備能夠通過互聯網收集、存儲和共享數據…

PCB原理圖解析(炸雞派為例)

晶振 這是外部晶振的原理圖。 32.768kHz 的晶振,常用于實時時鐘(RTC)電路,因為它的頻率恰好是一天的分數(32768 秒),便于實現秒計數。 C25 和 C24:兩個 12pF 的電容,用于…

Jupyter Notebook 中切換/使用 conda 虛擬環境的方式(解決jupyter notebook 環境默認在base下面的問題)

使用 nb_conda_kernels 添加所有環境 一鍵添加所有 conda 環境 conda activate my-conda-env # this is the environment for your project and code conda install ipykernel conda deactivateconda activate base # could be also some other environment conda in…

【JAVA】十三、基礎知識“接口”精細講解!(二)(新手友好版~)

哈嘍大家好呀qvq,這里是乎里陳,接口這一知識點博主分為三篇博客為大家進行講解,今天為大家講解第二篇java中實現多個接口,接口間的繼承,抽象類和接口的區別知識點,更適合新手寶寶們閱讀~更多內容持續更新中…

基于MuJoCo物理引擎的機器人學習仿真框架robosuite

Robosuite 基于 MuJoCo 物理引擎,能支持多種機器人模型,提供豐富多樣的任務場景,像基礎的抓取、推物,精細的開門、擰瓶蓋等操作。它可靈活配置多種傳感器,提供本體、視覺、力 / 觸覺等感知數據。因其對強化學習友好&am…

企業微信自建應用開發回調事件實現方案

目錄 1. 前言 2. 正文 2.1 技術方案 2.2 策略上下文 2.2 添加客戶策略實現類 2.3 修改客戶信息策略實現類 2.4 默認策略實現類 2.5 接收事件的實體類(可以根據事件格式的參數做修改) 2.6 實際接收回調結果的接口 近日在開發企業微信的自建應用時…