PyTorch 損失函數詳解:從理論到實踐

目錄

一、損失函數的基本概念

二、常用損失函數及實現

1. 均方誤差損失(MSELoss)

2. 平均絕對誤差損失(L1Loss/MAELoss)

3. 交叉熵損失(CrossEntropyLoss)

4. 二元交叉熵損失(BCELoss)

三、損失函數選擇指南

四、損失函數在訓練中的應用

五、總結


損失函數是深度學習模型訓練的核心組件,它量化了模型預測值與真實值之間的差異,指導模型參數的更新方向。本文將結合 PyTorch 代碼實例,詳細講解常用損失函數的原理、適用場景及實現方法。

一、損失函數的基本概念

損失函數(Loss Function)又稱代價函數(Cost Function),是衡量模型預測結果與真實標簽之間差異的指標。在模型訓練過程中,通過優化算法(如梯度下降)最小化損失函數,使模型逐漸逼近最優解。

損失函數的選擇取決于具體任務類型:

  • 回歸任務:預測連續值(如房價、溫度)
  • 分類任務:預測離散類別(如圖片分類、垃圾郵件識別)
  • 其他任務:如生成任務、序列標注等

二、常用損失函數及實現

1. 均方誤差損失(MSELoss)

均方誤差損失是回歸任務中最常用的損失函數,計算預測值與真實值之間平方差的平均值。

數學公式MSE = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 其中,y_i為真實值,\hat{y}_i為預測值,n為樣本數量。

代碼實現

import torch
import torch.nn as nn# 初始化MSE損失函數
mse_loss = nn.MSELoss()# 示例數據
y_true = torch.tensor([3.0, 5.0, 2.5])  # 真實值
y_pred = torch.tensor([2.5, 5.0, 3.0])  # 預測值# 計算損失
loss = mse_loss(y_pred, y_true)
print(f'MSE Loss: {loss.item()}')  # 輸出:MSE Loss: 0.0833333358168602

特點

  • 對異常值敏感,因為會對誤差進行平方
  • 是凸函數,存在唯一全局最小值
  • 適用于大多數回歸任務

2. 平均絕對誤差損失(L1Loss/MAELoss)

平均絕對誤差計算預測值與真實值之間絕對差的平均值,對異常值的敏感性低于 MSE。

數學公式MAE = \frac{1}{n} \sum_{i=1}^{n} |y_i - \hat{y}_i|

代碼實現

# 初始化L1損失函數
l1_loss = nn.L1Loss()# 計算損失
loss = l1_loss(y_pred, y_true)
print(f'L1 Loss: {loss.item()}')  # 輸出:L1 Loss: 0.25

特點

  • 對異常值更穩健
  • 梯度在零點處不連續,可能影響收斂速度
  • 適用于存在異常值的回歸場景

3. 交叉熵損失(CrossEntropyLoss)

交叉熵損失是多分類任務的標準損失函數,在 PyTorch 中內置了 Softmax 操作,直接作用于模型輸出的 logits。

數學公式CrossEntropyLoss = -\sum_{i=1}^{C} y_i \log(\hat{y}_i) 其中,C為類別數,y_i為真實標簽的 one-hot 編碼,\hat{y}_i為經過 Softmax 處理的預測概率。

代碼實現

def test_cross_entropy():# 模型輸出的logits(未經過softmax)logits = torch.tensor([[1.5, 2.0, 0.5], [0.5, 1.0, 1.5]])# 真實標簽(類別索引)labels = torch.tensor([1, 2])  # 第一個樣本屬于類別1,第二個樣本屬于類別2# 初始化交叉熵損失函數criterion = nn.CrossEntropyLoss()loss = criterion(logits, labels)print(f'Cross Entropy Loss: {loss.item()}')  # 輸出:Cross Entropy Loss: 0.6422222256660461test_cross_entropy()

計算過程解析

  1. 對 logits 應用 Softmax 得到概率分布
  2. 計算真實類別對應的負對數概率
  3. 取平均值作為最終損失

特點

  • 自動包含 Softmax 操作,無需手動添加
  • 適用于多分類任務(類別互斥)
  • 標簽格式為類別索引(非 one-hot 編碼)

4. 二元交叉熵損失(BCELoss)

二元交叉熵損失用于二分類任務,需要配合 Sigmoid 激活函數使用,確保輸入值在 (0,1) 范圍內。

數學公式BCELoss = -\frac{1}{n} \sum_{i=1}^{n} [y_i \log(\hat{y}_i) + (1-y_i) \log(1-\hat{y}_i)]

代碼實現

def test_bce_loss():# 模型輸出(已通過sigmoid處理)y_pred = torch.tensor([[0.7], [0.2], [0.9], [0.7]])# 真實標簽(0或1)y_true = torch.tensor([[1], [0], [1], [0]], dtype=torch.float)# 方法1:使用BCELossbce_loss = nn.BCELoss()loss1 = bce_loss(y_pred, y_true)# 方法2:使用functional接口loss2 = nn.functional.binary_cross_entropy(y_pred, y_true)print(f'BCELoss: {loss1.item()}')  # 輸出:BCELoss: 0.47234177589416504print(f'Functional BCELoss: {loss2.item()}')  # 輸出:Functional BCELoss: 0.47234177589416504test_bce_loss()

