Pytorch損失函數losses簡介

一般來說,監督學習的目標函數由損失函數和正則化項組成。(Objective = Loss + Regularization)

Pytorch中的損失函數一般在訓練模型時候指定。

注意Pytorch中內置的損失函數的參數和tensorflow不同,是y_pred在前,y_true在后,而Tensorflow是y_true在前,y_pred在后。

對于回歸模型,通常使用的內置損失函數是均方損失函數nn.MSELoss 。

對于二分類模型,通常使用的是二元交叉熵損失函數nn.BCELoss (輸入已經是sigmoid激活函數之后的結果) 或者 nn.BCEWithLogitsLoss (輸入尚未經過nn.Sigmoid激活函數) 。

對于多分類模型,一般推薦使用交叉熵損失函數 nn.CrossEntropyLoss。 (y_true需要是一維的,是類別編碼。y_pred未經過nn.Softmax激活。)

此外,如果多分類的y_pred經過了nn.LogSoftmax激活,可以使用nn.NLLLoss損失函數(The negative log likelihood loss)。 這種方法和直接使用nn.CrossEntropyLoss等價。

如果有需要,也可以自定義損失函數,自定義損失函數需要接收兩個張量y_pred,y_true作為輸入參數,并輸出一個標量作為損失函數值。

Pytorch中的正則化項一般通過自定義的方式和損失函數一起添加作為目標函數。

一,內置損失函數

內置的損失函數一般有類的實現和函數的實現兩種形式。

如:nn.BCE 和 F.binary_cross_entropy 都是二元交叉熵損失函數,前者是類的實現形式,后者是函數的實現形式。

實際上類的實現形式通常是調用函數的實現形式并用nn.Module封裝后得到的。

一般我們常用的是類的實現形式。它們封裝在torch.nn模塊下,并且類名以Loss結尾。

常用的一些內置損失函數說明如下。

nn.MSELoss(均方誤差損失,也叫做L2損失,用于回歸)

nn.L1Loss (L1損失,也叫做絕對值誤差損失,用于回歸)

nn.SmoothL1Loss (平滑L1損失,當輸入在-1到1之間時,平滑為L2損失,用于回歸)

nn.BCELoss (二元交叉熵,用于二分類,輸入已經過nn.Sigmoid激活,對不平衡數據集可以用weigths參數調整類別權重)

nn.BCEWithLogitsLoss (二元交叉熵,用于二分類,輸入未經過nn.Sigmoid激活)

nn.CrossEntropyLoss (交叉熵,用于多分類,要求label為稀疏編碼,輸入未經過nn.Softmax激活,對不平衡數據集可以用weigths參數調整類別權重)

nn.NLLLoss (負對數似然損失,用于多分類,要求label為稀疏編碼,輸入經過nn.LogSoftmax激活)

nn.CosineSimilarity(余弦相似度,可用于多分類)

nn.AdaptiveLogSoftmaxWithLoss (一種適合非常多類別且類別分布很不均衡的損失函數,會自適應地將多個小類別合成一個cluster)

更多損失函數的介紹參考如下知乎文章:

《PyTorch的十八個損失函數》

二,自定義L1和L2正則化項

通常認為L1 正則化可以產生稀疏權值矩陣,即產生一個稀疏模型,可以用于特征選擇。

而L2 正則化可以防止模型過擬合(overfitting)。一定程度上,L1也可以防止過擬合。

# L2正則化
def L2Loss(model,alpha):l2_loss = torch.tensor(0.0, requires_grad=True)for name, param in model.named_parameters():if 'bias' not in name: #一般不對偏置項使用正則l2_loss = l2_loss + (0.5 * alpha * torch.sum(torch.pow(param, 2)))return l2_loss# L1正則化
def L1Loss(model,beta):l1_loss = torch.tensor(0.0, requires_grad=True)for name, param in model.named_parameters():if 'bias' not in name:l1_loss = l1_loss +  beta * torch.sum(torch.abs(param))return l1_loss# 將L2正則和L1正則添加到FocalLoss損失,一起作為目標函數
def focal_loss_with_regularization(y_pred,y_true):focal = FocalLoss()(y_pred,y_true) l2_loss = L2Loss(model,0.001) #注意設置正則化項系數l1_loss = L1Loss(model,0.001)total_loss = focal + l2_loss + l1_lossreturn total_lossmodel.compile(loss_func =focal_loss_with_regularization,optimizer= torch.optim.Adam(model.parameters(),lr = 0.01),metrics_dict={"accuracy":accuracy})

