深度學習 精選筆記(7)前向傳播、反向傳播和計算圖

學習參考:

  • 動手學深度學習2.0
  • Deep-Learning-with-TensorFlow-book
  • pytorchlightning

①如有冒犯、請聯系侵刪。
②已寫完的筆記文章會不定時一直修訂修改(刪、改、增),以達到集多方教程的精華于一文的目的。
③非常推薦上面(學習參考)的前兩個教程,在網上是開源免費的,寫的很棒,不管是開始學還是復習鞏固都很不錯的。

深度學習回顧,專欄內容來源多個書籍筆記、在線筆記、以及自己的感想、想法,佛系更新。爭取內容全面而不失重點。完結時間到了也會一直更新下去,已寫完的筆記文章會不定時一直修訂修改(刪、改、增),以達到集多方教程的精華于一文的目的。所有文章涉及的教程都會寫在開頭、一起學習一起進步。

前向傳播用于計算模型的預測輸出,反向傳播用于根據預測輸出和真實標簽之間的誤差來更新模型參數。

前向傳播和反向傳播是神經網絡訓練中的核心步驟,通過這兩個過程,神經網絡能夠學習如何更好地擬合數據,提高預測準確性。

一、計算圖

計算圖(Computational Graph)是一種圖形化表示方法,用于描述數學表達式中各個變量之間的依賴關系和計算流程。在深度學習和機器學習領域,計算圖常用于可視化復雜的數學運算和函數計算過程,尤其是在反向傳播算法中的梯度計算過程中被廣泛應用。

計算圖通常包括兩種節點:

  • 計算節點(Compute Nodes):這些節點表示數學運算,如加法、乘法等。計算節點接受輸入,并產生輸出。
  • 數據節點(Data Nodes):這些節點表示數據或變量,如輸入數據、權重、偏置等。

通過連接計算節點和數據節點的邊,構建了一個有向圖,其中每個節點表示一個操作,邊表示數據流向。計算圖可以幫助理解復雜的計算過程,特別是在深度學習中涉及大量參數和運算的情況下。

二、前向傳播

前向傳播(forward propagation或forward pass) 指的是:按順序(從輸入層到輸出層)計算和存儲神經網絡中每層的結果。

前向傳播(Forward Propagation):

  • 定義:前向傳播是指輸入數據通過神經網絡模型的各層,逐層進行計算并傳遞至輸出層的過程。
  • 作用:在前向傳播過程中,輸入數據經過神經網絡的權重和激活函數的計算,最終得到模型的預測輸出。
  • 目的:前向傳播的目的是計算模型對輸入數據的預測值,為后續的損失函數計算和反向傳播提供基礎。

1.前向傳播的計算圖

假設單隱藏層神經網絡中,輸入樣本是 𝐱∈? d, 并且隱藏層不包括偏置項。 這里的中間變量是:
在這里插入圖片描述
其中 𝐖(1)∈??×𝑑 是隱藏層的權重參數。 將中間變量 𝐳∈?? 通過激活函數 𝜙 后, 得到長度為 ? 的隱藏激活向量是:
在這里插入圖片描述
隱藏變量 𝐡也是一個中間變量。 假設輸出層的參數只有權重 𝐖(2)∈?𝑞×?, 可以得到輸出層變量,它是一個長度為 𝑞 的向量:
在這里插入圖片描述
假設損失函數為 𝑙,樣本標簽為 𝑦,可以計算單個數據樣本的損失項,
在這里插入圖片描述
根據 𝐿2 正則化的定義,給定超參數 𝜆 ,正則化項為
在這里插入圖片描述
其中矩陣的Frobenius范數是將矩陣展平為向量后應用的 𝐿2范數。 最后,模型在給定數據樣本上的正則化損失為:
在這里插入圖片描述
該函數J就是目標函數。

繪制計算圖有助于可視化計算中操作符和變量的依賴關系。

與上述簡單網絡相對應的計算圖, 其中正方形表示變量,圓圈表示操作符。 左下角表示輸入,右上角表示輸出。 注意顯示數據流的箭頭方向主要是向右和向上的。
在這里插入圖片描述

