【深度學習所有損失函數】在 NumPy、TensorFlow 和 PyTorch 中實現(1/2)

一、說明

在本文中,討論了深度學習中使用的所有常見損失函數,并在NumPy,PyTorch和TensorFlow中實現了它們。

二、內容提要?

我們本文所談的代價函數如下所列:

  1. 均方誤差 (MSE) 損失
  2. 二進制交叉熵損失
  3. 加權二進制交叉熵損失
  4. 分類交叉熵損失
  5. 稀疏分類交叉熵損失
  6. 骰子損失
  7. 吉隆坡背離損失
  8. 平均絕對誤差 (MAE) / L1 損耗
  9. 胡貝爾損失

????????在下文,我們將逐一演示其不同實現辦法。?

三、均方誤差 (MSE) 損失

????????均方誤差 (MSE) 損失是回歸問題中常用的損失函數,其目標是預測連續變量。損失計算為預測值和真實值之間的平方差的平均值。MSE 損失的公式為:

MSE loss = (1/n) * sum((y_pred — y_true)2)

????????這里:

  • n 是數據集中的樣本數
  • 目標變量的預測值y_pred
  • y_true是目標變量的真實值

????????MSE損失對異常值很敏感,并且會嚴重懲罰大誤差,這在某些情況下可能是不可取的。在這種情況下,可以使用其他損失函數,如平均絕對誤差(MAE)或Huber損失。

????????在 NumPy 中的實現

import numpy as npdef mse_loss(y_pred, y_true):"""Calculates the mean squared error (MSE) loss between predicted and true values.Args:- y_pred: predicted values- y_true: true valuesReturns:- mse_loss: mean squared error loss"""n = len(y_true)mse_loss = np.sum((y_pred - y_true) ** 2) / nreturn mse_loss

????????在此實現中,和 是分別包含預測值和真值的 NumPy 數組。該函數首先計算 和 之間的平方差,然后取這些值的平均值來獲得 MSE 損失。該變量表示數據集中的樣本數,用于規范化損失。y_predy_truey_predy_truen

TensorFlow 中的實現

import tensorflow as tfdef mse_loss(y_pred, y_true):"""Calculates the mean squared error (MSE) loss between predicted and true values.Args:- y_pred: predicted values- y_true: true valuesReturns:- mse_loss: mean squared error loss"""mse = tf.keras.losses.MeanSquaredError()mse_loss = mse(y_true, y_pred)return mse_loss

在此實現中,和是分別包含預測值和真值的 TensorFlow 張量。該函數計算 和 之間的 MSE 損耗。該變量包含計算出的損失。y_predy_truetf.keras.losses.MeanSquaredError()y_predy_truemse_loss

在 PyTorch 中的實現

import torchdef mse_loss(y_pred, y_true):"""Calculates the mean squared error (MSE) loss between predicted and true values.Args:- y_pred: predicted values- y_true: true valuesReturns:- mse_loss: mean squared error loss"""mse = torch.nn.MSELoss()mse_loss = mse(y_pred, y_true)return mse_loss

在此實現中,和 是分別包含預測值和真值的 PyTorch 張量。該函數計算 和 之間的 MSE 損耗。該變量包含計算出的損失。y_predy_truetorch.nn.MSELoss()y_predy_truemse_loss

四、二進制交叉熵損失

????????二進制交叉熵損失,也稱為對數損失,是二元分類問題中使用的常見損失函數。它測量預測概率分布與實際二進制標簽分布之間的差異。

????????二進制交叉熵損失的公式如下:

????????L(y, ?) = -[y * log(?) + (1 — y) * log(1 — ?)]

????????其中 y 是真正的二進制標簽(0 或 1),? 是預測概率(范圍從 0 到 1),log 是自然對數。

????????等式的第一項計算真實標簽為 1 時的損失,第二項計算真實標簽為 0 時的損失。總損失是兩個項的總和。

