一、LSTM的背景與動機
1.1 為什么需要LSTM?
在深度學習中,普通的神經網絡(如全連接網絡或卷積神經網絡)在處理序列數據時表現不佳,因為它們無法捕捉數據中的時間依賴關系。循環神經網絡(RNN)被設計來處理序列數據,通過隱藏狀態在時間步之間傳遞信息。然而,傳統RNN存在兩個主要問題:
- 梯度消失/爆炸:在反向傳播時,梯度可能隨著時間步的增加變得極小(消失)或極大(爆炸),導致模型難以學習長期依賴關系。
- 長期依賴問題:RNN在理論上可以記住長時間步的信息,但實際上由于梯度問題,很難捕捉長序列中的遠距離依賴。
LSTM由Hochreiter和Schmidhuber在1997年提出,旨在解決這些問題。它通過引入門控機制(Gates)和記憶單元(Cell State),能夠選擇性地記住或遺忘信息,從而有效建模長期和短期依賴。
1.2 LSTM的核心思想
LSTM的核心是通過一個記憶單元(Cell State)來保存長期信息,并通過門控機制(輸入門、遺忘門、輸出門)控制信息的流動。這些門決定:
- 哪些信息需要被保留(長期記憶)。
- 哪些信息需要被遺忘。
- 當前時間步應該輸出什么。
這使得LSTM在處理長序列時表現優異,適合任務如機器翻譯、文本生成和時間序列預測。
二、LSTM的架構與工作原理
LSTM的基本單元由以下幾個部分組成:
- 記憶單元(Cell State):負責存儲長期信息,貫穿整個序列。
- 隱藏狀態(Hidden State):負責輸出當前時間步的信息,包含短期記憶。
- 門控機制:包括遺忘門(Forget Gate)、輸入門(Input Gate)和輸出門(Output Gate),控制信息的流動。
下面我們詳細解析每個部分。
2.1 記憶單元(Cell State)
記憶單元是LSTM的核心,它像一條“傳送帶”,貫穿所有時間步,負責存儲和傳遞長期信息。Cell State通過門控機制進行更新,確保模型能夠記住關鍵信息(如句子的主語)并遺忘無關信息。
數學上,Cell State在時間步 t t t 的更新公式為:
C t = f t ⊙ C t ? 1 + i t ⊙ C ~ t C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t Ct?=ft?⊙Ct?1?+it?⊙C~t?
其中:
- C t C_t Ct?:當前時間步的Cell State。
- C t ? 1 C_{t-1} Ct?1?:上一時間步的Cell State。
- f t f_t ft?:遺忘門輸出,決定保留多少上一時間步的信息。
- i t i_t it?:輸入門輸出,決定當前輸入有多少信息被加入。
- C ~ t \tilde{C}_t C~t?:候選Cell State,表示當前時間步的候選記憶。
- ⊙ \odot ⊙:逐元素相乘(Hadamard乘積)。
2.2 隱藏狀態(Hidden State)
隱藏狀態 h t h_t ht? 是LSTM的輸出,包含當前時間步的短期信息。它由Cell State通過輸出門進行調節:
h t = o t ⊙ tanh ? ( C t ) h_t = o_t \odot \tanh(C_t) ht?=ot?⊙tanh(Ct?)
其中:
- o t o_t ot?:輸出門輸出,控制Cell State的信息流向隱藏狀態。
- tanh ? \tanh tanh:激活函數,將Cell State的值壓縮到 [ ? 1 , 1 ] [-1, 1] [?1,1].
隱藏狀態 h t h_t ht? 通常被用作模型的輸出,或傳遞到下一層網絡。
2.3 門控機制
LSTM通過三個門控機制控制信息的流動,每個門都使用sigmoid激活函數(輸出范圍為 [ 0 , 1 ] [0, 1] [0,1]),決定信息保留的比例。
2.3.1 遺忘門(Forget Gate)
遺忘門決定上一時間步的Cell State中有多少信息需要被遺忘。它的計算公式為:
f t = σ ( W f ? [ h t ? 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft?=σ(Wf??[ht?1?,xt?]+bf?)
其中:
- h t ? 1 h_