深度學習 -- 梯度計算及上下文控制

深度學習 – 梯度計算及上下文控制

文章目錄

  • 深度學習 -- 梯度計算及上下文控制
  • 一,自動微分
    • 1.1 基礎概念
    • 1.2 計算梯度
      • 1.2.1 計算標量梯度
      • 1.2.2 計算向量梯度
      • 1.2.3 多標量梯度計算
      • 1.2.4 多向量梯度計算
  • 二,梯度上下文控制
    • 2.1 控制梯度計算
    • 2.2 累計梯度
    • 2.3 梯度清零
  • 三,案例一----求函數最小值
  • 四,案例二----函數參數求解


一,自動微分

自動微分模塊torch.autograd負責自動計算張量操作的梯度,具有自動求導功能。自動微分模塊是構成神經網絡訓練的必要模塊,可以實現網絡權重參數的更新,使得反向傳播算法的實現變得簡單而高效。

1.1 基礎概念

  1. 張量

    Torch中一切皆為張量,屬性requires_grad決定是否對其進行梯度計算。默認是 False,如需計算梯度則設置為True。

  2. 計算圖

    torch.autograd通過創建一個動態計算圖來跟蹤張量的操作,每個張量是計算圖中的一個節點,節點之間的操作構成圖的邊。

    在 PyTorch 中,當張量的 requires_grad=True 時,PyTorch 會自動跟蹤與該張量相關的所有操作,并構建計算圖。每個操作都會生成一個新的張量,并記錄其依賴關系。當設置為 True 時,表示該張量在計算圖中需要參與梯度計算,即在反向傳播(Backpropagation)過程中會自動計算其梯度;當設置為 False 時,不會計算梯度。

  • 計算依賴圖
    在這里插入圖片描述
    葉子結點判斷方式
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  # 葉子節點
y = x ** 2  # 非葉子節點(通過計算生成)
z = y.sum()print(x.is_leaf)  # True
print(y.is_leaf)  # False
print(z.is_leaf)  # False

1.2 計算梯度

使用tensor.backward()方法執行反向傳播,從而計算張量的梯度

1.2.1 計算標量梯度

import torchdef test01():x= torch.tensor(1.0,requires_grad=True)# requires_grad=True表示需要求導y=x**2# 操作張量y.backward() #計算梯度,也就是反向傳播print(x.grad) # 打印梯度if __name__ == '__main__':test01()

1.2.2 計算向量梯度

def test002():# 1. 創建張量:必須為浮點類型x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)y = x ** 2loss = y.mean()loss.backward()print(x.grad)if __name__ == "__main__":test002()

調用 loss.backward() 從輸出節點 loss 開始,沿著計算圖反向傳播,計算每個節點的梯度。

損失函數loss=mean(y)=1n∑i=1nyiloss=mean(y)=\frac{1}{n}∑_{i=1}^ny_iloss=mean(y)=n1?i=1n?yi?,其中 n=3。

對于每個 yiy_iyi?,其梯度為 ?loss?yi=1n=13\frac{?loss}{?y_i}=\frac{1}{n}=\frac13?yi??loss?=n1?=31?

對于每個 xix_ixi?,其梯度為:
?loss?xi=?loss?yi×?yi?xi=13×2xi=2xi3\frac{?loss}{?x_i}=\frac{?loss}{?y_i}×\frac{?y_i}{?x_i}=\frac1{3}×2x_i=\frac{2x_i}3 ?xi??loss?=?yi??loss?×?xi??yi??=31?×2xi?=32xi??
所以,x.grad 的值為:[2×1.03,2×2.03,2×3.03]=[23,43,2]≈[0.6667,1.3333,2.0000][\frac{2×1.0}3, \frac{2×2.0}3, \frac{2×3.0}3]=[\frac23,\frac43,2]≈[0.6667,1.3333,2.0000][32×1.0?,32×2.0?,32×3.0?]=[32?,34?,2][0.6667,1.3333,2.0000]

