pytorch如何計算導數_PyTorch怎么用?來看這里

構建深度學習模型的基本流程就是:搭建計算圖,求得損失函數,然后計算損失函數對模型參數的導數,再利用梯度下降法等方法來更新參數。

搭建計算圖的過程,稱為“正向傳播”,這個是需要我們自己動手的,因為我們需要設計我們模型的結構。由損失函數求導的過程,稱為“反向傳播”,求導是件辛苦事兒,所以自動求導基本上是各種深度學習框架的基本功能和最重要的功能之一,PyTorch也不例外。

我們今天來體驗一下PyTorch的自動求導吧,好為后面的搭建模型做準備。

一、設置Tensor的自動求導屬性

所有的tensor都有.requires_grad屬性,都可以設置成自動求導。具體方法就是在定義tensor的時候,讓這個屬性為True:

x = tensor.ones(2,4,requires_grad=True)

In [1]: import torchIn [2]: x = torch.ones(2,4,requires_grad=True)In [3]: print(x)tensor([[1., 1., 1., 1.], [1., 1., 1., 1.]], requires_grad=True)

只要這樣設置了之后,后面由x經過運算得到的其他tensor,就都有equires_grad=True屬性了。

可以通過x.requires_grad來查看這個屬性。

In [4]: y = x + 2In [5]: print(y)tensor([[3., 3., 3., 3.], [3., 3., 3., 3.]], grad_fn=)In [6]: y.requires_gradOut[6]: True

如果想改變這個屬性,就調用tensor.requires_grad_()方法:

In [22]: x.requires_grad_(False)Out[22]:tensor([[1., 1., 1., 1.], [1., 1., 1., 1.]])In [21]: print(x.requires_grad,y.requires_grad)False True

這里,注意區別tensor.requires_grad和tensor.requires_grad_()兩個東西,前面是調用變量的屬性值,后者是調用內置的函數,來改變屬性。

二、求導

下面我們來試試自動求導到底怎么樣。

我們首先定義一個計算圖(計算的步驟):

In [28]: x = torch.tensor([[1.,2.,3.],[4.,5.,6.]],requires_grad=True)In [29]: y = x+1In [30]: z = 2*y*yIn [31]: J = torch.mean(z)

這里需要注意的是,要想使x支持求導,必須讓x為浮點類型,也就是我們給初始值的時候要加個點:“.”。不然的話,就會報錯。

即,不能定義[1,2,3],而應該定義成[1.,2.,3.],前者是整數,后者才是浮點數。

上面的計算過程可以表示為:

9004db61ec6f72ac8c55f895328f4075.png

好了,重點注意的地方來了!

x、y、z都是tensor,但是size為(2,3)的矩陣。但是J是對z的每一個元素加起來求平均,所以J是標量。

求導,只能是【標量】對標量,或者【標量】對向量/矩陣求導!

所以,上圖中,只能J對x、y、z求導,而z則不能對x求導。

我們不妨試一試:

  • PyTorch里面,求導是調用.backward()方法。直接調用backward()方法,會計算對計算圖葉節點的導數。獲取求得的導數,用.grad方法。

試圖z對x求導:

In [31]: z.backward()# 會報錯:Traceback (most recent call last) in ()----> 1 z.backward()RuntimeError: grad can be implicitly created only for scalar outputs

正確的應該是J對x求導:

In [33]: J.backward()In [34]: x.gradOut[34]:tensor([[1.3333, 2.0000, 2.6667], [3.3333, 4.0000, 4.6667]])

檢驗一下,求的是不是對的。

J對x的導數應該是什么呢?

9047fbfc9f55a09571225a1d9795e5eb.png

檢查發現,導數就是:

[[1.3333, 2.0000, 2.6667],

[3.3333, 4.0000, 4.6667]]

總結一下,構建計算圖(正向傳播,Forward Propagation)和求導(反向傳播,Backward Propagation)的過程就是:

c8b8d0f0ea84b7cb7b97718457654cc9.png

三、關于backward函數的一些其他問題:

1. 不是標量也可以用backward()函數來求導?

在看文檔的時候,有一點我半天沒搞懂:

他們給了這樣的一個例子:

947acbb92c149d22d7e123d093b9a190.png

我在前面不是說“只有標量才能對其他東西求導”么?它這里的y是一個tensor,是一個向量。按道理不能求導呀。這個參數gradients是干嘛的?

但是,如果看看backward函數的說明,會發現,里面確實有一個gradients參數:

18d6ad4513caa5ed88630e81a1bf07ff.png
6459d0721136202797e512232899e414.png

從說明中我們可以了解到:

  • 如果你要求導的是一個標量,那么gradients默認為None,所以前面可以直接調用J.backward()就行了如果你要求導的是一個張量,那么gradients應該傳入一個Tensor。那么這個時候是什么意思呢?

在StackOverflow有一個解釋很好:

24cba549f87ee73450c53f7965264483.png

一般來說,我是對標量求導,比如在神經網絡里面,我們的loss會是一個標量,那么我們讓loss對神經網絡的參數w求導,直接通過loss.backward()即可。

