【機器學習】反向傳播如何求梯度(公式推導)

寫在前面

前期學習深度學習的時候,很多概念都是一筆帶過,只是覺得它在一定程度上解釋得通就行,但是在強化學習的過程中突然意識到,反向傳播求梯度其實并不是一件簡單的事情,這篇博客的目的就是要講清楚反向傳播是如何對特定的某一層求梯度,進而更新其參數的

為什么反向傳播不容易實現

首先,眾所周知,深度網絡就是一個廣義的多元函數,
但在通常情況下,想要求一個函數的梯度,就必須知道這個函數的具體表達式
但是問題就在于,深度網絡的“傳遞函數”并不容易獲得,或者說并不容易顯式地獲得
進而導致反向傳播的過程難以進行

為什么反向傳播可以實現

損失函數是關于參數的函數

如果要將一個函數F對一個變量x求偏導,那偏導存在的前提條件就是F是關于x的函數,否則求導結果就是0

  • 符號定義(后續公式均據此展開)
    • x=[x1,x2,…xn]x=[x_1,x_2,\dots x_n]x=[x1?,x2?,xn?]
    • ypred=[y1,y2,…ym]y_{pred}=[y_1,y_2, \dots y_m]ypred?=[y1?,y2?,ym?]
    • θ=[w1,b1,w2,b2,…wn,bn]\theta=[w_1,b_1,w_2,b_2,\dots w_n,b_n]θ=[w1?,b1?,w2?,b2?,wn?,bn?]
    • ai=第i層網絡激活函數的輸出,最后一層的輸出就是ypreda^{i}=第 i層網絡激活函數的輸出,最后一層的輸出就是y_{pred}ai=i層網絡激活函數的輸出,最后一層的輸出就是ypred?
    • zi=第i層網絡隱藏層的輸出z^{i}=第i層網絡隱藏層的輸出zi=i層網絡隱藏層的輸出
    • gi?′(zi)第i層激活函數的導數,在輸入=zi處的值g^i\ '(z^i)第i層激活函數的導數,在輸入=z^i處的值gi?(zi)i層激活函數的導數,在輸入=zi處的值
  • 關系式
    • 網絡的抽象函數式 ypred=F(x;θ)y_{pred}=F(x;\theta)ypred?=F(x;θ)
      即網絡就是一個巨大的多元函數,接受兩個向量(模型輸入和參數)作為輸入,經過內部正向傳播后輸出一個向量
    • 損失函數Loss的抽象函數式 Loss=L(ytrue,ypred)=L(ytrue,F(x;θ))Loss=L(y_{true},y_{pred})=L(y_{true},F(x;\theta))Loss=L(ytrue?,ypred?)=L(ytrue?,F(x;θ))
      其中 ytruey_{true}ytrue?xxx 屬于參變量,它雖然會變,但是和模型本身沒什么關系,唯一屬于模型自己的變量就是 θ\thetaθ,所以不難看出,損失函數L是關于模型參數 θ\thetaθ 的函數,損失值Loss完全由模型參數 θ\thetaθ 決定