只寫了部分,具體的參考《20天吃透Pytorch》

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/389348.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/389348.shtml
英文地址,請注明出處:http://en.pswp.cn/news/389348.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

讀取Mc1000的 唯一 ID 機器號

先引用Symbol.ResourceCoordination 然后引用命名空間 using System;using System.Security.Cryptography;using System.IO; 以下為類程序 /// <summary> /// 獲取設備id /// </summary> /// <returns></returns> public static string GetDevi…

樣本均值的抽樣分布_抽樣分布樣本均值

樣本均值的抽樣分布One of the most important concepts discussed in the context of inferential data analysis is the idea of sampling distributions. Understanding sampling distributions helps us better comprehend and interpret results from our descriptive as …

玩轉ceph性能測試---對象存儲(一)

筆者最近在工作中需要測試ceph的rgw&#xff0c;于是邊測試邊學習。首先工具采用的intel的一個開源工具cosbench&#xff0c;這也是業界主流的對象存儲測試工具。 1、cosbench的安裝&#xff0c;啟動下載最新的cosbench包wget https://github.com/intel-cloud/cosbench/release…

[BZOJ 4300]絕世好題

Description 題庫鏈接 給定一個長度為 \(n\) 的數列 \(a_i\) &#xff0c;求 \(a_i\) 的子序列 \(b_i\) 的最長長度&#xff0c;滿足 \(b_i\wedge b_{i-1}\neq 0\) &#xff08; \(\wedge\) 表示按位與&#xff09; \(1\leq n\leq 100000\) Solution 令 \(f_i\) 為二進制第 \(i…

因果關系和相關關系 大數據_數據科學中的相關性與因果關系

因果關系和相關關系 大數據Let’s jump into it right away.讓我們馬上進入。 相關性 (Correlation) Correlation means relationship and association to another variable. For example, a movement in one variable associates with the movement in another variable. For…

Pytorch構建模型的3種方法

這個地方一直是我思考的地方&#xff01;因為學的代碼太多了&#xff0c;構建的模型各有不同&#xff0c;這里記錄一下&#xff01; 可以使用以下3種方式構建模型&#xff1a; 1&#xff0c;繼承nn.Module基類構建自定義模型。 2&#xff0c;使用nn.Sequential按層順序構建模…

vue取數據第一個數據_我作為數據科學家的第一個月

vue取數據第一個數據A lot.很多。 I landed my first job as a Data Scientist at the beginning of August, and like any new job, there’s a lot of information to take in at once.我于8月初找到了數據科學家的第一份工作&#xff0c;并且像任何新工作一樣&#xff0c;一…

Flask-SocketIO 簡單使用指南

Flask-SocketIO 使 Flask 應用程序能夠訪問客戶端和服務器之間的低延遲雙向通信。客戶端應用程序可以使用 Javascript&#xff0c;C &#xff0c;Java 和 Swift 中的任何 SocketIO 官方客戶端庫或任何兼容的客戶端來建立與服務器的永久連接。 安裝 直接使用 pip 來安裝&#xf…

STL-開篇

基本概念 STL&#xff1a; Standard Template Library&#xff0c;標準模板庫 定義&#xff1a; c引入的一個標準類庫 特點&#xff1a;1&#xff09;數據結構和算法的 c實現&#xff08; 采用模板類和模板函數&#xff09;2&#xff09;數據的存儲和算法的分離3&#xff09;高…

Symbol Mc1000 聲音的設置以及播放