變種:BCEWithLogitsLoss
對于未經過 Sigmoid 處理的 logits,推薦使用BCEWithLogitsLoss,它內部會自動應用 Sigmoid,數值穩定性更好:

# 對于logits輸入(未經過sigmoid)
logits = torch.tensor([[0.8], [-0.5], [1.2], [0.6]])
bce_with_logits_loss = nn.BCEWithLogitsLoss()
loss = bce_with_logits_loss(logits, y_true)

三、損失函數選擇指南

任務類型推薦損失函數特點
回歸任務MSELoss對異常值敏感,適用于大多數回歸場景
回歸任務(含異常值)L1Loss對異常值穩健,梯度不連續
多分類任務CrossEntropyLoss內置 Softmax,處理互斥類別
二分類任務BCELoss/BCEWithLogitsLoss配合 Sigmoid 使用,輸出概率值
多標簽分類BCEWithLogitsLoss每個類別獨立判斷,可同時屬于多個類別

四、損失函數在訓練中的應用

以圖像分類任務為例,展示損失函數在完整訓練流程中的使用:

import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 數據預處理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加載MNIST數據集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 定義簡單的全連接網絡
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(28*28, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28*28)x = torch.relu(self.fc1(x))x = self.fc2(x)  # 輸出logits,不使用softmaxreturn x# 初始化模型、損失函數和優化器
model = SimpleNet()
criterion = nn.CrossEntropyLoss()  # 多分類任務
optimizer = optim.Adam(model.parameters(), lr=0.001)# 訓練循環
def train(epochs=5):model.train()for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:# 前向傳播outputs = model(images)loss = criterion(outputs, labels)# 反向傳播和優化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()# 打印每輪的平均損失avg_loss = running_loss / len(train_loader)print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}')train()

五、總結

損失函數的選擇直接影響模型的訓練效果和收斂速度,關鍵要點:

  1. 回歸任務優先選擇 MSELoss,存在異常值時考慮 L1Loss
  2. 多分類任務使用 CrossEntropyLoss,無需手動添加 Softmax
  3. 二分類任務推薦使用 BCEWithLogitsLoss,數值穩定性更好
  4. 訓練過程中需監控損失變化,判斷模型是否收斂或過擬合

合理選擇損失函數并配合適當的優化器,才能充分發揮模型的學習能力。在實際應用中,可根據具體任務特點和數據分布嘗試不同的損失函數,選擇表現最佳的方案。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/89559.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/89559.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/89559.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

MinIO深度解析:從核心特性到Spring Boot實戰集成

在當今數據爆炸的時代,海量非結構化數據的存儲與管理成為企業級應用的關鍵挑戰。傳統文件系統在TB級數據面前捉襟見肘,而昂貴的云存儲服務又讓中小企業望而卻步。MinIO作為一款開源高性能對象存儲解決方案,正以其獨特的技術優勢成為開發者的首…

騰訊云服務上下載docker以及使用Rabbitmq的流程

執行以下命令,添加 Docker 軟件源并配置為騰訊云源。sudo yum-config-manager --add-repohttps://mirrors.cloud.tencent.com/docker-ce/linux/centos/docker-ce.repo sudo sed -i "s/download.docker.com/mirrors.tencentyun.com\/docker-ce/g" /etc/yu…

UE5 一些關于過場動畫sequencer,軌道track的一些Python操作

刪除多余的軌道 import unreal def execute():movie_scene_actors []sequence_assets []data 0.0# 獲取編輯器實用工具庫lib unreal.EditorUtilityLibrary()selected_assets lib.get_selected_assets()for asset in selected_assets:if asset.get_class() unreal.LevelS…

前端性能優化“核武器”:新一代圖片格式(AVIF/WebP)與自動化優化流程實戰

前端性能優化“核武器”:新一代圖片格式(AVIF/WebP)與自動化優化流程實戰 當你的頁面加載時間超過3秒時,用戶的跳出率會飆升到40%以上。而在所有的前端性能優化手段中,圖片優化無疑是投入產出比最高的一環。一張未經優化的巨大圖片&#xff0…

單元測試學習+AI輔助單測

標題單元測試衡量指標具體測試1、Resource2、MockBean3、Test4、Test模板5、單測示例H2數據庫JSON1、使用方式AI輔助單測使用方法單元測試 單元測試一般指程序員在寫好代碼后,提交測試前,需要驗證自己的代碼是否可以正常工作,同時將自己的代…

Spring Cloud Gateway與Envoy Sidecar在微服務請求路由中的架構設計分享

Spring Cloud Gateway與Envoy Sidecar在微服務請求路由中的架構設計分享 在現代微服務架構中,請求路由層承擔著流量分發、安全鑒權、流量控制等多重職責。傳統的單一網關方案往往面臨可擴展性和可維護性挑戰。本文將從真實生產環境出發,分享如何結合Spri…

GitHub Pages+Jekyll 靜態網站搭建(二)