鏈式法則

  • 這個法則是深層網絡得以實現梯度計算的關鍵
    核心公式如下:
    ?L?θi=?L?zi??zi?θi \frac{\partial L}{\partial \theta^i}=\frac{\partial L}{\partial z^i}·\frac{\partial z^i}{\partial \theta^i} ?θi?L?=?zi?L???θi?zi?
    其中,?L?zi\frac{\partial L}{\partial z^i}?zi?L?是損失L對第i層加權輸入ziz^izi的梯度,?zi?θi\frac{\partial z^i}{\partial \theta^i}?θi?zi?是第i層加權輸入ziz^izi對本層參數θi\theta^iθi的梯度

  • 進一步深究可以發現?zi?θi\frac{\partial z^i}{\partial \theta^i}?θi?zi?相對容易求,因為它只涉及到當前層的當前神經元的求解,在面向對象語言中,很容易為每個屬于同一個類的實例增加一個方法,比如像這里的輸入對參數求導,舉例來說;
    if?θi=Wi?and?Zi=Wi?ai?1+bi,then??zi?θi=(ai?1)T if\ \theta^i=W^i\ and\ Z^i=W^i*a^{i-1}+b^i,\\ then\ \frac{\partial z^i}{\partial \theta^i}=(a^{i-1})^T if?θi=Wi?and?Zi=Wi?ai?1+bi,then??θi?zi?=(ai?1)T
    其中,
    (說實話,我非常想把隱藏層稱為“傳遞函數”,控制和機器學習實際上有非常多可以相互借鑒的地方,而且在事實上,二者也確實是不可分割的關系)

  • 然后我們要來處理相對麻煩的 ?L?zi\frac{\partial L}{\partial z^i}?zi?L?

    • 多層感知機為例,共k層,已知網絡輸出,求網絡第i層的梯度
    • 用數學歸納法在這種遞歸系統中比較合適
      • 歸納奠基
        L=L(ytrue,ypred)=L(ytrue,F(x;θ))L=L(y_{true},y_{pred})=L(y_{true},F(x;\theta))L=L(ytrue?,ypred?)=L(ytrue?,F(x;θ))

        ?L?zk=?L?ak??ak?zk=?L?ak?(gk)′(zk)\frac{\partial L}{\partial z^k}=\frac{\partial L}{\partial a^k}·\frac{\partial a^k}{\partial z^k}=\frac{\partial L}{\partial a^k}\otimes (g^k)'(z^k)?zk?L?=?ak?L???zk?ak?=?ak?L??(gk)(zk)

        上面的公式說明:損失對隱藏層輸出的偏導,等價于損失函數對最終輸出的偏導,再逐元素乘上最后層激活函數 在隱藏層輸出處 的導數

        其中,激活函數在創建網絡時就明確已知,因此求導取值并沒有難度
        由于Loss=L(ytrue,ypred)Loss=L(y_{true},y_{pred})Loss=L(ytrue?,ypred?)直接與網絡最終輸出ypredy_{pred}ypred?相關,因此損失對最終輸出的偏導并不難求;
        比如將損失函數定義為均方差MSE:(其他網絡基本同理)
        L=12∑j=1m(yj?ajk)2?L?ak=?(yj?ajk)L=\frac{1}{2}\sum^m_{j=1}(y_j-a_j^k)^2\\\frac{\partial L}{\partial a^k}=-(y_j-a_j^k)L=21?j=1m?(yj??ajk?)2?ak?L?=?(yj??ajk?)

      • 歸納遞推(從第 i 層到第 i-1 層)

        假設已知 ?L?zi\frac{\partial L}{\partial z^i}?zi?L?(反向傳播,因此我們假設的是后一層已知)

        由鏈式法則可得:
        ?L?zi?1=(?L?zi)?(?zi?ai?1)?(?ai?1?zi?1)\frac{\partial L}{\partial z^{i-1}}=(\frac{\partial L}{\partial z^i})·(\frac{\partial z^i}{\partial a^{i-1}})·(\frac{\partial a^{i-1}}{\partial z^{i-1}})?zi?1?L?=(?zi?L?)?(?ai?1?zi?)?(?zi?1?ai?1?)
        其中,第一個因子已知

        第二個因子?zi?ai?1\frac{\partial z^i}{\partial a^{i-1}}?ai?1?zi?,分子為第 i 層隱藏層的輸出,分母為第 i 層隱藏層的輸入(即第 i-1 層激活層的輸出),因此其值就是第 i 層隱藏層的權重矩陣WiW^iWi本身

        第三個因子?ai?1?zi?1\frac{\partial a^{i-1}}{\partial z^{i-1}}?zi?1?ai?1?,分子為第 i-1 層激活層的輸出,分母為第 i-1 層激活層的輸入,因此其值就是 第 i-1 層激活函數 在隱藏層輸出處 的導數

        綜上:在已知第 i 層損失對輸出的梯度的情況下,可以推出第 i-1 層損失對輸出的梯度,遞推成立

      • 歸納總結
        綜上所述,反向傳播求梯度完全可行,按照上面的過程撰寫程序,就可以很方便地反向逐層 根據損失梯度 更新參數

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/914300.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/914300.shtml
英文地址,請注明出處:http://en.pswp.cn/news/914300.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

