【PyTorch】PyTorch中torch.nn模塊的循環層

PyTorch深度學習總結

第九章 PyTorch中torch.nn模塊的循環層


文章目錄

  • PyTorch深度學習總結
  • 前言
  • 一、循環層
      • 1. 簡單循環層(RNN)
      • 2. 長短期記憶網絡(LSTM)
      • 3. 門控循環單元(GRU)
      • 4. 雙向循環層
  • 二、循環層參數
      • 1. 輸入維度相關參數
      • 2. 隱藏層相關參數
      • 3. 其他參數
  • 三、函數總結


前言

上文介紹了PyTorch中介紹了池化和torch.nn模塊中的池化層函數,本文將進一步介紹torch.nn模塊中的循環層。


一、循環層

在PyTorch中,循環層Recurrent Layers)是處理序列數據的重要組件,常用于自然語言處理、時間序列分析等領域。
下面為你詳細介紹幾種常見的循環層:

1. 簡單循環層(RNN)

  • 原理簡單循環層RNN)是最基礎的循環神經網絡結構,它在每個時間步接收當前輸入和上一個時間步的隱藏狀態,通過特定的激活函數計算當前時間步的隱藏狀態。這種結構使得RNN能夠對序列數據中的時間依賴關系進行建模。
  • PyTorch實現:在PyTorch中,可以使用torch.nn.RNN類來構建簡單循環層。以下是一個簡單的示例代碼:
import torch
import torch.nn as nn# 定義輸入維度、隱藏維度和層數
input_size = 10
hidden_size = 20
num_layers = 1# 創建RNN層
rnn = nn.RNN(input_size, hidden_size, num_layers)# 生成輸入數據
batch_size = 3
seq_len = 5
input_data = torch.randn(seq_len, batch_size, input_size)# 初始化隱藏狀態
h_0 = torch.randn(num_layers, batch_size, hidden_size)# 前向傳播
output, h_n = rnn(input_data, h_0)
  • 應用場景:簡單循環層適用于處理一些簡單的序列數據,例如短文本分類、簡單的時間序列預測等。但由于存在梯度消失梯度爆炸的問題,對于長序列數據的處理效果不佳。

2. 長短期記憶網絡(LSTM)

  • 原理長短期記憶網絡LSTM)是為了解決RNN的梯度消失問題而提出的。它引入了門控機制,包括輸入門、遺忘門和輸出門,通過這些門控單元可以更好地控制信息的流動,從而有效地捕捉序列數據中的長距離依賴關系。
  • PyTorch實現:在PyTorch中,可以使用torch.nn.LSTM類來構建LSTM層。示例代碼如下:
import torch
import torch.nn as nn# 定義輸入維度、隱藏維度和層數
input_size = 10
hidden_size = 20
num_layers = 1# 創建LSTM層
lstm = nn.LSTM(input_size, hidden_size, num_layers)# 生成輸入數據
batch_size = 3
seq_len = 5
input_data = torch.randn(seq_len, batch_size, input_size)# 初始化隱藏狀態和細胞狀態
h_0 = torch.randn(num_layers, batch_size, hidden_size)
c_0 = torch.randn(num_layers, batch_size, hidden_size)# 前向傳播
output, (h_n, c_n) = lstm(input_data, (h_0, c_0))
  • 應用場景:LSTM廣泛應用于自然語言處理中的機器翻譯、文本生成,以及時間序列分析中的股票價格預測、天氣預測等領域。

3. 門控循環單元(GRU)

  • 原理門控循環單元GRU)是LSTM的一種簡化版本,它將LSTM中的輸入門和遺忘門合并為一個更新門,并取消了細胞狀態,只保留隱藏狀態。這種簡化使得GRU的計算效率更高,同時也能夠較好地捕捉序列數據中的長距離依賴關系。
  • PyTorch實現:在PyTorch中,可以使用torch.nn.GRU類來構建GRU層。示例代碼如下:
import torch
import torch.nn as nn# 定義輸入維度、隱藏維度和層數
input_size = 10
hidden_size = 20
num_layers = 1# 創建GRU層
gru = nn.GRU(input_size, hidden_size, num_layers)# 生成輸入數據
batch_size = 3
seq_len = 5
input_data = torch.randn(seq_len, batch_size, input_size)# 初始化隱藏狀態
h_0 = torch.randn(num_layers, batch_size, hidden_size)# 前向傳播
output, h_n = gru(input_data, h_0)
  • 應用場景:GRU在一些對計算資源要求較高的場景中表現出色,例如實時語音識別、在線文本分類等。