三、反向傳播

反向傳播(Backpropagation):

  • 定義:反向傳播是指通過計算損失函數對模型參數的梯度(梯度是一個由偏導數組成的向量,表示函數在某一點處的變化率或者斜率方向、也就是在每個自變量方向上的偏導數),從輸出層向輸入層傳播梯度的過程。
  • 作用:在反向傳播過程中,根據損失函數計算模型參數的梯度,然后利用梯度下降等優化算法更新模型參數,以減小損失函數的值。
  • 目的:反向傳播的目的是根據模型預測與真實標簽的誤差,調整神經網絡中每個參數的值,使模型能夠更好地擬合訓練數據,并提高在新數據上的泛化能力。

反向傳播(backward propagation或backpropagation)指的是計算神經網絡參數梯度的方法。 簡言之,該方法根據微積分中的鏈式規則,按相反的順序從輸出層到輸入層遍歷網絡。 該算法存儲了計算某些參數梯度時所需的任何中間變量(偏導數)。 假設有函數 𝖸=𝑓(𝖷) 和 𝖹=𝑔(𝖸) , 其中輸入和輸出 𝖷,𝖸,𝖹 是任意形狀的張量。 利用鏈式法則,可以計算 𝖹 關于 𝖷 的導數:

在這里插入圖片描述
使用 prod 運算符在執行必要的操作(如換位和交換輸入位置)后將其參數相乘。 對于向量,這很簡單,它只是矩陣-矩陣乘法。

在前向傳播的計算圖中,單隱藏層簡單網絡的參數是 𝐖(1) 和 𝐖(2) 。 反向傳播的目的是計算梯度 ?𝐽/?𝐖(1)?𝐽/?𝐖(2) 。為此,應用鏈式法則,依次計算每個中間變量和參數的梯度。 計算的順序與前向傳播中執行的順序相反,因為需要從計算圖的結果開始,并朝著參數的方向努力。第一步是計算目標函數 𝐽=𝐿+𝑠 相對于損失項 𝐿 和正則項 𝑠 的梯度。

這里為什么等于1?因為單隱藏層簡單網絡的最后一層上面是
在這里插入圖片描述
根據鏈式法則計算目標函數關于輸出層變量 𝐨 的梯度:
在這里插入圖片描述
計算正則化項相對于兩個參數的梯度:

在這里插入圖片描述
計算最接近輸出層的模型參數的梯度 ?𝐽/?𝐖(2)∈?𝑞×? 。 使用鏈式法則得出:

在這里插入圖片描述
為了獲得關于 𝐖(1)的梯度,需要繼續沿著輸出層到隱藏層反向傳播。 關于隱藏層輸出的梯度 ?𝐽/?𝐡∈?? 由下式給出:
在這里插入圖片描述
由于激活函數 𝜙 是按元素計算的, 計算中間變量 𝐳的梯度 ?𝐽/?𝐳∈?? 需要使用按元素乘法運算符,用 表示:
在這里插入圖片描述
最后,可以得到最接近輸入層的模型參數的梯度 ?𝐽/?𝐖(1)∈??×𝑑 。 根據鏈式法則,我們得到:
在這里插入圖片描述

四、訓練神經網絡

在訓練神經網絡時,前向傳播和反向傳播相互依賴。

對于前向傳播,沿著依賴的方向遍歷計算圖并計算其路徑上的所有變量。 然后將這些用于反向傳播,其中計算順序與計算圖的相反。

以上述簡單網絡為例:
正則項:

在這里插入圖片描述
反向傳播中計算J對W(2)的梯度公式:
在這里插入圖片描述
反向傳播中計算J對W(1)的梯度公式:
在這里插入圖片描述
一方面,在前向傳播期間計算正則項取決于模型參數𝐖(1)和 𝐖(2)的當前值。 它們是由優化算法根據最近迭代的反向傳播給出的。 另一方面,反向傳播期間參數的梯度計算, 取決于由前向傳播給出的隱藏變量𝐡的當前值。