1.2.3 多標量梯度計算

#多標量梯度計算
def test03():x1= torch.tensor(1.0,requires_grad=True)x2=torch.tensor(3.0,requires_grad=True)y=x1**2+x2*7z=y.sum()z.backward()print(x1.grad,x2.grad)if __name__ == '__main__':test03()

1.2.4 多向量梯度計算

# 多向量梯度計算
def test04():x1 = torch.tensor([1.0, 2.0], requires_grad=True)x2 = torch.tensor([3.0, 4.0], requires_grad=True)y = x1 ** 2 + x2 * 7z = y.sum()z.backward()print(x1.grad, x2.grad)if __name__ == '__main__':test04()

二,梯度上下文控制

梯度計算的上下文控制和設置對于管理計算圖、內存消耗、以及計算效率至關重要。下面我們學習下Torch中與梯度計算相關的一些主要設置方式。

2.1 控制梯度計算

簡單的運算不需要梯度

import torchdef test01():x=torch.tensor(11.5,requires_grad=True)print(x.requires_grad)# Truey=x**2print(y.requires_grad)# Truewith torch.no_grad():z=x**2print(z.requires_grad)#使用裝飾器@torch.no_grad()def test():return x**2y=test()print(y.requires_grad)if __name__ == '__main__':test01()

2.2 累計梯度

默認情況下,當我們重復對一個自變量進行梯度計算時,梯度是累加的

import torchdef test002():x = torch.tensor([1.0, 2.0, 5.3], requires_grad=True)# 2. 累計梯度:每次計算都會累計梯度for i in range(3):y = x**2 + 2 * x + 7z = y.mean()z.backward()print(x.grad)if __name__ == "__main__":test002()

2.3 梯度清零

大多數情況下是不需要梯度累加的

#梯度清零
def test002():x = torch.tensor([1.0, 2.0, 5.3], requires_grad=True)# 2. 累計梯度:每次計算都會累計梯度for i in range(3):y = x**2 + 2 * x + 7z = y.mean()# 2.1 反向傳播之前先對梯度進行清零if x.grad is not None:x.grad.zero_()z.backward()print(x.grad)if __name__ == "__main__":test002()

三,案例一----求函數最小值

import torch
from matplotlib import pyplot as plt
import numpy as npdef test01():x = np.linspace(-10, 10, 100)y = x ** 2plt.plot(x, y)plt.show()def test02():# 初始化自變量Xx = torch.tensor([3.0], requires_grad=True, dtype=torch.float)# 迭代輪次epochs = 50# 學習率lr = 0.1list = []for i in range(epochs):# 計算函數表達式y = x ** 2# 梯度清零if x.grad is not None:x.grad.zero_()# 反向傳播y.backward()with torch.no_grad():x -= lr * x.gradprint('epoch:', i, 'x:', x.item(), 'y:', y.item())list.append((x.item(), y.item()))# 散點圖,觀察收斂效果x_list = [l[0] for l in list]y_list = [l[1] for l in list]plt.scatter(x=x_list, y=y_list)plt.show()if __name__ == "__main__":test01()test02()

四,案例二----函數參數求解