4. 雙向循環層

  • 原理雙向循環層Bidirectional RNN/LSTM/GRU)是在單向循環層的基礎上擴展而來的。它同時考慮了序列數據的正向和反向信息,通過將正向和反向的隱藏狀態拼接或相加,能夠更全面地捕捉序列數據中的上下文信息。
  • PyTorch實現:在PyTorch中,可以通過設置bidirectional=True來創建雙向循環層。以雙向LSTM為例,示例代碼如下:
import torch
import torch.nn as nn# 定義輸入維度、隱藏維度和層數
input_size = 10
hidden_size = 20
num_layers = 1# 創建雙向LSTM層
lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)# 生成輸入數據
batch_size = 3
seq_len = 5
input_data = torch.randn(seq_len, batch_size, input_size)# 初始化隱藏狀態和細胞狀態
h_0 = torch.randn(num_layers * 2, batch_size, hidden_size)
c_0 = torch.randn(num_layers * 2, batch_size, hidden_size)# 前向傳播
output, (h_n, c_n) = lstm(input_data, (h_0, c_0))
  • 應用場景雙向循環層自然語言處理中的命名實體識別、情感分析等任務中表現出色,因為這些任務需要充分利用上下文信息來做出準確的判斷。

二、循環層參數

以下為你詳細介紹 PyTorch 中幾種常見循環層(RNN、LSTM、GRU)的常見參數:

1. 輸入維度相關參數

  • input_size
  • 含義:該參數表示輸入序列中每個時間步的特征數量。可以理解為輸入數據的特征維度
    - 例子:在處理文本數據時,如果使用詞向量表示每個單詞,詞向量的維度就是 input_size。假如使用 300 維的詞向量,那么 input_size 就為 300。
  • batch_first
  • 含義:這是一個布爾類型的參數,用于指定輸入和輸出張量的維度順序。當 batch_first=True 時,輸入和輸出張量的形狀為 (batch_size, seq_len, input_size);當 batch_first=False(默認值)時,形狀為 (seq_len, batch_size, input_size)
    - 例子:假設 batch_size 為 32,seq_len 為 10,input_size 為 50。若 batch_first=True,輸入張量形狀就是 (32, 10, 50);若 batch_first=False,輸入張量形狀則為 (10, 32, 50)

2. 隱藏層相關參數

  • hidden_size
  • 含義:代表隱藏狀態的維度,即每個時間步中隱藏層神經元的數量。隱藏狀態在循環層的計算中起著關鍵作用,它會在不同時間步之間傳遞信息。
    - 例子:如果 hidden_size 設置為 128,意味著每個時間步的隱藏層有 128 個神經元,隱藏狀態的維度就是 128。
  • num_layers
  • 含義:表示循環層的堆疊層數。多層循環層可以學習更復雜的序列模式,通過堆疊多個循環層,模型能夠從不同抽象層次上處理序列數據。
    - 例子:當 num_layers 為 2 時,意味著有兩個循環層堆疊在一起,前一層的輸出會作為后一層的輸入。

3. 其他參數

  • bias
  • 含義:布爾類型參數,用于決定是否在循環層中使用偏置項。bias=True 表示使用偏置,bias=False 則不使用。
    - 例子:在大多數情況下,bias 默認為 True,即使用偏置項,這樣可以增加模型的靈活性。
  • dropout
  • 含義:該參數用于在循環層中應用 Dropout 正則化,以防止過擬合。取值范圍為 0 到 1 之間,表示 Dropout 的概率。
    - 例子:當 dropout = 0.2 時,意味著在訓練過程中,每個神經元有 20% 的概率被隨機置為 0。需要注意的是,dropout 只在 num_layers > 1 時有效。
  • bidirectional
  • 含義:布爾類型參數,用于指定是否使用雙向循環層。bidirectional=True 表示使用雙向循環層,bidirectional=False 表示使用單向循環層。
    - 例子:在雙向 LSTM 中,設置 bidirectional=True 后,模型會同時考慮序列的正向和反向信息,最后將正反向的隱藏狀態進行拼接或相加。
  • LSTM 的 proj_size
  • 含義:用于指定 LSTM 中投影層的維度。投影層可以將隱藏狀態的維度進行壓縮,從而減少模型的參數數量。
    - 例子:若 proj_size 為 64,原本 hidden_size 為 128,那么經過投影層后,隱藏狀態的維度會變為 64。

三、函數總結

