循環神經網絡(RNN):原理、架構與實戰

循環神經網絡(Recurrent Neural Network, RNN)是一類專門處理序列數據的神經網絡,如時間序列、自然語言、音頻等。與前饋神經網絡不同,RNN 引入了循環結構,能夠捕捉序列中的時序信息,使模型在不同時間步之間共享參數。這種結構賦予了 RNN 處理變長輸入、保留歷史信息的能力,成為序列建模的強大工具。

RNN 的基本原理與核心結構

傳統神經網絡在處理序列數據時,無法利用序列中的時序依賴關系。RNN 通過在網絡中引入循環連接,使得信息可以在不同時間步之間傳遞。

1. 簡單 RNN 的數學表達

在時間步t,RNN 的隱藏狀態\(h_t\)的計算如下:

\(h_t = \sigma(W_{hh}h_{t-1} + W_{xh}x_t + b)\)

其中,\(x_t\)是當前時間步的輸入,\(h_{t-1}\)是上一時間步的隱藏狀態,\(W_{hh}\)和\(W_{xh}\)是權重矩陣,b是偏置,\(\sigma\)是非線性激活函數(如 tanh 或 ReLU)。

2. RNN 的展開結構

雖然 RNN 在結構上包含循環,但在計算時通常將其展開為一個時間步序列。這種展開視圖更清晰地展示了 RNN 如何處理序列數據:

plaintext

x1    x2    x3    ...   xT
|     |     |           |
v     v     v           v
h0 -> h1 -> h2 -> ... -> hT
|     |     |           |
v     v     v           v
y1    y2    y3    ...   yT

其中,\(h_0\)通常初始化為零向量,\(y_t\)是時間步t的輸出(如果需要)。

3. RNN 的局限性

簡單 RNN 雖然能夠處理序列數據,但存在嚴重的梯度消失或梯度爆炸問題,導致難以學習長距離依賴關系。這限制了它在處理長序列時的性能。

長短期記憶網絡(LSTM)與門控循環單元(GRU)

為了解決簡單 RNN 的局限性,研究人員提出了更復雜的門控機制,主要包括 LSTM 和 GRU。

1. 長短期記憶網絡(LSTM)

LSTM 通過引入遺忘門、輸入門和輸出門,有效控制信息的流動:

\(\begin{aligned} f_t &= \sigma(W_f[h_{t-1}, x_t] + b_f) \\ i_t &= \sigma(W_i[h_{t-1}, x_t] + b_i) \\ o_t &= \sigma(W_o[h_{t-1}, x_t] + b_o) \\ \tilde{C}_t &= \tanh(W_C[h_{t-1}, x_t] + b_C) \\ C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\ h_t &= o_t \odot \tanh(C_t) \end{aligned}\)

其中,\(f_t\)、\(i_t\)、\(o_t\)分別是遺忘門、輸入門和輸出門,\(C_t\)是細胞狀態,\(\odot\)表示逐元素乘法。

2. 門控循環單元(GRU)

GRU 是 LSTM 的簡化版本,合并了遺忘門和輸入門,并將細胞狀態和隱藏狀態合并:

\(\begin{aligned} z_t &= \sigma(W_z[h_{t-1}, x_t] + b_z) \\ r_t &= \sigma(W_r[h_{t-1}, x_t] + b_r) \\ \tilde{h}_t &= \tanh(W_h[r_t \odot h_{t-1}, x_t] + b_h) \\ h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \end{aligned}\)

其中,\(z_t\)是更新門,\(r_t\)是重置門。

RNN 的典型應用場景

RNN 在各種序列建模任務中取得了廣泛應用:

  1. 自然語言處理:機器翻譯、文本生成、情感分析、命名實體識別等。
  2. 語音識別:將語音信號轉換為文本。
  3. 時間序列預測:股票價格預測、天氣預測等。
  4. 視頻分析:動作識別、視頻描述生成。
  5. 音樂生成:自動作曲。
使用 PyTorch 實現 RNN 進行文本分類

下面我們使用 PyTorch 實現一個基于 LSTM 的文本分類模型,使用 IMDB 電影評論數據集進行情感分析。

python

運行

