Pytorch的自動求導模塊

文章目錄

  • torch.autograd.backward()
    • 基本用法
    • 非標量張量的反向傳播
    • 保留計算圖
    • 指定輸入張量
    • 高階梯度計算
  • 與 y.backward() 的區別
  • torch.autograd.grad()
    • 基本用法
    • 非標量張量的梯度
    • 高階梯度計算
    • 多輸入、多輸出的梯度計算
    • 未使用的輸入張量
    • 保留計算圖
  • 與 backward() 的區別

torch.autograd.backward()

該函數實現自動求導梯度,函數如下:

torch.autograd.backward(tensors, grad_tensors=None, retain_graph=False, create_graph=False, inputs=None)

參數介紹:

  • tensors: 需要對其進行反向傳播的目標張量(或張量列表),例如:loss。
    這些張量通常是計算圖的最終輸出。
  • grad_tensors:與 tensors 對應的梯度權重(或權重列表)。
    如果 tensors 是標量張量(單個值),可以省略此參數。
    如果 tensors 是非標量張量(如向量或矩陣),則必須提供 grad_tensors,表示每個張量的梯度權重。例如:當有多個loss需要計算梯度時,需要設置每個loss的權值。
  • retain_graph:是否保留計算圖。
    默認值為 False,即反向傳播后會釋放計算圖。如果需要多次反向傳播,需設置為 True。
  • create_graph: 是否創建一個新的計算圖,用于高階梯度計算
    默認值為 False,如果需要計算二階或更高階梯度,需設置為 True。
  • inputs: 指定需要計算梯度的輸入張量(或張量列表)。
    如果指定了此參數,只有這些張量的 .grad 屬性會被更新,而不是整個計算圖中的所有張量。

基本用法

import torch  # 定義張量并啟用梯度計算  
x = torch.tensor(2.0, requires_grad=True)  
y = x ** 2  # y = x^2  # 使用 torch.autograd.backward() 觸發反向傳播  
torch.autograd.backward(y)  # 查看梯度  
print(x.grad)  # 輸出:4.0 (dy/dx = 2x, 當 x=2 時,dy/dx=4)

非標量張量的反向傳播

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  
y = x ** 2  # y = [x1^2, x2^2, x3^2]  # 指定 grad_tensors 權重  
grad_tensors = torch.tensor([1.0, 1.0, 1.0])  # 權重  
torch.autograd.backward(y, grad_tensors=grad_tensors)  # 查看梯度  
print(x.grad)  # 輸出:[2.0, 4.0, 6.0] (dy/dx = 2x)

保留計算圖

如果需要多次調用反向傳播,可以設置 retain_graph=True。

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次反向傳播  
torch.autograd.backward(y, retain_graph=True)  
print(x.grad)  # 輸出:12.0 (dy/dx = 3x^2, 當 x=2 時,dy/dx=12)  # 第二次反向傳播  
torch.autograd.backward(y, retain_graph=True)  
print(x.grad)  # 輸出:24.0 (梯度累積,12.0 + 12.0)

指定輸入張量

通過 inputs 參數,可以只計算指定張量的梯度,而忽略其他張量。

x = torch.tensor(2.0, requires_grad=True)  
z = torch.tensor(3.0, requires_grad=True)  
y = x ** 2 + z ** 3  # y = x^2 + z^3  # 只計算 x 的梯度  
torch.autograd.backward(y, inputs=[x])  
print(x.grad)  # 輸出:4.0 (dy/dx = 2x)  
print(z.grad)  # 輸出:None (未計算 z 的梯度)

高階梯度計算

通過設置 create_graph=True,可以構建新的計算圖,用于計算高階梯度。

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次反向傳播,創建新的計算圖  
torch.autograd.backward(y, create_graph=True)  
print(x.grad)  # 輸出:12.0 (dy/dx = 3x^2)  # 計算二階梯度  
x_grad = x.grad  
x_grad.backward()  
print(x.grad)  # 輸出:18.0 (d^2y/dx^2 = 6x)

與 y.backward() 的區別

  • 靈活性:

    • torch.autograd.backward() 更靈活,可以對多個張量同時進行反向傳播,并指定梯度權重。
    • y.backward() 是對單個張量的簡單封裝,適合常見場景。對多個loss求導時,需要指定gradient和grad_outputs相同作用。
  • 梯度權重:

    • torch.autograd.backward() 需要顯式提供 grad_tensors 參數(如果目標張量是非標量)。
    • y.backward() 會自動處理標量張量,非標量張量需要手動傳入權重。
  • 輸入控制:

    • torch.autograd.backward() 可以通過 inputs 參數指定只計算某些張量的梯度。
    • y.backward() 無法直接控制,只會更新計算圖中所有相關張量的 .grad。