????????當預測概率接近真實標簽時,損失較低,當預測概率遠離真實標簽時,損失較高。此損失函數通常用于在輸出層中使用 sigmoid 激活函數來預測二進制標簽的神經網絡模型。

4.1 在 NumPy 中的實現

????????在numpy中,二進制交叉熵損失可以使用我們前面描述的公式來實現。下面是如何計算它的示例:

# define true labels and predicted probabilities
y_true = np.array([0, 1, 1, 0])
y_pred = np.array([0.1, 0.9, 0.8, 0.3])# calculate the binary cross-entropy loss
loss = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)).mean()# print the loss
print(loss)

4.2 TensorFlow 中的實現

????????在 TensorFlow 中,二進制交叉熵損失可以使用 tf.keras.loss.BinaryCrossentropy() 函數實現。下面是如何使用它的示例:

import tensorflow as tf# define true labels and predicted probabilities
y_true = tf.constant([0, 1, 1, 0])
y_pred = tf.constant([0.1, 0.9, 0.8, 0.3])# define the loss function
bce_loss = tf.keras.losses.BinaryCrossentropy()# calculate the loss
loss = bce_loss(y_true, y_pred)# print the loss
print(loss)

4.3 在 PyTorch 中的實現

????????在 PyTorch 中,二進制交叉熵損失可以使用該函數實現。下面是如何使用它的示例:torch.nn.BCELoss()

import torch# define true labels and predicted probabilities
y_true = torch.tensor([0, 1, 1, 0], dtype=torch.float32)
y_pred = torch.tensor([0.1, 0.9, 0.8, 0.3], dtype=torch.float32)# define the loss function
bce_loss = torch.nn.BCELoss()# calculate the loss
loss = bce_loss(y_pred, y_true)# print the loss
print(loss)

4.4 加權二進制交叉熵損失

????????加權二元交叉熵損失是二元交叉熵損失的一種變體,允許為正熵和負示例分配不同的權重。這在處理不平衡的數據集時非常有用,其中一類與另一類相比明顯不足。

????????加權二元交叉熵損失的公式如下:

L(y, ?) = -[w_pos * y * log(?) + w_neg * (1 — y) * log(1 — ?)]

????????其中 y 是真正的二進制標簽(0 或 1),? 是預測概率(范圍從 0 到 1),log 是自然對數,w_pos 和 w_neg 分別是正權重和負權重。

????????等式的第一項計算真實標簽為 1 時的損失,第二項計算真實標簽為 0 時的損失。總損失是兩個項的總和,每個項按相應的權重加權。

????????可以根據每個類的相對重要性選擇正權重和負權重。例如,如果正類更重要,則可以為其分配更高的權重。同樣,如果負類更重要,則可以為其分配更高的權重。

????????當預測概率接近真實標簽時,損失較低,當預測概率遠離真實標簽時,損失較高。此損失函數通常用于在輸出層中使用 sigmoid 激活函數來預測二進制標簽的神經網絡模型。

五、分類交叉熵損失

????????分類交叉熵損失是多類分類問題中使用的一種常用損失函數。它衡量每個類的真實標簽和預測概率之間的差異。

????????分類交叉熵損失的公式為:

L = -1/N * sum(sum(Y * log(Y_hat)))

????????其中 是單熱編碼格式的真實標簽矩陣,是每個類的預測概率矩陣,是樣本數,表示自然對數。YY_hatNlog

????????在此公式中,形狀為 ,其中是樣本數,是類數。每行 表示單個樣本的真實標簽分布,列中的值 1 對應于真實標簽,0 對應于所有其他列。Y(N, C)NCY

????????類似地,具有 的形狀,其中每行表示單個樣本的預測概率分布,每個類都有一個概率值。Y_hat(N, C)

????????該函數逐個應用于預測的概率矩陣。該函數使用兩次來求和矩陣的兩個維度。logY_hatsumY

????????結果值表示數據集中所有樣本的平均交叉熵損失。訓練神經網絡的目標是最小化這種損失函數。LN