ALB、NLB、CLB 負載均衡深度剖析

ALB、NLB、CLB 負載均衡深度剖析 前言 筆者在上周的實際工作中遇到了一個典型的負載均衡選擇問題:在使用代理調用相關模型時,最初配置 Nginx 的代理地址為 ALB 的 7 層虛擬 IP(VIP),但由于集團網絡默認的超時時間為 3 …

歷史數據分析——云南白藥

醫藥板塊走勢分析: 從月線級別來看 2008年11月到2021年2月,月線上走出了兩個震蕩中樞的月線級別2085-20349的上漲段; 2021年2月到2024年9月,月線上走出了20349-6702的下跌段; 目前月線級別放巨量,總體還在震蕩區間內,后續還有震蕩和上漲的概率。 從周線級別來看 從…

【讀書筆記】《Effective Modern C++》第3章 Moving to Modern C++

《Effective Modern C》第3章 Moving to Modern C 一、區分圓括號 () 與大括號 {} (Item?7) C11 引入統一初始化(brace?initialization),即使用 {} 來初始化對象,與傳統的 () 存在細微差別:避…

Rust基礎-part1

Rust基礎[part1]—安裝和編譯 安裝 ? rust curl --proto https --tlsv1.2 https://sh.rustup.rs -sSf | sh安裝成功 [外鏈圖片轉存中…(img-ClSHJ4Op-1752058241580)] 驗證 ? rust rustc --version zsh: command not found: rustc因為我是用的是zsh,所以zsh配置…

PyQt5布局管理(QGridLayout(網格布局))

