【DL學習筆記】計算圖與自動求導

計算圖

  • 計算圖(Computation Graph)是一種用于描述計算過程的圖形化表示方法。

  • 在深度學習中,計算圖通常用于描述 網絡結構、運算過程 和數據流向

  • 計算圖是一種有向無環圖,用圖形方式來表示算子與變量之間的關系,直觀高效。

  • 它由節點(Node)和邊(Edge)組成,如下圖Netron庫可視化的例子,其中節點表示操作或函數,邊表示數據流向

在這里插入圖片描述

前向傳播 與 反向傳播

  • 在Pytorch中,計算圖的構建是通過神經網絡的 前向傳播 (forward) 過程完成的。

  • 反向傳播 根據計算圖來計算梯度,從而進行參數更新。它為自動微分(automatic differentiation)提供了基礎,使得深度學習框架能夠自動計算梯度并進行反向傳播。

靜態計算圖、動態計算圖

計算圖可以分為兩種類型:靜態計算圖 和 動態計算圖

  • 靜態計算圖: 在靜態計算圖中,計算圖在模型定義階段就被固定下來,不會發生變化。典型的例子是 TensorFlow 1.x 中的計算圖。在這種情況下,首先定義計算圖,然后運行會話(session)來執行圖中的操作。

  • 動態計算圖: 在動態計算圖中,計算圖在運行時根據輸入數據的形狀和大小動態構建。PyTorch 和 TensorFlow 2.x 采用了動態計算圖的方式。在這種情況下,每次前向傳播都會重新構建計算圖,使得模型更加靈活。

在整個前向計算過程中,PyTorch采用 動態計算圖 的形式進行組織,且在每次 前向傳播時重新構建。
其他深度學習架構,如TensorFlow、Keras 一般為靜態圖。

葉子節點、非葉子節點、根節點

在這里插入圖片描述

  • 上面的計算圖中,圓形表示變量矩形表示算子,這些變量和算子構成了一個完整的前向傳播過程
  • 葉子節點 : x、w、bx、w、bxwb 為葉子節點,它們是用戶創建的變量,不依賴于其他變量
  • 非葉子節點 : y、zy、zyz為非葉子節點,它們是通過計算得到的變量
  • 根節點 : zzz 為根節點,它之后不會再有后續的運算,我們一般讓根節點來執行 反向傳播方法 z.backward()

torch.tensor()的requires_grad參數

  • 對于葉子節點(Leaf Node)的 張量Tensor,需要用 requires_grad 指明是否記錄對其的操作運算,以便之后通過 反向傳播求梯度。

  • 一般僅對 葉子節點 設置 requires_grad, 這些葉子節點,一般就是網絡中層的參數,他們一般都是 torch.nn.Parameter 對象,requires_grad 屬性 默認為 True

  • 葉子結點如果需要求導,requires_grad 需設置為 True,那么由這些葉子節點計算得出的非葉子節點,requires_grad 會自動置為True

import torchx = torch.tensor([2.0], requires_grad=True)   # 葉子節點
w = torch.tensor([3.0], requires_grad=True)   # 葉子節點
b = torch.tensor([1.0], requires_grad=True)   # 葉子節點y = w * x  # 非葉子節點
z = y + b  # 非葉子節點# 查看葉子節點和非葉子節點的 requires_grad 屬性
print('x 的 requires_grad 屬性:', x.requires_grad)
print('w 的 requires_grad 屬性:', w.requires_grad)
print('b 的 requires_grad 屬性:', b.requires_grad)
print('y 的 requires_grad 屬性:', y.requires_grad)
print('z 的 requires_grad 屬性:', z.requires_grad) 

grad_fn屬性

  • 通過運算創建的非葉子節點 tensor,會自動被賦予 grad_fn 屬性,用于表明生成這個張量的操作
  • 葉子節點的 grad_fn 為 None,因為它們不是通過其他操作計算得來的,而是網絡的參數或輸入數據

一些常見的 grad_fn 類型包括:

  • <CatBackward>:表示這個張量是通過 torch.cat 操作得到的。
  • <MatMulBackward>:表示這個張量是通過矩陣乘法操作得到的。
  • <AddBackward>:表示這個張量是通過加法操作得到的。
  • <AddmmBackward>:表示一個張量是通過 torch.addmm 操作得到的。
  • <DivBackward>:表示這個張量是通過除法操作得到的。
  • <ReLUBackward>:表示這個張量是通過 ReLU 激活函數得到的。