GitHub PagesJekyll 靜態網站搭建(二)GitHub PagesJekyll 靜態網站搭建(二內容簡介搭建模板網站部署工作流程GitHub PagesJekyll 靜態網站搭建(二 內容簡介 🚩 Tech Contents 該文主要涉及Jekyll主題的下載與使用。Gi…

Django 實戰:I18N 國際化與本地化配置、翻譯與切換一步到位

文章目錄一、國際化與本地化介紹定義相關概念二、安裝配置安裝 gettext配置 settings.py三、使用國際化視圖中使用序列化器和模型中使用四、本地化操作創建或更新消息文件消息文件說明編譯消息文件五、項目實戰一、國際化與本地化介紹 定義 國際化和本地化的目標,…

通過國內扣子(Coze)搭建智能體并接入discord機器人

國內的扣子是無法直接授權給discord的,但是用國外的coze的話,大模型調用太貴,如果想要接入國外的平臺,那就需要通過調用API來實現。 1.搭建智能體(以工作流模式為例) 首先,我們需要在扣子平臺…

【辦公類-107-02】20250719視頻MP4轉gif(削減MB)

背景需求 最近在寫第五屆智慧項目結題(一共3篇)寫的昏天黑地,日以繼夜。 我自己《基于“AI技術”的幼兒園教學資源開發和運用》提到了AI繪畫、AI視頻和AI編程。 為了更好的展示AI編程的狀態,我在WORD里面插入了MP4轉gif的動圖。 【教學類-75-04】20241023世界名畫-《蒙…

一文講清楚React的render優化,包括shouldComponentUpdate、PureComponent和memo

文章目錄一文講清楚React的render優化,包括shouldComponentUpdate、PureComponent和memo1. React的渲染render機制2. shouldComponentUpdate2.1 先上單組件渲染,驗證state變化2.2 上父子組件,驗證props2. PureComponent2.1 單組件驗證state2.…

物聯網iot、mqtt協議與華為云平臺的綜合實踐(萬字0基礎保姆級教程)

本學期的物聯網技術與應用課程,其結課設計內容包含:mqtt、華為云、PyQT5和MySQL等結合使用,完成了從華為云配置產品信息以及轉發規則,到mqtt命令轉發,再到python編寫邏輯代碼實現相關功能,最后用PyQT5實現面…

使用IntelliJ IDEA和Maven搭建SpringBoot集成Fastjson項目

使用IntelliJ IDEA和Maven搭建SpringBoot集成Fastjson項目 下面我將詳細介紹如何在IntelliJ IDEA中使用Maven搭建一個集成Fastjson的SpringBoot項目,包含完整的環境配置和代碼實現。 一、環境準備 軟件要求 IntelliJ IDEA 2021.x或更高版本JDK 1.8或更高版本&#x…

Java從入門到精通!第九天, 重點!(集合(一))

十一、集合1. 為什么要使用集合(1) 數組存在的弊端1) 數組在初始化之后,長度就不能改變,不方便擴展。2) 數組中提供的屬性和方法比較少,不便于進行添加、刪除、修改等操作,并且效率不高,同時無法直接存儲元素的個數。3…

為什么使用時序數據庫

為什么使用時序數據庫? 時序數據庫(Time-Series Database, TSDB)是專為時間序列數據優化的數據庫,相比傳統關系型數據庫(如MySQL)或NoSQL數據庫(如MongoDB),它在以下方面…

計算機網絡:(十一)多協議標記交換 MPLS

計算機網絡:(十一)多協議標記交換 MPLS前言一、傳統網絡的問題二、MPLS:給數據包貼個“標簽”三、MPLS的工作流程1. 入站2. 中間3. 出站四、MPLS的能力前言 前面我們講解了計算機網絡中網絡層的相關知識,包括網絡層轉發…

docker run elasticsearch 報錯

谷粒商城 p103 前提條件: 下載鏡像文件 #存儲和檢索數據 docker pull elasticsearch:7.4.2 #可視化檢索數據 docker pull kibana:7.4.2 創建掛載的文件和配置 mkdir -p /mydata/elasticsearch/config mkdir -p /mydata/elasticsearch/data echo "http.h…

巧用Callbre RVE生成DRC HTML report及CTO的使用方法

對于后端版圖人員,在芯片TO前的LV signoff階段,猶如一段漫長而有期待的朝圣之旅,需要耐心,毅力和信心,在龐雜的DRC中找到一條收斂之路。為了讓此路更為清晰收斂,Calibre提供了一套可追溯對比的富文本方式-H…

產品需求文檔(PRD)格式全解析:從 RP 到 Word 的選擇與實踐

產品需求文檔(PRD)的形式多種多樣,但核心目標始終一致:清晰傳遞產品需求,讓團隊高效協作。不同公司對 PRD 的格式要求可能不同,有的偏愛直接在原型工具中撰寫,有的則習慣用 Word 整理歸檔。本文…

【C++】入門階段

一、初始化C中的初始化指為變量賦予初始值的過程。初始化方式多樣,適用于不同場景。char cha0; char chb{0}; char chc(\0); char chdcha; char che{};注意事項優先使用列表初始化({}),避免窄化轉換風險。在c11中{ }在變量&#x…