人工智能-python-深度學習-自動微分

自動微分:基礎概念與應用

自動微分(Autograd)是現代深度學習框架(如PyTorch、TensorFlow)中的一個核心功能。它通過構建計算圖并在計算圖上自動計算梯度,簡化了反向傳播算法的實現。以下是自動微分的基本概念及其操作。


1. 基礎概念

自動微分指的是通過跟蹤計算圖中的每一步計算,自動計算目標函數相對于模型參數的梯度。這些計算圖是在每次前向傳播時動態構建的。基于這個圖,系統可以在反向傳播時自動計算梯度,而不需要手動推導每個梯度。

1.1 張量

torch中一切皆為張量,屬性requires_grad決定是否對其進行梯度計算。默認是False,如需計算梯度則設置為True

1.2 計算圖

torch.autograd通過創建一個動態計算圖來跟蹤張良的操作,每個張量是計算圖中的一個節點,節點之間的操作構成圖的邊。

在Pytorch中,當張量的requiers_grad=Ture時,Pytorch會自動跟蹤與該張量相關的所有操作,并構建計算圖。每個操作都會生成一個新的張量,并記錄其依賴關系。當設置為True時,表示該張量在計算圖中需要參與梯度計算,即在反向傳播(Backpropagation)過程中惠子dog計算其梯度;當設置為False時,不會計算梯度。
例如
z=x?yloss=z.sum()z = x * y\\loss = z.sum()z=x?yloss=z.sum()
在上述代碼中,x 和 y 是輸入張量,即葉子節點,z 是中間結果,loss 是最終輸出。每一步操作都會記錄依賴關系:

z = x * y:z 依賴于 x 和 y。

loss = z.sum():loss 依賴于 z。

這些依賴關系形成了一個動態計算圖,如下所示:

	  x       y\     /\   /\ /z||vloss

葉子節點

在 PyTorch 的自動微分機制中,葉子節點(leaf node) 是計算圖中:

  • 由用戶直接創建的張量,并且它的 requires_grad=True。
  • 這些張量是計算圖的起始點,通常作為模型參數或輸入變量。

特征:

  • 沒有由其他張量通過操作生成。
  • 如果參與了計算,其梯度會存儲在 leaf_tensor.grad 中。
  • 默認情況下,葉子節點的梯度不會自動清零,需要顯式調用 optimizer.zero_grad() 或 x.grad.zero_() 清除。

如何判斷一個張量是否是葉子節點?

通過 tensor.is_leaf 屬性,可以判斷一個張量是否是葉子節點。

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  # 葉子節點
y = x ** 2  # 非葉子節點(通過計算生成)
z = y.sum()print(x.is_leaf)  # True
print(y.is_leaf)  # False
print(z.is_leaf)  # False

葉子節點與非葉子節點的區別

特性葉子節點非葉子節點
創建方式用戶直接創建的張量通過其他張量的運算生成
is_leaf 屬性TrueFalse
梯度存儲梯度存儲在 .grad 屬性中梯度不會存儲在 .grad,只能通過反向傳播傳遞
是否參與計算圖是計算圖的起點是計算圖的中間或終點
刪除條件默認不會被刪除在反向傳播后,默認被釋放(除非 retain_graph=True)

detach():張量 x 從計算圖中分離出來,返回一個新的張量,與 x 共享數據,但不包含計算圖(即不會追蹤梯度)。

特點

  • 返回的張量是一個新的張量,與原始張量共享數據。
  • 對 x.detach() 的操作不會影響原始張量的梯度計算。
  • 推薦使用 detach(),因為它更安全,且在未來版本的 PyTorch 中可能會取代 data。
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x.detach()  # y 是一個新張量,不追蹤梯度y += 1  # 修改 y 不會影響 x 的梯度計算
print(x)  # tensor([1., 2., 3.], requires_grad=True)
print(y)  # tensor([2., 3., 4.])

反向傳播

使用tensor.backward()方法執行反向傳播,從而計算張量的梯度。這個過程會自動計算每個張量對損失函數的梯度。例如:調用 loss.backward() 從輸出節點 loss 開始,沿著計算圖反向傳播,計算每個節點的梯度。

梯度

計算得到的梯度通過tensor.grad訪問,這些梯度用于優化模型參數,以最小化損失函數。