但是,有時候我們可能會有多個輸出值,比如loss=[loss1,loss2,loss3],那么我們可以讓loss的各個分量分別對x求導,這個時候就采用:

loss.backward(torch.tensor([[1.0,1.0,1.0,1.0]]))

如果你想讓不同的分量有不同的權重,那么就賦予gradients不一樣的值即可,比如:

loss.backward(torch.tensor([[0.1,1.0,10.0,0.001]]))

這樣,我們使用起來就更加靈活了,雖然也許多數時候,我們都是直接使用.backward()就完事兒了。

2. 一個計算圖只能backward一次

一個計算圖在進行反向求導之后,為了節省內存,這個計算圖就銷毀了。

如果你想再次求導,就會報錯。

比如你定義了計算圖:

ef6ed079cb9991e32c7b4bdcdecf2c7e.png

你先求p求導,那么這個過程就是反向的p對y求導,y對x求導。

求導完畢之后,這三個節點構成的計算子圖就會被釋放:

765b57889916d654f257036d5afd8172.png

那么計算圖就只剩下z、q了,已經不完整,無法求導了。

所以這個時候,無論你是想再次運行p.backward()還是q.backward(),都無法進行,報錯如下:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

好,怎么辦呢?

遇到這種問題,一般兩種情況:

1. 你的實際計算,確實需要保留計算圖,不讓子圖釋放。

那么,就更改你的backward函數,添加參數retain_graph=True,重新進行backward,這個時候你的計算圖就被保留了,不會報錯。

但是這樣會吃內存!,尤其是,你在大量迭代進行參數更新的時候,很快就會內存不足,memory out了。

2. 你實際根本沒必要對一個計算圖backward多次,而你不小心多跑了一次backward函數。

通常,你要是在IPython里面聯系PyTorch的時候,因為你會反復運行一個單元格的代碼,所以很容易一不小心把backward運行了多次,就會報錯。這個時候,你就檢查一下代碼,防止backward運行多次即可。

文章轉自:https://zhuanlan.zhihu.com/p/51385110

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/532947.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/532947.shtml
英文地址,請注明出處:http://en.pswp.cn/news/532947.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

全國計算機等級考試東營,東營計算機等級考試報名時間

2017年計算機等級考試已經結束,出國留學網為考生們整理了2018年東營市上半年計算機等級考試報名時間,希望能幫到大家,想了解更多資訊,請關注我們,小編會第一時間更新哦。2018年東營市上半年計算機等級考試報名時間根據…

crt中 新建的連接存儲在哪_數字存儲示波器的VPO技術

當使用數字存儲示波器測量串行傳輸信號、數字電路上的地址/數據/控制總線、信號元器件上的噪聲、復合視頻信號或調制信號時,面臨的最大困難在于這些信號隨機、變化迅速、雜亂或不具備周期性。因此,為了提高捕獲這些信號的幾率,減少數字存儲示…

計算機在平面設計中的作用,比例設計在平面設計中的作用與意義

隨著互聯網的不斷發展,用戶體驗在設計師的產品設計中占有的比重越大了,而今天我們就一起來了解一下,比例設計在平面設計中的作用與意義。一、平面設計中的比例是什么?比例尺是指設計元素相對于其他元素的相對大小。一個物體只有在與其他物體…

元組可以直接添加進數據庫嗎_數據庫篇-第一章:數據庫基本概念

面試必備基礎數據庫知識,掃碼關注公眾號提升01 第一,什么是數據庫?維基百科上是這樣定義的:所謂“數據庫”是以一定方式儲存在一起、能予多個用戶共享、具有盡可能小的冗余度、與應用程序彼此獨立的數據集合。一個數據庫由多個表空…

win7計算機找不到腳本文件夾,win7系統TXT文件打開提示找不到腳本文件的解決方法...

很多小伙伴都遇到過win7系統TXT文件打開提示找不到腳本文件的困惑吧,一些朋友看過網上零散的win7系統TXT文件打開提示找不到腳本文件的處理方法,并沒有完完全全明白win7系統TXT文件打開提示找不到腳本文件是如何解決的,今天小編準備了簡單的解…

大學計算機基礎 小報,word制作電子小報教案.doc

一、學習任務【能力目標】1、能利用word文字處理軟件進行板報類文本信息的處理。2、能設計出不同主題、形式的電子板報。【知識目標】1、初步掌握在word中運用圖片、藝術字、文本框、自選圖形進行綜合處理問題的方法。2、學會設計、評價電子板報。【德育目標】1、激發學生的創造…

剪切文件_轉錄組測序技術和結果解讀(十六)——可變剪切

可變剪切的概念可變剪切是指從一個mRNA前體中通過不同剪接方式,選擇不同的剪接位點組合,所產生不同的mRNA剪接異構體的過程。可變剪切的分類:外顯子缺失 (Exon skipping);可變的5’端剪切 (Alternative 5’ splicing);…

archlinux詳細安裝步驟_最新Centos的liunx安裝寶塔的詳細步驟