首先引用Symbol.Audio 加一命名空間using Symbol.Audio; /聲音設備的設置 //Select Device from device list Symbol.Audio.Device MyDevice (Symbol.Audio.Device)Symbol.StandardForms.SelectDevice.Select( Symbol.Audio.Controller.Title, Symbol.Audio.Devic…

/bin/bash^M: 壞的解釋器: 沒有那個文件或目錄

在win下編輯的時候&#xff0c;換行結尾是\n\r &#xff0c; 而在linux下 是\n&#xff0c;所以會多出來一個\r&#xff0c;這樣會出現錯誤 此時執行 sed -i s/\r$// file.sh 將file.sh中的\r都替換為空白&#xff0c;問題解決轉載于:https://www.cnblogs.com/zzdbullet/p/9890…

rcp rapido_為什么氣流非常適合Rapido

rcp rapidoBack in 2019, when we were building our data platform, we started building the data platform with Hadoop 2.8 and Apache Hive, managing our own HDFS. The need for managing workflows whether it’s data pipelines, i.e. ETL’s, machine learning predi…

pandas處理丟失數據與數據導入導出

3.4pandas處理丟失數據 頭文件&#xff1a; import numpy as np import pandas as pd丟棄數據部分&#xff1a; dates pd.date_range(20130101,periods6) df pd.DataFrame(np.random.randn(6,4),indexdates,columns[A,B,C,D]) df.iloc[0,1] np.nan df.iloc[1,2] np.nanp…

Mysql5.7開啟遠程

2019獨角獸企業重金招聘Python工程師標準>>> 1.注掉bind-address #bind-address 127.0.0.1 2.開啟遠程訪問權限 grant all privileges on *.* to root"xxx.xxx.xxx.xxx" identified by "密碼"; 或 grant all privileges on *.* to root"%…

分類結果可視化python_可視化分類結果的另一種方法

分類結果可視化pythonI love good data visualizations. Back in the days when I did my PhD in particle physics, I was stunned by the histograms my colleagues built and how much information was accumulated in one single plot.我喜歡出色的數據可視化。 早在我獲得…

算法組合 優化算法_算法交易簡化了風險價值和投資組合優化

算法組合 優化算法Photo by Markus Spiske (left) and Jamie Street (right) on UnsplashMarkus Spiske (左)和Jamie Street(右)在Unsplash上的照片 In the last post, we saw how actual algorithms are developed and tested. In this post, we will figure out the level of…

Symbol Mc1000 快捷鍵 的 設置 事件 開發

switch (e.KeyCode) { ///數據 case Keys.F1://清除數據 if(File.Exists("Storage Card/CG.sdf")) { Mc.gConn.Close(); Mc.gConn.Dispose(); File.Delete("Storage Card/CG.sdf"); } MessageBox.S…

pandas合并concatmerge和plot畫圖

3.6&#xff0c;3.7pandas合并concat&merge 頭文件&#xff1a; import pandas as pd import numpy as npconcat基礎合并用法 df1 pd.DataFrame(np.ones((3,4))*0,columns [a,b,c,d]) df2 pd.DataFrame(np.ones((3,4))*1,columns [a,b,c,d]) df3 pd.DataFrame(np.ones…

Android跳轉WIFI界面的四種方式

第一種 Intent intent new Intent(); intent.setAction("android.net.wifi.PICK_WIFI_NETWORK"); startActivity(intent); 第二種 startActivity(new Intent(android.provider.Settings.ACTION_WIFI_SETTINGS)); 第三種 Intent i new Intent(); if(android.os.Buil…

PS摳發絲技巧 「選擇并遮住…」

PS摳發絲技巧 「選擇并遮住…」 現在的海報設計&#xff0c;大多數都有模特MM&#xff0c;然而MM的頭發實用太多了&#xff0c;有的還飄起來…… 對于設計師(特別是淘寶美工)沒有一個強大、快速、實用的摳發絲技巧真的混不去哦。而PS CC 2017版本開始&#xff0c;就有了一個強大…