深度學習框架PyTorch——從入門到精通(6.1)自動微分

使用torch.autograd自動微分

  • 張量、函數和計算圖
  • 計算梯度
  • 禁用梯度追蹤
  • 關于計算圖的更多信息
  • 張量梯度和雅可比乘積

在訓練神經網絡時,最常用的算法是反向傳播。在該算法中,參數(模型權重)根據損失函數的梯度相對于給定參數進行調整。

為了計算這些梯度,PyTorch有一個內置的微分引擎,名為torch.autograd。它支持為任何計算圖自動計算梯度。

考慮最簡單的一層神經網絡,具有輸入x、參數w和b以及一些損失函數。它可以通過以下方式在PyTorch中定義:

import torchx = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

張量、函數和計算圖

剛才的代碼定義了以下計算圖:

在這里插入圖片描述
在這個網絡中,w和b是參數,我們需要優化。因此,我們需要能夠計算關于這些變量的損失函數的梯度。

為了做到這一點,我們設置了這些張量的requires_grad屬性。

注:您可以在創建張量時設置`requires_grad`的值,或者稍后使用`x.requires_grad_(True)`方法。

在 PyTorch 中,用于構建計算圖的張量操作函數實際上是Function類的對象。該對象不僅負責處理正向傳播時的函數計算,還能在反向傳播過程中計算導數。反向傳播函數的引用會存儲在張量的grad_fn屬性中。你可以在官方文檔中找到關于Function類的更多詳細信息。

print(f"Gradient function for z = {z.grad_fn}")
print(f"Gradient function for loss = {loss.grad_fn}")
# 輸出
Gradient function for z = <AddBackward0 object at 0x7fdf8cca1c30>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x7fdf8cca2c20>

計算梯度

為了優化神經網絡中參數的權重,我們需要計算損失函數關于參數的導數。也就是說,在輸入值 x 和目標值 y 的固定取值下,我們需要計算 ? l o s s ? w \frac{\partial loss}{\partial w} ?w?loss? ? l o s s ? b \frac{\partial loss}{\partial b} ?b?loss?。要計算這些導數,我們只需調用 loss.backward(),然后從參數 w.gradb.grad 中獲取對應的導數值。

loss.backward()
print(w.grad)
print(b.grad)
# 輸出
tensor([[0.3313, 0.0626, 0.2530],[0.3313, 0.0626, 0.2530],[0.3313, 0.0626, 0.2530],[0.3313, 0.0626, 0.2530],[0.3313, 0.0626, 0.2530]])
tensor([0.3313, 0.0626, 0.2530])
注:我們只能獲得計算圖葉節點的grad屬性,因為這些葉節點的requires_grad屬性是True。對于圖中其他的節點,獲得梯度屬性的方法將不可用。
另外出于性能原因,我們只能在給定圖上使用backward執行一次梯度計算。如果我們需要在同一張圖上執行幾個backward調用,我們需要將retain_graph=True傳遞給backward調用。

禁用梯度追蹤

默認情況下,所有具有requires_grad=True的張量都在跟蹤它們的計算歷史并支持梯度計算。
但是,有些情況下我們不需要這樣做。

例如,當我們已經訓練了模型并只想將其應用于一些輸入數據時,即我們只想通過網絡進行轉發計算。我們可以通過用torch.no_grad()塊:

z = torch.matmul(x, w)+b
print(z.requires_grad)with torch.no_grad():z = torch.matmul(x, w)+b
print(z.requires_grad)# 輸出
True
False

實現相同結果的另一種方法是使用對張量detach()方法:

z = torch.matmul(x, w)+b
z_det = z.detach()
print(z_det.requires_grad)
# 輸出
False

可能想要禁用漸變跟蹤的原因如下:

  • 將神經網絡中的某些參數標記為凍結參數
  • 當您只進行前向傳遞時,要加快計算速度,因為不跟蹤梯度的張量上的計算會更有效。

關于計算圖的更多信息

從概念上講,自動微分在由Function對象組成的有向無環圖(DAG)中記錄數據(張量)和所有執行的操作(以及生成的新張量)。
在這個有向無環圖(DAG)中,葉是輸入張量,根是輸出張量。通過從根到葉跟蹤這個圖,您可以使用鏈式規則自動計算梯度。

