pytorch 反向傳播

文章目錄

    • 概念
      • 計算圖
      • 自動求導的兩種模式
    • 自動求導-代碼
      • 標量的反向傳播
      • 非標量變量的反向傳播
      • 將某些計算移動到計算圖之外

概念

核心:鏈式法則

深度學習框架通過自動計算導數(自動微分)來加快求導。

實踐中,根據涉及號的模型,系統會構建一個計算圖,來跟蹤計算是哪些數據通過哪些操作組合起來產生輸出。

自動微分使系統能夠隨后反向傳播梯度。

反向傳播:跟蹤整個計算圖,填充關于每個參數的偏導數。

計算圖

  1. 將代碼分解成操作子,將計算表示成一個無環圖
  2. 將計算表示成一個無環圖、

自動求導的兩種模式

反向傳播

  1. 構造計算圖
  2. 前向:執行圖,存儲中間結果
  3. 反向:從相反方向執行圖 - 不需要的枝可以減去,比如正向里的x和y連接的那個枝

自動求導-代碼

標量的反向傳播

案例:假設對函數 y = 2 x T x y=2x^Tx y=2xTx關于列向量x求導

1.首先初始化一個向量

x = torch.arange(4.0) # 創建變量x并為其分配初始值
print(x) #tensor([0., 1., 2., 3.])

2.計算y關于x的梯度之前,需要一個地方來存儲梯度。

x.requires_grad_()等價于x=torch.arange(4.0,requires_grad=True),這樣PyTorch會跟蹤x的梯度,并生成grad屬性,該屬性里記錄梯度。

通常用于表示某個變量或返回值“有意為空”或"暫時沒有值",已經初始化但是沒有值

x.requires_grad_(True)
print(x.grad)  # 默認值是None,存儲導數。

3.計算y的值,y是一個標量,在python中表示為tensor(28., ),并記錄是通過某種乘法操作生成的。

y = 2 * torch.dot(x, x)
print(y) # tensor(28., grad_fn=<MulBackward0>)

4.調用反向傳播函數來自動計算y關于x每個分量的梯度。

y.backward()
print(x.grad) # tensor([ 0.,  4.,  8., 12.])

我們可以知道根據公式來算, y = 2 x T x y=2x^Tx y=2xTx關于列向量x求導的結果是4x,根據打印結果來看結果是正確的。

5.假設此時我們需要繼續計算x所有分量的和,也就是 y = x . s u m ( ) y=x.sum() y=x.sum()

在默認情況下,PyTorch會累計梯度,我們需要調用grad.zero_清空之前的值。

x.grad.zero_()
y = x.sum() # y = x? + x? + x? + x?
print(y)
y.backward()
print(x.grad) # tensor([1., 1., 1., 1.])

非標量變量的反向傳播

在深度學習中,大部分時候目的是 將批次的損失求和之后(標量)再對分量求導。

y.sum()將 y的所有元素相加,得到一個標量 s u m ( y ) = ∑ i = 1 n x i 2 sum(y)=\sum_{i=1}^n x_i^2 sum(y)=i=1n?xi2?

