循環神經網絡 - 參數學習之隨時間反向傳播算法

本文中,我們以同步的序列到序列模式為例來介紹循環神經網絡的參數學習。

循環神經網絡中存在一個遞歸調用的函數 𝑓(?),因此其計算參數梯度的方式和前饋神經網絡不太相同。在循環神經網絡中主要有兩種計算梯度的方式:隨時間反向傳播(BPTT)算法和實時循環學習(RTRL)算法。

本文我們來學習隨時間反向傳播算法。

BPTT 算法將循環神經網絡看作一個展開的多層前饋網絡,其中“每一層”對 應循環網絡中的“每個時刻”。這樣,循環神經網絡就可以按照前饋網絡中的反向傳播算法計算參數梯度。在“展開”的前饋網絡中,所有層的參數是共享的,因此參數的真實梯度是所有“展開層”的參數梯度之和。

一、數學推導:

以隨機梯度下降為例,給定一個訓練樣本 (𝒙, 𝒚),其中 𝒙1∶𝑇 =? (𝒙1, ? , 𝒙𝑇 )為長度是𝑇的輸入序列,𝑦1∶𝑇 =(𝑦1,?,𝑦𝑇)是長度為𝑇的標簽序列。即在每個時 刻 𝑡,都有一個監督信息 𝑦𝑡 ,我們定義時刻 𝑡 的損失函數為:

其中 𝑔(𝒉𝑡) 為第 𝑡 時刻的輸出,L 為可微分的損失函數,比如交叉熵。那么整個序列的損失函數為:

整個序列的損失函數 L 關于參數 𝑼 的梯度為

即每個時刻損失 L𝑡 對參數 𝑼 的偏導數之和。

基于參數 𝑼 和隱藏層在每個時刻 𝑘(1 ≤ 𝑘 ≤ 𝑡) 的凈輸入有關,通過數學推導(推導過程比較復雜,這里略過,大家著重掌握公式),可以得出:

計算偏導數

得到整個序列的損失函數 L 關于參數 𝑼 的梯度:

同理可得,L 關于權重 𝑾 和偏置 𝒃 的梯度為:

其中,類似前饋神經網絡中的誤差項為:

定義以下誤差項為第 𝑡 時刻的損失對第 𝑘 時刻隱藏神經層的凈輸入 𝒛𝑘 的導數,則當 1 ≤ 𝑘 < 𝑡 時

由上可以看出誤差項,時刻k的𝛿𝑡,𝑘可以由時刻k+1的𝛿𝑡,𝑘+1得出,即所謂的反向傳播。

下圖給出了誤差項隨時間進行反向傳播算法的示例:

二、進一步理解隨時間反向傳播算法BPTT

BPTT 的具體實現核心在于將 RNN 在時間維度上“展開”,從而把整個循環網絡視作一個深層的前饋網絡,然后利用反向傳播算法計算每個時間步的梯度。以下是關鍵步驟:

(一)前向傳播(Forward Pass)

在前向傳播階段,模型從初始隱藏狀態開始,按時間順序依次處理輸入序列的每個時間步。在每個時間步 t?中,RNN 會計算出當前隱藏狀態 ht:

同時,根據隱藏狀態產生輸出:

這些中間狀態和輸出都被存儲下來,供之后的反向傳播使用。

(二)損失計算

對于整個序列的輸出,我們會計算一個總體損失 LL,它通常是所有時間步損失 LtL_t 的求和或平均:

例如,在一個回歸或分類任務中,可能使用均方誤差交叉熵作為每個時間步的損失。

(三)時間展開(Unrolling)

為了使反向傳播適用于循環結構,我們把 RNN 展開成一個由 T?個層組成的前饋網絡,每一層對應一個時間步。雖然這些層共享同一組參數,但在展開的過程中,各個時間步之間的依賴關系(主要是隱藏狀態 ht)得以顯現。

這樣做的目的是為了使我們能夠用標準的反向傳播算法計算梯度,從而更新整個序列中共享的參數。下面通過一個簡單的例子說明這一過程。

假設情景

假設我們有一個 RNN 模型,用來處理一個長度為 3 的輸入序列 [x1,x2,x3](例如數值 1、2、3),初始隱藏狀態 h0? 設為零。模型的前向計算公式為:

將 RNN 展開成前饋網絡

原始的 RNN 是通過循環實現的,即使用同一組參數不斷將隱藏狀態從前一步傳遞到下一步。為了直觀地理解反向傳播的過程,我們將其在時間軸上展開,即把每個時間步看作網絡中的一層,這些層之間按照時間順序相連。

對于我們的序列,有如下展開:

  1. 時間步 1

  2. 時間步 2

  3. 時間步 3

這整個過程就像一個前饋網絡,共有 3 層(不包括初始狀態),每層的輸出 h_t? 都依賴于前一層的輸出 h_{t-1}? 和當前輸入 x_t?。注意,雖然在展開過程中每一層對應一個不同的時間步,但所有層共享同一組權重和偏置。

為什么這樣展開?

這種展開方式將時序依賴“展開”到層級結構中,使得整個序列可以看成一個深層網絡。這樣有兩個好處:

  1. 便于反向傳播計算
    我們可以像對普通前饋神經網絡那樣,基于鏈式法則逐層計算梯度,并且由于參數共享,每層計算的梯度會累積在同一組權重上。

  2. 捕捉長距離依賴
    通過展開,我們能直觀地理解誤差如何從最后一層傳回到第一層,反映長距離依賴問題,以及梯度消失或爆炸的問題。

總結

  • 展開過程:將 RNN 從時間步 1 到 T 展開,每個時間步視為一層前饋網絡,所有層使用同一組參數。

  • 前向傳播:依次計算每層隱藏狀態和輸出。

  • 反向傳播:從最后一層開始反向傳播,逐層累積梯度,更新共享參數。

這種展開不僅使得梯度計算過程清晰,而且方便我們理解如何利用 BPTT 解決時間依賴問題,確保模型能夠捕捉序列中長期和短期的信息。

(四)反向傳播(Backward Pass Through Time,BPTT)

從展開后的最后一個時間步 T?開始,依次向前計算梯度:

  • 局部梯度計算:在每個時間步,根據當前輸出與目標之間的誤差,首先計算當前時間步的輸出層梯度,然后通過當前隱藏狀態對損失的貢獻,計算激活函數(如 tanh?)的導數。

  • 梯度傳遞與累積:由于隱藏狀態 hth_t 不僅直接影響當前輸出,還間接影響后續所有時間步的輸出,因而需要將來自未來時間步傳回的梯度(往往稱為 “dh_next”)與當前時間步的梯度相加,形成一個總的梯度

  • 參數梯度更新:利用鏈式法則,通過隱藏狀態梯度計算出對輸入到隱藏權重 Wxh 和隱藏到隱藏權重 Whh 以及偏置的梯度。由于這些參數在每個時間步都是共享的,每一步計算出的梯度都會被累加起來。

  • 時間傳遞:在完成當前時間步梯度計算后,再將梯度通過 傳遞回前一個時間步,繼續重復這一過程直到第一個時間步。

下面以一個簡單的 RNN 模型展開一個長度為 3 的序列的反向傳播過程,來詳細說明 BPTT 中的四個關鍵步驟:局部梯度計算、梯度傳遞與累積、參數梯度更新、和時間傳遞。假設模型的前向傳播計算如下(激活函數采用 tanh):

  • 隱藏狀態更新

  • 輸出計算

假設我們的損失函數 L?是所有時間步損失的加和:

其中每個時間步的損失 Lt 是模型輸出 yt 和目標 (y_t)^{target} 之間的誤差(例如均方誤差)。

下面分步詳細說明 BPTT 的反向傳播過程。

1. 局部梯度計算

在反向傳播時,我們需要先計算每個時間步在輸出端的局部梯度,然后再傳回隱藏層。具體來說:

  • 對于時間步 t,我們先計算輸出層的梯度:

  • 接著,通過輸出層將梯度傳遞給隱藏狀態:

  • 由于隱藏狀態經過 tanh? 激活,,其局部梯度部分需乘上激活函數的導數:

    因此,得到當前時間步的局部梯度:

    其中“⊙”表示元素級相乘。