QGridLayout(網格布局) QGridLayout(網格布局)是將窗口分隔成行和列的網格來進行排列。通常可以使用函數addWidget()將被管理的控件(Widget)添加到窗口中,或者使用addLayout() 函數將布局(Layou…

Java設計模式之行為型模式(責任鏈模式)介紹與說明

一、核心概念與定義 責任鏈模式是一種行為型設計模式,其核心思想是將請求沿著處理對象鏈傳遞,直到某個對象能夠處理該請求為止。通過這種方式,解耦了請求的發送者與接收者,使多個對象有機會處理同一請求。 關鍵特點: 動…

SQL server之版本的初認知

SQL server之版本的初認知 為什么要編寫此篇文檔呢,主要是因為在最近測試OGG實時同步SQL server數據庫表數據的時候,經過多次測試,發現在安裝了一套SQL server2017初始版本,未安裝任何補丁的時候,在添加TRANDATA的時候…

【前端】jQuery動態加載CSS方法總結

在jQuery 中動態加載 CSS 文件有多種方法&#xff0c;以下是幾種常用實現方式&#xff1a; 方法 1&#xff1a;創建 <link> 標簽&#xff08;推薦&#xff09; // 動態加載外部 CSS 文件 function loadCSS(url) {$(<link>, {rel: stylesheet,type: text/css,href:…

Python爬蟲實戰:研究xlwings庫相關技術

1. 引言 在金融科技快速發展的背景下,數據驅動決策已成為投資領域的核心競爭力。金融市場數據具有海量、多源、實時性強等特點,傳統人工收集與分析方式難以滿足高效決策需求。Python 憑借其豐富的開源庫生態,成為金融數據分析的首選語言。結合 Requests、BeautifulSoup 等爬…

Linux 內核日志中常見錯誤

目錄 **1. `Oops`****含義****典型日志****可能原因****處理建議****2. `panic`****含義****典型日志****可能原因****處理建議****3. `BUG`****含義****典型日志****可能原因****處理建議****4. `kernel NULL pointer`****含義****典型日志****可能原因****處理建議****5. `WA…

Linux驅動開發2:字符設備驅動

Linux驅動開發2&#xff1a;字符設備驅動 字符設備驅動開發流程 字符設備是 Linux 驅動中最基本的一類設備驅動&#xff0c;字符設備就是一個一個字節&#xff0c;按照字節流進行讀寫操作的設備&#xff0c;讀寫數據是分先后順序的。比如最常見的點燈、按鍵、 IIC、 SPI&#x…

RuoYi-Cloud 驗證碼處理流程

以該處理流程去拓展其他功能模塊處理流程&#xff0c;進而熟悉項目開發代碼一、思路JavaWeb流程主干線&#xff1a;發起請求、處理請求、響應請求二、登錄頁面在登錄頁面按鍵F12打開開發者工具&#xff0c;點擊network&#xff0c;刷新頁面&#xff0c;點擊code&#xff0c;查看…

云計算三大服務模式深度解析:IaaS、PaaS、SaaS

架構本質&#xff1a;云計算服務模式定義了資源抽象層級和責任分擔邊界&#xff0c;形成從基礎設施到應用的全棧服務金字塔。三種模式共同構成云計算的服務交付模型核心框架。一、服務模式全景圖 #mermaid-svg-f0Klw2fbuhBQqJTh {font-family:"trebuchet ms",verdana…

【sql學習之拉鏈表】

1.拉鏈表理解 記錄歷史。記錄一個事物從開始&#xff0c;一直到當前狀態的所有變化的信息。字段說明&#xff1a; start_dt&#xff1a;該條記錄的生命周期開始時間 end_dt&#xff1a;該條記錄的生命周期結束時間 end_dt’9999/12/31’表示該條記錄目前處于有效狀態 如果查詢當…

STM32中實現shell控制臺(shell窗口輸入實現)

文章目錄 一、總體結構二、串口接收機制三、命令輸入與處理邏輯四、命令編輯與顯示五、歷史命令管理六、命令執行七、初始化與使用八、小結在嵌入式系統開發中,使用串口Shell控制臺是一種非常常見且高效的調試方式。本文將基于STM32平臺,分析一個簡潔但功能完整的Shell控制臺…

區分三種IO模型和select/poll/epoll

部分內容來源&#xff1a;JavaGuide select/poll/epoll 和 三種IO模型之間的關系是什么&#xff1f;區分普通IO和IO多路復用普通IO&#xff0c;即一個線程對應一個連接&#xff0c;因為每個線程只處理一個客戶端 socket&#xff0c;目標明確&#xff1a;線程中直接操作該 socke…

Actor-Critic重要性采樣原理

目錄 AC的數據低效性&#xff1a; 根本原因&#xff1a;策略更新導致數據失效 應用場景&#xff1a; 1. 離策略值函數估計 2. 離策略策略優化 3. 經驗回放&#xff08;Experience Replay&#xff09; 4. 策略梯度方法 具體場景分析 場景1&#xff1a;連續策略更新 場…

【贈書福利,回饋公號讀者】《智慧城市與智能網聯汽車,融合創新發展之路》

「5G行業應用」公號作家團隊推出《智慧城市與智能網聯汽車&#xff0c;融合創新發展之路》。本書由機械工業出版社出版&#xff0c;探討如何通過車城融合和創新應用&#xff0c;促進汽車產業轉型升級與生態集群發展&#xff0c;提升智慧城市精準治理與出行服務效能。&#xff0…

5G NR PDCCH之處理流程

本節主要介紹PDCCH處理流程概述。PDCCH&#xff08;Physical Downlink Control Channel&#xff0c;物理下行控制信道&#xff09;主要用于傳輸DCI&#xff08;Downlink Control Information&#xff0c;下行控制信息&#xff09;&#xff0c;用于通知UE資源分配&#xff0c;調…

力扣網編程135題:分發糖果(貪心算法)

一. 簡介本文記錄力扣網上涉及數組方面的編程題&#xff1a;分發糖果。這里使用貪心算法的思路來解決&#xff08;求局部最優&#xff0c;最終求全局最優解&#xff09;&#xff1a;每個孩子只需要考慮與相鄰孩子的相對關系。二. 力扣網編程135題&#xff1a;分發糖果&#xff…