2025-05-27 Python深度學習7——損失函數和反向傳播

文章目錄

  • 1 損失函數
    • 1.1 L1Loss
    • 1.2 MSELoss
    • 1.3 CrossEntropyLoss
  • 2 反向傳播

本文環境:

  • Pycharm 2025.1
  • Python 3.12.9
  • Pytorch 2.6.0+cu124

1 損失函數

? 損失函數 (loss function) 是將隨機事件或其有關隨機變量的取值映射為非負實數以表示該隨機事件的"風險"或"損失"的函數。在機器學習中,損失函數通常作為學習準則與優化問題相聯系,通過最小化損失函數來求解和評估模型。

? 損失函數主要分為兩類:

  • 回歸問題:常用 L1 損失函數 (MAE) 和 L2 損失函數 (MSE)。
  • 分類問題:常用 0-1 損失函數及其代理損失 (如交叉熵損失、鉸鏈損失等)。

1.1 L1Loss

image-20250527202547146

? L1Loss 計算輸入 ( x x x) 和目標 ( y y y) 之間的平均絕對誤差 (MAE)。數學公式如下:
l n = ∣ x n ? y n ∣ l_n=|x_n-y_n| ln?=xn??yn?

參數類型說明
size_average(bool, 可選)已棄用(請使用reduction)。默認情況下,損失會對批中每個損失元素求平均。注意對于某些損失,每個樣本可能有多個元素。如果設為False,則對每個minibatch的損失求和。當reduce為False時被忽略。默認: True
reduce(bool, 可選)已棄用(請使用reduction)。默認情況下,根據size_average對每個minibatch的觀測值求平均或求和。當reduce為False時,返回每個批元素的損失并忽略size_average。默認: True
reduction(str, 可選)指定應用于輸出的縮減方式: ‘none’|‘mean’|‘sum’。
- ‘none’: 不應用縮減,
- ‘mean’: 輸出總和除以元素數量,
- ‘sum’: 輸出求和。
注意: size_average和reduce正在被棄用,目前指定這兩個參數中的任何一個都會覆蓋reduction。默認: ‘mean’