import torchdef test02():x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float)y = torch.tensor([3, 5, 7, 9, 11], dtype=torch.float)a = torch.tensor([1.0], dtype=torch.float, requires_grad=True)b = torch.tensor([1.0], dtype=torch.float, requires_grad=True)lr = 0.05epochs = 1000for epoch in range(epochs):y_pred = a * x + bloss = ((y_pred - y) ** 2).mean()if a.grad is not None:a.grad.zero_()if b.grad is not None:b.grad.zero_()loss.backward()with torch.no_grad():a -= lr * a.gradb -= lr * b.gradif (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')print(f'a: {a.item():.4f}, b: {b.item():.4f}')if __name__ == '__main__':test02()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/915676.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/915676.shtml
英文地址,請注明出處:http://en.pswp.cn/news/915676.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Redisson RLocalCachedMap 核心參數詳解

🧑 博主簡介:CSDN博客專家,歷代文學網(PC端可以訪問:https://literature.sinhy.com/#/?__c1000,移動端可微信小程序搜索“歷代文學”)總架構師,15年工作經驗,精通Java編…

【Unity3D實例-功能-移動】角色移動-通過WSAD(Rigidbody方式)

你是否曾夢想在虛擬世界中自由翱翔,像海豚一樣在海洋自由穿梭,或者像宇航員一樣在宇宙中盡情探索?今天,我們就來聊聊如何在Unity中使用Rigidbody來實現角色移動。 廢話不多說,走,讓我們馬上來一探究竟&…

Vue接口平臺學習十一——業務流測試

效果圖及簡單說明 與之前的用例列表相似布局,也分左右,左邊用于顯示測試流程的名稱,右邊用于顯示流程相關信息。 左側點擊添加,直接增加一個新的業務流。 右側是點擊的業務流詳情,展示名稱,名稱的編輯保存&…

碳化硅缺陷分類與原因

01一、碳化硅晶體材料中的缺陷到底是什么?碳化硅晶體材料中的缺陷是指在晶體生長、加工或使用過程中出現的不完美結構。這些缺陷可能表現為晶體內部的裂紋、表面的凹坑、原子排列的錯誤等。雖然缺陷看起來微不足道,但它們卻可能對晶體的電學、熱學和機械…

Jenkins 實現項目的構建和發布

作者:小凱 沉淀、分享、成長,讓自己和他人都能有所收獲! 本文的宗旨在于通過簡單干凈實踐的方式教會讀者,如何在 Docker 中部署 Jenkins,并通過 Jenkins 完成對項目的打包構建并在 Docker 容器中部署。 Jenkins 的主要…

Django接口自動化平臺實現(三)

3.2 后臺 admin 添加數據 1)注冊模型類到 admin: 1 from django.contrib import admin2 from . import models3 4 5 class ProjectAdmin(admin.ModelAdmin):6 list_display ("id", "name", "proj_owner", "tes…

CentOS 7 配置環境變量常見的4種方式

?博客主頁: https://blog.csdn.net/m0_63815035?typeblog 💗《博客內容》:.NET、Java.測試開發、Python、Android、Go、Node、Android前端小程序等相關領域知識 📢博客專欄: https://blog.csdn.net/m0_63815035/cat…

k8s:手動創建PV,解決postgis數據庫本地永久存儲

1.離線環境CPU:Hygon C86 7285 32-core Processor 操作系統:麒麟操作系統 containerd:1.7.27 Kubernetes:1.26.12 KubeSphere:4.1.2 kubekey:3.1.10 Harbor:2.13.1 Postgis:17-3.52創建StorageClass2.1創建 apiVersion: storage.k8s.io/v1kin…

谷歌瀏覽器Chrome的多用戶配置文件功能

谷歌瀏覽器Chrome的多用戶配置文件功能允許在同一設備上創建多個獨立賬戶,每個賬戶擁有完全隔離的瀏覽數據(如書簽、歷史記錄、擴展、Cookies等),非常適合工作/生活賬戶分離、家庭共享或臨時多賬號登錄場景。 如何使用Chrome的多用戶配置文件功能? 一、創建與切換用戶 1.…

傲軟錄屏 專業高清錄屏軟件 ApowerREC Pro 下載與保姆級安裝教程!!

小編今天分享一款強大的電腦屏幕錄像軟件 傲軟錄屏 ApowerREC,能夠幫助用戶錄制中電腦桌面屏幕上的所有內容,包括畫面和聲音,支持全屏錄制、區域錄制、畫中畫以及攝像頭錄制等多種視頻錄制模式,此外,還支持計劃任務錄制…

【計算機網絡】MAC地址與IP地址:網絡通信的雙重身份標識

在計算機網絡領域,MAC地址與IP地址是兩個核心概念,它們共同構成了數據傳輸的基礎。理解二者的區別與聯系,對于網絡配置、故障排查及安全管理至關重要。 一、基本概念 1. MAC地址(物理地址) 定義:固化在網絡…

如何用keepAlive實現標簽頁緩存

什么是KeepAlive首先,要明確所說的是TCP的 KeepAlive 還是HTTP的 Keep-Alive。TCP的KeepAlive和HTTP的Keep-Alive是完全不同的概念,不能混為一談。實際上HTTP的KeepAlive寫法是Keep-Alive,跟TCP的KeepAlive寫法上也有不同。TCP的KeepAliveTCP…

數據庫隔離級別

隔離級別決定了事務之間的可見性規則,直接影響數據庫的并發性能和數據一致性。SQL 標準定義了 4 種隔離級別,從低到高依次為:讀未提交→讀已提交→可重復讀→串行化。隔離級別越高,對并發問題的解決能力越強,但對性能的…

基于Python flask的電影數據分析及可視化系統的設計與實現,可視化內容很豐富

摘要:基于Python的電影數據分析及可視化系統是一個應用于電影市場的數據分析平臺,旨在為廣大電影愛好者提供更準確、更詳細、更實用的電影數據。數據分析部分主要是對來自貓眼電影網站上的數據進行清洗、分類處理、存儲等步驟,數據可視化則是…

TCP通訊開發注意事項及常見問題解析

文章目錄一、TCP協議特性與開發挑戰二、粘包與拆包問題深度解析1. 成因原理2. 典型場景與實例驗證3. 系統化解決方案接收方每次讀取10字節2. 丟包檢測與驗證工具3. 工程化解決方案四、連接管理關鍵實踐1. 超時機制設計2. TIME_WAIT狀態優化3. 異常處理最佳實踐五、高性能TCP開發…

2021 RoboCom 世界機器人開發者大賽-本科組(復賽)解題報告 | 珂學家

前言 題解 睿抗機器人開發者大賽CAIP-編程技能賽-歷年真題 匯總 2021 RoboCom 世界機器人開發者大賽-本科組(復賽)解題報告 感覺這個T1特別有意思,非典型題,著重推演下結論。 T2是一道玄學題,但是涉及一些優化技巧…

《計算機“十萬個為什么”》之 MQ

《計算機“十萬個為什么”》之 MQ 📨 歡迎來到消息隊列的奇妙世界! 在這篇文章中,我們將探索 MQ 的奧秘,從基礎概念到實際應用,讓你徹底搞懂這個分布式系統中不可或缺的重要組件!🚀 作者&#x…

Django母嬰商城項目實踐(七)- 首頁數據業務視圖

7、首頁數據業務視圖 1、介紹 視圖(View)是Django的MTV架構模式的V部分,主要負責處理用戶請求和生成相應的響應內容,然后在頁面或其他類型文檔中顯示。 也可以理解為視圖是MVC架構里面的C部分(控制器),主要處理功能和業務上的邏輯。我們習慣使用視圖函數處理HTTP請求,…

android 12 的 aidl for HAL 開發示例

說明:aidl for HAL 這種機制,可以自動生成java代碼,app調用可以獲取中間過程的jar包,結合反射調用 ServiceManager.getService 方法,直接獲取 HAL 服務,不再需要費力在framework層添加代碼,方便…

網絡安全滲透攻擊案例實戰:某公司內網為目標的滲透測試全過程

目錄一、案例背景二、目標分析(信息收集階段)🌐 外部信息搜集🧠 指紋識別和端口掃描三、攻擊流程(滲透測試全過程)🎯 步驟1:Web漏洞利用 —— 泛微OA遠程命令執行漏洞(CV…