在向前傳播中,自動微分同時做兩件事:

  • 運行請求的操作以計算生成的張量
  • 在有向無環圖中維護操作的梯度函數。

當在有向無環圖(DAG)的根節點上調用 .backward() 方法時,反向傳播過程就啟動了。隨后,自動求導系統會進行以下操作:

  • 計算每個.grad_fn的梯度。
  • 將它們累加到相應張量的.grad屬性中
  • 使用鏈式規則,一直傳播到圖的葉張量(葉節點)。

在 PyTorch 中,有向無環圖(DAG)是動態的。需要重點理解的是:圖會 從頭重新構建;每次調用 .backward() 后,自動求導機制(autograd)都會開始生成一幅新圖。而這正是模型中能使用控制流語句(如循環、條件判斷)的關鍵——如果有需求,你完全可以在每次迭代時調整圖的形狀、規模以及具體操作。

張量梯度和雅可比乘積

在許多場景中,我們會用到標量損失函數,此時需要計算損失函數關于某些參數的梯度。但也存在輸出函數是任意張量的情況。這時,PyTorch 支持計算所謂的 雅可比積,而非直接計算實際的梯度。

對于向量函數 y ? = f ( x ? ) \vec{y} = f(\vec{x}) y ?=f(x )(其中 x ? = ? x 1 , . . . , x n ? \vec{x} = \langle x_1, ..., x_n \rangle x =?x1?,...,xn?? y ? = ? y 1 , . . . , y m ? \vec{y} = \langle y_1, ..., y_m \rangle y ?=?y1?,...,ym??), y ? \vec{y} y ? 關于 x ? \vec{x} x 的梯度由雅可比矩陣表示:
J = ( ? y 1 ? x 1 ? ? y 1 ? x n ? ? ? ? y m ? x 1 ? ? y m ? x n ) J = \begin{pmatrix} \frac{\partial y_1}{\partial x_1} & \cdots & \frac{\partial y_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_m}{\partial x_1} & \cdots & \frac{\partial y_m}{\partial x_n} \end{pmatrix} J= ??x1??y1????x1??ym????????xn??y1????xn??ym??? ?

PyTorch 并不直接計算雅可比矩陣本身,而是允許針對給定的輸入向量 v = ( v 1 , . . . , v m ) v = (v_1, ..., v_m) v=(v1?,...,vm?) 計算雅可比積 v T ? J v^T \cdot J vT?J。這一過程通過將 v v v 作為參數調用 backward 實現。需要注意的是, v v v 的維度必須與我們希望計算雅可比積的原始張量的維度一致。

inp = torch.eye(4, 5, requires_grad=True)
out = (inp+1).pow(2).t()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"First call\n{inp.grad}")
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")
inp.grad.zero_()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")
# 輸出
First call
tensor([[4., 2., 2., 2., 2.],[2., 4., 2., 2., 2.],[2., 2., 4., 2., 2.],[2., 2., 2., 4., 2.]])Second call
tensor([[8., 4., 4., 4., 4.],[4., 8., 4., 4., 4.],[4., 4., 8., 4., 4.],[4., 4., 4., 8., 4.]])Call after zeroing gradients
tensor([[4., 2., 2., 2., 2.],[2., 4., 2., 2., 2.],[2., 2., 4., 2., 2.],[2., 2., 2., 4., 2.]])

請注意,當我們使用sameargument第二次調用backward時,梯度的值是不同的。發生這種情況是因為在進行backward傳播時,PyTorch累積梯度,即計算梯度的值被添加到計算圖所有葉節點的grad屬性中。
如果你想計算正確的梯度,你需要先將grad屬性歸零。在實際訓練中,優化器能幫助我們做到這一點。

注:之前我們調用 backward() 函數時沒有傳入參數。實際上,這等同于調用 backward(torch.tensor(1.0))。在處理標量值函數時,這樣做是一種很實用的計算梯度的方法,比如在神經網絡訓練過程中計算損失函數的梯度就可以用這種方式。

更多內容請看:自動求導機制

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/73054.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/73054.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/73054.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

