2025-04-20 李沐深度學習4 —— 自動求導

文章目錄

  • 1 導數拓展
    • 1.1 標量導數
    • 1.2 梯度:向量的導數
    • 1.3 擴展到矩陣
    • 1.4 鏈式法則
  • 2 自動求導
    • 2.1 計算圖
    • 2.2 正向模式
    • 2.3 反向模式
  • 3 實戰:自動求導
    • 3.1 簡單示例
    • 3.2 非標量的反向傳播
    • 3.3 分離計算
    • 3.4 Python 控制流

硬件配置:

  • Windows 11
  • Intel?Core?i7-12700H
  • NVIDIA GeForce RTX 3070 Ti Laptop GPU

軟件環境:

  • Pycharm 2025.1
  • Python 3.12.9
  • Pytorch 2.6.0+cu124

1 導數拓展

1.1 標量導數

基本公式

  • 常數: d ( a ) / d x = 0 d(a)/dx = 0 d(a)/dx=0
  • 冪函數: d ( x n ) / d x = n ? x n ? 1 d(x^n)/dx = n·x^{n-1} d(xn)/dx=n?xn?1
  • 指數/對數:
    • d ( e x ) / d x = e x d(e^x)/dx = e^x d(ex)/dx=ex
    • d ( ln ? x ) / d x = 1 / x d(\ln x)/dx = 1/x d(lnx)/dx=1/x
  • 三角函數:
    • d ( sin ? x ) / d x = cos ? x d(\sin x)/dx = \cos x d(sinx)/dx=cosx
    • d ( cos ? x ) / d x = ? sin ? x d(\cos x)/dx = -\sin x d(cosx)/dx=?sinx
image-20250419111921569

求導法則
d ( u + v ) d x = d u d x + d v d x d ( u v ) d x = u d v d x + v d u d x d f ( g ( x ) ) d x = f ′ ( g ( x ) ) ? g ′ ( x ) \begin{aligned}&\frac{d(u+v)}{dx}=\frac{du}{dx}+\frac{dv}{dx}\\&\frac{d(uv)}{dx}=u\frac{dv}{dx}+v\frac{du}{dx}\\&\frac{df(g(x))}{dx}=f^{\prime}(g(x))\cdotp g^{\prime}(x)\end{aligned} ?dxd(u+v)?=dxdu?+dxdv?dxd(uv)?=udxdv?+vdxdu?dxdf(g(x))?=f(g(x))?g(x)?
不可微函數的導數:亞導數

  • ∣ x ∣ |x| x x = 0 x=0 x=0 時的亞導數: [ ? 1 , 1 ] [-1,1] [?1,1] 區間任意值。
  • ReLU 函數:max(0,x) x = 0 x=0 x=0 時導數可取 [ 0 , 1 ] [0,1] [0,1]
image-20250419112225706

1.2 梯度:向量的導數

形狀匹配規則

函數類型自變量類型導數形狀示例
標量 y y y標量 x x x標量 d y / d x = 2 x dy/dx = 2x dy/dx=2x
標量 y y y向量 x \mathbf{x} x行向量 d y / d x = [ 2 x 1 , 4 x 2 ] dy/d\mathbf{x} = [2x_1,4x_2] dy/dx=[2x1?,4x2?]
向量 y \mathbf{y} y標量 x x x列向量 d y / d x = [ cos ? x , ? sin ? x ] T d\mathbf{y}/dx = [\cos x, -\sin x]^T dy/dx=[cosx,?sinx]T
向量 y \mathbf{y} y向量 x \mathbf{x} x雅可比矩陣 d y / d x = [ [ 1 , 0 ] , [ 0 , 1 ] ] d\mathbf{y}/d\mathbf{x} = [[1,0],[0,1]] dy/dx=[[1,0],[0,1]]
image-20250419112819630

案例 1

  • y y y x 1 2 + 2 x 2 2 x_1^2 + 2x_2^2 x12?+2x22?(第一個元素的平方與第二個元素平方的 2 倍之和)
  • x \mathbf{x} x:向量。