2. 計算梯度

2.1 標量梯度計算

標量梯度計算指的是計算標量(通常是損失函數)相對于模型參數的梯度。在深度學習中,常見的損失函數(如均方誤差、交叉熵等)都是標量值。

import torch# 定義張量
x = torch.tensor(2.0, requires_grad=True)
y = x**2 + 3*x + 1  # 定義標量函數# 計算梯度
y.backward()  # 反向傳播
print(x.grad)  # 輸出x的梯度

為何需要標量梯度?

  • 在訓練過程中,我們需要計算損失函數相對于各個參數的梯度,從而調整模型參數。標量梯度的計算是整個訓練過程中優化模型的基礎。
2.2 向量梯度計算

向量梯度計算用于計算多維向量函數相對于輸入向量的梯度。例如,輸出是一個向量時,我們希望計算每個分量的梯度。

# 定義張量
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x**2  # 計算每個元素的平方# 計算梯度
y.backward(torch.tensor([1.0, 1.0]))  # 向量梯度計算
print(x.grad)  # 輸出x的梯度

為何需要向量梯度?

  • 在多輸入多輸出的情況下,向量梯度計算能有效地描述每個輸入對于輸出的影響。
2.3 多標量梯度計算

在一些復雜的場景中,損失函數可能有多個標量輸出。我們需要計算每個標量輸出對參數的梯度。

x = torch.tensor([2.0, 3.0], requires_grad=True)
y1 = x[0]**2 + 3*x[0] + 1
y2 = x[1]**3 + 2*x[1] - 5
y = y1 + y2  # 多標量函數y.backward()  # 計算梯度
print(x.grad)

為何需要多標量梯度?

  • 多標量梯度有助于處理多任務學習中的梯度計算,特別是當每個任務有不同的損失函數時。
2.4 多向量梯度計算

當輸出是多個向量時,我們通常需要計算每個向量對每個輸入的梯度。比如在生成對抗網絡(GAN)或多任務學習中,常見這種情況。

x = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = x[0]**2 + 3*x[0]
y2 = x[1]**3 + 2*x[1]grad_outputs = torch.tensor([1.0, 1.0])  # 指定多個輸出梯度
y1.backward(grad_outputs)  # 分別計算y1和y2的梯度

為何需要多向量梯度?

  • 計算多個輸出的梯度可以幫助我們進行多維度的優化,尤其是在復雜的網絡結構中,多個輸出有助于提高模型的多樣性和魯棒性。

3. 梯度上下文控制

在深度學習中,常常需要控制梯度計算的上下文,以節省內存或者針對性地優化某些參數。

3.1 控制梯度計算

我們可以通過torch.no_grad()with torch.set_grad_enabled(False)來臨時停止梯度計算,這對于不需要計算梯度的操作(例如推理階段)非常有用。

with torch.no_grad():y = x * 2  # 在此塊中,不會計算梯度

為何控制梯度計算?

  • 在推理階段,我們不需要梯度,這樣可以節省計算資源和內存。
3.2 累計梯度

在某些情況下,梯度計算需要分多個小批次進行累計。例如,使用小批次訓練時,梯度會在每個小批次上累加。

optimizer.zero_grad()  # 清空之前的梯度
y.backward()  # 累計梯度
optimizer.step()  # 更新參數

為何累計梯度?

  • 累計梯度可以使得模型在小批次上進行優化,而不丟失總體梯度信息,適用于大規模數據的訓練。
3.3 梯度清零

在每次更新前,我們需要清空之前計算的梯度,否則它會在下一步的計算中累加。

optimizer.zero_grad()  # 清除上次計算的梯度

為何清零梯度?

  • 防止梯度計算的累積影響下一次計算,確保每次計算梯度時的準確性。

4. 案例分析

4.1 求函數最小值

通過計算梯度并使用優化算法(如梯度下降),我們可以找到函數的最小值。

x = torch.tensor(2.0, requires_grad=True)
for _ in range(100):y = x**2 + 3*x + 1y.backward()with torch.no_grad():x -= 0.1 * x.grad  # 使用梯度更新xx.grad.zero_()  # 清空梯度

為何使用梯度下降求解最小值?

  • 通過不斷調整參數,沿著梯度方向前進,直到收斂到函數的最小值。
4.2 函數參數求解