這部分稱為“局部梯度計算”,即對當前時刻輸出誤差先求到隱藏層(通過 Why?),再結合激活函數求出對 zt 的梯度。

2. 梯度傳遞與累積

由于 RNN 中隱藏狀態間存在依賴,當前時刻 ht 不僅受當前時間步損失 Lt 影響,還間接受到后續時間步的反饋。因此,反向傳播時需要將未來時間步傳回來的梯度累積在當前時刻。設我們計算總的梯度 dht? 對隱藏狀態的偏導,其計算方式為:

即當前時刻的總梯度等于當前局部梯度加上由下一時間步傳回來的梯度經過隱藏層的權重傳遞后的結果。

3. 參數梯度更新

有了每個時間步的梯度 δt? 和從后續傳來的梯度累積 ,我們可以對各層參數求導。具體來說:

  • 對于 輸入到隱藏層權重 Wxh?

    其中 ? 表示外積,此處對每個時間步將 δt 與相應的輸入 xt 外積,然后累加。

  • 對于 隱藏到隱藏層權重 Whh?

    同樣每個時間步累加當前梯度與前一隱藏狀態的外積。

  • 對于 隱藏層偏置 bh?

  • 對于 隱藏到輸出層權重 Why?,輸出層的梯度已經在局部步驟中計算:

    .
  • 對于 輸出偏置 by?

這些梯度在反向傳播過程中在每個時間步內計算完畢后,通過累加得到整個序列上的梯度,接著就可以用常規優化方法更新參數。

4. 時間傳遞(從未來到過去)

在反向傳播過程中,必須將未來時間步的梯度傳遞到當前時間步,這就是“時間傳遞”。具體步驟為:

這種梯度傳遞過程在整個序列反向迭代中重復執行,從時間步 T?逐層傳回到時間步 1。

綜合一個詳細例子

假設我們有一個時間序列長度為 3 的 RNN,且以時間步 t=3 開始反向傳播。簡化起見,以下給出各步描述:

  1. 時間步 3

  2. 時間步 2

  3. 時間步 1

總結

  • 局部梯度計算:在每個時間步,根據輸出誤差乘以輸出層權重和激活函數導數,得到對當前隱藏單元輸入的梯度(δt?)。

  • 梯度傳遞與累積:從后向前逐步將未來時刻的梯度通過隱藏層(乘以 和激活導數)傳遞給前一時間步,累加成當前時刻的總梯度

  • 參數梯度更新:利用每個時間步局部梯度與輸入(或前一時刻隱藏狀態)的外積,累積得到 Wxh?、Whh? 和 bh? 的梯度;輸出層參數的梯度也由對應輸出誤差累積。

  • 時間傳遞:通過計算隱藏狀態之間的依賴(即 ??),將梯度從后續傳遞給當前,直至序列首端。

這種詳細步驟體現了 BPTT 如何讓 RNN 捕捉序列中長距離依賴,以及如何利用鏈式求導從序列的末端逐步將梯度傳回并更新共享參數。

(五)參數更新

在累積了整個序列上各時間步的梯度后,使用如梯度下降、Adam 等優化算法對共享參數進行更新,從而使整體損失下降,模型逐步學會捕捉時序依賴關系。

總體來說,BPTT 的實現流程可總結為:

  • 先在時間上前向傳播:依次計算每個時間步的隱藏狀態和輸出,并存儲中間結果。

  • 計算整個序列的總損失:對每一時間步的輸出和目標計算損失。

  • 從后向前反向傳播:將誤差信息沿時間展開的網絡逐層反向傳遞,每一步既考慮當前的局部誤差,也考慮來自未來時間步的反饋,累積梯度。

  • 更新共享參數:利用累積的梯度,通過優化算法更新各個權重和偏置。

這一過程確保了即使序列較長,模型也能捕捉到早期輸入對后續輸出的影響,從而在學習長距離依賴關系方面發揮關鍵作用。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/77505.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/77505.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/77505.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

體驗OceanBase的 并行導入功能

在數據庫的日常使用中&#xff0c;會經常遇到以下場景&#xff1a; ?數據復制?&#xff1a;將一個或多個表中的數據復制到目標表中&#xff0c;可能是復制全部數據&#xff0c;也可能僅復制部分數據。數據合并&#xff1a;將數據從一個表轉移到另一個表&#xff0c;或者將多…

