詳解pytorch中循環神經網絡(RNN、LSTM、GRU)的維度

詳解pytorch中循環神經網絡(RNN、LSTM、GRU)的維度

  • RNN
    • torch.nn.rnn詳解
    • RNN輸入輸出維度
  • LSTM
    • torch.nn.LSTM詳解
    • LSTM輸入輸出維度
  • GRU
    • torch.nn.GRU詳解
    • GRU輸入輸出維度
  • 三種RNN的示例

首先如果你對RNN、LSTM、GRU不太熟悉,可點擊查看。

RNN

torch.nn.rnn詳解

torch.nn.RNN(input_size,
hidden_size,
num_layers=1,
nonlinearity=‘tanh’,
bias=True,
batch_first=False,
dropout=0.0,
bidirectional=False,
device=None,
dtype=None)

原理
在這里插入圖片描述

參數詳解

  • input_size – 輸入x中預期特征的數量

  • hidden_size – 隱藏狀態h中的特征數量

  • num_layers – 循環層數。例如,設置num_layers=2 意味著將兩個LSTM堆疊在一起形成堆疊 LSTM,第二個 LSTM 接收第一個 LSTM 的輸出并計算最終結果。默認值:1

  • nonlinearity– 使用的非線性。可以是’tanh’或’relu’。默認:‘tanh’

  • bias– 如果False,則該層不使用偏差權重b_ih和b_hh。默認:True

  • batch_first – 如果,則輸入和輸出張量以(batch, seq, feature)True形式提供,而不是(seq, batch, feature)。請注意,這不適用于隱藏狀態或單元狀態。默認:False

  • dropout – 如果非零,則在除最后一層之外的每個LSTM層的輸出上 引入Dropout層,dropout 概率等于 。默認值:0.0

  • bidirectional – 如果True, 則成為雙向LSTM。默認:False

RNN輸入輸出維度

rnn = nn.RNN(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)

可以看到輸入是xh_0,h_0可以是None。如果batch_size是第0維度,需設置batch_first=True
輸出則是outputh_n。h_n存了每一層的t時刻的隱藏狀態值

# Efficient implementation equivalent to the following with bidirectional=False
def forward(x, h_0=None):if batch_first:x = x.transpose(0, 1)seq_len, batch_size, _ = x.size()if h_0 is None:h_0 = torch.zeros(num_layers, batch_size, hidden_size)...return output, h_n

輸入:
x的輸入維度:(batch_size, sequence_length, input_size) [前提:batch_first=True]
h_0的維度:(D?num_layers, hidden_size) [可以為None]

輸出: output的輸出維度:(batch_size, sequence_length, D*hidden_size)
[D=2 if bidirectional=True otherwise 1]
h_n的維度:(D?num_layers, hidden_size)

LSTM

torch.nn.LSTM詳解

torch.nn.LSTM(input_size,
hidden_size,
num_layers=1,
bias=True,
batch_first=False,
dropout=0.0,
bidirectional=False,
proj_size=0,
device=None,
dtype=None)

原理:

參數詳解:
相比于RNN多了proj_size參數,少了nonlinearity參數

  • input_size – 輸入x中預期特征的數量

  • hidden_size – 隱藏狀態h中的特征數量

  • num_layers – 循環層數。例如,設置num_layers=2 意味著將兩個LSTM堆疊在一起形成堆疊 LSTM,第二個 LSTM 接收第一個 LSTM 的輸出并計算最終結果。默認值:1

  • bias– 如果False,則該層不使用偏差權重b_ih和b_hh。默認:True

  • batch_first – 如果,則輸入和輸出張量以(batch, seq, feature)True形式提供,而不是(seq, batch, feature)。請注意,這不適用于隱藏狀態或單元狀態。默認:False

  • dropout – 如果非零,則在除最后一層之外的每個LSTM層的輸出上 引入Dropout層,dropout 概率等于 。默認值:0dropout

  • bidirectional – 如果True, 則成為雙向LSTM。默認:False

  • proj_size – 如果,將使用具有相應大小投影的LSTM 。默認值:0

LSTM輸入輸出維度

LSTM= nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = LSTM(input, (h0, c0))

輸入是x,此外h_0c_0可以是None。如果batch_size是第0維度,需設置batch_first=True
輸出則是output和一個元組(h_n, c_n)