**import torch# 創建葉子節點張量
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)# 創建非葉子節點張量,通過運算生成
z = x * y# 查看葉子節點和非葉子節點的 grad_fn 屬性
print('x 的 grad_fn:', x.grad_fn) 
print('y 的 grad_fn:', y.grad_fn)
print('z 的 grad_fn:', z.grad_fn) **

反向傳播

在反向傳播中,以 loss(根節點 tensor )為核心,步驟為:

  1. optimizer.zero_grad() 清空葉子節點梯度,避免多次 optimizer.step() 時梯度累加。
  2. 調用 loss.backward() 反向傳播,計算葉子節點梯度并存入 .grad 屬性。
  3. 執行 optimizer.step() ,依優化器算法和學習率,用 .grad 梯度更新葉子節點(即模型參數 ) 。
for epoch in range(epochs):model.train()for imgs, labels in train_loader        :# trainoptimizer.zero_grad()loss.backward()optimizer.step()

完整舉例:

import torch# 輸入張量 x, require_grad 屬性默認為 False
x = torch.Tensor([2])# 初始化 權重參數w, 偏移量b,并設置 require_grad 屬性為 True, 為自動求導
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)# 實現向前傳播
y = torch.mul(w, x)
z = torch.add(y, b)# 分別查看葉子節點 x, w, b 和 非葉子節點 y、z 的require_grad屬性
print(x.requires_grad, w.requires_grad, b.requires_grad)  # False True True
print(y.requires_grad, z.requires_grad )  # True True# 查看各節點是否為葉子節點
print(x.is_leaf, w.is_leaf, b.is_leaf, y.is_leaf, z.is_leaf)  # True True True False False# 分別查看 葉子節點 和 非葉子節點 的 grad_fn 屬性
print(x.grad_fn, w.grad_fn, b.grad_fn)   # None None None
print(y.grad_fn, z.grad_fn)   # <MulBackward0 object at 0x7f8ac1303910> <AddBackward0 object at 0x7f8ac1303070># 反向傳播計算梯度
z.backward()  # 查看葉子節點的梯度,x是葉子節點但它無須求導,故其梯度為None 
print(w.grad,b.grad,x.grad)  # tensor([2.]) tensor([1.]) None# 非葉子節點的梯度,執行backward之后,會自動清空 
print(y.grad,z.grad)  # None None

在這里插入圖片描述

自動求導 Autograd

  • 在神經網絡中,一個重要內容就是進行參數學習,而參數學習的反向傳播離不開求導。
  • 現在大部分深度學習架構都有自動求導的功能,torch.autograd包 就是用來自動求導的。
  • torch.autograd 包為張量上所有的操作提供了自動求導功能

實驗:backward()反向傳播自動求導

以下代碼實現 : 機器學習 回歸問題舉例,使用 backward() 反向傳播自動求導,并手動更新參數

  1. 先來造一批數據,作為樣本數據 x 和 標簽值y
import torch
import matplotlib.pyplot as plttorch.manual_seed(100)# 生成 x坐標數據,形狀為 100 x 1
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)# 生成 y坐標數據,,形狀為 100 x 1,加上一些噪聲
y = 3 * x.pow(2) + 2 + 0.2 * torch.rand(x.size())# 把tensor數據轉換為numpy數據,并可視化
plt.scatter(x.numpy(), y.numpy())
plt.show()

在這里插入圖片描述

  1. 定義一個模型 y = wx +b, 我們要學習出 w 和 b 的值,用來擬合 x 和 y
# 初始化權重參數,參數 w、b 為需要學習的,故需要設置參數 requires_grad=True
w = torch.randn(1, 1, dtype=torch.float, requires_grad=True)
b = torch.zeros(1, 1, dtype=torch.float, requires_grad=True)
print(w)  # tensor([[1.1046]], requires_grad=True)
print(b)  # tensor([[0.]], requires_grad=True)lr = 0.001 # 學習率for i in range(800):# 向前傳播,得到預測的y值,記為 y_predy_pred = w * x.pow(2) + b# 定義損失函數loss = (y - y_pred) ** 2loss = loss.sum()# 反向傳播,自動計算梯度,存放在 grad 屬性中loss.backward()# 手動更新參數,需要用torch.no_grad(), 使上下文環境中切斷自動求導的計算with torch.no_grad():# 更新參數w -= lr * w.gradb -= lr * b.grad# 梯度清零w.grad.zero_()b.grad.zero_()print(w)  # tensor([[2.9668]], requires_grad=True)
print(b)  # tensor([[2.1138]], requires_grad=True)