循環層類型原理PyTorch實現應用場景優缺點
簡單循環層(RNN)每個時間步接收當前輸入和上一個時間步的隱藏狀態,通過激活函數計算當前時間步隱藏狀態,對序列時間依賴關系建模rnn = nn.RNN(input_size, hidden_size, num_layers)短文本分類、簡單時間序列預測等簡單序列數據處理優點:結構簡單;缺點:存在梯度消失或爆炸問題,處理長序列效果不佳
長短期記憶網絡(LSTM)引入門控機制(輸入門、遺忘門和輸出門),控制信息流動,捕捉長距離依賴關系lstm = nn.LSTM(input_size, hidden_size, num_layers)機器翻譯、文本生成、股票價格預測、天氣預測等優點:能有效處理長序列;缺點:計算復雜度相對較高
門控循環單元(GRU)將LSTM的輸入門和遺忘門合并為更新門,取消細胞狀態,保留隱藏狀態gru = nn.GRU(input_size, hidden_size, num_layers)實時語音識別、在線文本分類等對計算資源要求高的場景優點:計算效率高;缺點:在某些復雜長序列任務效果可能不如LSTM
雙向循環層(Bidirectional RNN/LSTM/GRU)同時考慮序列正向和反向信息,通過拼接或相加正反向隱藏狀態捕捉上下文信息lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)命名實體識別、情感分析等需充分利用上下文信息的任務優點:能更全面捕捉上下文;缺點:計算量更大

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/88040.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/88040.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/88040.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Ubuntu 24.04 LTS 服務器配置:安裝 JDK、Nginx、Redis。

Ubuntu 24.04 LTS 服務器配置:安裝 JDK、Nginx、Redis。新建用來放置軟件安裝包的目錄 mkdir /home/software 配置目錄所有者為 ubuntu 用戶: chown ubuntu /home/software將軟件安裝包上傳到 /home/software配置 JDK-8 新建 jdk 安裝目錄 mkdir /usr/ja…

工作中用到過哪些設計模式?是怎么實現的?

1. 單例模式(結合 Spring Component)場景:配置中心、全局狀態管理 Spring 實現:java// 自動注冊為Spring Bean(默認單例) Component public class AppConfig {Value("${server.port}")private in…

Leetcode 3609. Minimum Moves to Reach Target in Grid

Leetcode 3609. Minimum Moves to Reach Target in Grid 1. 解題思路2. 代碼實現 題目鏈接:3609. Minimum Moves to Reach Target in Grid 1. 解題思路 這一題我一開始走岔了,走了一個正向遍歷走法的思路,無論怎么剪枝都一直超時。后來看了…

工作流引擎:IDEA沒有actiBPMN插件怎么辦?

文章目錄一、問題描述二、替代方案一、問題描述 我們在學習activiti7工作流引擎的時候,需要設計流程圖。 一般推薦的就是使用IDEA插件actiBPMN進行開發。 但是,這個插件在IDEA2019后的版本都不在支持。 也就是搜不到 那么,怎么辦了&#x…

Android音視頻探索之旅 | CMake基礎語法 創建支持Ffmpeg的Android項目

一.CMake語法 CMake語法非常多,我們知道如何導入靜態庫和動態庫以及最基礎的使用,目前是夠用的。其它方面則根據實際項目同步學習。 1.1.基礎語法-常用 cmake_minimum_required:指定cmake最小版本include_directories:引入&#x…

React Native 初始化項目和模擬器運行

中文官方文檔:https://reactnative.cn/docs/environment-setup 英文官方文檔:https://reactnative.dev/docs/getting-started-without-a-framework#step-1-creating-a-new-application 創建新項目 1、初始化 # 如果你之前全局安裝過舊的react-native-cli…

20250706-5-Docker 快速入門(上)-創建容器常用選項_筆記

一、創建容器常用選項1. 創建容器常用選項1)常用選項創建容器常用選項交互式選項:-i:保持標準輸入打開,允許交互式操作-t:分配偽終端,使容器像傳統終端一…

插值與擬合(3):B樣條曲線

在路徑規劃問題中,通常會用到B樣條來平滑路徑,本文實現并封裝了三次準均勻開放B樣條曲線,供大學學習使用。作者提供了三套代碼方案。可以用于不同平臺:方案1:MATLAB;方案2:標準C;方案…

[免費]基于Python豆瓣電影數據分析及可視化系統(Flask+echarts+pandas)【論文+源碼+SQL腳本】

