第一部分:為什么需要RNN?
在了解RNN是什么之前,我們先要明白它解決了什么問題。
傳統的神經網絡,比如我們常見的前饋神經網絡(Feedforward Neural Network)或者卷積神經網絡(CNN),它們有一個共同的特點:輸入之間是相互獨立的。
你給它一張貓的圖片,它判斷是貓。再給它一張狗的圖片,它判斷是狗。
這兩個判斷過程互不影響。前一次的輸入和輸出,對后一次的判斷沒有任何幫助。這在很多場景下是沒問題的。
但是,請思考以下任務:
閱讀理解: "今天天氣很好,我心情也很___。" 空格里很可能填“好”或“不錯”。這個推斷依賴于前面的“天氣很好”。
語音識別: 當你聽到一句話的開頭,它會幫助你預測后面可能出現的音節。
股票預測: 今天的股價,很大程度上取決于昨天、前天乃至過去一段時間的走勢。
這些任務的共同點是,它們處理的都是序列數據(Sequential Data)。序列中的數據不是獨立的,前一個數據點包含了對理解后一個數據點至關重要的信息。
傳統的神經網絡缺乏記憶能力,無法處理這種時間上的依賴關系。而RNN,就是為了解決這個問題而生的。
結論:RNN是一種專門用于處理序列數據的神經網絡,其設計的核心就是賦予網絡一種“記憶”能力,讓它能夠捕捉序列中的時間依賴關系。
第二部分:RNN的核心結構
1. 折疊形式 (Folded):
,-----,| | <-- (代表信息的循環)'-----'^|x_t ---> [ A ] ---> o_t(輸入) (RNN單元) (輸出)
[ A ]: 代表RNN的處理單元。
x_t: 代表在時間點
t
的輸入。o_t: 代表在時間點
t
的輸出。最重要的部分是那個指向自身的循環箭頭: 它表示
A
單元的輸出結果(具體來說是隱藏狀態h_t
,我們稍后會講)會作為下一次計算的輸入,再次進入A
單元。這就是“循環”或“記憶”的來源。
2. 展開形式 (Unfolded):
(初始記憶)h_(-1)|v... --> [ A ] --(傳遞記憶 h_0)--> [ A ] --(傳遞記憶 h_1)--> [ A ] --(傳遞記憶 h_2)--> ...| | |^ ^ ^| | |x_0 x_1 x_2 (序列輸入)| | |v v vo_0 o_1 o_2 (序列輸出)(t=0 時刻) (t=1 時刻) (t=2 時刻)
讓我們來詳細解讀一下這個結構:
x_t:這是在時間步(time step)t 的輸入。比如,在處理一句話 "I am a student" 時,x_0 就是 "I",x_1 就是 "am",以此類推。
h_t:這是在時間步 t 的隱藏狀態(Hidden State)。可以把它理解為RNN在時間點 t 的記憶。它不僅包含了當前輸入x_t的信息,還包含了上一個時間步的隱藏狀態h_t?1(也就是過去的記憶)的信息。
o_t:這是在時間步 t 的輸出。比如,在做下一個詞預測時,o_t 就是基于到x_t為止的所有信息,預測出的下一個最可能的詞。
A:代表RNN的計算單元。重要的是,在所有時間步中,這個A是完全相同的。它包含的參數(權重矩陣)在整個序列處理過程中是共享的。這大大減少了模型的參數量,也讓模型學會一種通用的處理規則,而不是為每個時間點都學一套新規則。
圖中雖然畫了多個
[ A ]
,但請記住,它們是同一個單元,擁有完全相同的參數(權重)。我們只是為了說明流程,把它在時間維度上復制了多份。
工作流程(前向傳播):
初始狀態:在 t=0 時,我們需要一個初始的隱藏狀態 h_?1(通常初始化為全零向量)。
t=0 時刻:RNN單元接收初始隱藏狀態 h_?1 和第一個輸入 x_0。通過內部計算,它會生成新的隱藏狀態(新的記憶)h_0,并可能產生一個輸出 o_0。
t=1 時刻:RNN單元接收上一時刻的記憶 h_0 和當前輸入 x_1。它將這兩者結合,更新自己的記憶,生成新的隱藏狀態 h_1,并輸出 o_1。
循環往復:這個過程一直持續下去,直到序列的所有輸入都被處理完畢。在每一步,h_t 都像一個“記憶膠囊”,攜帶著從序列開始到當前位置的所有重要信息,向下傳遞。
結論:RNN通過一個循環的隱藏狀態(Hidden State),將過去的信息編碼并傳遞到當前步驟,從而實現了對序列數據的記憶。
第三部分:深入RNN的數學原理
現在我們把那個黑盒子 "A" 打開,看看里面到底發生了什么計算。
在任意一個時間步 t,計算主要分為兩步:
1. 更新隱藏狀態 h_t:
拆解這個公式:
h_t?1:上一時刻的隱藏狀態(向量)。
x_t:當前時刻的輸入(向量)。
W_hh:隱藏狀態到隱藏狀態的權重矩陣。它決定了“應該保留多少上一時刻的記憶”。
W_xh:輸入到隱藏狀態的權重矩陣。它決定了“應該從當前輸入中吸收多少新信息”。
b_h:隱藏狀態的偏置項(bias)。
f:激活函數。在RNN中,通常使用 tanh(雙曲正切函數)。為什么用tanh?因為它能將輸出值壓縮到-1到1之間,這有助于控制信息流,防止梯度在網絡中傳播時變得過大或過小(盡管不能完全解決,后面會講)。
2. 計算輸出 o_t:
h_t:當前時刻剛剛計算出來的隱藏狀態。
W_hy:隱藏狀態到輸出的權重矩陣。它決定了“如何利用當前的記憶來生成輸出”。
b_y:輸出的偏置項。
g:輸出層的激活函數。這個根據具體任務決定。
如果是分類任務(比如情感分析),通常用 Softmax。
如果是回歸任務(比如預測股價),可能就不用激活函數或用線性激活函數。
關鍵點:在整個訓練過程中,模型要學習的就是這幾個共享的權重矩陣(W_hh,W_xh,W_hy)和偏置項。無論序列有多長,它們都是同一套參數。
第四部分:RNN的訓練與挑戰
訓練:BPTT算法
RNN的訓練算法叫做通過時間的反向傳播(Backpropagation Through Time, BPTT)。
還記得那個展開的RNN圖嗎?BPTT的原理其實很簡單:
前向傳播:按照我們上面講的流程,從頭到尾計算出所有時間步的隱藏狀態和輸出。
計算總損失:將每個時間步的輸出 o_t 與真實標簽 y_t 進行比較,計算損失(例如使用交叉熵損失),然后將所有時間步的損失相加,得到總損失。
反向傳播:將總損失從最后一個時間步開始,沿著展開的圖反向傳播,計算每個權重矩陣的梯度。因為權重是共享的,所以每個時間步計算出的梯度會累加到對應的共享權重上。
更新權重:使用梯度下降法(如Adam, SGD等)根據累加后的總梯度來更新權重矩陣 W_hh,W_xh,W_hy。
長期依賴問題(Long-Term Dependencies)
這是簡單RNN最致命的弱點。
想象這個句子:"I grew up in France... (此處省略很多文字)... therefore, I speak fluent French."
為了正確預測出 "French",模型需要記住很久以前的信息 "France"。
在BPTT過程中,梯度需要從序列末端一直傳播回序列的開端。根據鏈式法則,這個梯度會不斷地乘以權重矩陣 W_hh。
梯度消失(Vanishing Gradients):如果 W_hh 中的值(更準確地說是它的雅可比矩陣的范數)小于1,那么在多次連乘后,梯度會變得極其微小,趨近于0。這導致模型無法從遙遠的過去學習到信息,長期記憶丟失。這是最常見也最棘手的問題。
梯度爆炸(Exploding Gradients):反之,如果 W_hh 中的值大于1,多次連乘后梯度會變得非常大,導致模型訓練不穩定,參數更新幅度過大,甚至變成NaN。這個問題相對容易發現和解決(例如通過梯度裁剪 (Gradient Clipping) 來限制梯度的大小)。
由于梯度消失問題的存在,標準的RNN很難學習到超過5-10個時間步的依賴關系,這極大地限制了它的應用。
第五部分:解決方案與演進——LSTM與GRU
為了解決長期依賴問題,研究人員設計了更復雜的RNN變體,其中最成功、最流行的就是長短期記憶網絡(Long Short-Term Memory, LSTM)和門控循環單元(Gated Recurrent Unit, GRU)。
它們的核心思想是引入門(Gate)的結構。
你可以把門想象成一個信息過濾器,它由一個Sigmoid激活函數和一個逐元素相乘操作組成。Sigmoid的輸出在0到1之間,可以看作是一個開關:
輸出為0,表示“完全關閉”,不允許任何信息通過。
輸出為1,表示“完全打開”,讓所有信息通過。
輸出在0和1之間,表示“部分打開”,按比例讓信息通過。
LSTM: 它引入了一個獨立的細胞狀態(Cell State),專門負責長距離傳遞信息。然后,它設計了三個精巧的門來控制細胞狀態:
遺忘門(Forget Gate):決定應該從細胞狀態中丟棄哪些舊信息。
輸入門(Input Gate):決定哪些新信息應該被存入細胞狀態。
輸出門(Output Gate):決定細胞狀態中的哪些信息應該被用作當前的輸出。
通過這三個門的協同工作,LSTM可以明確地學習到何時遺忘、何時記憶、何時輸出,從而有效地解決了梯度消失問題,能夠捕捉非常長的序列依賴。
GRU: 這是LSTM的一個簡化版,它將遺忘門和輸入門合并為了一個更新門(Update Gate),并且沒有獨立的細胞狀態。GRU的結構更簡單,參數更少,計算效率更高,在許多任務上能達到和LSTM相近的效果。
明天我們講解RNN的pytorch逐行實現以及LSTM與GRU的深入原理講解