因此,在訓練神經網絡時,在初始化模型參數后, 交替使用前向傳播和反向傳播,利用反向傳播給出的梯度來更新模型參數。

注意,反向傳播重復利用前向傳播中存儲的中間值,以避免重復計算。 帶來的影響之一是需要保留中間值,直到反向傳播完成。 這也是訓練比單純的預測需要更多的內存(顯存)的原因之一。 此外,這些中間值的大小與網絡層的數量和批量的大小大致成正比。 因此,使用更大的批量來訓練更深層次的網絡更容易導致內存不足(out of memory)錯誤

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/718267.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/718267.shtml
英文地址,請注明出處:http://en.pswp.cn/news/718267.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

藍橋杯:單詞分析

題目 題目描述 小藍給學生們組織了一場考試,卷面總分為 100 分,每個學生的得分都是一個 0 到 100 的整數。 如果得分至少是 60 分,則稱為及格。如果得分至少為 85 分,則稱為優秀。 請計算及格率和優秀率,用百分數表…

Rstudio-深度學習執行代碼

RStudio是一個開源的集成開發環境(IDE),專門用于R編程語言的開發和數據分析。R語言是一種流行的統計計算和數據可視化語言,廣泛用于數據科學、統計學和機器學習領域。 RStudio提供了許多功能強大的工具,包括代碼編輯器…

SQL 基本條件查詢DQL 練習

DQL DQL(Data Query Language)是SQL語言中的一種類型,用于執行數據查詢操作。它是SQL的一部分,用于從數據庫中檢索數據。DQL語句用于從一個或多個表中選擇、過濾和排序數據。常見的DQL查詢語句包括SELECT、FROM、WHERE、GROUP BY…

U盤無法讀取?輕松掌握正確解決方法!

“為什么我的u盤插入電腦后會顯示無法讀取呢?想查看一些比較重要的文件,但就是無法讀取U盤,想問問大家,我應該怎么操作呢?” U盤作為一種便捷的數據存儲設備,廣泛應用于我們的日常生活和工作中。然而&#…

獨立游戲《星塵異變》UE5 C++程序開發日志2——創建并編寫一個C++類

在本篇日志中,我們將要用一個C類來實現一個游戲內的物品,同時介紹UCLASS、USTRUCT、UPROPERTY的使用 一、創建一個C類 我們在UE5的"內容側滑菜單"中,在右側空白中右鍵選擇"新建C類",然后可以選擇一個想要的…

python70-Python的函數入門,了解下函數

函數是執行特定任務的一段代碼,程序通過將一段代碼定義成函數,并為該函數指定一個函數名,這樣即可在需要的時候多次調用這段代碼。因此,函數是代碼復用的重要手段。學習函數需要重點掌握定義函數、調用函數的方法。 與函數緊密相關的另一個知識點是lambda表達式。lamda表達…

Spring AOP(Aspect-Oriented Programming,面向切面編程)介紹

Spring AOP(Aspect-Oriented Programming,面向切面編程)是Spring框架的一個重要模塊,它提供了一種強大的方式來幫助開發者實現橫切關注點(cross-cutting concerns)的模塊化。橫切關注點是指那些影響多個模塊…

Linux設備模型(十一) - platform設備

一,platform device概述 在Linux2.6以后的設備驅動模型中,需關心總線、設備和驅動這3個實體,總線將設備和驅動綁定。在系統每注冊一個設備的時候, 會尋找與之匹配的驅動;相反的,在系統每注冊一個設備的時…

【Redis】實際應用 - 緩存

文章目錄 1. 緩存的基本概念2. Redis作為緩存的優勢2.1 內存存儲2.2 持久性選項2.3 數據結構豐富 3. Redis緩存的使用3.1 安裝和配置Redis3.2 連接到Redis3.3 存儲和獲取數據3.4 設置過期時間 4. 緩存策略4.1 LRU(最近最少使用)4.2 數據失效4.3 主動刷新…

