pytorch基礎4-自動微分

專題鏈接:https://blog.csdn.net/qq_33345365/category_12591348.html

本教程翻譯自微軟教程:https://learn.microsoft.com/en-us/training/paths/pytorch-fundamentals/

初次編輯:2024/3/2;最后編輯:2024/3/3


本教程第一篇:介紹pytorch基礎和張量操作

本教程第二篇:介紹了數據集與歸一化

本教程第三篇:介紹構建模型層的基本操作。

本教程第四篇:介紹自動微分相關知識,即本博客內容。

另外本人還有pytorch CV相關的教程,見專題:

https://blog.csdn.net/qq_33345365/category_12578430.html


自動微分


使用torch.autograd自動微分 Automaic differentiation

在訓練神經網絡時,最常用的算法是反向傳播(back propagation)。在這個算法中,參數(模型權重)根據損失函數相對于給定參數的梯度進行調整。損失函數(loss function)計算神經網絡產生的預期輸出和實際輸出之間的差異。目標是使損失函數的結果盡可能接近零。該算法通過神經網絡向后遍歷以調整權重和偏差來重新訓練模型。這就是為什么它被稱為反向傳播。隨著時間的推移,通過反復進行這種回傳和前向過程來將損失(loss)減少到0的過程稱為梯度下降。

為了計算這些梯度,PyTorch具有一個內置的微分引擎,稱為torch.autograd。它支持對任何計算圖進行梯度的自動計算。

考慮最簡單的單層神經網絡,具有輸入x,參數wb,以及某些損失函數。可以在PyTorch中如下定義:

import torchx = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b  # z = x*w +b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

張量、函數與計算圖(computational graphs)

在這個網絡中,wb參數,他們會被損失函數優化。因此,需要能夠計算損失函數相對于這些變量的梯度。為此,我們將這些張量的requires_grad屬性設置為True。

**注意:**您可以在創建張量時設置requires_grad的值,也可以稍后使用x.requires_grad_(True)方法來設置。

我們將應用于張量的函數(function)用于構建計算圖,這些函數是Function類的對象。這個對象知道如何在前向方向上計算函數,還知道在反向傳播步驟中如何計算其導數。反向傳播函數的引用存儲在張量的grad_fn屬性中。

print('Gradient function for z =',z.grad_fn)
print('Gradient function for loss =', loss.grad_fn)

輸出是:

Gradient function for z = <AddBackward0 object at 0x00000280CC630CA0>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward object at 0x00000280CC630310>

計算梯度

為了優化神經網絡中參數的權重,需要計算損失函數相對于參數的導數,即我們需要在某些固定的xy值下計算 ? l o s s ? w \frac{\partial loss}{\partial w} ?w?loss? ? l o s s ? b \frac{\partial loss}{\partial b} ?b?loss?。為了計算這些導數,我們調用loss.backward(),然后從w.gradb.grad中獲取值。

loss.backward()
print(w.grad)
print(b.grad)

輸出是:

tensor([[0.2739, 0.0490, 0.3279],[0.2739, 0.0490, 0.3279],[0.2739, 0.0490, 0.3279],[0.2739, 0.0490, 0.3279],[0.2739, 0.0490, 0.3279]])
tensor([0.2739, 0.0490, 0.3279])

注意: 只能獲取計算圖中設置了requires_grad屬性為True的葉節點的grad屬性。對于計算圖中的所有其他節點,梯度將不可用。此外,出于性能原因,我們只能對給定圖執行一次backward調用以進行梯度計算。如果我們需要在同一圖上進行多次backward調用,我們需要在backward調用中傳遞retain_graph=True

禁用梯度追蹤 Disabling gradient tracking

默認情況下,所有requires_grad=True的張量都在跟蹤其計算歷史并支持梯度計算。然而,在某些情況下,我們并不需要這樣做,例如,當我們已經訓練好模型并且只想將其應用于一些輸入數據時,也就是說,我們只想通過網絡進行前向計算。我們可以通過將我們的計算代碼放在一個torch.no_grad()塊中來停止跟蹤計算:

z = torch.matmul(x, w)+b
print(z.requires_grad)with torch.no_grad():z = torch.matmul(x, w)+b
print(z.requires_grad)

輸出是:

True
False

另外一種產生相同結果的方法是在張量上使用detach方法:

z = torch.matmul(x, w)+b
z_det = z.detach()
print(z_det.requires_grad)

有一些理由你可能想要禁用梯度跟蹤:

  • 將神經網絡中的某些參數標記為凍結參數(frozen parameters)。這在微調預訓練網絡的情況下非常常見。
  • 當你只進行前向傳播時,為了加速計算,因為不跟蹤梯度的張量上的計算更有效率。

計算圖的更多知識

概念上,autograd 在一個有向無環圖 (DAG) 中保留了數據(張量)和所有執行的操作(以及生成的新張量),這些操作由 Function 對象組成。在這個 DAG 中,葉子節點是輸入張量,根節點是輸出張量。通過從根節點到葉子節點追蹤這個圖,你可以使用鏈式法則(chain rule)自動計算梯度。