輸入: x的輸入維度:(batch_size, sequence_length, input_size)`
[前提:batch_first=True]

輸出: output的輸出維度:(batch_size, sequence_length, D*hidden_size)
[D=2 if bidirectional=True otherwise 1]

具體可參考官方文檔:nn.LSTM
在這里插入圖片描述

GRU

torch.nn.GRU詳解

torch.nn.GRU(input_size,
hidden_size,
num_layers=1,
bias=True,
batch_first=False,
dropout=0.0,
bidirectional=False,
device=None,
dtype=None)

原理:
在這里插入圖片描述

參數詳解:
與上文LSTM相比,缺少了proj_size參數,與RNN相比也缺少了nonlinearity參數

GRU輸入輸出維度

gru= nn.GRU(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = gru(input, h0)

與RNN一致見上文,相比LSTM少了c_n

三種RNN的示例

import torch
import torch.nn as nnrnn = nn.RNN(10, 20, 2, batch_first=True) # (input_size, hidden_size, num_layer)
lstm = nn.LSTM(10, 20, 2, batch_first=True)
gru = nn.GRU(10, 20, 2, batch_first=True)input = torch.randn(5, 3, 10)  # (batchsize, seq, input_size)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)output_rnn, h_n = rnn(input)
output_lstm, (hn, cn) = lstm(input)
output_gru, h_n2 = gru(input)
print("輸入維度:", input.shape)
print(f"RNN 輸出維度:{output_rnn.shape}, h_n維度:{h_n.shape}" )
print("LSTM 輸出維度:", output_lstm.shape)
print("GRU 輸出維度:", output_gru.shape)"""
輸入維度: torch.Size([5, 3, 10])
RNN 輸出維度:torch.Size([5, 3, 20]), h_n維度:torch.Size([2, 5, 20])
LSTM 輸出維度: torch.Size([5, 3, 20])
GRU 輸出維度: torch.Size([5, 3, 20])
"""

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/12546.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/12546.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/12546.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

python數據可視化:層次聚類熱圖clustermap()

【小白從小學Python、C、Java】 【考研初試復試畢業設計】 【Python基礎AI數據分析】 python數據可視化: 層次聚類熱圖 clustermap() [太陽]選擇題 請問關于以下代碼表述錯誤的選項是? import seaborn as sns import matplotlib.pyplot as plt import n…

代碼隨想錄—— 填充每個節點的下一個右側節點指針(Leetcode116)

題目鏈接 層序遍歷 /* // Definition for a Node. class Node {public int val;public Node left;public Node right;public Node next;public Node() {}public Node(int _val) {val _val;}public Node(int _val, Node _left, Node _right, Node _next) {val _val;left _…

開源的全自動生成視頻文案、視頻素材、視頻字幕、視頻背景音樂的AI項目

網址 https://github.com/harry0703/MoneyPrinterTurbo 只需提供一個視頻 主題 或 關鍵詞 ,就可以全自動生成視頻文案、視頻素材、視頻字幕、視頻背景音樂,然后合成一個高清的短視頻。 如果用來做視頻,可以先收藏一下,值得本地…

51 單片機[2-1]:點亮一個LED

一、在 Keil5 中新建項目 打開 Keil5 ,點擊 Project —— new μVision Project 新建文件夾 KeilProject ,以后的項目都在這個文件夾下,再建一個文件夾 2-1 點亮一個LED。在該文件夾下創建名為 Project 的文件,并保存。推薦起這…

Python快速入門3:面向對象OOP(需要有編程基礎)

面向對象是什么: 面向對象編程(Object-Oriented Programming,OOP)是一種編程范式,它以對象為基礎,將數據和操作封裝在一起以創建可重用的代碼模塊。在面向對象編程中,對象是程序的基本單元&…

mysql實戰題目練習

1、創建和管理數據庫 創建一個名為school的數據庫。 列出所有的數據庫,并確認school數據庫已經創建。 如果school數據庫已經存在,刪除它并重新創建。 mysql> create database school; Query OK, 1 row affected (0.01 sec)mysql> mysql> sh…

Spring Boot:異常處理

Spring Boot 前言使用自定義錯誤頁面處理異常使用 ExceptionHandler 注解處理異常使用 ControllerAdvice 注解處理異常使用配置類處理異常使用自定義類處理異常 前言 在 Spring Boot 中,異常處理是一個重要的部分,可以允許開發者優雅地處理應用程序中可…

復利效應(應用于成長)

應用 每個人在智力、知識、經驗上,復利效應都一樣,只要能積累的東西,基本上最終都會產生復利效應。 再來看一下復利公式:FP*(1i)^n P本金;i利率;n持有期限。在使用時,一定要注意4個限定條件&a…

AI圖書推薦:ChatGPT等生成式AI在高等教育中的應用

自2022年11月以來,ChatGPT及其在高等教育各個層面的影響已成為所有教育對話的核心內容。Chan和Colloton所著的書籍是首批全面探討ChatGPT與生成式人工智能(GenAI)在高等教育中應用及影響的作品之一。 該書深入研究了針對專業環境定制的AI素養…

js中Array的2個容易被遺忘的函數some和array

Array.prototype.some() 和 Array.prototype.every() 是 JavaScript 中的兩個容易被遺忘的數組方法。它們都用于檢查數組中的元素是否滿足某個條件。 1. Array.prototype.some() some() 方法用于檢查數組中至少有一個元素滿足給定的條件。當找到滿足條件的元素時,…

基礎學習-Git(分布式版本控制系統)

學習視頻推薦 http://【黑馬程序員Git全套教程,完整的git項目管理工具教程,一套精通git】 https://www.bilibili.com/video/BV1MU4y1Y7h5/?p5&share_sourcecopy_web&vd_source2b85bd9be9213709642d908906c3d863 1、Git環境配置 安裝Git Git下…

wireshark_概念

ARP (Address Resolution Protocol)協議,即地址解析協議。該協議的功能就是將IP地址解析成MAC地址。 混雜模式 抓取經過網卡的所有數據包,包括發往本網卡和非發往本網卡的。 非混雜模式 只抓取目標地址是本網卡的數據包,對于發往…

《控制系統實驗與綜合設計》綜合四至六(含程序和題目)

1.電機模型辨識實驗 1.1 實驗目的 (1)掌握一階系統階躍響應的特點,通過實驗加深對直流電解模型的理解; (2)掌握系統建模過程中參數的整定,體會參數變化對系統的影響; &#xff0…

單片機開發板上外設資源講解

單片機開發電路板上簡單外設 開發板上各基礎外設LED燈按鍵:數碼管介紹液晶屏矩陣鍵盤掃描的概念LED點陣屏實時時鐘蜂鳴器存儲器 溫度傳感器&單總線 開發板上各基礎外設 LED燈 中文名:發光二極管 外文名:Light Emitting Diode 簡稱&…

楊校老師項目之基于單片機STC89C52的智能環境監測系統【嵌入式】

獲取全套資料: 有償獲取:mryang511688 技術:C語言、單片機等 摘要: 此設計可分為三個主要部分。此中的溫度和濕度的檢測功能,通過操縱單總線型溫濕度傳感器DHT11以數字形式顯示,實現了切確測得溫濕度的功能…

如何管理多個版本的Node.js

我們如何在本地管理多個版本的Node.js,有沒有那種不需要重新安裝軟件再修改配置文件和環境變量的方法?經過我的查找,還真有這種方式,那就是nvm(Node Version Manager)。 下面我就給大家介紹下NVM的使用 1…

vs2019 c++中模板 enable_if_t 的使用

&#xff08;1&#xff09; 該模板的定義如下&#xff1a; template <bool _Test, class _Ty void> struct enable_if {}; // no member "type" when !_Testtemplate <class _Ty> struct enable_if<true, _Ty> { // type is _Ty for _Testusing …

Golang | Leetcode Golang題解之第89題格雷編碼

題目&#xff1a; 題解&#xff1a; func grayCode(n int) []int {ans : make([]int, 1<<n)for i : range ans {ans[i] i>>1 ^ i}return ans }

MSR810-LM快速配置通過LTE模塊上網

正文共&#xff1a;1111 字 13 圖&#xff0c;預估閱讀時間&#xff1a;1 分鐘 之前買了一個無線版本的MSR810-W&#xff08;淘了一臺二手的H3C企業路由器&#xff0c;就用它來打開網絡世界的大門&#xff09;&#xff0c;并整理了一份快速配置&#xff08;腳本案例來了&#x…

三菱FX3U-4AD模擬量電壓輸入采集實例

硬件&#xff1a;&#xff30;&#xff2c;&#xff23;模塊 &#xff26;&#xff38;&#xff13;&#xff27;&#xff21;-&#xff12;&#xff14;&#xff2d;&#xff34; &#xff1b;&#xff21;&#xff0f;&#xff24;模塊&#xff26;&#xff38;&#xff13…