? 依據reduction的不同,輸出結果也不同:
? ( x , y ) = { 1 N ∑ n = 1 N l n , if?reduction = ’mean’ . ∑ n = 1 N l n , if?reduction = ’sum’ . \ell(x,y)= \begin{cases} \displaystyle\frac{1}{N}\sum_{n=1}^N l_n,&\text{if reduction}=\text{'mean'}.\\\\ \displaystyle\sum_{n=1}^N l_n,&\text{ if reduction}=\text{'sum'}. \end{cases} ?(x,y)=? ? ??N1?n=1N?ln?,n=1N?ln?,?if?reduction=’mean’.?if?reduction=’sum’.?
? 其中, N N N 為每個批次的數量。

  • 輸入: (?), 其中 ? 表示任意維數。
  • 目標: (?), 與輸入形狀相同。
  • 輸出: 標量。如果reduction為’none’,則形狀與輸入相同(?)。
import torch
from torch import nninputs = torch.tensor([1., 2, 3])
targets = torch.tensor([1, 2, 5])loss = nn.L1Loss()
result = loss(inputs, targets)  # 計算平均絕對誤差
print(result)  # tensor(0.6667) 計算:(0 + 0 + 2)/3 = 0.6667

特點:

  • 對異常值不敏感,具有較好的魯棒性。
  • 梯度恒定(±1),在接近最優解時可能導致震蕩。
  • 適用于對異常值敏感的場景。

1.2 MSELoss

image-20250527203806901

? MSELoss 計算輸入 ( x x x) 和目標 ( y y y) 之間的均方誤差 (MSE)。
l n = ( x n ? y n ) 2 l_n=(x_n-y_n)^2 ln?=(xn??yn?)2

參數類型說明
size_average(bool, 可選)已棄用(請使用reduction)。默認對批中每個損失元素求平均。設為False則對每個minibatch的損失求和。當reduce為False時被忽略。默認: True
reduce(bool, 可選)已棄用(請使用reduction)。默認根據size_average對觀測值求平均或求和。當reduce為False時,返回每個批元素的損失。默認: True
reduction(str, 可選)指定輸出縮減方式: ‘none’|‘mean’|‘sum’。
- ‘none’: 不縮減,
- ‘mean’: 輸出總和除以元素數量,
- ‘sum’: 輸出求和。
注意: size_average和reduce將被棄用。默認: ‘mean’

? 依據reduction的不同,輸出結果也不同:
? ( x , y ) = { 1 N ∑ n = 1 N l n , if?reduction = ’mean’ . ∑ n = 1 N l n , if?reduction = ’sum’ . \ell(x,y)= \begin{cases} \displaystyle\frac{1}{N}\sum_{n=1}^N l_n,&\text{if reduction}=\text{'mean'}.\\\\ \displaystyle\sum_{n=1}^N l_n,&\text{ if reduction}=\text{'sum'}. \end{cases} ?(x,y)=? ? ??N1?n=1N?ln?,n=1N?ln?,?if?reduction=’mean’.?if?reduction=’sum’.?
? 其中, N N N 為每個批次的數量。

  • 輸入: (?), 其中 ? 表示任意維數
  • 目標: (?), 與輸入形狀相同
  • 輸出: 標量。如果reduction為’none’,則形狀與輸入相同(?)
解釋import torch
from torch import nninputs = torch.tensor([1., 2, 3])
targets = torch.tensor([1, 2, 5])loss_mse = nn.MSELoss()
result_mse = loss_mse(inputs, targets)
print(result_mse)  # tensor(1.3333) 計算:(0 + 0 + 4)/3 = 1.3333

特點:

  • 對較大誤差懲罰更重(平方放大效應)。
  • 對異常值敏感。
  • 梯度隨誤差減小而減小,收斂速度較快。
  • 適用于數據質量較好的場景。

1.3 CrossEntropyLoss

image-20250527214211838 ? CrossEntropyLoss 計算輸入 ($x$) 和目標 ($y$) 之間的交叉熵。該損失函數結合 LogSoftmax 和 NLLLoss (負對數似然損失) 的操作,適用于多類分類任務。

l n = ? ∑ c = 1 C w c log ? exp ? ( x n , c ) ∑ i = 1 C exp ? ( x n , i ) y n , c l_n=-\sum_{c=1}^Cw_c\log\frac{\exp(x_{n,c})}{\sum_{i=1}^C\exp(x_{n,i})}y_{n,c} ln?=?c=1C?wc?logi=1C?exp(xn,i?)exp(xn,c?)?yn,c?
? 其中 w w w 為權重, C C C 為類別數。

參數類型說明
weight(Tensor, 可選)為每個類別分配權重的一維張量,用于處理類別不平衡問題
ignore_index(int, 可選)指定要忽略的目標值,不參與梯度計算
reduction(str, 可選)指定輸出縮減方式: ‘none’|‘mean’|‘sum’。默認: ‘mean’
label_smoothing(float, 可選)標簽平滑系數,范圍[0.0,1.0]。0.0表示無平滑

? 依據reduction的不同,輸出結果也不同:
? ( x , y ) = { 1 N ∑ n = 1 N l n , if?reduction = ’mean’ . ∑ n = 1 N l n , if?reduction = ’sum’ . \ell(x,y)= \begin{cases} \displaystyle\frac{1}{N}\sum_{n=1}^N l_n,&\text{if reduction}=\text{'mean'}.\\\\ \displaystyle\sum_{n=1}^N l_n,&\text{ if reduction}=\text{'sum'}. \end{cases} ?(x,y)=? ? ??N1?n=1N?ln?,n=1N?ln?,?if?reduction=’mean’.?if?reduction=’sum’.?
? 其中, N N N 為每個批次的數量。

輸入形狀

  • 無批處理: ( C ) (C) (C)
  • 批處理: ( N , C ) (N, C) (N,C) ( N , C , d 1 , d 2 , . . . , d K ) , K ≥ 1 (N, C, d?, d?,...,d_K), K≥1 (N,C,d1?,d2?,...,dK?),K1

目標形狀

  • 類別索引: ( ) , ( N ) (), (N) (),(N) ( N , d 1 , d 2 , . . . , d K ) (N, d?, d?,...,d_K) (N,d1?,d2?,...,dK?)
  • 類別概率: 必須與輸入形狀相同。
from torch import nnx = torch.tensor([0.1, 0.2, 0.3])  # 預測值(未歸一化)
y = torch.tensor([1])  # 真實類別索引
x = x.reshape(1, -1)  # 調整為(batch_size, num_classes)loss_cross = nn.CrossEntropyLoss()
result_cross = loss_cross(x, y)
print(result_cross)  # tensor(1.1019)

計算過程:

  1. 對 x 應用 softmax 得到概率分布:[0.3006,0.3322,0.3672]。
  2. 取真實類別 (1) 的概率:0.3322。
  3. 計算負對數: ? l o g ( 0.3322 ) ≈ 1.1019 -log(0.3322)\approx1.1019 ?log(0.3322)1.1019

特點:

  • 結合了 Softmax 和負對數似然。
  • 梯度計算高效,適合多分類問題。
  • 對預測概率與真實標簽的差異敏感。

2 反向傳播

? 反向傳播(Backpropagation)是神經網絡訓練的核心算法,通過鏈式法則計算損失函數對網絡參數的梯度。關鍵步驟:

  1. 前向傳播:計算網絡輸出和損失值。
  2. 反向傳播
    • 計算損失函數對輸出的梯度。
    • 逐層傳播梯度到各參數。
    • 應用鏈式法則計算參數梯度。
  3. 參數更新:使用優化器根據梯度更新參數。

? 以 CIFAR10 網絡為例:

from collections import OrderedDictimport torch
import torchvision
from torch import nnfrom torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms# 加載CIFAR10數據集
dataset = torchvision.datasets.CIFAR10(root='./dataset',  # 數據集存放路徑train=False,  # 是否為訓練集download=True,  # 是否下載數據集transform=transforms.ToTensor()  # 數據預處理
)# 加載數據集
dataloader = DataLoader(dataset, batch_size=1)class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.model1 = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(3, 32, 5, padding=2)),('maxpool1', nn.MaxPool2d(2)),('conv2', nn.Conv2d(32, 32, 5, padding=2)),('maxpool2', nn.MaxPool2d(2)),('conv3', nn.Conv2d(32, 64, 5, padding=2)),('maxpool3', nn.MaxPool2d(2)),('flatten', nn.Flatten()),('linear1', nn.Linear(64 * 4 * 4, 64)),('linear2', nn.Linear(64, 10))]))def forward(self, x):x = self.model1(x)return xloss = nn.CrossEntropyLoss()model = MyModel()
for data in dataloader:imgs, targets = dataoutputs = model(imgs)result_loss = loss(outputs, targets)result_loss.backward()  # 使用反向傳播print(result_loss)