import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.legacy import data, datasets
import random
import numpy as np# 設置隨機種子,保證結果可復現
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True# 定義字段
TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm', include_lengths=True)
LABEL = data.LabelField(dtype=torch.float)# 加載IMDB數據集
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)# 創建驗證集
train_data, valid_data = train_data.split(random_state=random.seed(SEED))# 構建詞匯表
MAX_VOCAB_SIZE = 25000
TEXT.build_vocab(train_data, max_size=MAX_VOCAB_SIZE, vectors="glove.6B.100d")
LABEL.build_vocab(train_data)# 創建迭代器
BATCH_SIZE = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits((train_data, valid_data, test_data), batch_size=BATCH_SIZE,sort_within_batch=True,device=device)# 定義LSTM模型
class LSTMClassifier(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout, pad_idx):super().__init__()# 嵌入層self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)# LSTM層self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, bidirectional=bidirectional, dropout=dropout)# 全連接層self.fc = nn.Linear(hidden_dim * 2, output_dim)# Dropout層self.dropout = nn.Dropout(dropout)def forward(self, text, text_lengths):# text = [sent len, batch size]# 應用dropout到嵌入層embedded = self.dropout(self.embedding(text))# embedded = [sent len, batch size, emb dim]# 打包序列packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.to('cpu'))# 通過LSTM層packed_output, (hidden, cell) = self.lstm(packed_embedded)# 展開序列output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)# output = [sent len, batch size, hid dim * num directions]# hidden = [num layers * num directions, batch size, hid dim]# cell = [num layers * num directions, batch size, hid dim]# 我們使用雙向LSTM的最終隱藏狀態hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))# hidden = [batch size, hid dim * num directions]return self.fc(hidden)# 初始化模型
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 100
HIDDEN_DIM = 256
OUTPUT_DIM = 1
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.5
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]model = LSTMClassifier(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS, BIDIRECTIONAL, DROPOUT, PAD_IDX)# 加載預訓練的詞向量
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)# 優化器和損失函數
optimizer = optim.Adam(model.parameters())
criterion = nn.BCEWithLogitsLoss()model = model.to(device)
criterion = criterion.to(device)# 準確率計算函數
def binary_accuracy(preds, y):"""Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8"""# 四舍五入預測值rounded_preds = torch.round(torch.sigmoid(preds))correct = (rounded_preds == y).float()  # 轉換為float計算準確率acc = correct.sum() / len(correct)return acc# 訓練函數
def train(model, iterator, optimizer, criterion):epoch_loss = 0epoch_acc = 0model.train()for batch in iterator:optimizer.zero_grad()text, text_lengths = batch.textpredictions = model(text, text_lengths).squeeze(1)loss = criterion(predictions, batch.label)acc = binary_accuracy(predictions, batch.label)loss.backward()optimizer.step()epoch_loss += loss.item()epoch_acc += acc.item()return epoch_loss / len(iterator), epoch_acc / len(iterator)# 評估函數
def evaluate(model, iterator, criterion):epoch_loss = 0epoch_acc = 0model.eval()with torch.no_grad():for batch in iterator:text, text_lengths = batch.textpredictions = model(text, text_lengths).squeeze(1)loss = criterion(predictions, batch.label)acc = binary_accuracy(predictions, batch.label)epoch_loss += loss.item()epoch_acc += acc.item()return epoch_loss / len(iterator), epoch_acc / len(iterator)# 訓練模型
N_EPOCHS = 5best_valid_loss = float('inf')for epoch in range(N_EPOCHS):train_loss, train_acc = train(model, train_iterator, optimizer, criterion)valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)if valid_loss < best_valid_loss:best_valid_loss = valid_losstorch.save(model.state_dict(), 'lstm-model.pt')print(f'Epoch: {epoch+1:02}')print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')# 測試模型
model.load_state_dict(torch.load('lstm-model.pt'))
test_loss, test_acc = evaluate(model, test_iterator, criterion)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

RNN 的挑戰與發展趨勢

盡管 RNN 在序列建模中取得了成功,但仍面臨一些挑戰:

  1. 長序列處理困難:即使是 LSTM 和 GRU,在處理極長序列時仍有困難。
  2. 并行計算能力有限:RNN 的時序依賴性導致難以高效并行化。
  3. 注意力機制的興起:注意力機制可以更靈活地捕獲序列中的長距離依賴,減少對完整歷史的依賴。

近年來,RNN 的發展趨勢包括:

  1. 注意力機制與 Transformer:注意力機制和 Transformer 架構在許多序列任務中取代了傳統 RNN,如 BERT、GPT 等模型。
  2. 混合架構:結合 RNN 和注意力機制的優點,如 Google 的 T5 模型。
  3. 少樣本學習與遷移學習:利用預訓練模型(如 XLNet、RoBERTa)進行微調,減少對大量標注數據的需求。
  4. 神經圖靈機與記憶網絡:增強 RNN 的記憶能力,使其能夠處理更復雜的推理任務。