大家好,我是java1234_小鋒老師,看到一個不錯的于Python豆瓣電影數據分析及可視化系統(Flaskechartpandas)【論文源碼SQL腳本】,分享下哈。項目介紹隨著如今電影越來越多,各種各樣的爛片和撈錢的商業片也層出不窮,而有意…

SQL127 月總刷題數和日均刷題數

SQL127 月總刷題數和日均刷題數 withtemp as (selectDATE_FORMAT(submit_time, "%Y%m") as submit_month,count(question_id) as month_q_cnt,round(count(question_id) / day(last_day(max(submit_time))),3) as avg_day_q_cntfrompractice_recordwhereyear(submit…

unity luban接入

1.找到luban官網并下載他的例子和.net8.0的sdk安裝 官網地址如下 快速上手 | Luban 參考大佬教程如下 Luban新版本接入教程_嗶哩嗶哩_bilibili 2.找到他的luban_examples-main示例下的兩個文件MiniTemplate和tool 3.MiniTemplate這個文件復制一份到項目工程下,自…

Django服務開發鏡像構建

最后完整的項目目錄結構1、安裝依賴pip install django django-tables2 django-filter2、創建項目和主應用django-admin startproject configcd configpython manage.py startapp dynamic_models3、配置settings.py將項目模塊dynamic_models加入進來,django_tables2…

20250706-3-Docker 快速入門(上)-常用鏡像管理命令_筆記

一、配置加速器1. Docker Hub簡介與地址公共鏡像倉庫: 由Docker公司維護的公共鏡像倉庫,包含大量容器鏡像默認下載源: Docker工具默認從這個公共鏡像庫下載鏡像訪問地址: https://hub.docker.com鏡像搜索功能: 可通過瀏覽器訪問圖形化管理系…

【unity游戲開發——優化篇】使用Occlusion Culling遮擋剔除,只渲染相機視野內的游戲物體提升游戲性能

注意:考慮到優化的內容比較多,我將該內容分開,并全部整合放在【unity游戲開發——優化篇】專欄里,感興趣的小伙伴可以前往逐一查看學習。 文章目錄 前言實戰1、確保所有靜止的3D物體都標記為Occluder Static靜態遮擋體和Occludee …

通用業務編號生成工具類(MyBatis-Plus + Spring Boot)詳解 + 3種調用方式

在企業應用開發中,我們經常需要生成類似 BZ -240704-0001 這種“業務編號”,它通常具有以下特點:前綴:代表業務類型,如 BZ 表示包裝日期:年月日格式,通常為 yyMMdd序列號:當天內遞增…

前端相關性能優化筆記

1.打開速度怎么變快 - 首屏加載優化2.再次打開速度怎么變快 - 緩存優化了3.操作怎么才順滑 - 渲染優化4.動畫怎么保證流暢 - 長任務拆分2.1 首屏加載指標細化:1.FP(First Paint 首次繪制) 2.FCP(First contentful Paint 首次內容繪制),FP 到 FCP 中間其實主要是 SPA…

7.7晚自習作業

實操作業02:Spark核心開發 作業說明 請嚴格按照步驟操作,并將最終結果文件(命名為:sparkcore_result.txt)于20點前上傳。結果文件需包含每一步的關鍵命令執行結果文本輸出。 一、數據讀取與轉換操作 上傳賬戶數據$…

手機FunASR識別SIM卡通話占用內存和運行性能分析

手機FunASR識別SIM卡通話占用內存和運行性能分析 --本地AI電話機器人 上一篇:手機無網離線使用FunASR識別SIM卡語音通話內容 下一篇:手機通話語音離線ASR識別商用和優化方向 一、前言 書接上一文《阿里FunASR本地斷網離線識別模型簡析》,…

虛幻引擎Unreal Engine5恐怖游戲設計制作教程,從入門到精通從零開始完整項目開發實戰詳細講解中英字幕

和大家分享一個以前收集的UE5虛幻引擎恐怖游戲開發教程,這是國外一個大神制作的視頻教程,教程從零開始到制作出一款完整的游戲。內容講解全面,如藍圖基礎知識講解、角色控制、高級交互系統、高級庫存系統、物品檢查、恐怖環境氛圍設計、過場動…

多人協同開發時Git使用命令

拉取倉庫代碼 # 拉取遠程倉庫至本地tar_dir路徑 git clone gitgithub.com:your-repo.git target_dir # 默認是拉取遠程master分支,下面拉取并切換到自己需要開發的分支上 # 假設自己需要開發的分支是/feature/my_branch分支 git checkout -b feature/my_branch orig…