動手學深度學習(Pytorch版)代碼實踐 -循環神經網絡- 56門控循環單元(`GRU`)

56門控循環單元(GRU

我們討論了如何在循環神經網絡中計算梯度, 以及矩陣連續乘積可以導致梯度消失或梯度爆炸的問題。 下面我們簡單思考一下這種梯度異常在實踐中的意義:

  • 我們可能會遇到這樣的情況:早期觀測值對預測所有未來觀測值具有非常重要的意義。 考慮一個極端情況,其中第一個觀測值包含一個校驗和, 目標是在序列的末尾辨別校驗和是否正確。 在這種情況下,第一個詞元的影響至關重要。 我們希望有某些機制能夠在一個記憶元里存儲重要的早期信息。 如果沒有這樣的機制,我們將不得不給這個觀測值指定一個非常大的梯度, 因為它會影響所有后續的觀測值。
  • 我們可能會遇到這樣的情況:一些詞元沒有相關的觀測值。 例如,在對網頁內容進行情感分析時, 可能有一些輔助HTML代碼與網頁傳達的情緒無關。 我們希望有一些機制來跳過隱狀態表示中的此類詞元。
  • 我們可能會遇到這樣的情況:序列的各個部分之間存在邏輯中斷。 例如,書的章節之間可能會有過渡存在, 或者證券的熊市和牛市之間可能會有過渡存在。 在這種情況下,最好有一種方法來重置我們的內部狀態表示。

門控循環單元與普通的循環神經網絡之間的關鍵區別在于: 前者支持隱狀態的門控。 這意味著模型有專門的機制來確定應該何時更新隱狀態, 以及應該何時重置隱狀態。 這些機制是可學習的,并且能夠解決了上面列出的問題。 例如,如果第一個詞元非常重要, 模型將學會在第一次觀測之后不更新隱狀態。 同樣,模型也可以學會跳過不相關的臨時觀測。 最后,模型還將學會在需要的時候重置隱狀態。

1.重置門和更新門
  • 重置門有助于捕獲序列中的短期依賴關系。
  • 更新門有助于捕獲序列中的長期依賴關系。

在這里插入圖片描述

2.候選隱狀態

在這里插入圖片描述

3.隱狀態

在這里插入圖片描述

4.從零開始實現
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt# 定義批量大小和時間步數
batch_size, num_steps = 32, 35# 使用d2l庫的load_data_time_machine函數加載數據集
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)def get_params(vocab_size, num_hiddens, device):"""初始化GRU模型的參數。參數:vocab_size (int): 詞匯表的大小。num_hiddens (int): 隱藏單元的數量。device (torch.device): 張量所在的設備。返回:list of torch.Tensor: 包含所有參數的列表。"""num_inputs = num_outputs = vocab_size  # 輸入和輸出的數量都等于詞匯表大小def normal(shape):"""使用均值為0,標準差為0.01的正態分布初始化張量。參數: shape (tuple): 張量的形狀。返回:torch.Tensor: 初始化后的張量。"""return torch.randn(size=shape, device=device) * 0.01def three():"""初始化GRU門的參數。返回:tuple of torch.Tensor: 包含門的權重和偏置的元組。"""return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three()   # 更新門參數W_xr, W_hr, b_r = three()   # 重置門參數W_xh, W_hh, b_h = three()   # 候選隱藏狀態參數# 輸出層參數W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 將所有參數收集到一個列表中params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params: # 啟用所有參數的梯度計算param.requires_grad_(True)return paramsdef init_gru_state(batch_size, num_hiddens, device):"""初始化GRU的隱藏狀態。參數:batch_size (int): 批量大小。num_hiddens (int): 隱藏單元的數量。device (torch.device): 張量所在的設備。返回:tuple of torch.Tensor: 初始隱藏狀態。"""return (torch.zeros((batch_size, num_hiddens), device=device), )def gru(inputs, state, params):"""定義GRU的前向傳播。參數:inputs (torch.Tensor): 輸入數據。state (tuple of torch.Tensor): 隱藏狀態。params (list of torch.Tensor): GRU的參數。返回:torch.Tensor: GRU的輸出。tuple of torch.Tensor: 更新后的隱藏狀態。"""W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = state  # 獲取隱藏狀態outputs = []  # 存儲輸出的列表for X in inputs:  # 遍歷每一個輸入時間步# 計算更新門ZZ = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)# 計算重置門RR = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)# 計算候選隱藏狀態H_tildaH_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)# 更新隱藏狀態HH = Z * H + (1 - Z) * H_tilda# 計算輸出YY = H @ W_hq + b_qoutputs.append(Y)  # 將輸出添加到列表中return torch.cat(outputs, dim=0), (H,)  # 返回連接后的輸出和更新后的隱藏狀態# 獲取詞匯表大小、隱藏單元數量和設備
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
# 定義訓練的輪數和學習率
num_epochs, lr = 500, 1
# 初始化GRU模型
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_gru_state, gru)
# 使用d2l庫的train_ch8函數訓練模型
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.1, 38557.3 tokens/sec on cuda:0
# time traveller for so it will be convenient to speak of himwas e