如果已知函數并希望通過梯度來求解某些未知參數,可以使用反向傳播來更新這些參數。

def func(x):return x**2 - 4*x + 3x = torch.tensor(3.0, requires_grad=True)
for i in range(100):y = func(x)y.backward()x.data -= 0.1 * x.gradx.grad.zero_()

為何求解函數參數?

  • 在機器學習中,模型的參數通過梯度計算來優化,進而提高模型的性能。

結論

自動微分的引入讓深度學習框架大大簡化了梯度計算過程。通過自動計算標量、向量梯度以及控制梯度的計算上下文,開發者可以專注于模型設計而非手動推導梯度公式。


本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/94461.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/94461.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/94461.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

k8s原理及操作

簡介 kubernetes的本質是一組服務器集群,它可以在集群的每個節點上運行特定的程序,來對節點中的容器 進行管理。目的是實現資源管理的自動化,主要提供了如下的主要功能: 自我修復:一旦某一個容器崩潰,能夠在…

理解音頻響度:LUFS 標準及其計算實現

LUFS 及其重要性 1.1、什么是 LUFS? LUFS(Loudness Units relative to Full Scale)是音頻工程中用于測量感知響度的標準單位。它已成為廣播、流媒體和音樂制作領域的行業標準,用于確保不同音頻內容具有一致的響度水平。 LUFS 是 I…

【在ubuntu下使用vscode打開c++的make項目及編譯調試】

在ubuntu下使用vscode打開c的make項目及編譯調試第一步:安裝必要的軟件第二步:示例項目準備1. 創建C源文件: main.cpp2. 創建頭文件: utils.h3. 創建實現文件: utils.cpp第三步:使用 VS Code 打開項目第四步…

3-2.Python 函數 - None(None 概述、None 應用場景)