跟我學C++中級篇——std::not_fn

一、std::not_fn定義和說明 std::not_fn這個模板函數非常有意思&#xff0c;在前面我們學習過wrapper&#xff08;包裝器&#xff09;&#xff0c;其實它就是通過封裝一個包裝器來實現返回值的非。它的基本定義如下&#xff1a; template< class F > /* 未指定 */ not_…

階躍星辰開源300億參數視頻模型Step-Video-TI2V:運動可控+102幀長視頻生成

階躍星辰&#xff08;StepFun&#xff09;正式開源其新一代圖生視頻模型 Step-Video-TI2V &#xff0c;該模型基于300億參數的Step-Video-T2V訓練&#xff0c;支持文本與圖像聯合驅動生成長達102幀的高質量視頻&#xff0c;在運動控制與場景適配性上實現突破。 核心亮點 …

java查詢es超過10000條數據

java查詢es超過10000條數據 背景:需要每天零點導出es中日志數據到數據庫中給數據分析人員做清洗&#xff0c;然后展示給業務人員。但在es中默認一次最多只能查詢10000條數據。 在這里我就只貼一下關鍵代碼 SearchRequest searchRequest new SearchRequest("索引名"…

使用 libevent 構建高性能網絡應用

使用 libevent 構建高性能網絡應用 在現代網絡編程中&#xff0c;高性能和可擴展性是開發者追求的核心目標。為了實現這一目標&#xff0c;許多開發者選擇使用事件驅動庫來管理 I/O 操作和事件處理。libevent 是一個輕量級、高性能的事件通知庫&#xff0c;廣泛應用于網絡服務…

HeyGem.ai 全離線數字人生成引擎加入 GitCode:開啟本地化 AIGC 創作新時代

在人工智能技術飛速演進的時代&#xff0c;數據隱私與創作自由正成為全球開發者關注的焦點。硅基智能旗下開源項目 HeyGem.ai 近日正式加入 GitCode&#xff0c;以全球首個全離線數字人生成引擎的顛覆性技術&#xff0c;重新定義人工智能生成內容&#xff08;AIGC&#xff09;的…

【leetcode hot 100 39】組合總和

錯誤解法一&#xff1a;每一次回溯都遍歷提供的數組 class Solution {public List<List<Integer>> combinationSum(int[] candidates, int target) {List<List<Integer>> result new ArrayList<List<Integer>>();List<Integer> te…

VSCODE右下角切換環境沒用

VSCODE惦記右下角python版本&#xff0c;切換別的虛擬環境時&#xff0c;始終切換不了&#xff0c;同時右下角彈出&#xff1a; Client Pylance: connection to server is erroring. 取消繼承環境也改了。https://www.cnblogs.com/coreylin/p/17509610.html 還是不行&#xf…

【sql靶場】第23、25,25a關過濾繞過保姆級教程

目錄 【sql靶場】第23、25-28關過濾繞過保姆級教程 第二十三關 第二十五關 1.爆出數據庫 2.爆出表名 3.爆出字段 4.爆出賬號密碼 【sql靶場】第23、25&#xff0c;25a關過濾繞過保姆級教程 第二十三關 從本關開始又是get傳參&#xff0c;并且還有了對某些字符或字段的過…

python每日十題(5)

保留字&#xff0c;也稱關鍵字&#xff0c;是指被編程語言內部定義并保留使用的標識符。Python 3.x版本中有35個保留字&#xff0c;分別為&#xff1a;and, as,assert,async,await,break,class,continue,def,del,elif,else, except, False, finally,for,from,global, if,import…

Pytorch使用手冊—自定義 C++ 和 CUDA 擴展(專題五十二)

提示 從 PyTorch 2.4 開始,本教程已被廢棄。請參考 PyTorch 自定義操作符,了解關于通過自定義 C++/CUDA 擴展擴展 PyTorch 的最新指南。 PyTorch 提供了大量與神經網絡、任意張量代數、數據處理等相關的操作。然而,您可能仍然會發現自己需要一個更自定義的操作。例如,您可能…

CHM(ConcurrentHashMap)中的 sizeCtl 的作用與值變化詳解