在這里插入圖片描述

5.簡潔實現
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt# 定義批量大小和時間步數
batch_size, num_steps = 32, 35
# 使用d2l庫的load_data_time_machine函數加載數據集
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)num_epochs, lr = 500, 1
# # 獲取詞匯表大小、隱藏單元數量和設備
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens) # 定義一個GRU層,輸入大小為num_inputs,隱藏單元數量為num_hiddens
model = d2l.RNNModel(gru_layer, len(vocab)) # 使用GRU層和詞匯表大小創建一個RNN模型
model = model.to(device)
# 該函數需要模型、訓練數據迭代器、詞匯表、學習率、訓練輪數和設備作為參數
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.0, 248342.8 tokens/sec on cuda:0
# time travelleryou can show black is white by argument said filby

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/42595.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/42595.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/42595.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

機器人動力學模型及其線性化阻抗控制模型

機器人動力學模型 機器人動力學模型描述了機器人的運動與所受力和力矩之間的關系。這個模型考慮了機器人的質量、慣性、關節摩擦、重力等多種因素,用于預測和解釋機器人在給定輸入下的動態行為。動力學模型是設計機器人控制器的基礎,它可以幫助我們理解…

html的浮動作用詳解

HTML中的“浮動”(Float)是一個CSS布局技術,它原本設計用于文本環繞圖像或實現簡單的布局效果,比如并排排列元素。然而,隨著Web開發的演進,浮動也被廣泛用于更復雜的頁面布局設計中,盡管現代CSS…

2024/7/7周報

文章目錄 摘要Abstract文獻閱讀題目問題本文貢獻問題描述圖神經網絡Framework實驗數據集實驗結果 深度學習MAGNN模型相關代碼GNN為什么要用GNN?GNN面臨挑戰 總結 摘要 本周閱讀了一篇用于多變量時間序列預測的多尺度自適應圖神經網絡的文章,多變量時間序…

SAP已下發EWM的交貨單修改下發狀態

此種情況針對EWM未接收到ERP交貨單時,可以使用此程序將ERP交貨單調整為未分配狀態,在進行調整數據后,然后使用VL06I(啟用自動下發EWM配置,則在交貨單修改保存后會立即下發EWM)重新下發EWM系統。 操作步驟如…

3ds Max渲染曝光過度怎么辦?

3dmax效果圖云渲染平臺——渲染100 以3ds Max 2025、VR 6.2、CR 11.2等最新版本為基礎,兼容fp、acescg等常用插件,同時LUT濾鏡等參數也得到了同步支持。 注冊填邀請碼【7788】可領30元禮包和免費渲染券哦~ 遇到3ds Max渲染過程中曝光過度的問題&#xf…

SLF4J的介紹與使用(有logback和log4j2的具體實現案例)