torch.autograd.grad()

torch.autograd.grad() 是 PyTorch 中用于計算張量梯度的函數,與 backward() 不同的是,它不會更新張量的 .grad 屬性,而是直接返回計算的梯度值。它適用于需要手動獲取梯度值而不修改計算圖中張量的 .grad 屬性的場景。

torch.autograd.grad(  outputs,   inputs,   grad_outputs=None,   retain_graph=False,   create_graph=False,   only_inputs=True,   allow_unused=False  
)

參數介紹:

  • outputs:
    目標張量(或張量列表),即需要對其進行求導的輸出張量。
  • inputs:
    需要計算梯度的輸入張量(或張量列表)。
    這些張量必須啟用了 requires_grad=True。
  • grad_outputs:
    與 outputs 對應的梯度權重(或權重列表)。
    如果 outputs 是標量張量,可以省略此參數;如果是非標量張量,則需要提供權重,表示每個輸出張量的梯度權重。
  • retain_graph:
    是否保留計算圖。
    默認值為 False,即反向傳播后會釋放計算圖。如果需要多次計算梯度,需設置為 True。
  • create_graph:
    是否創建一個新的計算圖,用于高階梯度計算。
    默認值為 False,如果需要計算二階或更高階梯度,需設置為 True。
  • only_inputs:
    是否只對 inputs 中的張量計算梯度。
    默認值為 True,表示只計算 inputs 的梯度。
  • allow_unused:
    是否允許 inputs 中的某些張量未被 outputs 使用。
    默認值為 False,如果某些 inputs 未被 outputs 使用,會拋出錯誤。如果設置為 True,未使用的張量的梯度會返回 None。

返回值:

  • 返回一個元組,包含 inputs 中每個張量的梯度值。
  • 如果某個輸入張量未被 outputs 使用,且 allow_unused=True,則對應的梯度為 None。

基本用法

import torch  # 定義張量并啟用梯度計算  
x = torch.tensor(2.0, requires_grad=True)  
y = x ** 2  # y = x^2  # 使用 torch.autograd.grad() 計算梯度  
grad = torch.autograd.grad(y, x)  
print(grad)  # 輸出:(4.0,) (dy/dx = 2x, 當 x=2 時,dy/dx=4)

非標量張量的梯度

當目標張量是非標量時,需要提供 grad_outputs 參數:

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  
y = x ** 2  # y = [x1^2, x2^2, x3^2]  # 指定 grad_outputs 權重  
grad_outputs = torch.tensor([1.0, 1.0, 1.0])  # 權重  
grad = torch.autograd.grad(y, x, grad_outputs=grad_outputs)  
print(grad)  # 輸出:(tensor([2.0, 4.0, 6.0]),) (dy/dx = 2x)

高階梯度計算

通過設置 create_graph=True,可以計算高階梯度:

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次計算梯度  
grad = torch.autograd.grad(y, x, create_graph=True)  
print(grad)  # 輸出:(12.0,) (dy/dx = 3x^2)  # 計算二階梯度  
grad2 = torch.autograd.grad(grad[0], x)  
print(grad2)  # 輸出:(6.0,) (d^2y/dx^2 = 6x)

多輸入、多輸出的梯度計算

可以對多個輸入和輸出同時計算梯度:

x = torch.tensor(2.0, requires_grad=True)  
z = torch.tensor(3.0, requires_grad=True)  
y1 = x ** 2 + z ** 3  # y1 = x^2 + z^3  
y2 = x * z  # y2 = x * z  # 對多個輸入計算梯度  
grads = torch.autograd.grad([y1, y2], [x, z], grad_outputs=[torch.tensor(1.0), torch.tensor(1.0)])  
print(grads)  # 輸出:(7.0, 11.0) (dy1/dx + dy2/dx, dy1/dz + dy2/dz)

未使用的輸入張量

如果某些輸入張量未被目標張量使用,需設置 allow_unused=True:

x = torch.tensor(2.0, requires_grad=True)  
z = torch.tensor(3.0, requires_grad=True)  
y = x ** 2  # y = x^2  # z 未被 y 使用  
grad = torch.autograd.grad(y, [x, z], allow_unused=True)  
print(grad)  # 輸出:(4.0, None) (dy/dx = 4, z 未被使用,梯度為 None)

保留計算圖