d y d x = [ 2 x 1 , 4 x 2 ] \frac{dy}{d\mathbf{x}}=\begin{bmatrix}2x_1,&4x_2\end{bmatrix} dxdy?=[2x1?,?4x2??]

  • 幾何解釋:梯度向量 [2, 4] 指向函數值增長最快方向。
image-20250419113359974
  • 其他情況

    image-20250419113541593

案例 2

  • y \mathbf{y} y:向量。
  • x x x:標量。
image-20250419113648965

案例 3

  • y \mathbf{y} y:向量。
  • x \mathbf{x} x:向量。
image-20250419113904878
  • 其他情況
image-20250419113933471

1.3 擴展到矩陣

image-20250419114117161

1.4 鏈式法則

標量鏈式法則的向量化

? 當 y = f ( u ) , u = g ( x ) y = f(u), u = g(x) y=f(u),u=g(x) 時:
d y d x = d y d u ? d u d x \frac{dy}{d\mathbf{x}}=\frac{dy}{du}\cdot\frac{du}{d\mathbf{x}} dxdy?=dudy??dxdu?

  • d y / d u dy/du dy/du:標量 → 形狀不變
  • d u / d x du/dx du/dx:若 u u u 是向量, x x x 是向量 → 雅可比矩陣(形狀 [ d i m ( u ) , d i m ( x ) ] [dim(u), dim(x)] [dim(u),dim(x)]
image-20250419114750475

多變量鏈式法則
d z d w = d z d b ? d b d a ? d a d w = 2 b ? 1 ? x T \frac{dz}{d\mathbf{w}}=\frac{dz}{db}\cdot\frac{db}{da}\cdot\frac{da}{d\mathbf{w}}=2b\cdot1\cdot\mathbf{x}^T dwdz?=dbdz??dadb??dwda?=2b?1?xT

  • 示例:線性回歸 z = ( x T w ? y ) 2 z = (x^Tw - y)^2 z=(xTw?y)2 的梯度計算
image-20250419114839290

2 自動求導

? 自動求導計算一個函數在指定值上的導數,它有別于

  • 符號求導
  • 數值求導
image-20250420134058142

2.1 計算圖

構建原理

  • 將代碼分解成操作子

  • 將計算表示成一個無換圖

    • 節點:輸入變量(如 x , w , y x,w,y x,w,y)或基本操作(如 + , ? , × +,-,× +,?,×

    • 邊:數據流向

image-20250420134206599

顯式 vs 隱式構造

類型代表框架特點
顯式TensorFlow,Mxnet,Theano先定義計算圖,后喂入數據
隱式PyTorch,Mxnet動態構建圖,操作即記錄
image-20250420134557212

2.2 正向模式

? 從輸入到輸出逐層計算梯度,每次計算一個輸入變量對輸出的梯度,通過鏈式法則逐層傳遞梯度。

? 以 z = ( x ? w ? y ) 2 z = (x \cdot w - y)^2 z=(x?w?y)2 為例(線性回歸損失函數):

# 正向計算過程
a = x * w    # a對x的梯度:?a/?x = w
b = a - y    # b對a的梯度:?b/?a = 1
z = b ** 2   # z對b的梯度:?z/?b = 2b
  • 特點:每次只能計算一個輸入變量(如 xw)的梯度,需多次計算。

  • 計算復雜度:O(n)n 為輸入維度)

  • 內存復雜度:O(1)(不需要存儲中間結果)

  • 適用場景:輸入維度低(如參數少)、輸出維度高的函數。

2.3 反向模式

? 從輸出到輸入反向傳播梯度,一次性計算所有輸入變量對輸出的梯度。

數學原理

  • 前向計算

    計算所有中間值(a,b,z)并存儲。

  • 反向傳播(Back Propagation,也稱反向傳遞)

    從輸出z開始,按鏈式法則逐層回傳梯度。

    先計算 ?z/?b = 2b,再計算 ?b/?a = 1,最后計算 ?a/?x = w

image-20250420134829609

? 同樣以 z = ( x ? w ? y ) 2 z = (x \cdot w - y)^2 z=(x?w?y)2 為例:

  1. 前向計算

    a = x * w    # 存儲 a
    b = a - y    # 存儲 b
    z = b ** 2
    
  2. 反向傳播

    dz_db = 2 * b          # ?z/?b
    db_da = 1              # ?b/?a
    da_dx = w              # ?a/?x
    dz_dx = dz_db * db_da * da_dx  # 最終梯度
    
image-20250420135917081
  • 計算復雜度:O(n)(與正向模式相同)
  • 內存復雜度:O(n)(需存儲所有中間變量)
  • 適用場景:深度學習(輸入維度高,輸出為標量損失函數)。

3 實戰:自動求導

3.1 簡單示例

? 以函數 y = 2 x ? x y=2\mathbf{x}^{\top}\mathbf{x} y=2x?x 為例,關于列向量 x \mathbf{x} x 求導。

  1. 首先,創建變量x并為其分配一個初始值。

    import torchx = torch.arange(4.0)
    x
    
    image-20250420214507130
  2. 在計算 y y y 關于 x \mathbf{x} x 的梯度之前,需要一個地方來存儲梯度。

    我們不會在每次對一個參數求導時都分配新的內存。

    因為我們經常會成千上萬次地更新相同的參數,每次都分配新的內存可能很快就會將內存耗盡。

    注意,一個標量函數關于向量 x \mathbf{x} x 的梯度是向量,并且與 x \mathbf{x} x 具有相同的形狀。

    x.requires_grad_(True)  # 等價于x=torch.arange(4.0,requires_grad=True)
    x.grad  # 默認值是None
    
  3. 現在計算 y y y

    y = 2 * torch.dot(x, x)
    y
    
    image-20250420214911500
  4. x是一個長度為 4 的向量,計算xx的點積,得到了我們賦值給y的標量輸出。
    接下來,通過調用反向傳播函數來自動計算y關于x每個分量的梯度,并打印這些梯度。

    y.backward()
    x.grad
    
    image-20250420215014257
  5. 函數 y = 2 x ? x y=2\mathbf{x}^{\top}\mathbf{x} y=2x?x 關于 x \mathbf{x} x 的梯度應為 4 x 4\mathbf{x} 4x。讓我們快速驗證這個梯度是否計算正確。

    x.grad == 4 * x
    
    image-20250420215153504
  6. 探究x的另一個函數。

    # 在默認情況下,PyTorch會累積梯度,我們需要清除之前的值
    x.grad.zero_()
    y = x.sum()
    y.backward()
    x.grad
    
    image-20250420215427676

3.2 非標量的反向傳播

? 當y不是標量時,向量y關于向量x的導數的最自然解釋是一個矩陣。

? 對于高階和高維的yx,求導的結果可以是一個高階張量。

? 雖然這些更奇特的對象確實出現在高級機器學習中(包括[深度學習中]),但當調用向量的反向計算時,我們通常會試圖計算一批訓練樣本中每個組成部分的損失函數的導數。
? 這里,我們的目的不是計算微分矩陣,而是單獨計算批量中每個樣本的偏導數之和。

# 對非標量調用backward需要傳入一個gradient參數,該參數指定微分函數關于self的梯度。
# 本例只想求偏導數的和,所以傳遞一個1的梯度是合適的
x.grad.zero_()
y = x * x
# 等價于y.backward(torch.ones(len(x)))
y.sum().backward()
x.grad
image-20250420220408467

理解

  1. x.grad.zero_()

    • 作用:清空 x 的梯度(grad)緩存。

    • 為什么需要清零?

      • PyTorch 會累積梯度(grad),如果之前已經計算過 x 的梯度(比如在循環中多次 backward()),新的梯度會加到舊的梯度上。
      • 調用 zero_() 可以避免梯度累積,確保每次計算都是新的梯度。

  1. y = x \* x

    • 計算 y = x2(逐元素相乘)。

    • 例如:

      • 如果 x = [1, 2, 3],那么 y = [1, 4, 9]

  1. y.backward(torch.ones(len(x)))

    • backward() 的作用:計算 yx 的梯度(即 dy/dx)。
    • 為什么需要 gradient 參數?
      • 如果 y 是 標量(單個值),可以直接調用 y.backward(),PyTorch 會自動計算 dy/dx
      • 但如果 y 是 非標量(向量/矩陣),PyTorch 不知道如何計算梯度,必須傳入一個 gradient 參數(形狀和 y 相同),表示 y 的梯度權重。
    • gradient=torch.ones(len(x)) 的含義:
      • 這里 gradient 是一個全 1 的張量,表示我們希望計算 y 的所有分量對 x 的 梯度之和(相當于 sum(y)x 的梯度)。
      • 數學上:
        • y = [y?, y?, y?] = [x?2, x?2, x?2]
        • sum(y) = x?2 + x?2 + x?2
        • d(sum(y))/dx = [2x?, 2x?, 2x?](這就是 x.grad 的結果)

  1. 結果 x.grad

    • 由于 y = x2dy/dx = 2x

    • 由于 gradient=torch.ones(len(x)),PyTorch 計算的是 sum(y) 的梯度:

      • x.grad = [2x?, 2x?, 2x?](即 2 * x)。
    • 例如:

      • 如果 x = [1, 2, 3],那么 x.grad = [2, 4, 6]

3.3 分離計算

? 有時,我們希望將某些計算移動到記錄的計算圖之外。例如,假設y是作為x的函數計算的,而z則是作為yx的函數計算的。

? 想象一下,我們想計算z關于x的梯度,但由于某種原因,希望將y視為一個常數,并且只考慮到xy被計算后發揮的作用。這里可以分離y來返回一個新變量u,該變量與y具有相同的值,但丟棄計算圖中如何計算y的任何信息?

? 換句話說,梯度不會向后流經ux。因此,下面的反向傳播函數計算z = u * x關于x的偏導數,同時將u作為常數處理,而不是z = x * x * x關于x的偏導數。

x.grad.zero_()
y = x * x
u = y.detach()
z = u * xz.sum().backward()
x.grad == u
image-20250420222006669

? 由于記錄了y的計算結果,我們可以隨后在y上調用反向傳播,得到y = x * x關于的x的導數,即2 * x

x.grad.zero_()
y.sum().backward()
x.grad == 2 * x
image-20250420225127299

3.4 Python 控制流

? 使用自動微分的一個好處是:即使構建函數的計算圖需要通過 Python 控制流(例如,條件、循環或任意函數調用),我們仍然可以計算得到的變量的梯度。
? 在下面的代碼中,while循環的迭代次數和if語句的結果都取決于輸入a的值。

def f(a):# type: (torch.Tensor)->torch.Tensorb = a * 2while b.norm() < 1000:b = b * 2if b.sum() > 0:c = belse:c = 100 * breturn c

? 讓我們計算梯度。

a = torch.randn(size=(), requires_grad=True)
d = f(a)
d.backward()

? 我們現在可以分析上面定義的f函數。請注意,它在其輸入a中是分段線性的。換言之,對于任何a,存在某個常量標量k,使得f(a)=k*a,其中k的值取決于輸入a,因此可以用d/a驗證梯度是否正確。

a.grad, d / a, a.grad == d / a
image-20250420225434814

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/76932.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/76932.shtml
英文地址,請注明出處:http://en.pswp.cn/web/76932.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Redis的使用總結

Redis 核心使用場景 緩存加速 高頻訪問數據緩存&#xff08;如商品信息、用戶信息&#xff09; 緩解數據庫壓力&#xff0c;提升響應速度 會話存儲 分布式系統共享 Session&#xff08;替代 Tomcat Session&#xff09; 支持 TTL 自動過期 排行榜/計數器 實時排序&#x…

富文本編輯器實現

&#x1f3a8; 富文本編輯器實現原理全解析 &#x1f4dd; 基本實現路徑圖 #mermaid-svg-MO1B8a6kAOmD8B6Y {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-MO1B8a6kAOmD8B6Y .error-icon{fill:#552222;}#mermaid-s…

LeetCode熱題100——283. 移動零

給定一個數組 nums&#xff0c;編寫一個函數將所有 0 移動到數組的末尾&#xff0c;同時保持非零元素的相對順序。 請注意 &#xff0c;必須在不復制數組的情況下原地對數組進行操作。 示例 1: 輸入: nums [0,1,0,3,12] 輸出: [1,3,12,0,0] 示例 2: 輸入: nums [0] 輸出:…

與Ubuntu相關命令

windows將文件傳輸到Ubuntu 傳輸文件夾或文件 scp -r 本地文件夾或文件 ubuntu用戶名IP地址:要傳輸到的文件夾路徑 例如&#xff1a; scp -r .\04.py gao192.168.248.129:/home/gao 如果傳輸文件也可以去掉-r 安裝軟件 sudo apt-get update 更新軟件包列表 sudo apt insta…

Kafka 在小流量和大流量場景下的順序消費問題

一、低流量系統 特點 消息量較少&#xff0c;吞吐量要求低。系統資源&#xff08;如 CPU、內存、網絡&#xff09;相對充足。對延遲容忍度較高。 保證順序消費的方案 單分區 單消費者 將消息發送到單個分區&#xff08;例如固定 Partition 0&#xff09;&#xff0c;由單個…

小程序 GET 接口兩種傳值方式

前言 一般 GET 接口只有兩種URL 參數和路徑參數 一&#xff1a;URL 參數&#xff08;推薦方式&#xff09; 你希望請求&#xff1a; https://serve.zimeinew.com/wx/products/info?id5124接口應該寫成這樣&#xff0c;用 req.query.id 取 ?id5124&#xff1a; app.get(&…

小白學習java第14天(中):數據庫

1.DML data manage language數據庫管理語言 外鍵:外鍵是什么&#xff1f;就是對其進行表與表之間的聯系&#xff0c;就是使用的鍵進行關聯&#xff01; 方法一&#xff1a;我們在數據庫里面就對其進行表與表之間的連接【這種是不建議的&#xff0c;我不太喜歡就是將數據里面弄…

NO.95十六屆藍橋杯備戰|圖論基礎-單源最短路|負環|BF判斷負環|SPFA判斷負環|郵遞員送信|采購特價產品|拉近距離|最短路計數(C++)

P3385 【模板】負環 - 洛谷 如果圖中存在負環&#xff0c;那么有可能不存在最短路。 BF算法判斷負環 執?n輪松弛操作&#xff0c;如果第n輪還存在松弛操作&#xff0c;那么就有負環。 #include <bits/stdc.h> using namespace std;const int N 2e3 10, M 3e3 1…

K8s pod 應用

/** 個人學習筆記&#xff0c;如有問題歡迎交流&#xff0c;文章編排和格式等問題見諒&#xff01; */ &#xff08;1&#xff09;編寫 pod.yaml 文件 pod 是 kubernetes 中最小的編排單位&#xff0c;一個 pod 里包含一個或多個容器。 apiVersion: v1 # 指定api版本 kind…

Oracle創建觸發器實例

一 創建DML 觸發器 DML觸發器基本要點&#xff1a; 觸發時機&#xff1a;指定觸發器的觸發時間。如果指定為BEFORE&#xff0c;則表示在執行DML操作之前觸發&#xff0c;以便防止某些錯誤操作發生或實現某些業務規則&#xff1b;如果指定為AFTER&#xff0c;則表示在執行DML操作…

Filename too long 錯誤

Filename too long 錯誤表明文件名超出了文件系統或版本控制系統允許的最大長度。 可能的原因 文件系統限制 不同的文件系統對文件名長度有不同的限制。例如&#xff0c;FAT32 文件名最長為 255 個字符&#xff0c;而 NTFS 雖然支持較長的文件名&#xff0c;但在某些情況下也…

網絡不可達network unreachable問題解決過程

問題&#xff1a;訪問一個環境中的路由器172.16.1.1&#xff0c;發現ssh無法訪問&#xff0c;ping發現回網絡不可達 C:\Windows\System32>ping 172.16.1.1 正在 Ping 172.16.1.1 具有 32 字節的數據: 來自 172.16.81.1 的回復: 無法訪問目標網。 來自 172.16.81.1 的回復:…

Python設計模式:備忘錄模式

1. 什么是備忘錄模式&#xff1f; 備忘錄模式是一種行為設計模式&#xff0c;它允許在不暴露對象內部狀態的情況下&#xff0c;保存和恢復對象的狀態。備忘錄模式的核心思想是將對象的狀態保存到一個備忘錄對象中&#xff0c;以便在需要時可以恢復到之前的狀態。這種模式通常用…

Python基礎語法3

目錄 1、函數 1.1、語法格式 1.2、函數返回值 1.3、變量作用域 1.4、執行過程 1.5、鏈式調用 1.6、嵌套調用 1.7、函數遞歸 1.8、參數默認值 1.9、關鍵字參數 2、列表 2.1、創建列表 2.2、下標訪問 2.3、切片操作 2.4、遍歷列表元素 2.5、新增元素 2.6、查找元…

JavaEE學習筆記(第二課)

1、好用的AI代碼工具cursor 2、Java框架&#xff1a;Spring(高級框架)、Servelt、Struts、EJB 3、Spring有兩層含義&#xff1a; ①Spring Framework&#xff08;原始框架&#xff09; ②Spring家族 4、Spring Boot(為了使Spring簡化) 5、創建Spring Boot 項目 ① ② ③…

基于Flask與Ngrok實現Pycharm本地項目公網訪問:從零部署

目錄 概要 1. 環境與前置條件 2. 安裝與配置 Flask 2.1 創建虛擬環境 2.2 安裝 Flask 3. 安裝與配置 Ngrok 3.1 下載 Ngrok 3.2 注冊并獲取 Authtoken 4. 在 PyCharm 中創建 Flask 項目 5. 運行本地 Flask 服務 6. 啟動 Ngrok 隧道并獲取公網地址 7. 完整示例代碼匯…

Ragflow、Dify、FastGPT、COZE核心差異對比與Ragflow的深度文檔理解能力??和??全流程優化設計

一、Ragflow、Dify、FastGPT、COZE核心差異對比 以下從核心功能、目標用戶、技術特性等維度對比四款工具的核心差異&#xff1a; 核心功能定位 ? Ragflow&#xff1a;專注于深度文檔理解的RAG引擎&#xff0c;擅長處理復雜格式&#xff08;PDF、掃描件、表格等&#xff09;的…

LeetCode[232]用棧實現隊列

思路&#xff1a; 一道很簡單的題&#xff0c;就是棧是先進后出&#xff0c;隊列是先進先出&#xff0c;用兩個棧底相互對著&#xff0c;這樣一個隊列就產生了&#xff0c;右棧為空的情況&#xff0c;左棧棧底就是隊首元素&#xff0c;所以我們需要將左棧全部壓入右棧&#xff…

postman 刪除注銷賬號

一、刪除賬號 1.右上角找到 頭像&#xff0c;view profile https://123456-6586950.postman.co/settings/me/account 二、找回賬號 1.查看日志所在位置 三、postman更新后只剩下history 在 Postman 中&#xff0c;如果你發現更新后只剩下 History&#xff08;歷史記錄&…

微服務相比傳統服務的優勢

這是一道面試題&#xff0c;咱們先來分析這道題考察的是什么。 如果分析面試官主要考察以下幾個方面&#xff1a; 技術理解深度 你是否清楚微服務架構&#xff08;Microservices&#xff09;和傳統單體架構&#xff08;Monolithic&#xff09;的本質區別。能否從設計理念、技術…