GRU模型
雙向GRU筆記:https://blog.csdn.net/weixin_44579176/article/details/146459952
概念
-
GRU(Gated Recurrent Unit)也稱為門控循環單元,是一種改進版的RNN。與LSTM一樣能夠有效捕捉長序列之間的語義關聯,通過引入兩個"門"機制(重置門和更新門)來控制信息的流動,從而避免了傳統RNN中的梯度消失問題,并減少了LSTM模型中的復雜性。
[^ 要點]:1.GRU同樣是通過門機制來解決傳統RNN中的梯度消失問題的 2.GRU相比于LSTM更為簡潔,它只引入了兩個門 :更新門(Update Gate), 重置門(Reset Gate)
核心組件
-
重置門(Reset Gate)
-
作用: 決定如何將新的輸入與之前的隱藏狀態結合。
- 當重置門值接近0時,表示當前時刻的輸入幾乎不依賴上一時刻的隱藏狀態。
- 當重置門值接近1時,表示當前時刻的輸入幾乎完全依賴上一時刻的隱藏狀態。
-
公式(變體版本): r t = σ ( W r ? [ h t ? 1 , x t ] + b r ) r_t = σ(W_r·[h_{t-1},x_t] + b_r) rt?=σ(Wr??[ht?1?,xt?]+br?)
- r t r_t rt?| 重置門值, r t ∈ ( 0 , 1 ) r_t ∈ (0,1) rt?∈(0,1)
- W r W_r Wr? 和$ b_r$ | 重置門權值和偏置項
- σ | sigmoid函數 保證 r t r_t rt?的輸出值在 0 到 1之間
-
-
更新門(Update Gate)
-
作用: 決定多少之前的信息需要保留,多少新的信息需要更新。
- 當更新門值接近0時,意味著網絡只記住舊的隱藏狀態,幾乎沒有新的信息。
- 當更新門值接近1時,意味著網絡更傾向于使用新的隱藏狀態,記住當前輸入的信息。
-
公式(變體版本): z t = σ ( W r ? [ h t ? 1 , x t ] + b z ) z_t = σ(W_r·[h_{t-1},x_t] + b_z) zt?=σ(Wr??[ht?1?,xt?]+bz?)
- z t z_t zt?| 更新門值, z t ∈ ( 0 , 1 ) z_t ∈ (0,1) zt?∈(0,1)
- W r W_r Wr? 和$ b_r$ | 重置門權值和偏置項
- σ | sigmoid函數 保證 z t z_t zt?的輸出值在 0 到 1之間
-
-
候選隱藏狀態(Candidate Hidden State)
-
作用: 捕捉當前時間步的信息,多少前一隱藏狀態的信息被保留。
-
公式(變體版本): h ^ t = t a n h ( W h ? [ r t ⊙ h t ? 1 , x t ] + b h ) ?_t = tanh(W_h · [r_t \odot h_{t-1} , x_t] + b_h) h^t?=tanh(Wh??[rt?⊙ht?1?,xt?]+bh?)
- h ^ t ?_t h^t?| 候選隱藏狀態值, h ^ t ∈ ( ? 1 , 1 ) ?_t ∈ (-1,1) h^t?∈(?1,1)
- W h W_h Wh? 和$ b_h$ | 候選隱藏狀態的權重和偏置項
- tanh| 雙曲正切函數 保證 h t h_t ht?的輸出值在 -1 到 1之間
- ⊙ \odot ⊙ | Hadamard Product
-
-
最終隱藏狀態(Final Hidden State)
-
作用: 控制信息更新,傳遞長期依賴。
-
公式(變體版本): h t = ( 1 ? z t ) ⊙ h t ? 1 + z t ⊙ h ^ t h_t = (1-z_t) \odot h_{t-1} + z_t \odot ?_t ht?=(1?zt?)⊙ht?1?+zt?⊙h^t?
- h t h_t ht?| 當前時間步的隱藏狀態
- z t z_t zt? | 更新門的輸出,控制新舊信息的比例
- ⊙ \odot ⊙ | Hadamard Product
重置門與更新的對比
門控機制 核心功能 直觀理解 重置門(Reset Gate) 控制歷史信息對當前候選狀態的影響:決定是否忽略部分或全部歷史信息,從而生成新的候選隱藏狀態。 “是否忘記過去,重新開始?”(例如:處理句子中的突變或新段落) 更新門(Update Gate) 控制新舊信息的融合比例:決定保留多少舊狀態的信息,同時引入多少候選狀態的新信息。 “保留多少舊記憶,吸收多少新知識?”(例如:維持長期依賴關系) 重置門作用舉例:
? input: [‘風’,‘可以’,‘吹起’,‘一大張’,‘白紙’,‘卻’,‘無法’,‘吹走’,‘一只’,‘蝴蝶’,‘因為’,‘生命’,‘的’,‘力量’,‘在于’,‘不’,‘順從’]
-
當處理到 ‘卻’ 時,上文信息 : 風可以吹起一大張白紙
- 重置門值 : r t = 0.3 r_t = 0.3 rt?=0.3
- 作用:忽略部分歷史信息,弱化上文影響,為后續信息(無法吹走一只蝴蝶)騰出空間
- 更新門值 : z t = 0.8 z_t = 0.8 zt?=0.8
- 作用: 表示保留更多候選隱藏狀態(由于 r t r_t rt?是一個較小的值,所以候選隱藏狀態中新信息占比更大) 的信息
[^ 注]: 此時$ h_t $接近 $ ?_t$,隱藏狀態被重置為“準備處理轉折后的新邏輯”。
- 重置門值 : r t = 0.3 r_t = 0.3 rt?=0.3
-
當處理到 ‘因為’ 時,上文信息 : 少部分的 "風可以吹起一大張白紙 " + 大部分的 “無法吹走蝴蝶”
- 重置門值 : r t = 0.8 r_t = 0.8 rt?=0.8
- 作用:保留更多上文信息,以便與后續原因關聯
- 更新門值 : z t = 0.5 z_t = 0.5 zt?=0.5
- 作用: 平衡舊狀態(上文結論) 和 新狀態(下文原因) ,逐步構建完整的邏輯鏈
- 重置門值 : r t = 0.8 r_t = 0.8 rt?=0.8
-
內部結構
- GRU的更新門和重置門結構圖
Pytorch實現
nn.GRU(input_size, hidden_size, num_layers, bidirectional, batch_first, dropout)[^ input_size ]:輸入特征的維度
[^ hidden_size ]:隱藏狀態的維度
[^ num_layers ]:GRU的層數(默認值為1)
[^ batch_first ]:如果為True,輸入和輸出的形狀為 (batch_size, seq_len, input_size);否則為 (seq_len, batch_size, input_size)
[^ bidirectional ]:如果為True,使用雙向GRU;否則為單向GRU(默認False)
[^ dropout ]:在多層GRU中,是否在層之間應用dropout(默認值為0)
使用示例
# 定義GRU的參數含義: (input_size, hidden_size, num_layers)
# 定義輸入張量的參數含義: (sequence_length, batch_size, input_size)
# 定義隱藏層初始張量的參數含義: (num_layers * num_directions, batch_size, hidden_size)
import torch.nn as nn
import torchdef dm_gru():# 創建GRU層gru = nn.GRU(input_size=5, hidden_size=6, num_layers=2)# 創建輸入張量input = torch.randn(size=(1, 3, 5))# 初始化隱藏狀態h0 = torch.randn(size=(2, 3, 6))# hn輸出兩層隱藏狀態, 最后1個隱藏狀態值等于output輸出值output, hn = gru(input, h0)print('output--->', output.shape, output)print('hn--->', hn.shape, hn)