如果需要多次計算梯度,可以設置 retain_graph=True:

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次計算梯度  
grad1 = torch.autograd.grad(y, x, retain_graph=True)  
print(grad1)  # 輸出:(12.0,)  # 第二次計算梯度  
grad2 = torch.autograd.grad(y, x)  
print(grad2)  # 輸出:(12.0,)

與 backward() 的區別

  • 梯度存儲
    • torch.autograd.grad() 不會修改張量的 .grad 屬性,而是直接返回梯度值。
    • backward() 會將計算的梯度累積到 .grad 屬性中。
  • 靈活性:
    • torch.autograd.grad() 可以對多個輸入和輸出同時計算梯度,并支持未使用的輸入張量。
    • backward() 只能對單個輸出張量進行反向傳播。
  • 高階梯度:
    • torch.autograd.grad() 支持通過 create_graph=True 計算高階梯度。
    • backward() 也支持高階梯度,但需要手動設置 create_graph=True。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/64903.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/64903.shtml
英文地址,請注明出處:http://en.pswp.cn/web/64903.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Mac OS

本文來自智譜清言 ------ Mac OS(現稱為macOS)是蘋果公司開發和銷售的操作系統,自1984年推出以來,它已經經歷了多次重大的演變和發展。 起源:Mac OS 1.0的誕生 - 1984年,蘋果發布了Macintosh計算機&#…

spring中使用@Validated,什么是JSR 303數據校驗,spring boot中怎么使用數據校驗

文章目錄 一、JSR 303后臺數據校驗1.1 什么是 JSR303?1.2 為什么使用 JSR 303? 二、Spring Boot 中使用數據校驗2.1 基本注解校驗2.1.1 使用步驟2.1.2 舉例Valid注解全局統一異常處理 2.2 分組校驗2.2.1 使用步驟2.2.2 舉例Validated注解Validated和Vali…

ubuntu常用快捷鍵和變量記錄

alias b‘cd …/’ alias bb‘cd …/…/’ alias bbb‘cd …/…/…/’ alias bbbb‘cd …/…/…/…/’ alias bbbbb‘cd …/…/…/…/…/’ alias bbbbbb‘cd …/…/…/…/…/…/’ alias apkinfo‘aapt dump badging’ alias npp‘notepad-plus-plus’ export ANDROID_HOME/h…

AWS S3文件存儲工具類

pom依賴 <!--aws-s3--> <dependency><groupId>com.amazonaws</groupId><artifactId>aws-java-sdk-s3</artifactId><version>1.12.95</version></dependency>S3Utils import cn.hutool.core.util.ZipUtil; import com.a…

【SOC 芯片設計 DFT 學習專欄 -- 測試向量生成 ATPG (Automatic Test Pattern Generation) 】

文章目錄 OverviewATPG 的基本功能ATPG 的工作流程ATPG 應用場景示例示例 1&#xff1a;檢測單個信號的 Stuck-at Fault示例 2&#xff1a;針對 Transition Fault 的 ATPG ATPG 工具與常用工具鏈ATPG 優化與挑戰 Overview 本文主要介紹 DFT scan 中的 ATPG 功能。在 DFT (Desi…

2024 高通邊緣智能創新應用大賽智能邊緣計算賽道冠軍方案解讀

2024 高通邊緣智能創新應用大賽聚焦不同細分領域的邊緣智能創新應用落地&#xff0c;共設立三大熱門領域賽道——工業智能質檢賽道、智能邊緣計算賽道和智能機器人賽道。本文為智能邊緣計算賽道冠軍項目《端側大模型智能翻譯機》的開發思路與成果分享。 賽題要求 聚焦邊緣智能…

【Python運維】用Python和Ansible實現高效的自動化服務器配置管理

《Python OpenCV從菜鳥到高手》帶你進入圖像處理與計算機視覺的大門! 解鎖Python編程的無限可能:《奇妙的Python》帶你漫游代碼世界 隨著云計算和大規模數據中心的興起,自動化配置管理已經成為現代IT運維中不可或缺的一部分。通過自動化,企業可以大幅提高效率,降低人為錯…

微信小程序獲取后端數據

在小程序中獲取后端接口數據 通常可以使用 wx.request 方法&#xff0c;以下是一個基本示例&#xff1a; // pages/index/index.js Page({data: {// 用于存儲后端返回的數據resultData: [] },onLoad() {this.fetchData();},fetchData() {wx.request({url: https://your-backe…

應用架構模式-總體思路