Kafka和RocketMQ相比有什么區別?那個更好用?

Kafka和RocketMQ相比有什么區別?那個更好用? Kafka 和 RocketMQ 都是廣泛使用的消息隊列系統&#xff0c;它們有很多相似之處&#xff0c;但也有一些關鍵的區別。具體選擇哪個更好用&#xff0c;要根據你的應用場景和需求來決定。以下是它們之間的主要區別&#xff1a; 1. …

UniApp 實現兼容 H5 和小程序的拖拽排序組件

如何使用 UniApp 實現一個兼容 H5 和小程序的 九宮格拖拽排序組件&#xff0c;實現思路和關鍵步驟。 一、實現目標 支持拖動菜單項改變順序拖拽過程實時預覽移動位置拖拽松開后自動吸附回網格兼容 H5 和小程序平臺 二、功能結構拆解以及完整代碼 完整代碼&#xff1a; <…

[raspberrypi 0w and respeaker 2mic]實時音頻波形

0. 環境 ubuntu22主機&#xff0c; 192.168.8.162&#xff0c; raspberry 0w&#xff0c; 192.168.8.220 路由器 1. 樹莓派 # rpi - send.py # 或者命令行&#xff1a;arecord -D plughw:1,0 -t wav -f cd -r 16000 -c 2 | nc 192.168.8.162 12345import socket imp…

公司內部建立apt源

有一篇建立pypi源的在這里需要的可以查看&#xff1a;公司內部建立pypi源-CSDN博客 背景&#xff0c;公司內部有很多工具僅供內部使用&#xff0c;如果用apt的方式就比較方便&#xff0c;只需要修改sources.list將源添加進去就可以了。我們接下來的操作就是為了實現這個需求。…

UE5中如何修復后處理動畫藍圖帶來的自然狀態下的metablriger身體綁定形變(如聳肩)問題

【[metablriger] UE5中如何修復后處理動畫藍圖帶來的自然狀態下的metablriger身體綁定形變(如聳肩)問題】 UE5中如何修復后處理動畫藍圖帶來的自然狀態下的metablriger身體綁定形變(如聳肩)問題

AWS Bedrock生成視頻詳解:AI視頻創作新時代已來臨

?? TL;DR: AWS Bedrock現已支持AI視頻生成功能,讓企業無需深厚AI專業知識即可創建高質量視頻內容。本文詳解Bedrock視頻生成能力的工作原理、應用場景和實操指南,助你快速掌握這一革命性技術。 ?? AWS Bedrock視頻生成:改變內容創作的游戲規則 還記得幾年前,制作一個專…

1.2 測試設計階段:打造高質量的測試用例

測試設計階段&#xff1a;打造高質量的測試用例 摘要 本文詳細介紹了軟件測試流程中的測試設計階段&#xff0c;包括測試用例設計、測試數據準備、測試環境搭建和測試方案設計等內容。通過本文&#xff0c;讀者可以系統性地了解測試設計的方法和技巧&#xff0c;掌握如何高效…

jQueryHTML與插件

1.jQuery 事件機制 1.1 注冊事件 bind()、on()方法向被選元素添加一個或多個事件處理程序&#xff0c;以及當事件發生時運行的函數 $("p").on({"click": function () {alert("點擊了")},"mouseenter": function () {…

MySQL 觸發器與存儲過程:數據庫的自動化工廠

在數據世界的工業區&#xff0c;有一座運轉高效的自動化工廠&#xff0c;那里的機器人日夜不停地處理數據…這就是 MySQL 的觸發器與存儲過程系統&#xff0c;它讓數據庫從"手工作坊"變成了"現代化工廠"… 什么是 MySQL 觸發器與存儲過程&#xff1f;&…

PostgreSQL-中文字段排序-修改字段的排序規則

最新版本更新 https://code.jiangjiesheng.cn/article/365?fromcsdn 推薦 《高并發 & 微服務 & 性能調優實戰案例100講 源碼下載》 -- 修改字段的排序規則 ALTER TABLE "public"."your_table_name" ALTER COLUMN "name" TYPE varcha…