y.sum().backward()等價于y.backward(torch.ones(len(x))

x.grad.zero_()
y = x * x  # y是一個矩陣
print(y) # tensor([0., 1., 4., 9.], grad_fn=<MulBackward0>)  4*1的矩陣
# 等價于y.backward(torch.ones(len(x)))
y.sum().backward()
print(x.grad)  # [0., 2., 4., 6.]

將某些計算移動到計算圖之外

假設 y = f ( x ) , z = g ( y , x ) y=f(x),z=g(y,x) y=f(x),z=g(y,x),我們需要計算 z z z關于 x x x的梯度,正常反向傳播時,梯度會通過 y y y x x x 兩條路徑傳播到 x x x ? z ? x = ? g ? y ? y ? x + ? g ? x \frac{\partial z}{\partial x} = \frac{\partial g}{\partial y} \frac{\partial y}{\partial x} +\frac{\partial g}{\partial x} ?x?z?=?y?g??x?y?+?x?g?。但由于某種原因,希望將 y y y視為一個常數,忽略 y y y x x x的依賴: ? z ? x ∣ y 常數 = ? g ? x \frac{\partial z}{\partial x} |_{y常數} =\frac{\partial g}{\partial x} ?x?z?y常數?=?x?g?

通過 detach() 方法將 y y y從計算圖中分離,使其不參與梯度計算。

z . s u m ( ) 求導 = ? ∑ z i ? x i = u i z.sum() 求導 = \frac{\partial \sum z_i}{\partial x_i} = u_i z.sum()求導=?xi??zi??=ui?

x.grad.zero_()
y = x * x 
print(y) # tensor([0., 1., 4., 9.], grad_fn=<MulBackward0>)
u = y.detach() # 把y看成一個常數從計算圖中分離,不參與梯度計算,但值還是x*x
print(u) # tensor([0., 1., 4., 9.])
z = u * x # z是一個常數*x
print(z) # tensor([ 0.,  1.,  8., 27.], grad_fn=<MulBackward0>)
z.sum().backward() print(x.grad == u) # tensor([True,True,true,True])

執行y.detach()返回一個計算圖之外,但值同y一樣的tensor,只是將函數z中的y替換成了這個等價變量。

但對于y本身來說還是一個在該計算圖中,就可以在y上調用反向傳播函數,得到 y = x ? x y=x*x y=x?x關于 x x x的導數 2 x 2x 2x

x.grad.zero_()
y.sum().backward()
print(x.grad == 2 * x) # tensor([True,True,true,True])

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/75722.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/75722.shtml
英文地址,請注明出處:http://en.pswp.cn/web/75722.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Kotlin日常使用函數記錄

文章目錄 前言字符串集合1.兩個集合的差集2.集合轉數組2.1.集合轉基本數據類型數組2.2.集合轉對象數組 Map1.合并Map1.1.使用 操作符1.2.使用 操作符1.3.使用 putAll 方法1.4.使用 merge 函數 前言 記錄一些kotlin開發中&#xff0c;日常使用的函數和方式之類的&#xff0c;…

詳解正則表達式中的?:、?= 、 ?! 、?<=、?<!

1、?: - 非捕獲組 語法: (?:pattern) 作用: 創建一個分組但不捕獲匹配結果&#xff0c;不會將匹配的文本存儲到內存中供后續使用。 優勢: 提高性能和效率 不占用編號&#xff08;不會影響后續捕獲組的編號&#xff09; 減少內存使用 // 使用捕獲組 let regex1 /(hell…

【無標題】spark編程

Value類型&#xff1a; 9) distinct ? 函數簽名 def distinct()(implicit ord: Ordering[T] null): RDD[T] def distinct(numPartitions: Int)(implicit ord: Ordering[T] null): RDD[T] ? 函數說明 將數據集中重復的數據去重 val dataRDD sparkContext.makeRDD(Lis…

GPT-2 語言模型 - 模型訓練

本節代碼是一個完整的機器學習工作流程&#xff0c;用于訓練一個基于GPT-2的語言模型。下面是對這段代碼的詳細解釋&#xff1a; 文件目錄如下 1. 初始化和數據準備 設置隨機種子 random.seed(1002) 確保結果的可重復性。 定義參數 test_rate 0.2 context_length 128 tes…

架構師面試(二十九):TCP Socket 編程

問題 今天考察網絡編程的基礎知識。 在基于 TCP 協議的網絡 【socket 編程】中可能會遇到很多異常&#xff0c;在下面的相關描述中說法正確的有哪幾項呢&#xff1f; A. 在建立連接被拒絕時&#xff0c;有可能是因為網絡不通或地址錯誤或 server 端對應端口未被監聽&#x…

HTTP實現心跳模塊

HTTP實現心跳模塊 使用輕量級的cHTTP庫cpp-httplib重現實現HTTP心跳模塊 頭文件HttplibHeartbeat.h #ifndef HTTPLIB_HEARTBEAT_H #define HTTPLIB_HEARTBEAT_H#include <string> #include <thread> #include <atomic> #include <chrono> #include …

openharmony—release—4.1開發環境搭建(踩坑記錄)

環境開發需要分別在window以及ubuntu下進行相應設置 一、window 1.安裝DevEco Device Tool OpenAtom OpenHarmony 二、ubuntu 1.將Ubuntu Shell環境修改為bash ls -l /bin/sh 2.打開終端工具&#xff0c;執行如下命令&#xff0c;輸入密碼&#xff0c;然后選擇No&#xff0…

Go學習系列文章聲明

本次學習是基于B站的視頻&#xff0c;【Udemy高分熱門付費課程】Golang&#xff1a;完整開發者指南&#xff08;基礎知識和高級特性&#xff09;中英文字幕_嗶哩嗶哩_bilibili 本人會嘗試輸出視頻中的內容&#xff0c;如有錯誤歡迎指出 next page: Go installation process

error: RPC failed; HTTP 408 curl 22 The requested URL returned error: 408

在git push時報錯&#xff1a;error: RPC failed; HTTP 408 curl 22 The requested URL returned error: 408 原因&#xff1a;可能是推送的文件太大&#xff0c;要么是緩存不夠&#xff0c;要么是網絡不行。 解決方法&#xff1a; 將本地 http.postBuffer 數值調整到500MB&…

Android.bp中添加條件判斷編譯方式

背景&#xff1a; 馬哥學員朋友以前在vip群里&#xff0c;有問道如何在Android.bp中添加條件判斷&#xff0c;在工作中經常需要一套代碼兼容發貨目標版本&#xff0c;即代碼都是公共的一套&#xff0c;但是需要用這一套代碼集成到各個產品設備上 但是這個產品設備可能面臨比…

swift ui基礎

一個樸實無華的目錄 今日學習內容&#xff1a;1.三種布局&#xff08;可以相互包裹&#xff09;1.1 vstack&#xff08;豎直&#xff09;&#xff1a;先寫的在上面1.1 hstack&#xff08;水平&#xff09;&#xff1a;先寫的在左邊1.1 zstack&#xff08;前后&#xff09;&…

第16屆藍橋杯單片機模擬試題Ⅲ

試題 代碼 sys.h #ifndef __SYS_H__ #define __SYS_H__#include <STC15F2K60S2.H> //sys.c extern unsigned char UI; //界面標志(0濕度界面、1參數界面、2時間界面) extern unsigned char time; //時間間隔(1s~10S) extern bit ssflag; //啟動/停止標志…

Node.js中URL模塊詳解

Node.js 中 URL 模塊全部 API 詳解 1. URL 類 const { URL } require(url);// 1. 創建 URL 對象 const url new URL(https://www.example.com:8080/path?queryvalue#hash);// 2. URL 屬性 console.log(協議:, url.protocol); // https: console.log(主機名:, url.hos…

Java接口性能優化面試問題集錦:高頻考點與深度解析

1. 如何定位接口性能瓶頸&#xff1f;常用哪些工具&#xff1f; 考察點&#xff1a;性能分析工具的使用與問題定位能力。 核心答案&#xff1a; 工具&#xff1a;Arthas&#xff08;在線診斷&#xff09;、JProfiler&#xff08;內存與CPU分析&#xff09;、VisualVM、Prometh…

WheatA小麥芽:農業氣象大數據下載器

今天為大家介紹的軟件是WheatA小麥芽&#xff1a;專業純凈的農業氣象大數據系統。下面&#xff0c;我們將從軟件的主要功能、支持的系統、軟件官網等方面對其進行簡單的介紹。 主要內容來源于軟件官網&#xff1a;WheatA小麥芽的官方網站是http://www.wheata.cn/ &#xff0c;…

Python10天突擊--Day 2: 實現觀察者模式

以下是 Python 實現觀察者模式的完整方案&#xff0c;包含同步/異步支持、類型注解、線程安全等特性&#xff1a; 1. 經典觀察者模式實現 from abc import ABC, abstractmethod from typing import List, Anyclass Observer(ABC):"""觀察者抽象基類""…

CST1019.基于Spring Boot+Vue智能洗車管理系統

計算機/JAVA畢業設計 【CST1019.基于Spring BootVue智能洗車管理系統】 【項目介紹】 智能洗車管理系統&#xff0c;基于 Spring Boot Vue 實現&#xff0c;功能豐富、界面精美 【業務模塊】 系統共有三類用戶&#xff0c;分別是&#xff1a;管理員用戶、普通用戶、工人用戶&…

Windows上使用Qt搭建ARM開發環境

在 Windows 上使用 Qt 和 g++-arm-linux-gnueabihf 進行 ARM Linux 交叉編譯(例如針對樹莓派或嵌入式設備),需要配置 交叉編譯工具鏈 和 Qt for ARM Linux。以下是詳細步驟: 1. 安裝工具鏈 方法 1:使用 MSYS2(推薦) MSYS2 提供 mingw-w64 的 ARM Linux 交叉編譯工具鏈…

Python爬蟲教程011:scrapy爬取當當網數據開啟多條管道下載及下載多頁數據

文章目錄 3.6.4 開啟多條管道下載3.6.5 下載多頁數據3.6.6 完整項目下載3.6.4 開啟多條管道下載 在pipelines.py中新建管道類(用來下載圖書封面圖片): # 多條管道開啟 # 要在settings.py中開啟管道 class DangdangDownloadPipeline:def process_item(self, item, spider):…

Mysql -- 基礎

SQL SQL通用語法&#xff1a; SQL分類&#xff1a; DDL: 數據庫操作 查詢&#xff1a; SHOW DATABASES&#xff1b; 創建&#xff1a; CREATE DATABASE[IF NOT EXISTS] 數據庫名 [DEFAULT CHARSET字符集] [COLLATE 排序規則]&#xff1b; 刪除&#xff1a; DROP DATABA…