Table of Contents
- 1. RNN Basics: The Core Idea of Sequence Modeling
- 1.1 What an RNN Is and How It Works
- 1.2 Application Scenarios and Structural Variants
- 2. The Vanilla RNN: Where Sequence Models Began
- 2.1 Internal Structure and Mathematical Formulation
- 2.2 Worked Example
- 2.3 The RNN API in PyTorch
- 2.4 Code Examples
- 2.5 Strengths, Weaknesses, and the Gradient Problem
- 3. LSTM: Gating Mechanisms Crack Long-Term Dependencies
- 3.1 The Four Gating Components in Detail
- 3.2 Full LSTM Computation for the "Bob" Example
- 3.3 The LSTM API in PyTorch
- 3.4 Code Examples
- 3.5 The Mathematical Essence of Gating
- 4. GRU: A Lightweight Evolution of the LSTM
- 4.1 The Simplified Two-Gate Structure
- 4.2 Full GRU Computation for the "Bob" Example
- 4.3 The GRU API in PyTorch
- 4.4 Code Examples
- 5. Comparing the Three Models and Choosing in Practice
- 5.1 Core Metrics Compared
- 5.2 Recommended Use Cases
- 5.3 Key Points
As the core model family for sequence data, recurrent neural networks (RNNs) play a key role in natural language processing, time-series analysis, and related fields. This article systematically walks through the internal mechanisms, mathematical formulations, and practical usage of the vanilla RNN, the LSTM, and the GRU, comparing their computations on one shared example to help readers understand how sequence models have evolved.
1. RNN Basics: The Core Idea of Sequence Modeling

1.1 What an RNN Is and How It Works

The core innovation of the RNN (Recurrent Neural Network) is its "recurrent memory" mechanism: the hidden state from the previous time step is combined with the current input, which gives the model the ability to capture dependencies across a sequence. In essence, it extracts features along the time dimension with shared parameters:

$h_t = f(W \cdot [x_t, h_{t-1}] + b)$

- $x_t$: input vector at the current time step
- $h_{t-1}$: hidden state from the previous time step
- $f$: activation function (usually tanh or sigmoid)

This structure lets an RNN capture short-range dependencies in a sequence. For example, deciding the part of speech of "apple" in "I love eating apples" depends on the preceding verb "eating".
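To make the parameter-sharing idea concrete, here is a minimal NumPy sketch of the recurrence above; the dimensions and random weights are arbitrary illustration values, not part of the original example:

```python
import numpy as np

input_size, hidden_size = 3, 4
rng = np.random.default_rng(0)
W = rng.normal(size=(hidden_size, input_size + hidden_size))  # one shared weight matrix
b = np.zeros(hidden_size)

def rnn_step(x_t, h_prev):
    # h_t = tanh(W . [x_t, h_{t-1}] + b), using the same W and b at every time step
    return np.tanh(W @ np.concatenate([x_t, h_prev]) + b)

h = np.zeros(hidden_size)                       # h_0
for x_t in rng.normal(size=(5, input_size)):    # a length-5 toy sequence
    h = rnn_step(x_t, h)
print(h.shape)   # (4,): the final hidden state summarizes the whole sequence
```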
1.2 Application Scenarios and Structural Variants

By input/output structure, RNNs fall into four categories:
- N vs N: input and output of equal length, e.g. frame-level labeling in speech recognition
- N vs 1: sequence in, single output out, typically text classification
- 1 vs N: single input, sequence out, commonly used for image captioning
- N vs M: the seq2seq architecture with unconstrained input and output lengths, the foundation of machine translation
2. The Vanilla RNN: Where Sequence Models Began

(Figure: structure diagram of the vanilla RNN cell)
2.1 Internal Structure and Mathematical Formulation

The computation of a vanilla RNN cell breaks down into:
- Concatenate the inputs: $[x_t, h_{t-1}]$
- Apply a linear transform: $W \cdot [x_t, h_{t-1}] + b$
- Apply the activation: $h_t = \tanh(\cdot)$

For an RNN with a 3-dimensional input and a 4-dimensional hidden state, the parameters are:
- input weights $W_{ih} \in \mathbb{R}^{4 \times 3}$
- hidden weights $W_{hh} \in \mathbb{R}^{4 \times 4}$
- bias $b \in \mathbb{R}^4$
2.2 Worked Example

Feature extraction for the name "Bob"

Assume:
- character encodings: 'B' = [0.1, 0, 0], 'o' = [0, 0.1, 0], 'b' = [0, 0, 0.1]
- RNN parameters (simplified for illustration): $W_{ih} = [[1,0,0],[0,1,0],[0,0,1],[1,1,1]]$
- $W_{hh} = I_4$ (the 4×4 identity), with bias $b = 0$ and initial state $h_0 = [0,0,0,0]$

Computation steps:

Time step 1: input 'B'
$h_1 = \tanh(W_{ih} \cdot x_1 + W_{hh} \cdot h_0) = \tanh([0.1, 0, 0, 0.1]) \approx [0.099, 0, 0, 0.099]$ (since $h_0$ is all zeros)

Time step 2: input 'o'
$h_2 = \tanh(W_{ih} \cdot x_2 + W_{hh} \cdot h_1) = \tanh([0 + 0.099,\ 0.1 + 0,\ 0,\ 0.1 + 0.099]) \approx [0.099, 0.1, 0, 0.197]$

Time step 3: input 'b'
$h_3 = \tanh([0.099,\ 0.1,\ 0.1,\ 0.1 + 0.197]) \approx [0.099, 0.099, 0.1, 0.289]$

The final hidden state $h_3$ has absorbed information from all three characters and serves as the sequence-level feature representation of "Bob".
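The hand calculation above is easy to verify programmatically. A minimal NumPy sketch, using exactly the simplified weights from the example:

```python
import numpy as np

# Simplified example parameters from above (illustrative values, not trained weights)
W_ih = np.array([[1, 0, 0],
                 [0, 1, 0],
                 [0, 0, 1],
                 [1, 1, 1]], dtype=float)
W_hh = np.eye(4)                      # identity hidden-to-hidden weights
chars = {'B': np.array([0.1, 0.0, 0.0]),
         'o': np.array([0.0, 0.1, 0.0]),
         'b': np.array([0.0, 0.0, 0.1])}

h = np.zeros(4)                       # h_0
for c in "Bob":
    h = np.tanh(W_ih @ chars[c] + W_hh @ h)   # no bias in this toy example
    print(c, np.round(h, 4))
# final step prints approximately: b [0.099  0.0993 0.0997 0.2886]
```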
2.3 The RNN API in PyTorch

- RNN class definition and core parameters

```python
torch.nn.RNN(
    input_size,            # input feature dimension
    hidden_size,           # hidden state dimension
    num_layers=1,          # number of stacked RNN layers
    nonlinearity='tanh',   # activation function, 'tanh' or 'relu'
    bias=True,             # whether to use bias terms
    batch_first=False,     # if True, inputs are (batch, seq, feature)
    dropout=0,             # dropout probability between layers
    bidirectional=False    # whether the RNN is bidirectional
)
```

- Input and output formats

  Inputs:
  - input: the input sequence, of shape (seq_len, batch, input_size) by default, or (batch, seq_len, input_size) when batch_first=True
  - $h_0$: the initial hidden state, of shape (num_layers * num_directions, batch, hidden_size)

  Outputs:
  - output: the hidden states of all time steps, of shape (seq_len, batch, hidden_size * num_directions)
  - $h_n$: the hidden state of the last time step, with the same shape as $h_0$

- Key attributes and methods

  Weight matrices:
  - weight_ih_l[k]: input-to-hidden weights of layer k
  - weight_hh_l[k]: hidden-to-hidden weights of layer k
  - bias_ih_l[k] and bias_hh_l[k]: the corresponding biases

  Forward pass:

```python
output, h_n = rnn(input, h_0)
```
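As a quick check of these attribute names and shapes (a small sketch; the sizes here are arbitrary):

```python
import torch.nn as nn

rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2)
print(rnn.weight_ih_l0.shape)   # torch.Size([20, 10]) -> (hidden_size, input_size)
print(rnn.weight_hh_l0.shape)   # torch.Size([20, 20]) -> (hidden_size, hidden_size)
print(rnn.weight_ih_l1.shape)   # torch.Size([20, 20]) -> layer 1 consumes layer 0's output
print(rnn.bias_ih_l0.shape)     # torch.Size([20])
```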
2.4 Code Examples

- Basic usage

```python
import torch
import torch.nn as nn

# Build the RNN model
rnn = nn.RNN(
    input_size=10,       # input feature dimension
    hidden_size=20,      # hidden state dimension
    num_layers=2,        # 2 stacked RNN layers
    batch_first=True,    # use (batch, seq, feature) layout
    bidirectional=True   # bidirectional RNN
)

# Prepare the input
batch_size = 3
seq_len = 5
x = torch.randn(batch_size, seq_len, 10)   # input sequence

# Initialize the hidden state (optional)
h0 = torch.zeros(2 * 2, batch_size, 20)    # num_layers * num_directions

# Forward pass
output, hn = rnn(x, h0)

# Inspect the output shapes
print("Output shape:", output.shape)        # (3, 5, 40)  [batch, seq, hidden*2]
print("Final hidden shape:", hn.shape)      # (4, 3, 20)  [layers*directions, batch, hidden]
```

- Getting the hidden state of the last time step

```python
# Option 1: take the last time step from output
last_output = output[:, -1, :]              # (batch, hidden * num_directions)

# Option 2: take it from hn; for a bidirectional RNN, concatenate both directions of the top layer
if rnn.bidirectional:
    last_hidden = torch.cat([hn[-2], hn[-1]], dim=1)
else:
    last_hidden = hn[-1]
```
2.5 Strengths, Weaknesses, and the Gradient Problem

- Strengths: simple structure, few parameters, computationally efficient on short sequences
- Fatal flaw: severe vanishing gradients on long sequences. Backpropagation through time multiplies one factor per step (a small demonstration follows this list):

  $\nabla = \prod_{i=1}^{n} \sigma'(z_i) \cdot w_i$

  When $|\sigma'(z_i) \cdot w_i| < 1$, the repeated product drives the gradient toward 0, so parameters far back in time can no longer be updated.
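A small sketch of this effect: repeatedly applying a tanh step with a recurrent weight below 1 and watching the gradient of the first input shrink (toy values, a single hidden unit):

```python
import torch

w = torch.tensor(0.9)            # recurrent weight with |w| < 1
x0 = torch.tensor(0.5, requires_grad=True)

h = x0
for _ in range(50):              # 50 "time steps"
    h = torch.tanh(w * h)        # each step multiplies the gradient by w * tanh'(.) < 1

h.backward()
print(x0.grad)                   # a tiny number: the gradient has all but vanished
```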
3. LSTM: Gating Mechanisms Crack Long-Term Dependencies

3.1 The Four Gating Components in Detail

The LSTM introduces a gating system and splits the single hidden state of the vanilla RNN into:
- the cell state C: the carrier of long-term memory
- the hidden state h: the short-term feature representation

The core equations, gate by gate (a minimal NumPy sketch of one cell step follows this list):

- Forget gate: decides which historical information to discard
  - Role: choose which parts of the cell state's history to throw away.
  - Computation: the current input $x_t$ and the previous hidden state $h_{t-1}$ are concatenated and passed through a fully connected layer with a sigmoid:
    $f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$
  - $f_t$ is a gate value between 0 and 1: 1 means "keep everything", 0 means "forget everything".

- Input gate: filters new information
  - Role: decide which parts of the current input should be written into the cell state.
  - Computation:
    - gate value $i_t$ (sigmoid, analogous to the forget gate): $i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$
    - candidate cell state $\tilde{C}_t$ (tanh): $\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$

- Cell state update:
  - Role: store long-term memory, updated through the gates:
    $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$
  - $f_t \odot C_{t-1}$: the forget gate acts on the old cell state, discarding part of the history;
  - $i_t \odot \tilde{C}_t$: the input gate filters the new information carried by the candidate state.

- Output gate: produces the current hidden state
  - Role: decide which parts of the cell state are exposed as the current hidden state.
  - Computation:
    - gate value $o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$
    - the cell state is passed through tanh and multiplied elementwise by the gate value: $h_t = o_t \odot \tanh(C_t)$
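To make the four equations concrete, here is a minimal NumPy sketch of a single LSTM cell step; the randomly initialized weights are purely illustrative:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W_f, W_i, W_c, W_o, b_f, b_i, b_c, b_o):
    z = np.concatenate([h_prev, x_t])          # [h_{t-1}, x_t]
    f = sigmoid(W_f @ z + b_f)                 # forget gate
    i = sigmoid(W_i @ z + b_i)                 # input gate
    c_tilde = np.tanh(W_c @ z + b_c)           # candidate cell state
    c = f * c_prev + i * c_tilde               # cell state update
    o = sigmoid(W_o @ z + b_o)                 # output gate
    h = o * np.tanh(c)                         # new hidden state
    return h, c

# Toy dimensions: 3-dim input, 4-dim hidden state
rng = np.random.default_rng(0)
W_f, W_i, W_c, W_o = (rng.normal(size=(4, 7)) for _ in range(4))   # 4 x (hidden + input)
b_f = b_i = b_c = b_o = np.zeros(4)

h, c = np.zeros(4), np.zeros(4)                # h_0, C_0
for x_t in rng.normal(size=(3, 3)):            # a length-3 toy sequence of 3-dim inputs
    h, c = lstm_step(x_t, h, c, W_f, W_i, W_c, W_o, b_f, b_i, b_c, b_o)
print(np.round(h, 3), np.round(c, 3))
```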
3.2 "Bob"案例的LSTM完整計算示例
為便于與RNN對比,我們保持輸入維度、隱藏狀態維度一致,并使用相似的參數設置:
假設條件:
- 輸入維度=3,隱藏狀態維度=4
- 字符編碼:‘B’=[0.1,0,0], ‘o’=[0,0.1,0], ‘b’=[0,0,0.1]
- LSTM參數(簡化后):
(注:每個權重矩陣實際為4x7,因拼接 h t ? 1 h_{t-1} ht?1?(4維)和 x t x_t xt?(3維))W_f = [[1,0,0,1,0,0,0], [0,1,0,0,1,0,0], [0,0,1,0,0,1,0], [0,0,0,1,0,0,1]] W_i = [[1,0,0,1,0,0,0], [0,1,0,0,1,0,0], [0,0,1,0,0,1,0], [0,0,0,1,0,0,1]] W_c = [[1,0,0,1,0,0,0], [0,1,0,0,1,0,0], [0,0,1,0,0,1,0], [0,0,0,1,0,0,1]] W_o = [[1,0,0,1,0,0,0], [0,1,0,0,1,0,0], [0,0,1,0,0,1,0], [0,0,0,1,0,0,1]]
Detailed computation:

Time step 1: input 'B' = [0.1, 0, 0]

- Forget gate:
  $f_1 = \sigma(W_f \cdot [h_0, x_1] + b_f)$
  With $b_f = [0,0,0,0]$ and $h_0 = 0$:
  $W_f \cdot [h_0, x_1] = [0.1, 0, 0, 0]$
  $f_1 = \sigma([0.1, 0, 0, 0]) = [0.525, 0.5, 0.5, 0.5]$

- Input gate:
  $i_1 = \sigma(W_i \cdot [h_0, x_1] + b_i) = [0.525, 0.5, 0.5, 0.5]$
  $\tilde{C}_1 = \tanh(W_C \cdot [h_0, x_1] + b_C) = \tanh([0.1, 0, 0, 0]) = [0.099, 0, 0, 0]$

- Cell state update:
  $C_1 = f_1 \odot C_0 + i_1 \odot \tilde{C}_1$
  $= [0.525, 0.5, 0.5, 0.5] \odot [0,0,0,0] + [0.525, 0.5, 0.5, 0.5] \odot [0.099, 0, 0, 0]$
  $= [0.052, 0, 0, 0]$

- Output gate and hidden state:
  $o_1 = \sigma(W_o \cdot [h_0, x_1] + b_o) = [0.525, 0.5, 0.5, 0.5]$
  $h_1 = o_1 \odot \tanh(C_1) = [0.525, 0.5, 0.5, 0.5] \odot [0.052, 0, 0, 0] = [0.027, 0, 0, 0]$
Time step 2: input 'o' = [0, 0.1, 0]

- Forget gate:
  $W_f \cdot [h_1, x_2] = [0.027, 0.1, 0, 0]$
  $f_2 = \sigma([0.027, 0.1, 0, 0]) = [0.507, 0.525, 0.5, 0.5]$

- Input gate:
  $i_2 = \sigma([0.027, 0.1, 0, 0]) = [0.507, 0.525, 0.5, 0.5]$
  $\tilde{C}_2 = \tanh([0.027, 0.1, 0, 0]) = [0.027, 0.099, 0, 0]$

- Cell state update:
  $C_2 = f_2 \odot C_1 + i_2 \odot \tilde{C}_2 = [0.026, 0, 0, 0] + [0.014, 0.052, 0, 0] = [0.04, 0.052, 0, 0]$

- Output gate and hidden state:
  $o_2 = \sigma([0.027, 0.1, 0, 0]) = [0.507, 0.525, 0.5, 0.5]$
  $h_2 = o_2 \odot \tanh(C_2) = [0.507, 0.525, 0.5, 0.5] \odot [0.04, 0.052, 0, 0] = [0.02, 0.027, 0, 0]$
Time step 3: input 'b' = [0, 0, 0.1]

- Forget gate:
  $W_f \cdot [h_2, x_3] = [0.02, 0.027, 0.1, 0]$
  $f_3 = \sigma([0.02, 0.027, 0.1, 0]) = [0.505, 0.507, 0.525, 0.5]$

- Input gate:
  $i_3 = \sigma([0.02, 0.027, 0.1, 0]) = [0.505, 0.507, 0.525, 0.5]$
  $\tilde{C}_3 = \tanh([0.02, 0.027, 0.1, 0]) = [0.02, 0.027, 0.099, 0]$

- Cell state update:
  $C_3 = f_3 \odot C_2 + i_3 \odot \tilde{C}_3 = [0.02, 0.026, 0, 0] + [0.01, 0.014, 0.052, 0] = [0.03, 0.04, 0.052, 0]$

- Output gate and hidden state:
  $o_3 = \sigma([0.02, 0.027, 0.1, 0]) = [0.505, 0.507, 0.525, 0.5]$
  $h_3 = o_3 \odot \tanh(C_3) = [0.015, 0.02, 0.027, 0]$
Final results compared

| Model | Feature representation of "Bob" (final hidden state) |
|---|---|
| RNN | [0.099, 0.099, 0.1, 0.289] |
| LSTM | [0.015, 0.02, 0.027, 0] |

Key differences:

- How information is retained:
  - The RNN simply adds each new input onto its hidden state, so the accumulating dimensions keep growing and later inputs weigh heavily
  - The LSTM balances the contribution of each character through its gates, keeping a more even feature representation
- Gradient propagation:
  - The RNN's gradient is repeatedly multiplied by the tanh derivative (at most 1), so it decays easily
  - The LSTM's cell state passes the gradient through $f_t$ (which can stay close to 1), mitigating the vanishing-gradient problem
3.3 The LSTM API in PyTorch

- LSTM class definition and core parameters

```python
torch.nn.LSTM(
    input_size,            # input feature dimension
    hidden_size,           # hidden state dimension
    num_layers=1,          # number of stacked LSTM layers
    bias=True,             # whether to use bias terms
    batch_first=False,     # if True, inputs are (batch, seq, feature)
    dropout=0,             # dropout probability between layers
    bidirectional=False    # whether the LSTM is bidirectional
)
```

- Input and output formats

  Inputs:
  - input: the input sequence, of shape (seq_len, batch, input_size) by default, or (batch, seq_len, input_size) when batch_first=True
  - $h_0$: the initial hidden state, of shape (num_layers * num_directions, batch, hidden_size)
  - $c_0$: the initial cell state, of shape (num_layers * num_directions, batch, hidden_size)

  Outputs:
  - output: the hidden states of all time steps, of shape (seq_len, batch, hidden_size * num_directions)
  - $h_n$: the hidden state of the last time step, with the same shape as $h_0$
  - $c_n$: the cell state of the last time step, with the same shape as $c_0$

- Key attributes and methods

  Weight matrices:
  - weight_ih_l[k]: input-to-hidden weights of layer k (the four gates stacked together)
  - weight_hh_l[k]: hidden-to-hidden weights of layer k (the four gates stacked together)
  - bias_ih_l[k] and bias_hh_l[k]: the corresponding biases

  Forward pass:

```python
output, (h_n, c_n) = lstm(input, (h_0, c_0))
```
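A quick way to see the "four gates stacked together" layout (the sizes here are arbitrary):

```python
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20)
print(lstm.weight_ih_l0.shape)   # torch.Size([80, 10]) -> (4 * hidden_size, input_size)
print(lstm.weight_hh_l0.shape)   # torch.Size([80, 20]) -> (4 * hidden_size, hidden_size)
# The row blocks are ordered as input, forget, cell (candidate), and output gates.
```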
3.4 Code Examples

- Basic usage

```python
import torch
import torch.nn as nn

# Build the LSTM model
lstm = nn.LSTM(
    input_size=10,       # input feature dimension
    hidden_size=20,      # hidden state dimension
    num_layers=2,        # 2 stacked LSTM layers
    batch_first=True,    # use (batch, seq, feature) layout
    bidirectional=True   # bidirectional LSTM
)

# Prepare the input
batch_size = 3
seq_len = 5
x = torch.randn(batch_size, seq_len, 10)   # input sequence

# Initialize the hidden and cell states (optional)
h0 = torch.zeros(2 * 2, batch_size, 20)    # num_layers * num_directions
c0 = torch.zeros(2 * 2, batch_size, 20)

# Forward pass
output, (hn, cn) = lstm(x, (h0, c0))

# Inspect the output shapes
print("Output shape:", output.shape)        # (3, 5, 40)  [batch, seq, hidden*2]
print("Final hidden shape:", hn.shape)      # (4, 3, 20)  [layers*directions, batch, hidden]
print("Final cell shape:", cn.shape)        # (4, 3, 20)
```

- Getting the hidden state of the last time step

```python
# Option 1: take the last time step from output
last_output = output[:, -1, :]              # (batch, hidden * num_directions)

# Option 2: take it from hn; for a bidirectional LSTM, concatenate both directions of the top layer
if lstm.bidirectional:
    last_hidden = torch.cat([hn[-2], hn[-1]], dim=1)
else:
    last_hidden = hn[-1]
```
3.5 The Mathematical Essence of Gating

Because the cell state is updated by a linear combination ($C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$), the LSTM gives gradients an almost "direct" path backward through time and avoids the multiplicative decay of the vanilla RNN. Mathematically:

$\frac{\partial C_t}{\partial C_{t-1}} = f_t$

When $f_t$ is close to 1, the gradient passes to earlier time steps nearly unchanged; this is the core of how the LSTM handles long-term dependencies.
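This relationship is easy to verify with autograd. A small sketch with toy values, where the gates are held fixed so that only the direct path through the cell state is measured:

```python
import torch

f_t = torch.tensor([0.9, 0.5, 0.99])            # forget gate values (held fixed)
i_t = torch.tensor([0.3, 0.7, 0.2])             # input gate values (held fixed)
c_tilde = torch.tensor([0.1, -0.2, 0.4])        # candidate cell state (held fixed)

c_prev = torch.tensor([0.5, -1.0, 2.0], requires_grad=True)
c_t = f_t * c_prev + i_t * c_tilde              # C_t = f_t * C_{t-1} + i_t * C~_t

c_t.sum().backward()
print(c_prev.grad)                              # tensor([0.9000, 0.5000, 0.9900]) == f_t
```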
4. GRU: A Lightweight Evolution of the LSTM

4.1 The Simplified Two-Gate Structure

The GRU reduces the LSTM's four-gate structure to:
- Update gate: controls how much historical information is kept
  $z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z)$
- Reset gate: controls how much historical information is forgotten when forming the candidate
  $r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r)$

Core equations (a minimal cell-step sketch follows this list):
- Candidate state: $\tilde{h}_t = \tanh(W \cdot [r_t \odot h_{t-1}, x_t] + b)$
- State update: $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$
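As with the LSTM sketch above, here is a minimal NumPy sketch of a single GRU cell step, following this article's convention $h_t = (1-z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$; the randomly initialized weights are purely illustrative:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_t, h_prev, W_z, W_r, W_h, b_z, b_r, b_h):
    z_in = np.concatenate([h_prev, x_t])            # [h_{t-1}, x_t]
    z = sigmoid(W_z @ z_in + b_z)                   # update gate
    r = sigmoid(W_r @ z_in + b_r)                   # reset gate
    h_tilde = np.tanh(W_h @ np.concatenate([r * h_prev, x_t]) + b_h)   # candidate state
    return (1 - z) * h_prev + z * h_tilde           # state update

# Toy dimensions: 3-dim input, 4-dim hidden state
rng = np.random.default_rng(0)
W_z, W_r, W_h = (rng.normal(size=(4, 7)) for _ in range(3))   # 4 x (hidden + input)
b_z = b_r = b_h = np.zeros(4)

h = np.zeros(4)                                     # h_0
for x_t in rng.normal(size=(3, 3)):                 # a length-3 toy sequence
    h = gru_step(x_t, h, W_z, W_r, W_h, b_z, b_r, b_h)
print(np.round(h, 3))
```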
4.2 "Bob"案例的GRU完整計算過程
為便于與RNN和LSTM對比,我們保持相同的輸入維度、隱藏狀態維度,并使用相似的參數設置:
假設條件:
- 輸入維度=3,隱藏狀態維度=4
- 字符編碼:‘B’=[0.1,0,0], ‘o’=[0,0.1,0], ‘b’=[0,0,0.1]
- GRU參數(簡化后):
(注:每個權重矩陣實際為4x7,因拼接 h t ? 1 h_{t-1} ht?1?(4維)和 x t x_t xt?(3維))W_z = [[1,0,0,1,0,0,0], [0,1,0,0,1,0,0], [0,0,1,0,0,1,0], [0,0,0,1,0,0,1]] W_r = [[1,0,0,1,0,0,0], [0,1,0,0,1,0,0], [0,0,1,0,0,1,0], [0,0,0,1,0,0,1]] W_h = [[1,0,0,1,0,0,0], [0,1,0,0,1,0,0], [0,0,1,0,0,1,0], [0,0,0,1,0,0,1]]
Detailed computation:

Time step 1: input 'B' = [0.1, 0, 0]

- Update gate:
  $z_1 = \sigma(W_z \cdot [h_0, x_1] + b_z)$
  With $b_z = [0,0,0,0]$ and $h_0 = 0$:
  $W_z \cdot [h_0, x_1] = [0.1, 0, 0, 0]$
  $z_1 = \sigma([0.1, 0, 0, 0]) = [0.525, 0.5, 0.5, 0.5]$

- Reset gate:
  $r_1 = \sigma(W_r \cdot [h_0, x_1] + b_r) = [0.525, 0.5, 0.5, 0.5]$

- Candidate hidden state:
  $\tilde{h}_1 = \tanh(W_h \cdot [r_1 \odot h_0, x_1] + b_h) = \tanh([0.1, 0, 0, 0]) = [0.099, 0, 0, 0]$

- Final hidden state:
  $h_1 = (1 - z_1) \odot h_0 + z_1 \odot \tilde{h}_1$
  $= [0.475, 0.5, 0.5, 0.5] \odot [0,0,0,0] + [0.525, 0.5, 0.5, 0.5] \odot [0.099, 0, 0, 0]$
  $= [0.052, 0, 0, 0]$
Time step 2: input 'o' = [0, 0.1, 0]

- Update gate:
  $W_z \cdot [h_1, x_2] = [0.052, 0.1, 0, 0]$
  $z_2 = \sigma([0.052, 0.1, 0, 0]) = [0.513, 0.525, 0.5, 0.5]$

- Reset gate:
  $r_2 = \sigma([0.052, 0.1, 0, 0]) = [0.513, 0.525, 0.5, 0.5]$

- Candidate hidden state:
  $\tilde{h}_2 = \tanh(W_h \cdot [r_2 \odot h_1, x_2] + b_h) = \tanh([0.027, 0.1, 0, 0]) = [0.027, 0.099, 0, 0]$

- Final hidden state:
  $h_2 = (1 - z_2) \odot h_1 + z_2 \odot \tilde{h}_2 = [0.026, 0, 0, 0] + [0.014, 0.05, 0, 0] = [0.04, 0.05, 0, 0]$
Time step 3: input 'b' = [0, 0, 0.1]

- Update gate:
  $W_z \cdot [h_2, x_3] = [0.04, 0.05, 0.1, 0]$
  $z_3 = \sigma([0.04, 0.05, 0.1, 0]) = [0.51, 0.512, 0.525, 0.5]$

- Reset gate:
  $r_3 = \sigma([0.04, 0.05, 0.1, 0]) = [0.51, 0.512, 0.525, 0.5]$

- Candidate hidden state:
  $\tilde{h}_3 = \tanh(W_h \cdot [r_3 \odot h_2, x_3] + b_h) = \tanh([0.02, 0.026, 0.1, 0]) = [0.02, 0.026, 0.099, 0]$

- Final hidden state:
  $h_3 = (1 - z_3) \odot h_2 + z_3 \odot \tilde{h}_3 = [0.02, 0.025, 0, 0] + [0.01, 0.013, 0.052, 0] = [0.03, 0.038, 0.052, 0]$
Final feature representations of the three models

| Model | Feature representation of "Bob" (final hidden state) |
|---|---|
| RNN | [0.099, 0.099, 0.1, 0.289] |
| LSTM | [0.015, 0.02, 0.027, 0] |
| GRU | [0.03, 0.038, 0.052, 0] |

Key differences:

- How information is fused:
  - The RNN adds inputs directly onto the hidden state, so later and accumulated information dominates
  - The LSTM keeps long-term memory in a separate cell state, at the cost of more computation
  - The GRU balances historical and current information dynamically through its update gate, with better computational efficiency

- Parameter efficiency:
  - A GRU has roughly 3/4 as many parameters as an LSTM of the same size and trains faster
  - On short-sequence tasks, a GRU usually reaches accuracy close to an LSTM's
4.3 The GRU API in PyTorch

- GRU class definition and core parameters

```python
torch.nn.GRU(
    input_size,            # input feature dimension
    hidden_size,           # hidden state dimension
    num_layers=1,          # number of stacked GRU layers
    bias=True,             # whether to use bias terms
    batch_first=False,     # if True, inputs are (batch, seq, feature)
    dropout=0,             # dropout probability between layers
    bidirectional=False    # whether the GRU is bidirectional
)
```

- Input and output formats

  Inputs:
  - input: the input sequence, of shape (seq_len, batch, input_size) by default, or (batch, seq_len, input_size) when batch_first=True
  - h_0: the initial hidden state, of shape (num_layers * num_directions, batch, hidden_size)

  Outputs:
  - output: the hidden states of all time steps, of shape (seq_len, batch, hidden_size * num_directions)
  - h_n: the hidden state of the last time step, with the same shape as h_0

- Key attributes and methods

  Weight matrices:
  - weight_ih_l[k]: input-to-hidden weights of layer k (reset, update, and candidate weights stacked together)
  - weight_hh_l[k]: hidden-to-hidden weights of layer k
  - bias_ih_l[k] and bias_hh_l[k]: the corresponding biases

  Forward pass:

```python
output, h_n = gru(input, h_0)   # unlike the LSTM, there is no cell state c_n
```
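A quick check of the stacked-gate layout (the sizes here are arbitrary):

```python
import torch.nn as nn

gru = nn.GRU(input_size=10, hidden_size=20)
print(gru.weight_ih_l0.shape)   # torch.Size([60, 10]) -> (3 * hidden_size, input_size)
print(gru.weight_hh_l0.shape)   # torch.Size([60, 20]) -> (3 * hidden_size, hidden_size)
# The row blocks are ordered as reset, update, and new (candidate) weights.
```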
4.4 Code Examples

- Basic usage

```python
import torch
import torch.nn as nn

# Build the GRU model
gru = nn.GRU(
    input_size=10,       # input feature dimension
    hidden_size=20,      # hidden state dimension
    num_layers=2,        # 2 stacked GRU layers
    batch_first=True,    # use (batch, seq, feature) layout
    bidirectional=True   # bidirectional GRU
)

# Prepare the input
batch_size = 3
seq_len = 5
x = torch.randn(batch_size, seq_len, 10)   # input sequence

# Initialize the hidden state (optional)
h0 = torch.zeros(2 * 2, batch_size, 20)    # num_layers * num_directions

# Forward pass
output, hn = gru(x, h0)

# Inspect the output shapes
print("Output shape:", output.shape)        # (3, 5, 40)  [batch, seq, hidden*2]
print("Final hidden shape:", hn.shape)      # (4, 3, 20)  [layers*directions, batch, hidden]
```

- Getting the hidden state of the last time step

```python
# Option 1: take the last time step from output
last_output = output[:, -1, :]              # (batch, hidden * num_directions)

# Option 2: take it from hn; for a bidirectional GRU, concatenate both directions of the top layer
if gru.bidirectional:
    last_hidden = torch.cat([hn[-2], hn[-1]], dim=1)
else:
    last_hidden = hn[-1]
```
5. Comparing the Three Models and Choosing in Practice

5.1 Core Metrics Compared

| Model | Gates | Parameters (input n, hidden m, ignoring biases) | Long-term dependency modeling | Computational efficiency |
|---|---|---|---|---|
| Vanilla RNN | 0 | $nm + m^2$ | poor | high |
| LSTM | 4 | $4(nm + m^2)$ | excellent | low |
| GRU | 2 | $3(nm + m^2)$ | good | medium |
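These ratios are easy to confirm by counting parameters in PyTorch (n = 10 and m = 20 here are arbitrary; biases are included, but the 1 : 4 : 3 ratio still holds):

```python
import torch.nn as nn

def count_params(module):
    return sum(p.numel() for p in module.parameters())

n, m = 10, 20
for name, cls in [("RNN", nn.RNN), ("LSTM", nn.LSTM), ("GRU", nn.GRU)]:
    model = cls(input_size=n, hidden_size=m)
    print(name, count_params(model))
# Prints 640, 2560, 1920: a 1 : 4 : 3 ratio, matching the table above
```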
5.2 Recommended Use Cases

- Vanilla RNN:
  short-sequence tasks (e.g. text classification with sequences shorter than about 20 tokens) or severely constrained compute budgets
- LSTM:
  long-sequence modeling (machine translation, speech recognition) and accuracy-critical tasks
- GRU:
  medium-length sequences (e.g. dialogue systems, time-series forecasting) and scenarios that need a balance of accuracy and efficiency
5.3 Key Points

- Gradient handling:
  - LSTM and GRU naturally mitigate vanishing gradients, but gradient clipping is still needed to guard against exploding gradients (see the sketch after this list)
- Parameter initialization:
  - vanilla RNNs need careful weight initialization to avoid gradient problems
- Bidirectional and multi-layer variants:
  - bidirectional structures capture dependencies in both directions, and stacked layers improve feature extraction, but both significantly increase computation
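A minimal sketch of gradient clipping inside one training step; the model, data, target, and clipping threshold are placeholder values for illustration:

```python
import torch
import torch.nn as nn

model = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

x = torch.randn(3, 5, 10)          # placeholder batch: (batch, seq, feature)
target = torch.randn(3, 5, 20)     # placeholder target with the output's shape

output, _ = model(x)
loss = criterion(output, target)

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip before the update
optimizer.step()
```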
The evolution of recurrent neural networks is a balancing act between expressive power and computational cost. From the RNN to the LSTM to the GRU, every improvement revolves around modeling sequence dependencies more effectively. In practice, choose the architecture that best fits your sequence lengths, compute budget, and accuracy requirements.