????????損失函數對模型的懲罰更大,因為在預測低概率的類時犯了大錯誤。目標是最小化損失函數,這意味著使預測概率盡可能接近真實標簽。

5.1 在 NumPy 中的實現

????????在numpy中,分類交叉熵損失可以使用我們前面描述的公式來實現。下面是如何計算它的示例:

import numpy as np# define true labels and predicted probabilities as NumPy arrays
y_true = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
y_pred = np.array([[0.8, 0.1, 0.1], [0.2, 0.3, 0.5], [0.1, 0.6, 0.3]])# calculate the loss
loss = -1/len(y_true) * np.sum(np.sum(y_true * np.log(y_pred)))# print the loss
print(loss)In this example, y_true represents the true labels (in integer format), and y_pred represents the predicted probabilities for each class (in a 2D array). The eye() function is used to convert the true labels to one-hot encoding, which is required for the loss calculation. The categorical cross-entropy loss is calculated using the formula we provided earlier, and the mean() function is used to average the loss over the entire dataset. Finally, the calculated loss is printed to the console.

????????在此示例中, 以獨熱編碼格式表示真實標簽,并表示每個類的預測概率,兩者都為 NumPy 數組。使用上述公式計算損失,然后使用該函數打印到控制臺。請注意,該函數使用兩次來對矩陣的兩個維度求和。y_truey_predprintnp.sumY

5.2 TensorFlow 中的實現

????????在TensorFlow中,分類交叉熵損失可以使用該類輕松計算。下面是如何使用它的示例:tf.keras.losses.CategoricalCrossentropy