可讓照片人物“開口說話”阿里圖生視頻模型EMO,高啟強普法

3 月 1 日消息,阿里巴巴研究團隊近日發布了一款名為“EMO(Emote Portrait Alive)”的 AI 框架,該框架號稱可以用于“對口型”,只需要輸入人物照片及音頻,模型就能夠讓照片中的人物開口說出相關音頻&#xf…

PDN分析及應用系列二-簡單5V電源分配-Altium Designer仿真分析-AD

PDN分析及應用系列二 —— 案例1:簡單5V電源分配 預模擬DC網絡識別 當最初為PCB設計打開PDN分析儀時,它將嘗試根據公共電源網絡命名法從設計中識別所有直流電源網絡。 正確的DC網絡識別對于獲得最準確的模擬結果非常重要。 在示例項目中已經識別出主DC網絡以簡化該過程。 …

Vulnhub靶機:Bellatrix

一、介紹 運行環境:Virtualbox 攻擊機:kali(10.0.2.4) 靶機:Bellatrix(10.0.2.9) 目標:獲取靶機root權限和flag 靶機下載地址:https://www.vulnhub.com/entry/hogwa…

Leetcode 3070. Count Submatrices with Top-Left Element and Sum Less Than k

Leetcode 3070. Count Submatrices with Top-Left Element and Sum Less Than k 1. 解題思路2. 代碼實現 題目鏈接:3070. Count Submatrices with Top-Left Element and Sum Less Than k 1. 解題思路 這一題就是一個二維的累積數組的問題,我們直接求一…

網絡學習:MPLS技術基礎知識

目錄 一、MPLS技術產生背景 二、MPLS網絡組成(基本概念) 1、MPLS技術簡介:Multiprotocol Lable Switching,多協議標簽交換技術 2、MPLS網絡組成 三、MPLS的優勢 四、MPLS的實際應用 一、MPLS技術產生背景 1、IP采用最長掩碼…

Power BI vs Superset BI 調研報告

調研結論 SupersetPower BI價格開源①. Power BI Pro 每人 $10/月($120/年/人) ②. Power BI Premium 每人 $20/月($240/年/人) ③. Power BI Embedded:4C10G $11W/年 權限基于角色的訪問控制,支持細粒度的訪問: 表級別、庫級別、圖表級別,看板級別,用戶級別 基于角色…

每天一個數據分析題(一百八十五)

給定下述Python代碼段,試問哪個選項正確描述了該代碼段的功能? data_raw[‘gender’] data_raw[‘gender’].map({‘Male’: 1, ‘Female’: 0}) A. 代碼中對gender變量進行了獨熱編碼(One-Hot Encoding),并將gender中的缺失值填充為類別平…

深度學習API——keras初學

keras定義: Keras是一個深度學習API(人工神經網絡庫),使用Python語言編寫的github開源項目,主要開發者為谷歌工程師。Keras底層可調用不同的機器學習平臺,如TensorFlow、Theano或micsoft-CNTK。 作用&…

Tomcat的配置文件

Tomcat的配置文件詳解 一.Tomcat的配置文件 Tomcat的配置文件默認存放在$CATALINA_HOME/conf目錄中,主要有以下幾個: 1.server.xml: Tomcat的主配置文件,包含Service, Connector, Engine, Realm, Valve, Hosts主組件的相關配置信息&#x…

【推薦】免費AI論文寫作神器-「智元兔 AI」

還在為寫論文焦慮?免費AI寫作大師來幫你三步搞定! 智元兔AI是ChatGPT的人工智能助手,并且具有出色的論文寫作能力。它能夠根據用戶提供的題目或要求,自動生成高質量的論文。 不論是論文、畢業論文、散文、科普文章、新聞稿件&…

#WEB前端(浮動與定位)

1.實驗&#xff1a; 2.IDE&#xff1a;VSCODE 3.記錄&#xff1a; float、position 沒有應用浮動前 應用左浮動和右浮動后 應用定位 4.代碼&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><me…