在前向傳播中,autograd 同時執行兩件事情:

  • 運行所請求的操作以計算結果張量,并且
  • 在 DAG 中維護操作的 梯度函數(gradient function)

當在 DAG 根節點上調用 .backward() 時,反向傳播開始。autograd 然后:

  • 從每個 .grad_fn 計算梯度,
  • 將它們累積在相應張量的 .grad 屬性中,并且
  • 使用鏈式法則一直傳播到葉子張量。

PyTorch 中的 DAG 是動態的

一個重要的事情要注意的是,圖是從頭開始重新創建的;在每次 .backward() 調用之后,autograd 開始填充一個新的圖。這正是允許您在模型中使用控制流語句的原因;如果需要,您可以在每次迭代中更改形狀、大小和操作。

代碼匯總:

import torchx = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)print('Gradient function for z =', z.grad_fn)
print('Gradient function for loss =', loss.grad_fn)loss.backward()
print(w.grad)
print(b.grad)z = torch.matmul(x, w) + b
print(z.requires_grad)with torch.no_grad():z = torch.matmul(x, w) + b
print(z.requires_grad)z = torch.matmul(x, w) + b
z_det = z.detach()
print(z_det.requires_grad)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/716826.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/716826.shtml
英文地址,請注明出處:http://en.pswp.cn/news/716826.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【Java EE】JUC(java.util.concurrent) 的常見類

目錄 &#x1f334;Callable 接口&#x1f38d;ReentrantLock&#x1f340;原子類&#x1f333;線程池&#x1f332;信號量 Semaphore??CountDownLatch、?相關面試題 &#x1f334;Callable 接口 Callable 是?個 interface . 相當于把線程封裝了?個 “返回值”. ?便程序…

什么是灰色預測

灰色預測是一種基于灰色系統理論的預測方法&#xff0c;用于處理數據不完全、信息不充分或未知的情況下的預測問題。它適用于樣本數據較少、無法建立精確的數學模型的情況。 灰色預測的基本思想是利用已知數據的特點和規律來推斷未知數據的發展趨勢。它的核心是灰色關聯度的概念…

(學習日記)2024.03.01:UCOSIII第三節 + 函數指針 (持續更新文件結構)

寫在前面&#xff1a; 由于時間的不足與學習的碎片化&#xff0c;寫博客變得有些奢侈。 但是對于記錄學習&#xff08;忘了以后能快速復習&#xff09;的渴望一天天變得強烈。 既然如此 不如以天為單位&#xff0c;以時間為順序&#xff0c;僅僅將博客當做一個知識學習的目錄&a…

Kubernetes: 本地部署dashboard

本篇文章主要是介紹如何在本地部署kubernetes dashboard, 部署環境是mac m2 下載dashboard.yaml 官網release地址: kubernetes/dashboard/releases 本篇文章下載的是kubernetes-dashboard-v2.7.0的版本&#xff0c;通過wget命令下載到本地: wget https://raw.githubusercont…

【Python】進階學習:pandas--isin()用法詳解

【Python】進階學習&#xff1a;pandas–isin()用法詳解 &#x1f308; 個人主頁&#xff1a;高斯小哥 &#x1f525; 高質量專欄&#xff1a;Matplotlib之旅&#xff1a;零基礎精通數據可視化、Python基礎【高質量合集】、PyTorch零基礎入門教程&#x1f448; 希望得到您的訂閱…

【NDK系列】Android tombstone文件分析

文件位置 data/tombstone/tombstone_xx.txt 獲取tombstone文件命令&#xff1a; adb shell cp /data/tombstones ./tombstones 觸發時機 NDK程序在發生崩潰時&#xff0c;它會在路徑/data/tombstones/下產生導致程序crash的文件tombstone_xx&#xff0c;記錄了死亡了進程的…

單細胞Seurat - 細胞聚類(3)

本系列持續更新Seurat單細胞分析教程&#xff0c;歡迎關注&#xff01; 維度確定 為了克服 scRNA-seq 數據的任何單個特征中廣泛的技術噪音&#xff0c;Seurat 根據 PCA 分數對細胞進行聚類&#xff0c;每個 PC 本質上代表一個“元特征”&#xff0c;它結合了相關特征集的信息。…

深入測探:用Python玩轉分支結構與循環操作——技巧、場景及面試寶典

在編程的世界里&#xff0c;分支結構和循環操作是構建算法邏輯的基礎磚石。它們如同編程的“鹽”&#xff0c;賦予代碼生命&#xff0c;讓靜態的數據跳躍起來。本文將帶你深入探索Python中的分支結構和循環操作&#xff0c;通過精心挑選的示例和練習題&#xff0c;不僅幫助你掌…

mysql5*-mysql8 區別