在這里插入圖片描述

  1. 可視化一下結果,紅色曲線是預測結果 ,藍色點是真實標簽值
    在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/94003.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/94003.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/94003.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

大型地面光伏電站開發建設流程

?地面電站特特點&#xff1a;規模大&#xff0c;通常占用土地、水面等&#xff0c;地面式選址選項多&#xff0c;且不斷拓展出新的用地模式&#xff0c;地面式選址集中在山體、灘涂、沼澤、戈壁、沙漠、受污染土地等閑置或廢棄土地上。

除數博弈(動態規劃)

愛麗絲和鮑勃一起玩游戲&#xff0c;他們輪流行動。愛麗絲先手開局。最初&#xff0c;黑板上有一個數字 n 。在每個玩家的回合&#xff0c;玩家需要執行以下操作&#xff1a;選出任一 x&#xff0c;滿足 0 < x < n 且 n % x 0 。用 n - x 替換黑板上的數字 n 。如果玩家…

一起學springAI系列一:初體驗

Spring AI是干嘛的官網最權威&#xff0c;直接粘貼&#xff1a;“Spring AI”項目旨在簡化那些包含人工智能功能的應用程序的開發過程&#xff0c;同時避免不必要的復雜性。AI相關領域的功能對python的支持是最好的&#xff0c;相關供應商在出了啥功能的時候&#xff0c;都會優…

Ext JS極速項目之 Coworkee

ExtJS Coworkee 是什么? Ext JS 的 Coworkee 是一個由 Sencha 官方提供的完整員工管理應用示例,旨在展示 Ext JS 框架在企業級應用開發中的能力。 在線試用的地址是: https://examples.sencha.com/coworkee/#home 頁面效果與布局 登錄頁面: 主頁效果 左右分區結構:左…

飛算科技:原創技術重塑 Java 開發,引領行業數智化新浪潮

在科技革新的浪潮中&#xff0c;飛算科技作為一家堅持自主創新的數字科技企業&#xff0c;同時也是國家級高新技術企業&#xff0c;正深耕互聯網科技、大數據、人工智能等前沿領域&#xff0c;為眾多企業的數字化與智能化轉型提供強勁動力。?飛算科技的成長軌跡&#xff0c;是…

cesium FBO(一)渲染到紋理(RTT)

一聽到三維的RTT&#xff08;Render To Texture&#xff09;&#xff0c;似乎很神秘&#xff0c;但從底層實現一看&#xff0c;其實也就那樣&#xff0c;設計API的哪些頂級家伙已經幫你安排的明明白白了&#xff0c;咱們只需要學會怎么用就可以了。我認為得從WebGL入手&#xf…

PNP機器人機器人學術年會展示靈巧手動作捕捉方案。

2025年8月1-3日&#xff0c;第六屆中國機器人學術年會&#xff08;CCRS2025&#xff09;在長沙國際會議中心舉行&#xff0c;主題“人機共融&#xff0c;智向未來”。PNP機器人與靈巧智能聯合展出最新靈巧手模仿學習方案&#xff1a;基于少量示教數據即可快速復現復雜抓取動作&…

【45】C#入門到精通——C#調用C/C++生成動態庫.dll及C++ 生成動態庫.dll ,DllImport()方式導入 C++動態庫.dll方法總結

文章目錄1 C 生成動態庫.dll2 C#調用C/C生成動態庫.dll2.1 [DllImport()] 方式導入 C動態庫.dll2.2 調用測試3 C/C 生成通用dll,改進3.1改進后.h3.2 .cpp3.2 C# 調用4 [DllImport()] 方式導入C生成的 .dll 總結4.1 指定路徑導入4.2 .dll放在 執行目錄下&#xff08;一定要放對&…

從協議棧到ath12k_mac_op_tx的完整調用路徑

文章目錄 從協議棧到ath12k_mac_op_tx的完整調用路徑 1. 整體架構概覽 2. 詳細調用路徑分析 2.1 應用層到Socket層 2.2 協議層處理 2.3 網絡設備層到mac80211 2.4 mac80211發送入口 2.5 mac80211核心發送處理 2.6 mac80211發送核心處理 2.7 mac80211發送調度 2.8 最終驅動調用 …