很多人買的服務器是win系統或者是liunx系統,要是說win那就基本上不用學習就和自己的電腦一樣操作就可以,但是有些新人剛接觸liunx系統不知道怎么安裝寶塔環境那今天126云就給大家詳細介紹一下 步驟和操作請看下圖準備的東西是 掛載磁盤 這個簡單介紹就是…

excel表格從某個標志計算機,讓Excel也玩多標簽 多個圖表一個窗口 -電腦資料

很多用戶習慣了傲游、火狐瀏覽器的多標簽瀏覽功能,希望能夠在文檔、窗口中也均能實現,將多個程序以標簽的形式顯示在同一窗口之中,軟件安裝(如圖1)和使用方法非常簡單,安裝后,用戶點擊Excel表格,并同時打開…

卡諾模型案例分析_3個維度看競品分析!

誰都想站在巨人的肩膀上,問題是怎么上去?ABC分享會線下24期回顧時間:10月24日 下午13:00-17:30地點:上海嘉定U-CUBE創意空間 參與人數:18人主題:怎樣做競品分析這次活動是第二次有上…

intellij服務器證書不受信任,ssl證書不受信任怎么辦?ssl證書不受信任解決方案有什么?...

隨之愈來愈多的ssl證書錯誤的狀況出現,大伙兒都是有ssl證書不受信任怎么辦這類的難題,而且對這種難題很頭痛,下邊將帶大伙兒解析一下ssl證書不受信任的緣故及解決方案。一、ssl證書不會受到信任是什么緣故1、SSL證書并不是來源于認可的SSL證書…

小馬源碼_Java互聯網架構-重新認識Java8-HashMap-不一樣的源碼解讀

歡迎關注頭條號:java小馬哥周一至周日早九點半!下午三點半!精品技術文章準時送上!!!精品學習資料獲取通道,參見文末看源碼前我們必須先知道一下ConcurrentHashMap的基本結構。ConcurrentHashMap…

安裝默認報表服務器虛擬目錄,報表服務器虛擬目錄(Reporting Services 配置)

報表服務器虛擬目錄(Reporting Services 配置)12/15/2008本文內容使用“報表服務器虛擬目錄”頁可以配置報表服務器的虛擬目錄。用于訪問報表服務器 Web 服務的 URL 將包含該虛擬目錄名稱。完整的 URL 包括前綴(http:// 或 https://)、服務器名稱和虛擬目錄。服務器名稱可能是內…

小程序向webview傳參_獨家 | 支付寶小程序向個人開發者開放公測

基于興趣和周圍小群體開發的個人小程序,才是為支付寶提供更加多樣化的生活服務場景的來源。文 | Tech星球 (微信ID:tech618) 尹非凡、劉寧寧2月26日,Tech星球(微信ID:tech618) 獨家獲悉,支付寶小程序今日正式面向個人…

原神服務器維護后抽獎池會更新嗎,原神:更新維護一小時,補償60原石,玩家祈求多維護幾天!...

10月21號,原神社區發布公告,游戲將會在10月22號7點至11點進行停服維護,所有玩家在這個時間段將無法進入游戲。而作為補償,官方會贈送5級以上的玩家240原石(停服一小時送60原石)。這是偷偷的更新嗎?官方并沒有說更新內容…

涉及子模塊_COMSOL Multiphysics 5.6 RF模塊更新詳解

業界領先的多物理場仿真、App 設計與部署的軟件解決方案提供商COMSOL 公司發布了全新的COMSOL Multiphysics 軟件5.6 版本。新版本為多核和集群計算提供了計算速度更快且內存需求更低的求解器、更加高效的CAD 裝配處理功能、仿真App 布局模板,以及一系列包括剪裁平面…

系統參數shell服務器,shell 調用遠程服務器shell

shell 調用遠程服務器shell 內容精選換一換流程定義文件描述業務邏輯的XML文件,包括workflow.xml、coordinator.xml、bundle.xml三類,最終由Oozie引擎解析并執行。描述業務邏輯的XML文件,包括workflow.xml、coordinator.xml、bundle.xml三類&…

endnote國標_Citavi 與 Endnote 在 Word 插入引用,哪個更適合你?

前言:不黑、不吹,客觀討論,如有補充請留言,我們一定完善內容。我們先看下兩者在 Word 界面的顯示截圖:Endnote :(看起來很簡潔)Citavi :(看起來功能多一些&am…

思科服務器如何修改啟動項,思科配置tftp服務器

思科配置tftp服務器 內容精選換一換使用mount命令掛載文件系統到云服務器,云服務器系統提示timed out。原因1:網絡狀態不穩定。原因2:網絡連接異常。原因3:云服務器DNS配置錯誤,導致解析不到文件系統的域名&#xff0c…

社保費客戶端顯示服務器連接異常,社保費客戶端登錄服務器異常

社保費客戶端登錄服務器異常 內容精選換一換本章節指導您使用MongoDB客戶端,通過彈性云服務器內網方式連接GaussDB(for Mongo)集群實例。操作系統使用場景:彈性云服務器的操作系統以Linux為例,客戶端本地使用的計算機系統以Windows為例。目標…