1.Mysql5.7-Mysql8.0 sysbench https://github.com/geekgogie/mysql57_vs_8-benchmark_scripts 1.讀、寫、刪除更新 速度 512 個線程以后才會出現如下的。 2.刪除速度 2.事務處理性能 3.CPU利用率 mysql8 利用率高。 4.排序 5.7 只能ASC&#xff0c;不能降序 數據越來越大

牢記于心單獨說出來的知識點(后續會加)

第一個 非十進制&#xff08;八進制&#xff0c;十六進制&#xff09;寫在文件中它本身就是補碼&#xff0c;計算機是不用進行內存轉換&#xff0c;它直接存入內存。&#xff08;因為十六進制本身是補碼&#xff0c;所以計算機里面我們看到的都是十六進制去存儲&#xff09; …

Qt 簡約美觀的加載動畫 文本風格 第八季

今天和大家分享一個文本風格的加載動畫, 有兩類,其中一個可以設置文本內容和文本顏色,演示了兩份. 共三個動畫, 效果如下: 一共三個文件,可以直接編譯 , 如果對您有所幫助的話 , 不要忘了點贊呢. //main.cpp #include "LoadingAnimWidget.h" #include <QApplic…

MySQL:開始深入其數據(一)DML

在上一章初識MySQL了解了如何定義數據庫和數據表&#xff08;DDL&#xff09;&#xff0c;接下來我們開始開始深入其數據,對其數據進行訪問&#xff08;DAL&#xff09;、查詢DQL&#xff08;&#xff09;和操作(DML)等。 通過DML語句操作管理數據庫數據 DML (數據操作語言) …

一文搞定 FastAPI 路徑參數

路徑參數定義 路徑操作裝飾器中對應的值就是路徑參數,比如: from fastapi import FastAPI app = FastAPI()@app.get("/hello/{name}") def say_hello(name: str):return {

突破編程_C++_STL教程( list 的基礎知識)

1 std::list 概述 std::list 是 C 標準庫中的一個雙向鏈表容器。它支持在容器的任何位置進行常數時間的插入和刪除操作&#xff0c;但不支持快速隨機訪問。與 std::vector 或 std::deque 這樣的連續存儲容器相比&#xff0c;std::list 在插入和刪除元素時不需要移動其他元素&a…

計算機網絡之傳輸層 + 應用層

.1 UDP與TCP IP中的檢驗和只檢驗IP數據報的首部, 但UDP的檢驗和檢驗 偽首部 首部 數據TCP的交互單位是數據塊, 但仍說TCP是面向字節流的, 因為TCP僅把應用層傳下來的數據看成無結構的字節流, 根據當時的網絡環境組裝成大小不一的報文段.10秒內有1秒用于發送端發送數據, 信道…

【Python】進階學習:pandas--groupby()用法詳解

&#x1f4ca;【Python】進階學習&#xff1a;pandas–groupby()用法詳解 &#x1f308; 個人主頁&#xff1a;高斯小哥 &#x1f525; 高質量專欄&#xff1a;Matplotlib之旅&#xff1a;零基礎精通數據可視化、Python基礎【高質量合集】、PyTorch零基礎入門教程&#x1f448;…

Python算法100例-3.5 親密數

1.問題描述2.問題分析3.算法設計4.確定程序框架5.完整的程序6.問題拓展 1&#xff0e;問題描述 如果整數A的全部因子&#xff08;包括1&#xff0c;不包括A本身&#xff09;之和等于B&#xff0c;且整數B的全部因子&#xff08;包括1&#xff0c;不包括B本身&#xff09;之和…

中國電子學會2020年6月份青少年軟件編程Sc ratch圖形化等級考試試卷四級真題。

第 1 題 【 單選題 】 1.執行下面程序&#xff0c;輸入4和7后&#xff0c;角色說出的內容是&#xff1f; A&#xff1a;4&#xff0c;7 B&#xff1a;7&#xff0c;7 C&#xff1a;7&#xff0c;4 D&#xff1a;4&#xff0c;4 2.執行下面程序&#xff0c;輸出是&#xff…

Oracle自帶的網絡工具(計算傳輸redo需要的帶寬,使用STATSPACK,計算redo壓縮率,db_ultra_safe)

--根據primary database redo產生的速率,計算傳輸redo需要的帶寬. 除去tcp/ip網絡其余30%的開銷,計算需要的帶寬公式: 需求帶寬((每秒產生redo的速率峰值/0.75)*8)/1,000,000帶寬(Mbps) --可以通過去多次業務高峰期的Statspack/AWR獲取每秒產生redo的速率峰值,也可以通過查詢視…

post請求體內容無法重復獲取

post請求體內容無法重復獲取 為什么會無法重復讀取呢&#xff1f; 以tomcat為例&#xff0c;在進行請求體讀取時實際底層調用的是org.apache.catalina.connector.Request的getInputStream()方法&#xff0c;而該方法返回的是CoyoteInputStream輸入流 public ServletInputStream…