循環神經網絡為序列數據處理提供了強大的工具,盡管面臨一些挑戰,但通過不斷的研究和創新,RNN 及其變體仍在眾多領域發揮著重要作用,并將繼續推動序列建模技術的發展。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/82550.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/82550.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/82550.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

java 項目登錄請求業務解耦模塊全面

登錄是統一的閘機&#xff1b; 密碼存在數據庫中&#xff0c;用的是密文&#xff0c;后端加密&#xff0c;和數據庫中做對比 1、UserController public class UserController{Autowiredprivate IuserService userservicepublic JsonResult login(Validated RequestBody UserLo…

【手寫數據庫核心揭秘系列】第9節 可重入的SQL解析器,不斷解析Structure Query Language,語言翻譯好幫手

可重入的SQL解析器 文章目錄 可重入的SQL解析器一、概述 二、可重入解析器 2.1 可重入設置 2.2 記錄狀態的數據結構 2.3 節點數據類型定義 2.4 頭文件引用 三、調整后的程序結構 四、總結 一、概述 現在就來修改之前sqlscanner.l和sqlgram.y程序,可以不斷輸入SQL語句,循環執…

微軟開源bitnet b1.58大模型,應用效果測評(問答、知識、數學、邏輯、分析)

微軟開源bitnet b1.58大模型,應用效果測評(問答、知識、數學、邏輯、分析) 目 錄 1. 前言... 2 2. 應用部署... 2 3. 應用效果... 3 1.1 問答方面... 3 1.2 知識方面... 4 1.3 數字運算... 6 1.4 邏輯方面... …

用HTML5+JavaScript實現漢字轉拼音工具

用HTML5JavaScript實現漢字轉拼音工具 前一篇博文&#xff08;https://blog.csdn.net/cnds123/article/details/148067680&#xff09;提到&#xff0c;當需要將拼音添加到漢字上面時&#xff0c;用python實現比HTML5JavaScript實現繁瑣。在這篇博文中用HTML5JavaScript實現漢…

鴻蒙OSUniApp 開發的動態背景動畫組件#三方框架 #Uniapp

使用 UniApp 開發的動態背景動畫組件 前言 在移動應用開發中&#xff0c;動態背景動畫不僅能提升界面美感&#xff0c;還能增強用戶的沉浸感和品牌辨識度。無論是登錄頁、首頁還是活動頁&#xff0c;恰到好處的動態背景都能讓產品脫穎而出。隨著鴻蒙&#xff08;HarmonyOS&am…

云原生技術架構技術探索

文章目錄 前言一、什么是云原生技術架構二、云原生技術架構的優勢三、云原生技術架構的應用場景結語 前言 在當今的技術領域&#xff0c;云原生技術架構正以一種勢不可擋的姿態席卷而來&#xff0c;成為了眾多開發者、企業和技術愛好者關注的焦點。那么&#xff0c;究竟什么是…

AWS之AI服務

目錄 一、AWS AI布局 ??1. 底層基礎設施與芯片?? ??2. AI訓練框架與平臺?? ??3. 大模型與應用層?? ??4. 超級計算與網絡?? ??與競品對比?? AI服務 ??1. 機器學習平臺?? ??2. 預訓練AI服務?? ??3. 邊緣與物聯網AI?? ??4. 數據與AI…

lwip_bind、lwip_listen 是阻塞函數嗎

在 lwIP 協議棧中&#xff0c;lwip_bind 和 lwip_listen 函數本質上是非阻塞的。 通常&#xff0c;bind和listen在大多數實現中都是非阻塞的&#xff0c;因為它們只是設置套接字的屬性&#xff0c;不需要等待外部事件。阻塞通常發生在接受連接&#xff08;accept&#xff09;、…

【后端高階面經:消息隊列篇】28、從零設計高可用消息隊列

一、消息隊列架構設計的核心目標與挑戰 設計高性能、高可靠的消息隊列需平衡功能性與非功能性需求,解決分布式系統中的典型問題。 1.1 核心設計目標 吞吐量:支持百萬級消息/秒處理,通過分區并行化實現橫向擴展。延遲:端到端延遲控制在毫秒級,適用于實時業務場景。可靠性…

【運維實戰】Linux 內存調優之進程內存深度監控