? 在 Pycharm 中,將第 48 行注釋,點擊調試。

image-20250527211328521

? 依次在變量窗口中展開“model”-》“model1”-》“conv1”,可看到 conv1 層中的權重參數 weight。

image-20250527211446267

? 展開“weight”,其 grad 屬性此時為 None。

image-20250527211608982

? 點擊“步過”按鈕,運行 48 行,“weight”的 grad 屬性被賦值。此值即為本次迭代的梯度數據。

image-20250527211725204

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/82631.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/82631.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/82631.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

python+tkinter實現GUI界面調用即夢AI文生圖片API接口

背景 目前字節跳動公司提供了即夢AI的接口免費試用,但是并發量只有1,不過足夠我們使用了。我這里想做個使用pythontkinter實現的GUI可視化界面客戶端,這樣就不用每次都登錄官方網站去進行文生圖片,當然文生視頻,或者圖…

#git 儲藏庫意外被清空 Error: bad index – Fatal: index file corrupt

問題:通常是由于 Git 的索引文件損壞導致 原因:系統崩潰或斷電、硬盤故障、Git 操作錯誤等 方案:重建索引文件:將當前的索引文件重命名為其他名稱或刪除,比如 index.m,然后命令行重建索引,git…

GitLab 18.0 正式發布,15.0 將不再受技術支持,須升級【二】

GitLab 是一個全球知名的一體化 DevOps 平臺,很多人都通過私有化部署 GitLab 來進行源代碼托管。極狐GitLab 是 GitLab 在中國的發行版,專門為中國程序員服務。可以一鍵式部署極狐GitLab。 學習極狐GitLab 的相關資料: 極狐GitLab 官網極狐…

車載網關策略 --- 車載網關通信故障處理機制深度解析

我是穿拖鞋的漢子,魔都中堅持長期主義的汽車電子工程師。 老規矩,分享一段喜歡的文字,避免自己成為高知識低文化的工程師: 鈍感力的“鈍”,不是木訥、遲鈍,而是直面困境的韌勁和耐力,是面對外界噪音的通透淡然。 生活中有兩種人,一種人格外在意別人的眼光;另一種人無論…

Unity數字人開發筆記

開源工程地址:https://github.com/zhangliwei7758/unity-AI-Chat-Toolkit 先致敬zhangliwei7758,開放這個源碼 一、建立工程 建立Unity工程(UnityAiChat)拖入Unity-AI-Chat-Toolkit.unitypackage打開chatSample工程,可…

Cherry Studio連接配置MCP服務器

之前寫了一篇關于Cherry Studio的文章,不了解的可以先看一下 AI工具——Cherry Studio,搭建滿血DeepSeek R1的AI對話客戶端【硅基流動DeepSeek API】-CSDN博客 最近Cherry Studio更新了一個新功能:MCP服務器 在 v1.2.9 版本中,…

OpenSSH 服務配置與會話保活完全指南

一、/etc/ssh/sshd_config 配置機制 1. 配置文件基礎 文件作用 OpenSSH 服務器 (sshd) 的主配置文件,控制連接、認證、端口轉發等行為。 加載與生效 修改后需重啟服務:sudo systemctl restart sshd # Systemd 系統 sudo service ssh restart # S…

阿里云國際版注冊郵箱格式詳解