GitHub優秀項目:數據湖的管理系統LakeFS

lakeFS 是一個開源工具&#xff0c;它將用戶的對象存儲轉換為類似Git的存儲庫。使用戶可以像管理代碼一樣管理數據湖。借助 lakeFS&#xff0c;可以構建可重復、原子化和版本化的數據湖操作--從復雜的ETL作業到數據科學和分析。 Stars 數11090Forks 數3157 主要特點 強大的數據…

頁面編輯器CodeMirror初始化不顯示行號或文本內容

延遲刷新 本來想延遲100毫秒的&#xff0c;但是會出現樣式向左偏移的情況&#xff0c;于是試了試500毫秒&#xff0c;發現就沒有問題了&#xff0c;可能是樣式什么是需要一個加載過程吧。 useEffect(() > {editorRef.current?.setValue(value || );setTimeout(() > {edi…

使用 Spring Boot 和 Uniapp 搭建 NFC 讀取系統

目錄 一、NFC 技術原理大揭秘1.1 NFC 簡介1.2 NFC 工作原理1.3 NFC 應用場景 二、Spring Boot 開發環境搭建2.1 創建 Spring Boot 項目2.2 項目基本配置 三、Spring Boot 讀取 NFC 數據3.1 NFC 設備連接與初始化3.2 數據讀取邏輯實現3.3 數據處理與存儲 四、Uniapp 前端界面開發…

臺式電腦插入耳機沒有聲音或麥克風不管用

目錄 一、如何確定插孔對應功能1.常見音頻插孔顏色及功能2.如何確認電腦插孔?3.常見問題二、 解決方案1. 檢查耳機連接和設備選擇2. 檢查音量設置和靜音狀態3. 更新或重新安裝聲卡驅動4. 檢查默認音頻格式5. 禁用音頻增強功能6. 排查硬件問題7. 檢查系統服務8. BIOS設置(可選…

Gerrit的安裝與使用說明(Ubuntu)

#本頁面按192.168.60.148服務器舉例進行安裝配置 1.權限配置 ## 使用root或者有sudo權限用戶執行 # 創建gerrit用戶 sudo useradd gerrit # 設置gerrit用戶的密碼 sudo passwd gerrit # 增加sudo權限 sudo visudo 在root ALL(ALL:ALL) ALL行下添加如下內容 gerrit ALL(ALL:…

Visual Studio 2019 配置VTK9.3.1

文章目錄 參考博客1、 VTK下載和編譯2、vs2019配置vtk9.3.1參考博客 Visual Studio 2022 配置VTK9.3.0 1、 VTK下載和編譯 見博客 CMake編譯VTK 2、vs2019配置vtk9.3.1 新建一個項目 寫入以下代碼 #include <vtkActor.h> #include <vtkAssembly.h> #include…

C++進階——C++11_右值引用和移動語義_可變參數模板_類的新功能

目錄 1、右值引用和移動語義 1.1 左值和右值 1.2 左值引用和右值引用 1.3 引用延長生命周期 1.4 左值和右值的參數匹配 1.5 右值引用和移動語義的使用場景 1.5.1 左值引用主要使用場景 1.5.2 移動構造和移動賦值 1.5.3 右值引用和移動語義解決傳值返回問題 1.5.4 右值…

HTTP協議原理深度解析:從基礎到實踐

引言 在互聯網技術體系中,HTTP(HyperText Transfer Protocol)協議如同數字世界的"通用語言",支撐著全球超50億網民的日常網絡交互。作為爬蟲開發、Web應用構建的核心技術基礎,理解HTTP原理是每個開發者必須掌握的技能。本文將從協議本質、技術演進、安全機制三…

Web品質 - 重要的HTML元素

Web品質 - 重要的HTML元素 在構建一個優秀的Web頁面時,HTML元素的選擇和運用至關重要。這些元素不僅影響頁面的結構,還直接關系到頁面的可用性、可訪問性和SEO表現。本文將深入探討一些關鍵的HTML元素,并解釋它們在提升Web品質方面的重要性。 1. <html> 根元素 HTM…