目錄 1.日志門面的介紹 常見的日志門面 : 常見的日志實現: 日志門面和日志實現的關系: 2.SLF4J 的介紹 業務場景(問題): SLF4J的作用 SLF4J 的基本介紹 日志框架的綁定(重點&#xff09…

Influxdb中,Flux常用的函數

目錄 一、Flux常用的函數及其簡要描述 1. 數據源和篩選函數 2. 聚合函數 3. 時間序列操作函數 4. 轉換和映射函數 5. 窗口函數 6. 其他常用函數 注意事項 二、使用方法舉例 1. 數據源和篩選 2. 聚合 3. 時間序列操作 4. 窗口函數 5. 轉換和映射 注意事項 三、…

跨越界限的溫柔堅守

跨越界限的溫柔堅守 —— 鄭乃馨與男友的甜蜜抉擇在這個光怪陸離、瞬息萬變的娛樂圈里,每一段戀情像是夜空中劃過的流星,璀璨短暫。然而,當“鄭乃馨與男友甜蜜約會”的消息再次躍入公眾視野,它不僅僅是一段簡單的愛情故事&#xf…

iOS中多個tableView 嵌套滾動特性探索

嵌套滾動的機制 目前的結構是這樣的,整個頁面是一個大的tableView, Cell 是整個頁面的大小,cell 中嵌套了一個tableView 通過測試我們發現滾動的時候,系統的機制是這樣的, 我們滑動內部小的tableView, 開始滑動的時候&#xff0c…

C/C++ 代碼注釋規范及 doxygen 工具

參考 谷歌項目風格指南——注釋 C doxygen 風格注釋示例 ubuntu20 中 doxygen 文檔生成 doxygen 官方文檔 在 /Doxygen/Special Command/ 章節介紹 doxygen 的關鍵字 注釋說明 注釋的目的是提高代碼的可讀性與可維護性。 C 風格注釋 // 單行注釋/* 多行注釋 */ C 風格注…

設置某些路由為公開訪問,不需要登錄狀態即可訪問

在單頁面應用(SPA)框架中,如Vue.js,路由守衛是一種非常有用的功能,它允許你控制訪問路由的權限。Vue.js 使用 Vue Router 作為其官方路由管理器。路由守衛主要分為全局守衛和組件內守衛。 以下是如何設置路由守衛以允…

k8s 部署RuoYi-Vue-Plus之mysql搭建

1.直接部署一個pod 需要掛載存儲款, 可參考 之前文章設置 https://blog.csdn.net/weimeibuqieryu/article/details/140183843 2.部署yaml 先創建命名空間ruoyi kubectl create namespace ruoyi創建部署文件 mysql-deploy.yaml --- apiVersion: v1 kind: PersistentVolume …

【論文閱讀筆記】Meta 3D AssetGen

【論文閱讀筆記】Meta 3D AssetGen: Text-to-Mesh Generation with High-Quality Geometry, Texture, and PBR Materials Info摘要引言創新點 相關工作T23D基于圖片的3d 重建使用 PBR 材料的 3D 建模。 方法文本到圖像:從文本中生成陰影和反照率圖像Image-to-3D:基于pbr的大型重…

搭建NEMU與QEMU的DiffTest環境(動態庫方式)

搭建NEMU與QEMU的DiffTest環境(動態庫方式) 1 DiffTest原理簡述2 編譯NEMU3 編譯qemu-dl-difftest3.1 修改NEMU/scripts/isa.mk3.2 修改NEMU/tools/qemu-dl-diff/src/diff-test.c3.3 修改NEMU/scripts/build.mk3.4 讓qemu-dl-difftest帶調試信息3.5 編譯…

C語言實現字符串排序

如果只有英文字符且不區分大小寫的話按照字典序排序可以用strcmp函數&#xff0c;兩個字符串自左向右逐個字符相比&#xff08;按ASCII值大小相比較&#xff09; strcmp(s1,s2) 當s1<s2時&#xff0c;返回為負數&#xff1b; 當s1s2時&#xff0c;返回值 0&#xff1b; …

安卓的組件

人不走空 &#x1f308;個人主頁&#xff1a;人不走空 &#x1f496;系列專欄&#xff1a;算法專題 ?詩詞歌賦&#xff1a;斯是陋室&#xff0c;惟吾德馨 目錄 &#x1f308;個人主頁&#xff1a;人不走空 &#x1f496;系列專欄&#xff1a;算法專題 ?詩詞歌…

【Linux】打包命令——tar

打包和壓縮 雖然打包和壓縮都涉及將多個文件組合成單個實體&#xff0c;但它們之間存在重要差異。 打包和壓縮的區別&#xff1a; 打包是將多個文件或目錄組合在一起&#xff0c;但不對其進行壓縮。這意味著打包后的文件大小可能與原始文件相同或更大。此外&#xff0c;打包…

Win10精英控制器2代青春版 設備刪除失敗,藍牙連接斷斷續續

前提 更新了主板rog z790帶WiFi、藍牙&#xff0c;但是精英控制器連上老師斷斷續續。 過程 在設備管理中嘗試了卸載、重裝主板對應的藍牙驅動&#xff0c;怎么都不行&#xff0c;都已經想放棄了。 但是想起來之前主板沒有藍牙&#xff0c;用的是綠聯的USB藍牙接收器&#xf…

Ubuntu24.04修改系統的環境變量

apache/tomcat配置要用到JDK&#xff0c;使用torch有時也會用到系統庫&#xff0c;涉及到環境變量 1. 查看環境變量 cat /etc/environment2. 新建環境變量 sudo nano /etc/environment在文件底部添加新的環境變量 MY_VARIABLE"your_value"3. 修改環境變量 臨時—…

數字化精益生產系統--APS 排程管理系統

APS&#xff08;Advanced Planning and Scheduling&#xff09;排程管理系統&#xff0c;即高級生產計劃與排程系統&#xff0c;是一種高度智能化的計劃和排程系統。它通過整合各種生產和供應鏈數據&#xff0c;運用先進的算法和數據模型&#xff0c;根據各種約束條件&#xff…