循環神經網絡(RNN):從理論到翻譯

循環神經網絡(RNN)是一種專為處理序列數據設計的神經網絡,如時間序列、自然語言或語音。與傳統的全連接神經網絡不同,RNN具有"記憶"功能,通過循環傳遞信息,使其特別適合需要考慮上下文或順序的任務。它出現在Transformer之前,廣泛應用于文本生成、語音識別和時間序列預測(如股價預測)等領域。

RNN的數學基礎

rnn-https://zlu.me

核心方程

在每個時間步 t t t,RNN執行以下操作:

  1. 隱藏狀態更新
    h t = tanh ( W h h h t ? 1 + W x h x t + b h ) h_t = \text{tanh}(W_{hh}h_{t-1} + W_{xh}x_t + b_h) ht?=tanh(Whh?ht?1?+Wxh?xt?+bh?)

    • h t h_t ht?: 時間 t t t的新隱藏狀態(形狀:[hidden_size]
    • h t ? 1 h_{t-1} ht?1?: 前一個隱藏狀態(形狀:[hidden_size]
    • x t x_t xt?: 時間 t t t的輸入(形狀:[input_size]
    • W h h W_{hh} Whh?: 隱藏到隱藏的權重矩陣(形狀:[hidden_size, hidden_size]
    • W x h W_{xh} Wxh?: 輸入到隱藏的權重矩陣(形狀:[hidden_size, input_size]
    • b h b_h bh?: 隱藏層偏置項(形狀:[hidden_size]
    • tanh \text{tanh} tanh: 雙曲正切激活函數
  2. 輸出計算
    o t = W h y h t + b y o_t = W_{hy}h_t + b_y ot?=Why?ht?+by?

    • o t o_t ot?: 時間 t t t的輸出(形狀:[output_size]
    • W h y W_{hy} Why?: 隱藏到輸出的權重矩陣(形狀:[output_size, hidden_size]
    • b y b_y by?: 輸出偏置項(形狀:[output_size]

隨時間反向傳播(BPTT)

RNN使用BPTT進行訓練,它通過時間展開網絡并應用鏈式法則:

? L ? W = ∑ t = 1 T ? L t ? o t ? o t ? h t ∑ k = 1 t ( ∏ i = k + 1 t ? h i ? h i ? 1 ) ? h k ? W \frac{\partial L}{\partial W} = \sum_{t=1}^T \frac{\partial L_t}{\partial o_t} \frac{\partial o_t}{\partial h_t} \sum_{k=1}^t \left( \prod_{i=k+1}^t \frac{\partial h_i}{\partial h_{i-1}} \right) \frac{\partial h_k}{\partial W} ?W?L?=t=1T??ot??Lt???ht??ot??k=1t?(i=k+1t??hi?1??hi??)?W?hk??

這可能導致梯度消失/爆炸問題,LSTM和GRU架構可以解決這個問題。

GRU:門控循環單元

在深入翻譯示例之前,讓我們先了解GRU的數學基礎。GRU通過門控機制解決了標準RNN中的梯度消失問題。

GRU方程

在每個時間步 t t t,GRU計算以下內容:

  1. 更新門 ( z t z_t zt?):
    z t = σ ( W z ? [ h t ? 1 , x t ] + b z ) z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) zt?=σ(Wz??[ht?1?,xt?]+bz?)

    • z t z_t zt?: 更新門(形狀:[hidden_size]
    • W z W_z Wz?: 更新門的權重矩陣(形狀:[hidden_size, hidden_size + input_size]
    • b z b_z bz?: 更新門的偏置項(形狀:[hidden_size]
    • h t ? 1 h_{t-1} ht?1?: 前一個隱藏狀態
    • x t x_t xt?: 當前輸入
    • σ \sigma σ: Sigmoid激活函數(將值壓縮到0和1之間)

    更新門決定保留多少之前的隱藏狀態。

  2. 重置門 ( r t r_t rt?):
    r t = σ ( W r ? [ h t ? 1 , x t ] + b r ) r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) rt?=σ(Wr??[ht?1?,xt?]+br?)

    • r t r_t rt?: 重置門(形狀:[hidden_size]
    • W r W_r Wr?: 重置門的權重矩陣(形狀:[hidden_size, hidden_size + input_size]
    • b r b_r br?: 重置門的偏置項(形狀:[hidden_size]

    重置門決定忘記多少之前的隱藏狀態。

  3. 候選隱藏狀態 ( h ~ t \tilde{h}_t h~t?):
    h ~ t = tanh ( W ? [ r t ⊙ h t ? 1 , x t ] + b ) \tilde{h}_t = \text{tanh}(W \cdot [r_t \odot h_{t-1}, x_t] + b) h~t?=tanh(W?[rt?ht?1?,xt?]+b)

    • h ~ t \tilde{h}_t h~t?: 候選隱藏狀態(形狀:[hidden_size]
    • W W W: 候選狀態的權重矩陣(形狀:[hidden_size, hidden_size + input_size]
    • b b b: 偏置項(形狀:[hidden_size]
    • ⊙ \odot : 逐元素乘法(哈達瑪積)

    這表示可能使用的新隱藏狀態內容。

  4. 最終隱藏狀態 ( h t h_t ht?):
    h t = ( 1 ? z t ) ⊙ h t ? 1 + z t ⊙ h ~ t h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t ht?=(1?zt?)ht?1?+zt?h~t?

    • 最終隱藏狀態是前一個隱藏狀態和候選狀態的組合
    • z t z_t zt?作為新舊信息之間的插值因子

GRU在翻譯中的優勢

  1. 更新門

    • 在英中翻譯中,這有助于決定:
      • 保留多少上下文(例如,保持句子的主語)
      • 更新多少新信息(例如,遇到新詞時)
  2. 重置門

    • 幫助忘記不相關的信息
    • 例如,在翻譯新句子時,可以重置前一個句子的上下文
  3. 梯度流動

    • 最終隱藏狀態計算中的加法更新( + + +)有助于保持梯度流動
    • 這對于學習翻譯任務中的長程依賴關系至關重要

簡單的RNN示例

這個簡化示例訓練一個RNN來預測單詞"hello"中的下一個字符。

  1. 模型定義

    • nn.RNN處理循環計算
    • 全連接層(fc)將隱藏狀態映射到輸出(字符預測)
  2. 數據

    • 使用"hell"作為輸入,期望輸出為"ello"(序列移位)
    • 字符轉換為one-hot向量(例如,‘h’ → [1, 0, 0, 0])
  3. 訓練

    • 通過最小化預測字符和目標字符之間的交叉熵損失來學習
  4. 預測

    • 訓練后,模型可以預測下一個字符
import torch
import torch.nn as nnclass SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.rnn = nn.RNN(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x, hidden):out, hidden = self.rnn(x, hidden)out = self.fc(out)return out, hiddendef init_hidden(self, batch_size):return torch.zeros(1, batch_size, self.hidden_size)# 超參數
input_size = 4   # 唯一字符數 (h, e, l, o)
hidden_size = 8  # 隱藏狀態大小
output_size = 4  # 與input_size相同
learning_rate = 0.01# 字符詞匯表
chars = ['h', 'e', 'l', 'o']
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}# 輸入數據:"hell" 預測 "ello"
input_seq = "hell"
target_seq = "ello"# 轉換為one-hot編碼
def to_one_hot(seq):tensor = torch.zeros(1, len(seq), input_size)  # [batch_size, seq_len, input_size]for t, char in enumerate(seq):tensor[0][t][char_to_idx[char]] = 1  # 批大小為1return tensor# 準備輸入和目標張量
input_tensor = to_one_hot(input_seq)  # 形狀: [1, 4, 4]
print("輸入張量形狀:", input_tensor.shape)
target_tensor = torch.tensor([char_to_idx[ch] for ch in target_seq], dtype=torch.long)  # 形狀: [4]# 初始化模型、損失函數和優化器
model = SimpleRNN(input_size, hidden_size, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 訓練循環
for epoch in range(100):hidden = model.init_hidden(1)  # 批大小為1print("隱藏狀態形狀:", hidden.shape)  # 應該是 [1, 1, 8]optimizer.zero_grad()output, hidden = model(input_tensor, hidden)  # 輸出: [1, 4, 4], 隱藏: [1, 1, 8]loss = criterion(output.squeeze(0), target_tensor)  # output.squeeze(0): [4, 4], target: [4]loss.backward()optimizer.step()if epoch % 20 == 0:print(f'輪次 {epoch}, 損失: {loss.item():.4f}')# 測試模型
with torch.no_grad():hidden = model.init_hidden(1)

英中翻譯示例

我們將使用PyTorch的GRU(門控循環單元)構建一個簡單的英中翻譯模型,GRU是RNN的一種變體,能更好地處理長程依賴關系。

1. 數據準備

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np# 樣本平行語料(英文 -> 中文)
english_sentences = ["hello", "how are you", "i love machine learning","good morning", "artificial intelligence"
]chinese_sentences = ["你好", "你好嗎", "我愛機器學習","早上好", "人工智能"
]# 創建詞匯表
eng_chars = sorted(list(set(' '.join(english_sentences))))
zh_chars = sorted(list(set(''.join(chinese_sentences))))# 添加特殊標記
SOS_token = 0  # 句子開始
EOS_token = 1  # 句子結束
eng_chars = ['<SOS>', '<EOS>', '<PAD>'] + eng_chars
zh_chars = ['<SOS>', '<EOS>', '<PAD>'] + zh_chars# 創建詞到索引的映射
eng_to_idx = {ch: i for i, ch in enumerate(eng_chars)}
zh_to_idx = {ch: i for i, ch in enumerate(zh_chars)}# 將句子轉換為張量
def sentence_to_tensor(sentence, vocab, is_target=False):indices = [vocab[ch] for ch in (sentence if not is_target else sentence)]if is_target:indices.append(EOS_token)  # 為目標添加EOS標記return torch.tensor(indices, dtype=torch.long).view(-1, 1)

2. 模型架構

class Seq2Seq(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Seq2Seq, self).__init__()self.hidden_size = hidden_size# 編碼器(英文到隱藏狀態)self.embedding = nn.Embedding(input_size, hidden_size)self.gru = nn.GRU(hidden_size, hidden_size)# 解碼器(隱藏狀態到中文)self.out = nn.Linear(hidden_size, output_size)self.softmax = nn.LogSoftmax(dim=1)def forward(self, input_seq, hidden=None, max_length=10):# 編碼器embedded = self.embedding(input_seq).view(1, 1, -1)output, hidden = self.gru(embedded, hidden)# 解碼器decoder_input = torch.tensor([[SOS_token]], device=input_seq.device)decoder_hidden = hiddendecoded_words = []for _ in range(max_length):output, decoder_hidden = self.gru(self.embedding(decoder_input).view(1, 1, -1),decoder_hidden)output = self.softmax(self.out(output[0]))topv, topi = output.topk(1)if topi.item() == EOS_token:breakdecoded_words.append(zh_chars[topi.item()])decoder_input = topi.detach()return ''.join(decoded_words), decoder_hiddendef init_hidden(self):return torch.zeros(1, 1, self.hidden_size)

3. 訓練模型

# 超參數
hidden_size = 256
learning_rate = 0.01
n_epochs = 1000# 初始化模型
model = Seq2Seq(len(eng_chars), hidden_size, len(zh_chars))
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 訓練循環
for epoch in range(n_epochs):total_loss = 0for eng_sent, zh_sent in zip(english_sentences, chinese_sentences):# 準備數據input_tensor = sentence_to_tensor(eng_sent, eng_to_idx)target_tensor = sentence_to_tensor(zh_sent, zh_to_idx, is_target=True)# 前向傳播model.zero_grad()hidden = model.init_hidden()# 編碼器前向傳播embedded = model.embedding(input_tensor).view(len(input_tensor), 1, -1)_, hidden = model.gru(embedded, hidden)# 準備解碼器decoder_input = torch.tensor([[SOS_token]])decoder_hidden = hiddenloss = 0# 教師強制:使用目標作為下一個輸入for di in range(len(target_tensor)):output, decoder_hidden = model.gru(model.embedding(decoder_input).view(1, 1, -1),decoder_hidden)output = model.out(output[0])loss += criterion(output, target_tensor[di])decoder_input = target_tensor[di]# 反向傳播和優化loss.backward()optimizer.step()total_loss += loss.item() / len(target_tensor)# 打印進度if (epoch + 1) % 100 == 0:print(f'輪次 {epoch + 1}, 平均損失: {total_loss / len(english_sentences):.4f}')# 測試翻譯
def translate(sentence):with torch.no_grad():input_tensor = sentence_to_tensor(sentence.lower(), eng_to_idx)output_words, _ = model(input_tensor)return output_words# 示例翻譯
print("\n翻譯結果:")
print(f"'hello' -> '{translate('hello')}'")
print(f"'how are you' -> '{translate('how are you')}'")
print(f"'i love machine learning' -> '{translate('i love machine learning')}'")

4. 理解輸出

訓練后,模型應該能夠將簡單的英文短語翻譯成中文。例如:

  • 輸入: “hello”

    • 輸出: “你好”
  • 輸入: “how are you”

    • 輸出: “你好嗎”
  • 輸入: “i love machine learning”

    • 輸出: “我愛機器學習”

5. 關鍵組件解釋

  1. 嵌入層

    • 將離散的詞索引轉換為連續向量
    • 捕捉詞與詞之間的語義關系
  2. GRU(門控循環單元)

    • 使用更新門和重置門控制信息流
    • 解決標準RNN中的梯度消失問題
  3. 教師強制

    • 在訓練過程中使用目標輸出作為下一個輸入
    • 幫助模型更快地學習正確的翻譯
  4. 束搜索

    • 可以用于提高翻譯質量
    • 在解碼過程中跟蹤多個可能的翻譯

6. 挑戰與改進

  1. 處理變長序列

    • 使用填充和掩碼
    • 實現注意力機制以獲得更好的對齊
  2. 詞匯表大小

    • 使用子詞單元(如Byte Pair Encoding, WordPiece)
    • 實現指針生成網絡處理稀有詞
  3. 性能

    • 使用雙向RNN增強上下文理解
    • 實現Transformer架構以實現并行處理

這個示例為使用RNN進行序列到序列學習提供了基礎。對于生產系統,建議使用基于Transformer的模型(如BART或T5),這些模型在機器翻譯任務中表現出色。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/84591.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/84591.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/84591.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

window批處理文件(.bat),用來清理git的master分支

echo off chcp 65001 > nul setlocal enabledelayedexpansionecho 正在檢查Git倉庫... git rev-parse --is-inside-work-tree >nul 2>&1 if %errorlevel% neq 0 (echo 錯誤&#xff1a;當前目錄不是Git倉庫&#xff01;pauseexit /b 1 )echo 警告&#xff1a;這將…

C#中的CLR屬性、依賴屬性與附加屬性

CLR屬性的主要特征 封裝性&#xff1a; 隱藏字段的實現細節 提供對字段的受控訪問 訪問控制&#xff1a; 可單獨設置get/set訪問器的可見性 可創建只讀或只寫屬性 計算屬性&#xff1a; 可以在getter中執行計算邏輯 不需要直接對應一個字段 驗證邏輯&#xff1a; 可以…

【mysql】聯合索引和單列索引的區別

區別核心&#xff1a;聯合索引可加速多個字段組合查詢&#xff0c;單列索引只能加速一個字段。 &#x1f539;聯合索引&#xff08;復合索引&#xff09; INDEX(col1, col2, col3)適用范圍&#xff1a; WHERE col1 ... ? WHERE col1 ... AND col2 ... ? WHERE col1 ..…

如何用 HTML 展示計算機代碼

原文&#xff1a;如何用 HTML 展示計算機代碼 | w3cschool筆記 &#xff08;請勿將文章標記為付費&#xff01;&#xff01;&#xff01;&#xff01;&#xff09; 在編程學習和文檔編寫過程中&#xff0c;清晰地展示代碼是一項關鍵技能。HTML 作為網頁開發的基礎語言&#x…

大模型筆記_模型微調

1. 大模型微調的概念 大模型微調&#xff08;Fine-tuning&#xff09;是指在預訓練大語言模型&#xff08;如GPT、BERT、LLaMA等&#xff09;的基礎上&#xff0c;針對特定任務或領域&#xff0c;使用小量的目標領域數據對模型進行進一步訓練&#xff0c;使其更好地適配具體應…

React Native UI 框架與動畫系統:打造專業移動應用界面

React Native UI 框架與動畫系統&#xff1a;打造專業移動應用界面 關鍵要點 UI 框架加速開發&#xff1a;NativeBase、React Native Paper、UI Kitten 和 Tailwind-RN 提供預構建組件&#xff0c;幫助開發者快速創建美觀、一致的界面。動畫提升體驗&#xff1a;React Native…

在QT中使用OpenGL

參考資料&#xff1a; 主頁 - LearnOpenGL CN https://blog.csdn.net/qq_40120946/category_12566573.html 由于OpenGL的大多數實現都是由顯卡廠商編寫的&#xff0c;當產生一個bug時通常可以通過升級顯卡驅動來解決。 OpenGL中的名詞解釋 OpenGL 上下文&#xff08;Conte…

Qt::QueuedConnection詳解

在多線程編程中&#xff0c;線程間的通信是一個關鍵問題。Qt框架提供了強大的信號和槽機制來處理線程通信&#xff0c;其中Qt::QueuedConnection是一種非常有用的連接類型。本文將深入探討Qt::QueuedConnection的原理、使用場景及注意事項。 一、基本概念 Qt::QueuedConnecti…

X86 OpenHarmony5.1.0系統移植與安裝

近期在研究X86鴻蒙,通過一段時間的研究終于成功了,在X86機器上成功啟動了openharmony系統了.下面做個總結和分享 1. 下載源碼 獲取OpenHarmony標準系統源碼 repo init -u https://gitee.com/openharmony/manifest.git -b refs/tags/OpenHarmony-v5.1.0-Release --no-repo-ve…

如何診斷服務器硬盤故障?出現硬盤故障如何處理比較好?

當服務器硬盤出現故障時&#xff0c;及時診斷問題并采取正確的處理方法至關重要。硬盤故障可能導致數據丟失和系統不穩定&#xff0c;影響服務器的正常運行。以下是診斷服務器硬盤故障并處理的最佳實踐&#xff1a; 診斷服務器硬盤故障的步驟 1. 監控警報 硬盤監控工具&#…

vue3提供的hook和通常的函數有什么區別

Vue 3 提供的 hook&#xff08;組合式函數&#xff09; 和普通函數在使用場景、功能和設計目的上有明顯區別&#xff0c;它們是 Vue 3 組合式 API 的核心概念。下面從幾個關鍵維度分析它們的差異&#xff1a; 1. 設計目的不同 Hook&#xff08;組合式函數&#xff09; 專為 Vu…

Spark提交流程

bin/spark-submit --class org.apache.spark.examples.SparkPi --master yarn ./examples/jars/spark-examples_2.12-3.3.1.jar 10 這一句命令實際上是 啟動一個Java程序 java org.apache.spark.deploy.SparkSubmit 并將命令行參數解析到這個類的對應屬性上 因為master給…

Microsoft Copilot Studio - 嘗試一下Agent

1.簡單介紹 Microsoft Copilot Studio以前的名字是Power Virtual Agent(簡稱PVA)。Power Virutal Agent是2019年出現的&#xff0c;是低代碼平臺Power Platform的一部分。當時Generative AI還沒有出現&#xff0c;但是基于已有的Conversation AI技術&#xff0c;即Microsoft L…

【源碼剖析】2-搭建kafka源碼環境

在上篇文章kafka核心概念中&#xff0c;解釋了kafka的核心概念&#xff0c;下面開始進行kafka源碼編譯。為什么學習源碼需要進行源碼編譯呢&#xff0c;我認為主要有兩點&#xff1a; 可以進行debug&#xff0c;跟蹤代碼執行邏輯可以對源碼改動&#xff0c;強化學習學習效果 …

小紅書視頻圖文提取:采集+CV的實戰手記

項目說明&#xff1a;這波視頻&#xff0c;值不值得采&#xff1f; 你有沒有遇到過這樣的場景&#xff1f;老板說&#xff1a;“我們得看看最近小紅書上關于‘旅行’的視頻都說了些什么。”團隊做數據分析的&#xff0c;立馬傻眼&#xff1a;官網打不開、接口抓不著、視頻不能…

Cloudflare 從 Nginx 到 Pingora:性能、效率與安全的全面升級

在互聯網的快速發展中&#xff0c;高性能、高效率和高安全性的網絡服務成為了各大互聯網基礎設施提供商的核心追求。Cloudflare 作為全球領先的互聯網安全和基礎設施公司&#xff0c;近期做出了一個重大技術決策&#xff1a;棄用長期使用的 Nginx&#xff0c;轉而采用其內部開發…

從編輯到安全設置: 如何滿足專業文檔PDF處理需求

隨著數字化辦公的發展&#xff0c;PDF 已成為跨平臺文檔交互的標準格式。無論是在日常辦公、學術研究&#xff0c;還是項目協作中&#xff0c;對 PDF 文件進行高效編輯與管理的需求日益增長。功能全面、操作流暢且無額外負擔的 PDF 編輯工具&#xff0c;它是一款在功能上可與 A…

Kafka消費者組位移重設指南

#作者&#xff1a;張桐瑞 文章目錄 一、Kafka 與傳統消息引擎的核心差異二、重設消費者組位移的核心原因三、重設位移的兩大維度與七種策略四、重設位移的實現方式&#xff08;一&#xff09;Java API 方式&#xff08;二&#xff09;命令行腳本方式&#xff08;Kafka 0.11&am…

分類模型:邏輯回歸

1、針對設計&#xff1a;二分類 Logistic 回歸最初是為二分類問題設計的&#xff0c; Logistic 回歸基于概率&#xff0c;通過 Sigmoid 函數轉換輸入特征的線性組合&#xff0c;將任意實數映射到 [0, 1] 區間內。 通過引入一個決策規則&#xff08;通常是概率的閾值&#xff…

CppCon 2015 學習:C++ WAT

這段代碼展示了 C 中的一些有趣和令人困惑的特性&#xff0c;尤其是涉及數組訪問和某些語法的巧妙之處。讓我們逐個分析&#xff1a; 1. assert(map[“Hello world!”] e;) 這一行看起來很不尋常&#xff0c;因為 map 在這里被用作數組下標訪問器&#xff0c;但是在前面沒有…