WPFC#超市管理系統(4)入庫管理

入庫管理7. 商品入庫管理7.2 入庫實現顯示名稱、圖片、單位7.3 界面設計7.3 功能實現7. 商品入庫管理 數據庫中StockRecord表需要增加商品出入庫Type類型為nvarchar(50)。C#中的數據庫重新同步StockRecord表在Entity→Model中新建枚舉類型StockType namespace 超市管理系統.E…

CSS 打字特效

效果圖.wxml <view class"tips"><text>{{ tipsText }}</text><text class"tips-line">|</text> </view>.wxss .tips{padding: 50rpx 100rpx;font-size: 28rpx; } .tips-line{color: #ccc;animation: tips-line .5s al…

直播小程序 app 系統架構分析

一、引言 直播行業近年來發展迅猛&#xff0c;直播小程序和 APP 成為眾多用戶獲取直播內容以及主播進行內容輸出的重要平臺。一個完善且高效的系統架構是支撐直播業務穩定運行、提供優質用戶體驗的關鍵。本文將詳細剖析直播小程序 / APP 的系統架構&#xff0c;包括整體架構設計…

Vue常見題目

1. 什么是 Vue.js&#xff1f;它的核心特點是什么&#xff1f; Vue.js 是一個漸進式 JavaScript 框架&#xff0c;用于構建用戶界面。它的核心特點包括&#xff1a; - 響應式數據綁定 - 組件化開發 - 虛擬 DOM - 指令系統 - 輕量級且易于集成 - 豐富的生態系統&#xff08;Vue…

ipynb文件直接發布csdn

第一步&#xff0c;下載markdown文件 file --> save and export notebook as --> markdown第二步&#xff0c;導入markdown文件 進入csdn發布文章界面&#xff0c;點擊導入&#xff0c;選擇第一步下載的markdown文件即可

廣東省省考備考(第六十四天8.2)——判斷推理(重點回顧)

判斷推理&#xff1a;數量規律 錯題解析解析解析解析解析解析解析標記題解析解析解析解析解析解析解析今日題目正確率&#xff1a;53% 判斷推理&#xff1a;屬性規律 錯題解析解析解析解析解析解析標記題解析解析今日題目正確率&#xff1a;60%

【C++/STL】vector的OJ,深度剖析和模擬實現

vector在OJ中的使用 1.只出現一次的數字 class Solution { public:int singleNumber(vector<int>& nums) {int value 0;for(auto e : v) {value ^ e; }return value;} };2.楊輝三角 class Solution { public:vector<vector<int>> generate(int numRow…

衡石湖倉一體架構深度解構:統一元數據層如何破除數據孤島?

一、數據融合的世紀難題典型困境二、衡石統一元數據層設計架構核心關鍵技術實現智能元數據發現自動構建跨源血緣關系動態查詢重寫 將標準SQL翻譯為最優執行計劃text Original: SELECT SUM(sales) FROM virtual_view Rewritten: [S3] SELECT SUM(amount) FROM crm_sales [My…

Windows 下 fping 指令使用指南

fping 作為一款強大的網絡工具&#xff0c;能夠同時向多個主機發送 ICMP 回聲請求&#xff0c;相較于傳統的 ping 命令&#xff0c;在處理大量主機時具有顯著優勢。 一、fping 簡介? fping 是 “fast pinger” 的縮寫&#xff0c;它可以向一系列 IP 地址發送 ICMP 回聲請求。…

代碼隨想錄day52圖論3

文章目錄101. 孤島的總面積102. 沉沒孤島103. 水流問題104.建造最大島嶼101. 孤島的總面積 題目鏈接 文章講解 #include<bits/stdc.h> using namespace std;int ans 0; // 記錄不與邊界相連的孤島數量 int sum 0; // 當前孤島的面積 bool flag false; /…

linux pip/conda 修改默認cache位置

1 pip pip cache默認在/home/{username}目錄下&#xff0c;容易導致系統盤寫滿報錯。查看pip cache位置pip cache dir假設移動pip cache目錄到 /data/.cache/pip/cache&#xff0c;命令如下pip config set global.cache-dir /data/.cache/pip/cache2 conda 查看conda緩存位置c…