學海無涯&#xff0c;志當存遠。燃心礪志&#xff0c;奮進不輟。愿諸君得此雞湯&#xff0c;如沐春風&#xff0c;學業有成。若覺此言甚善&#xff0c;煩請賜贊一枚&#xff0c;共勵學途&#xff0c;同鑄輝煌 ConcurrentHashMap常簡寫為CHM&#xff0c;尤其是在討論并發編程時。…

VLAN綜合實驗報告

一、實驗拓撲 網絡拓撲結構包括三臺交換機&#xff08;LSW1、LSW2、LSW3&#xff09;、一臺路由器&#xff08;AR1&#xff09;以及六臺PC&#xff08;PC1-PC6&#xff09;。交換機之間通過Trunk鏈路相連&#xff0c;交換機與PC、路由器通過Access或Hybrid鏈路連接。 二、實驗…

OpenGL ES ->計算多個幀緩沖對象(Frame Buffer Object)+疊加多個濾鏡作用后的Bitmap

XML文件 <?xml version"1.0" encoding"utf-8"?> <RelativeLayout xmlns:android"http://schemas.android.com/apk/res/android"android:layout_width"match_parent"android:layout_height"match_parent"><…

Java線程池深度解析:從使用到調優

適合人群&#xff1a;Java中級開發者 | 并發編程入門者 | 系統調優實踐者 目錄 一、引言&#xff1a;為什么線程池是Java并發的核心&#xff1f; 二、線程池核心知識點詳解 1. 線程池核心參數與原理 2. 線程池的創建與使用 (1) 基礎用法示例 (2) 內置線程池的隱患 3. 線…

【工具變量】全國地級市地方ZF債務數據集(2014-2023年)

地方ZF債務是地方財政運作的重要組成部分&#xff0c;主要用于基礎設施建設、公共服務及經濟發展&#xff0c;是衡量地方財政健康狀況的重要指標。近年來&#xff0c;我國地級市的地方ZF債務規模不斷變化&#xff0c;涉及一般債務和專項債務等多個方面&#xff0c;對金融市場、…

大模型訓練的調參與算力調度技術分析

大模型訓練的調參與算力調度 雖然從網絡上&#xff0c;還有通過和大模型交流&#xff0c;了解了很多訓練和微調的技術。但沒有實踐&#xff0c;也沒有什么機會實踐。因為大模型訓練門檻還是挺高的&#xff0c;想要有一手資料比較困難。如果需要多機多卡&#xff0c;硬件成本小…

深入理解 lt; 和 gt;:HTML 實體轉義的核心指南!!!

&#x1f6e1;? 深入理解 < 和 >&#xff1a;HTML 實體轉義的核心指南 &#x1f6e1;? 在編程和文檔編寫中&#xff0c;< 和 > 符號無處不在&#xff0c;但它們也是引發語法錯誤、安全漏洞和渲染混亂的頭號元兇&#xff01;&#x1f525; 本文將聚焦 <&#…

GRS認證的注意事項!GRS認證的定義

GRS認證的注意事項&#xff0c;對于企業而言&#xff0c;是通往可持續發展和環保生產道路上的重要里程碑。在追求這一認證的過程中&#xff0c;企業必須細致入微&#xff0c;確保每一個環節都符合嚴格的標準與要求。 首先&#xff0c;企業必須全面理解GRS認證的核心原則&#…

位運算--求二進制中1的個數

位運算–求二進制中1的個數 給定一個長度為 n 的數列&#xff0c;請你求出數列中每個數的二進制表示中 1 的個數。 輸入格式 第一行包含整數 n。 第二行包含 n 個整數&#xff0c;表示整個數列。 輸出格式 共一行&#xff0c;包含 n 個整數&#xff0c;其中的第 i 個數表…

Linux常用指令(3)

大家好,今天我們繼續來介紹一下linux常用指令的語法,加深對linux操作系統的了解,話不多說,來看. 1.rmdir指令 功能&#xff1a;刪除空目錄 基本語法&#xff1a; rmdir 要刪除的空目錄 ??rmdir刪除的是空目錄,如果目錄下有內容是無法刪除 2.mkdir指令 功能&#xff1a;創…