import tensorflow as tf# define true labels and predicted probabilities as TensorFlow Tensors
y_true = tf.constant([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
y_pred = tf.constant([[0.8, 0.1, 0.1], [0.2, 0.3, 0.5], [0.1, 0.6, 0.3]])# create the loss object
cce_loss = tf.keras.losses.CategoricalCrossentropy()# calculate the loss
loss = cce_loss(y_true, y_pred)# print the loss
print(loss.numpy())

????????在此示例中,以獨熱編碼格式表示真實標簽,并表示每個類的預測概率,兩者都作為 TensorFlow 張量。該類用于創建損失函數的實例,然后通過將真實標簽和預測概率作為參數傳遞來計算損失。最后,使用該方法將計算出的損失打印到控制臺。y_truey_predCategoricalCrossentropy.numpy()

請注意,該類在內部處理將真實標簽轉換為獨熱編碼,因此無需顯式執行此操作。如果你的真實標簽已經是獨熱編碼格式,你可以將它們直接傳遞給損失函數,沒有任何問題。CategoricalCrossentropy

5.3 在 PyTorch 中的實現

????????在 PyTorch 中,分類交叉熵損失可以使用該類輕松計算。下面是如何使用它的示例:torch.nn.CrossEntropyLoss

import torch# define true labels and predicted logits as PyTorch Tensors
y_true = torch.LongTensor([1, 2, 0])
y_logits = torch.Tensor([[0.8, 0.1, 0.1], [0.2, 0.3, 0.5], [0.1, 0.6, 0.3]])# create the loss object
ce_loss = torch.nn.CrossEntropyLoss()# calculate the loss
loss = ce_loss(y_logits, y_true)# print the loss
print(loss.item())

????????在此示例中, 以整數格式表示真實標簽,并表示每個類的預測對數,兩者都作為 PyTorch 張量。該類用于創建損失函數的實例,然后通過將預測的對數和 true 標簽作為參數傳遞來計算損失。最后,使用該方法將計算出的損失打印到控制臺。y_truey_logitsCrossEntropyLoss.item()

????????請注意,該類將 softmax 激活函數和分類交叉熵損失組合到一個操作中,因此您無需單獨應用 softmax。另請注意,真正的標簽應采用整數格式,而不是獨熱編碼格式。CrossEntropyLoss

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/38596.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/38596.shtml
英文地址,請注明出處:http://en.pswp.cn/news/38596.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

“深入解析JVM內部機制:探索Java虛擬機的奧秘“

標題:深入解析JVM內部機制:探索Java虛擬機的奧秘 JVM(Java虛擬機)是Java程序的核心執行環境,它負責將Java字節碼轉換為機器碼并執行。了解JVM的內部機制對于理解Java程序的執行過程和性能優化至關重要。本文將深入解析…

開啟想象翅膀:輕松實現文本生成模型的創作應用,支持LLaMA、ChatGLM、UDA、GPT2、Seq2Seq、BART、T5、SongNet等模型,開箱即用

開啟想象翅膀:輕松實現文本生成模型的創作應用,支持LLaMA、ChatGLM、UDA、GPT2、Seq2Seq、BART、T5、SongNet等模型,開箱即用 TextGen: Implementation of Text Generation models 1.介紹 TextGen實現了多種文本生成模型,包括&a…

c++——::作用域、命名空間、using(聲明和編譯指令)

c 作用域和名字控制 一、::(雙冒號) 作用域 <::>運算符是一個作用域如果<::>前面什么都沒有加 代表是全局作用域 二、命名空間&#xff08;namespace) 1、namespace 本質是作用域,可以更好的控制標識符的作用域命名空間 就可以存放 變量 函數 類 結構體 … 2…

【kubernetes】在k8s集群環境上,部署kubesphere

部署kubesphere 學習于尚硅谷kubesphere課程 前置環境配置-部署默認存儲類型 這里使用nfs #所有節點安裝 yum install -y nfs-utils# 在master節點執行以下命令 echo "/nfs/data/ *(insecure,rw,sync,no_root_squash)" > /etc/exports # 執行以下命令&#xff…

QML與C++交互

目錄 1 QML獲取C的變量值 2 QML獲取C創建的自定義對象 3 QML發送信號綁定C端的槽 4 C端發送信號綁定qml端槽 5 C調用QML端函數 1 QML獲取C的變量值 QQmlApplicationEngine engine; 全局對象 上下文屬性 QQmlApplicationEngine engine; QQmlContext *context1 engine.…

flowable流程移植新項目前端問題匯總

flowable流程移植到新項目時&#xff0c;出現一些前端問題&#xff0c;匯總如下&#xff1a; PS F:\khxm\NBCIO_VUE> yarn run serve yarn run v1.21.1 $ vue-cli-service serve INFO Starting development server... ERROR Error: Vue packages version mismatch: -…

25 | 葡萄酒質量數據分析

基于kaggle提供的公開數據集,對全球葡萄酒分布情況和質量情況進行數據探索和分析 from kaggle: https://www.kaggle.com/zynicide/wine-reviews 分析思路: 0、數據準備 1、葡萄酒的種類 2、葡萄酒質量 3、葡萄酒價格 4、葡萄酒描述詞庫 5、品鑒師信息 6、總結 0、數據準備 …

學習Vue:組件的概念和優勢

在現代的前端開發中&#xff0c;組件化開發是一種重要的方法&#xff0c;它可以將復雜的應用程序拆分成多個獨立的、可復用的組件。Vue.js 是一個流行的前端框架&#xff0c;它支持組件化開發&#xff0c;讓開發者能夠更輕松地構建和維護復雜的用戶界面。在本文中&#xff0c;我…

計算機組成部分

計算機的五大部件是什么&#xff1f;答案&#xff1a;計算機的五大部件是運算器&#xff0c;控制器&#xff0c;存儲器&#xff0c;輸入設備和輸出設備。 其中運算器和控制器合稱中央處理器&#xff0c;是計算機的核心部件&#xff1b; 存儲器是用來存儲程序指令和數據用的&am…

修改第三方組件默認樣式

深度選擇器 修改el-input的樣式&#xff1a; <el-input class"input-area"></el-input>查看DOM結構&#xff1a; 原本使用 /deep/ 但是可能不兼容 使用 :deep .input-area {:deep(.el-input__inner){background-color: blue;} }將 input 框背景色改為…

【Kubernetes】Kubernetes的Pod進階

Pod進階 一、資源限制和重啟策略1. 資源限制2. 資源單位2.1 CPU 資源單位2.2 內存 資源單位 3. 重啟策略&#xff08;restartPolicy&#xff09; 二、健康檢查的概念1. 健康檢查1.1 探針的三種規則1.2 Probe 支持三種檢查方法 2. 示例2.1 exec 方式2.2 httpGet 方式2.3 tcpSock…

臨床試驗三原則-對照、重復、隨機

臨床試驗必須遵循三個基本原則&#xff1a;對照、重復、隨機。 一、對照原則和對照的設置 核心觀點&#xff1a;有比較才有鑒別。 對照組和試驗組同質可比。 三臂試驗 安慰劑&#xff1a;試驗組&#xff1a;陽性對照組1&#xff1a;n&#xff1a;m&#xff08;n≥m&#xff…

FFmpeg常見命令行(五):FFmpeg濾鏡使用

前言 在Android音視頻開發中&#xff0c;網上知識點過于零碎&#xff0c;自學起來難度非常大&#xff0c;不過音視頻大牛Jhuster提出了《Android 音視頻從入門到提高 - 任務列表》&#xff0c;結合我自己的工作學習經歷&#xff0c;我準備寫一個音視頻系列blog。本文是音視頻系…

Nginx反向代理服務流式輸出設置

Nginx反向代理服務流式輸出設置 1.問題場景 提問&#xff1a;為什么我部署的服務沒有流式響應 最近在重構原有的GPT項目時&#xff0c;遇到gpt回答速度很慢的現象。在使用流式輸出的接口時&#xff0c;接口響應速度居然還是達到了30s以上。 2.現象分析 分析現象我發現&…

Leetcode鏈表篇 Day3

.24. 兩兩交換鏈表中的節點 - 力扣&#xff08;LeetCode&#xff09; 1.構建虛擬結點 2.兩兩一組&#xff0c;前繼結點一定在兩兩的前面 3.保存結點1和結點3 19. 刪除鏈表的倒數第 N 個結點 - 力扣&#xff08;LeetCode&#xff09; 1.雙指針&#xff1a;快慢指針 兩個指針的差…

新能源汽車需要檢測哪些項目

截至2022年底&#xff0c;中國新能源車保有量達1310萬輛&#xff0c;其中純電動汽車保有量1045萬輛。為把好新能源汽車安全關&#xff0c;我國新能源汽車除了完善的強制性產品認證型式實驗外&#xff0c;還建立了“車企-地方-國家”逐級上報的三級監管體系實行新能源汽車全生命…

2023.8.14論文閱讀

文章目錄 ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation摘要本文方法實驗結果 DeepFusion: Lidar-Camera Deep Fusion for Multi-Modal 3D Object Detection摘要本文方法實驗結果 ESPNet: Efficient Spatial Pyramid of Dilated Convo…

vue 路由地址把#去掉

在路由對象里邊添加history模式就不顯示# mode:history // 4.通過規則創建對象 const router new VueRouter({routes,// 默認模式為hash 帶# // history 不帶#mode:history })想把端口號8000換成其他的 比如我這樣的3000更換端口號教程

Android Framework 動態更新插拔設備節點執行權限

TF卡設備節點是插上之后動態添加&#xff0c;所以不能通過初始化設備節點權限來解決&#xff0c;需要監聽TF插入事件&#xff0c;在init.rc 監聽插入后動態更新設備節點執行權限 添加插拔TF卡監聽 frameworks/base/services/core/java/com/android/server/StorageManagerServic…

IL匯編ldc指令學習

ldc指令是把值送到棧上&#xff0c; 說明如下&#xff0c; ldc.i4 將所提供的int32類型的值作為int32推送到計算堆棧上&#xff1b; ldc.i4.0 將數值0作為int32推送到計算堆棧上&#xff1b; ... ldc.i4.8 將數值8作為int32推送到計算堆棧上&#xff1b; ldc.i4.m1 將數值-…