“為什么我的阿里云國際版注冊總提示郵箱無效?” 這是許多初次接觸阿里云國際版(Alibaba Cloud International)的用戶常遇到的困惑。隨著全球化進程加速,越來越多的企業選擇阿里云國際版部署海外業務,而注冊環節中郵箱…

【IDEA問題】springboot本地啟動應用報錯:程序包不存在;找不到符號

問題: springboot本地啟動應用報錯: 程序包xxx不存在;找不到符號 解決方案: 1.確保用maven重新導入依賴 2.刪除.idea文件夾 3.invalidate caches里,把能選擇的都勾選上,然后清除緩存重啟 4.再在上方工具欄…

FFmpeg 時間戳回繞處理:保障流媒體時間連續性的核心機制

FFmpeg 時間戳回繞處理:保障流媒體時間連續性的核心機制 一、回繞處理函數 /** * Wrap a given time stamp, if there is an indication for an overflow * * param st stream // 傳入一個指向AVStream結構體的指針,代表流信息 * pa…

【b站計算機拓荒者】【2025】微信小程序開發教程 - chapter3 項目實踐 -1 項目功能描述

1 項目功能描述 # 智慧社區-小程序-1 歡迎頁-加載后端:動態變化-2 首頁-輪播圖:動態-公共欄:動態-信息采集,社區活動,人臉檢測,語音識別,心率檢測,積分商城-3 信息采集頁面-采集人數…

5.27 day 30

知識點回顧: 導入官方庫的三種手段導入自定義庫/模塊的方式導入庫/模塊的核心邏輯:找到根目錄(python解釋器的目錄和終端的目錄不一致) 作業:自己新建幾個不同路徑文件嘗試下如何導入 一、導入官方庫 我們復盤下學習py…

【GitHub Pages】部署指南

vue項目 編輯你的 vite.config.ts 文件,加上 base 路徑,設置為你的 GitHub 倉庫名 import { defineConfig } from vite import vue from vitejs/plugin-vue// 假設你的倉庫是 https://github.com/your-username/my-vue-app export default defineConfi…

遠程控制技術全面解析:找到適合你的最佳方案

背景:遠程控制為何成為企業核心需求? 隨著企業數字化轉型的推進,遠程控制技術已成為異地辦公和運維的關鍵工具。無論是跨國企業需要高效管理全球設備,還是中小型企業追求經濟高效的解決方案,選擇合適的遠程控制技術&a…

觸覺智能RK3506星閃開發板規格書 型號IDO-EVB3506-V1

產品概述 觸覺智能RK3506星閃開發板,型號IDO-EVB3506-V1采用 Rockchip RK3506(三核 Cortex-A7單核Cortex-M0, 主頻最高1.5GHz)設計的評估開發板,專為家電顯控、顯示HMI、手持終端、工業IOT網關、工業控制、PLC等領域而設計。內置…

九級融智臺階與五大要素協同的量子化解析

九級融智臺階與五大要素協同的量子化解析 摘要:本文構建了一個量子力學框架下的九級融智模型,將企業創新過程映射為量子能級躍遷。研究發現五大要素協同態決定系統躍遷概率(P∣?Ψ_m∣H_協同∣Ψ_n?∣^2),當要素協同…

Kotlin學習34-data數據類1

定義如下:與普通類對比學習 //普通類 class NormalClass(val name: String, val age: Int, val sex: Char) //數據類 data class DataClass(val name: String, val age: Int, val sex: Char)對應找到java反編譯的代碼路徑:Tool-->Kotlin-->Show K…

博圖SCL基礎知識-表達式及賦值運算

S7-1200 從 V2.2 版本開始支持 SCL 語言。 語言元素 SCL 除了包含 PLC 的典型元素(例如,輸入、輸出、定時器或存儲器位)外,還包含高級編程語言表達式、賦值運算和運算符。 程序控制語句 SCL 提供了簡便的指令進行程序控制。例…

海思3519V200ARM Linux 下移植 Qt5.8.0

一、移植背景及意義 海思3519V200是一款基于ARM架構的嵌入式芯片,廣泛應用于智能安防、工業控制等領域。在這些應用場景中,對設備的圖形用戶界面(GUI)有著越來越高的要求。Qt5.8.0作為一個功能強大、跨平臺的GUI開發框架,能夠幫助開發者快速開發出美觀、高效的用戶界面。…

msql的樂觀鎖和冪等性問題解決方案

目錄 1、介紹 2、樂觀鎖 2.1、核心思想 2.2、實現方式 1. 使用 version 字段(推薦) 2. 使用 timestamp 字段 2.3、如何處理沖突 2.4、樂觀鎖局限性 3、冪等性 3.1、什么是冪等性 3.2、樂觀鎖與冪等性的關系 1. 樂觀鎖如何輔助冪等性&#xf…