采用引導式設計方法&#xff1a;以企業級架構為指導&#xff0c;形成較為齊全的規范指引。在實踐中總結重要設計形成決策要點&#xff0c;一個決策要點對應一個設計模式。自底向上總結采用該設計模式的必備條件&#xff0c;將之轉化通過簡單需求分析就能得到的業務特點&#xf…

【數據結構】雙向循環鏈表的使用

雙向循環鏈表的使用 1.雙向循環鏈表節點設計2.初始化雙向循環鏈表-->定義結構體變量 創建頭節點&#xff08;1&#xff09;示例代碼&#xff1a;&#xff08;2&#xff09;圖示 3.雙向循環鏈表節點頭插&#xff08;1&#xff09;示例代碼&#xff1a;&#xff08;2&#xff…

【Java設計模式-3】門面模式——簡化復雜系統的魔法

在軟件開發的世界里&#xff0c;我們常常會遇到復雜的系統&#xff0c;這些系統由多個子系統或模塊組成&#xff0c;各個部分之間的交互錯綜復雜。如果直接讓外部系統與這些復雜的子系統進行交互&#xff0c;不僅會讓外部系統的代碼變得復雜難懂&#xff0c;還會增加系統之間的…

Linux一些問題

修改YUM源 Centos7將yum源更換為國內源保姆級教程_centos使用中科大源-CSDN博客 直接安裝包&#xff0c;走鏈接也行 Index of /7.9.2009/os/x86_64/Packages 直接復制里面的安裝包鏈接&#xff0c;在命令行直接 yum install https://vault.centos.org/7.9.2009/os/x86_64/Pa…

微信小程序 覆蓋組件cover-view

wxml 覆蓋組件 <video src"../image/1.mp4" controls"{{false}}" event-model"bubble"> <cover-view class"controls"> <cover-view class"play" bind:tap"play"> <cover-image class"…

HTML——57. type和name屬性

<!DOCTYPE html> <html><head><meta charset"UTF-8"><title>type和name屬性</title></head><body><!--1.input元素是最常用的表單控件--><!--2.input元素不僅可以在form標簽內使用也可以在form標簽外使用-…

uniapp本地加載騰訊X5瀏覽器內核插件

概述 TbsX5webviewUTS插件封裝騰訊x5webview離線內核加載模塊&#xff0c;可以把uniapp的瀏覽器內核直接替換成Android X5 Webview(騰訊TBS)最新內核&#xff0c;提高交互體驗和流暢度。 功能說明 下載SDK插件 1.集成x5內核后哪些頁面會由x5內核渲染&#xff1f; 所有plus…

力扣hot100——二叉樹

94. 二叉樹的中序遍歷 class Solution { public:vector<int> inorderTraversal(TreeNode* root) {vector<int> ans;stack<TreeNode*> stk;while (root || stk.size()) {while (root) {stk.push(root);root root->left;}auto cur stk.top();stk.pop();a…

設計模式 創建型 單例模式(Singleton Pattern)與 常見技術框架應用 解析

單例模式&#xff08;Singleton Pattern&#xff09;是一種創建型設計模式&#xff0c;旨在確保某個類在應用程序的生命周期內只有一個實例&#xff0c;并提供一個全局訪問點來獲取該實例。這種設計模式在需要控制資源訪問、避免頻繁創建和銷毀對象的場景中尤為有用。 一、核心…

您的公司需要小型語言模型

當專用模型超越通用模型時 “越大越好”——這個原則在人工智能領域根深蒂固。每個月都有更大的模型誕生&#xff0c;參數越來越多。各家公司甚至為此建設價值100億美元的AI數據中心。但這是唯一的方向嗎&#xff1f; 在NeurIPS 2024大會上&#xff0c;OpenAI聯合創始人伊利亞…

uniapp-vue3(下)

關聯鏈接&#xff1a;uniapp-vue3&#xff08;上&#xff09; 文章目錄 七、咸蝦米壁紙項目實戰7.1.咸蝦米壁紙項目概述7.2.項目初始化公共目錄和設計稿尺寸測量工具7.3.banner海報swiper輪播器7.4.使用swiper的縱向輪播做公告區域7.5.每日推薦滑動scroll-view布局7.6.組件具名…

使用 Python 實現隨機中點位移法生成逼真的裂隙面

使用 Python 實現隨機中點位移法生成逼真的裂隙面 一、隨機中點位移法簡介 1. 什么是隨機中點位移法&#xff1f;2. 應用領域 二、 Python 代碼實現 1. 導入必要的庫2. 函數定義&#xff1a;隨機中點位移法核心邏輯3. 設置隨機數種子4. 初始化二維裂隙面5. 初始化網格的四個頂點…