一、None 概述在 Python 中,None 是一個特殊的常量,用于表示空值或無值None 是 Python 中唯一的一個 NoneType 類型的實例二、None 應用場景 1、定義變量 None 常用于初始化變量,表示該變量暫時不需要有具體值 name Noneprint(name) print(t…

js獲取html元素并設置高度為100vh-鍵盤高度

獲取HTML元素并設置高度為(100vh - 鍵盤高度) 我將設計一個頁面,展示如何獲取HTML元素并動態設置其高度為視口高度減去鍵盤高度,這在移動設備上特別有用,可以避免鍵盤遮擋內容。 設計思路 創建一個帶有輸入框的界面,模擬鍵盤彈…

基于SpringBoot的校園博客管理系統

🔗 目錄 一. 前言 ??二. 前端框架、后端框架以及存儲框架使用情況說明 ??三. 核心技術 ????1. ?Java開發語言 ????2. ?MyBatis ????3. ?Mysql ????4. ?Vue ????5. ?部署項目 ??四. 演示效果 ????1. 管理員功能模塊 ??????…

Nginx + Certbot配置 HTTPS / SSL 證書

前提條件: 1.已有域名 2.Nginx 已安裝并正在運行,且有對應的 Server 配置 3.防火墻開放 80 和 443 端口 安裝 EPEL 倉庫: sudo yum install epel-release -y安裝 Snapd sudo yum install snapd -y啟用并啟動 Snapd Socket sudo systemctl ena…

圖結構使用 Louvain 社區檢測算法進行分組

圖結構使用 Louvain 社區檢測算法進行分組 flyfish Louvain 算法是一種基于模塊度最大化的社區檢測算法,核心目標是在復雜網絡中找到“內部連接緊密、外部連接稀疏”的社區結構。它的優勢在于高效性(可處理百萬級節點的大規模網絡)和近似最優…

layui.formSelects自定義多選組件在layer.open中使用、獲取、復現

layui.formSelects自定義多選組件在layer.open中使用、獲取、復現 引入css和js //<th:block th:include"include :: layui-formSelects-css"/> <link th:href"{/ajax/libs/layui-formSelects/formSelects-v4.css}" rel"stylesheet"/>…

基于SpringBoot的社團管理系統【2026最新】

作者&#xff1a;計算機學姐 開發技術&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源碼”。 專欄推薦&#xff1a;前后端分離項目源碼、SpringBoot項目源碼、Vue項目源碼、SSM項目源碼、微信小程序源碼 精品專欄&#xff1a;…

運行node18報錯

又碰到一個奇葩的問題&#xff0c;報錯如下> tigermes.vue30.1.0 serve > vue-cli-service serveBrowserslist: caniuse-lite is outdated. Please run:npx update-browserslist-dblatestWhy you should do it regularly: https://github.com/browserslist/update-db#rea…

Python第三方庫IPFS-API使用詳解:構建去中心化應用的完整指南

目錄 Python第三方庫IPFS-API使用詳解&#xff1a;構建去中心化應用的完整指南 引言&#xff1a;IPFS與去中心化存儲的革命 星際文件系統&#xff08;IPFS&#xff0c;InterPlanetary File System&#xff09;是一種革命性的點對點超媒體協議&#xff0c;旨在創建持久且分布式的…

ETL與iPaaS的融合方案:加速數據集成流程

在今天的商業世界里&#xff0c;數據幾乎無處不在。企業每天都在產生和接收海量的數據——從CRM到ERP&#xff0c;從云端SaaS應用到本地數據庫&#xff0c;來源越來越分散&#xff0c;集成也越來越復雜。 傳統的ETL工具&#xff08;提取、轉換、加載&#xff09;在處理結構化數…

詳解flink SQL基礎(四)

文章目錄1.Flink SQL介紹2.streaming SQL&watermarks使用3.窗口聚合&#xff08;window aggregations&#xff09;4.over aggregations5.FlinkSQL 流連接&#xff08;Streaming join&#xff09;6.使用MATCH_RECOGNIZE 進行模式識別和復雜事件處理7.變更記錄&#xff08;ch…

有鹿機器人:為城市描繪清潔新圖景的智能使者

一、智慧清潔&#xff1a;科技賦能的環境革新每天清晨&#xff0c;當我沿著小區路徑緩緩行駛&#xff0c;雙激光雷達系統便開始精準測繪環境。我的專業清掃能力源自2cm精度死亡貼邊技術&#xff0c;這項讓同行驚嘆的能力&#xff0c;可以輕松震出嵌了十年的煙頭&#xff0c;徹底…

Tableau Server高危漏洞允許攻擊者上傳任意惡意文件

Tableau Server 存在一個嚴重安全漏洞&#xff0c;可能允許攻擊者上傳并執行惡意文件&#xff0c;最終導致系統完全淪陷。該漏洞編號為 CVE-2025-26496&#xff0c;CVSS 評分為 9.6 分&#xff0c;影響 Windows 和 Linux 平臺上的多個 Tableau Server 和 Tableau Desktop 版本。…

數據結構07(Java)-- (堆,大根堆,堆排序)

前言 本文為本小白&#x1f92f;學習數據結構的筆記&#xff0c;將以算法題為導向&#xff0c;向大家更清晰的介紹數據結構相關知識&#xff08;算法題都出自&#x1f64c;B站馬士兵教育——左老師的課程&#xff0c;講的很好&#xff0c;對于想入門刷題的人很有幫助&#x1f4…

onnx入門教程(七)——如何添加 TensorRT 自定義算子

在前面的模型入門系列文章中&#xff0c;我們介紹了部署一個 PyTorch 模型到推理后端&#xff0c;如 ONNXRuntime&#xff0c;這其中可能遇到很多工程性的問題。有些可以通過創建 ONNX 節點來解決&#xff0c;該節點仍然使用后端原生的實現進行推理。而有些無法導出到后端的算法…

YggJS RButton 按鈕組件 v1.0.0 使用教程

&#x1f4cb; 目錄 簡介核心特性快速開始安裝指南基礎使用主題系統高級功能API 參考最佳實踐性能優化故障排除總結 &#x1f680; 簡介 YggJS RButton 是一個專門為 React 應用程序設計的高性能按鈕組件庫。它提供了兩套完整的設計主題&#xff1a;科技風主題和極簡主題&…

Linux(二十)——SELinux 概述與狀態切換

文章目錄前言一、SELinux 概述1.1 SELinux 簡介1.2 SELinux 特點1.2.1 MAC&#xff08;Mandatory Access Control&#xff09;1.2.2 RBAC&#xff08;Role-Based Access Control&#xff09;1.2.3 TE&#xff08;Type Enforcement&#xff09;1.3 SELinux 的執行模式1.4 SELinu…