寫在前面 內容涉及 Linux 進程內存監控 監控方式包括傳統工具 ps/top/pmap ,以及 cgroup 內存子系統&#xff0c;proc 內存偽文件系統 監控內容包括進程內存使用情況&#xff0c; 內存全局數據統計&#xff0c;內存事件指標&#xff0c;以及進程內存段數據監控 監控進程的內…

決策樹 GBDT XGBoost LightGBM

一、決策樹 1. 決策樹有一個很強的假設&#xff1a; 信息是可分的&#xff0c;否則無法進行特征分支 2. 決策樹的種類&#xff1a; 2. ID3決策樹&#xff1a; ID3決策樹的數劃分標準是信息增益&#xff1a; 信息增益衡量的是通過某個特征進行數據劃分前后熵的變化量。但是&…

java基礎學習(十四)

文章目錄 4-1 面向過程與面向對象4-2 Java語言的基本元素&#xff1a;類和對象面向對象的思想概述 4-3 對象的創建和使用內存解析匿名對象 4-1 面向過程與面向對象 面向過程(POP) 與 面向對象(OOP) 二者都是一種思想&#xff0c;面向對象是相對于面向過程而言的。面向過程&…

TCP 三次握手,第三次握手報文丟失會發生什么?

文章目錄 RTO(Retransmission Timeout)注意 客戶端收到服務端的 SYNACK 報文后&#xff0c;會回給服務端一個 ACK 報文&#xff0c;之后處于 ESTABLISHED 狀態 因為第三次握手的 ACK 是對第二次握手中 SYN 的確認報文&#xff0c;如果第三次握手報文丟失了&#xff0c;服務端就…

deepseek告訴您http與https有何區別?

有用戶經常問什么是Http , 什么是Https &#xff1f; 兩者有什么區別&#xff0c;下面為大家介紹一下兩者的區別 一、什么是HTTP HTTP是一種無狀態的應用層協議&#xff0c;用于在客戶端瀏覽器和服務器之間傳輸網頁信息&#xff0c;默認使用80端口 二、HTTP協議的特點 HTTP協議…

openresty如何禁止海外ip訪問

前幾天&#xff0c;我有一個徒弟問我&#xff0c;如何禁止海外ip訪問他的網站系統&#xff1f;操作系統采用的是centos7.9&#xff0c;發布服務采用的是openresty。通過日志他發現&#xff0c;有很多類似以下數據 {"host":"172.30.7.95","clientip&q…

理解 Redis 事務-20 (MULTI、EXEC、DISCARD)

理解 Redis 事務&#xff1a;MULTI、EXEC、DISCARD Redis 事務允許你將一組命令作為一個單一的原子操作來執行。這意味著事務中的所有命令要么全部執行&#xff0c;要么全部不執行。這對于在需要一起執行多個操作時保持數據完整性至關重要。本課程將涵蓋 Redis 事務的基礎知識…

Milvus分區-分片-段結構詳解與最佳實踐

導讀&#xff1a;在構建大規模向量數據庫應用時&#xff0c;數據組織架構的設計往往決定了系統的性能上限。Milvus作為主流向量數據庫&#xff0c;其獨特的三層架構設計——分區、分片、段&#xff0c;為海量向量數據的高效存儲和檢索提供了堅實基礎。 本文通過圖書館管理系統的…

Kettle 遠程mysql 表導入到 hadoop hive

kettle 遠程mysql 表導入到 hadoop hive &#xff08;教學用 &#xff09; 文章目錄 kettle 遠程mysql 表導入到 hadoop hive創建 對象 執行 SQL 語句 -mysql 導出 CSV格式CSV 文件遠程上傳到 HDFS運行 SSH 命令遠程登錄 run SSH 并執行 hadoop fs -put 建表和加載數據總結 創…

Linux輸出命令——echo解析

摘要 全面解析Linux echo命令核心功能&#xff0c;涵蓋文本輸出、變量解析、格式控制及高級技巧&#xff0c;助力提升Shell腳本開發與終端操作效率。 一、核心功能與定位 作為Shell腳本開發的基礎工具&#xff0c;echo命令承擔著信息輸出與數據傳遞的重要角色。其主要功能包…

Windows系統下 NVM 安裝 Node.js 及版本切換實戰指南

以下是 Windows 11 系統下使用 NVM 安裝 Node.js 并實現版本自由切換的詳細步驟&#xff1a; 一、安裝 NVM&#xff08;Node Version Manager&#xff09; 1. 卸載已有 Node.js 如果已安裝 Node.js&#xff0c;請先卸載&#xff1a; 控制面板 ? 程序與功能 ? 找到 Node.js…