RNN中遠距離時間步梯度消失問題及解決辦法
- RNN 遠距離時間步梯度消失問題
- LSTM如何解決遠距離時間步梯度消失問題
RNN 遠距離時間步梯度消失問題
經典的RNN結構如下圖所示:
假設我們的時間序列只有三段, S 0 S_{0} S0? 為給定值,神經元沒有激活函數,則RNN最簡單的前向傳播過程如下:
S 1 = W x X 1 + W s S 0 + b 1 , O 1 = W 0 S 1 + b 2 S_{1} = W_{x} X_{1} + W_{s}S_{0} + b_{1},O_{1} = W_{0} S_{1} + b_{2} S1?=Wx?X1?+Ws?S0?+b1?,O1?=W0?S1?+b2?
S 2 = W x X 2 + W s S 1 + b 1 , O 2 = W 0 S 2 + b 2 S_{2} = W_{x} X_{2} + W_{s}S_{1} + b_{1},O_{2} = W_{0} S_{2} + b_{2} S2?=Wx?X2?+Ws?S1?+b1?,O2?=W0?S2?+b2?
S 3 = W x X 3 + W s S 2 + b 1 , O 3 = W 0 S 3 + b 2 S_{3} = W_{x} X_{3} + W_{s}S_{2} + b_{1},O_{3} = W_{0} S_{3} + b_{2} S3?=Wx?X3?+Ws?S2?+b1?,O3?=W0?S3?+b2?
假設在 t = 3 t=3 t=3時刻,損失函數為 L 3 = 1 2 ( Y 3 ? O 3 ) 2 L_3 = \frac{1}{2}(Y_3 - O_3)^2 L3?=21?(Y3??O3?)2 。則對于一次訓練任務的損失函數為 L = ∑ t = 0 T L t L = \sum_{t=0}^{T} L_t L=∑t=0T?Lt? ,即每一時刻損失值的累加。
使用隨機梯度下降法訓練RNN其實就是對 W x W_x Wx? 、 W s W_s Ws? 、 W o W_o Wo? 以及 b 1 、 b 2 b_1 、 b_2 b1?、b2? 求偏導,并不斷調整它們以使 L L L盡可能達到最小的過程。
現在假設我們我們的時間序列只有三段:t1,t2,t3。我們只對 t 3 t3 t3時刻的 W x W_x Wx?、 W s W_s Ws?、 W o W_o Wo? 求偏導(其他時刻類似):
? L 3 ? W 0 = ? L 3 ? O 3 ? O 3 ? W o = ? L 3 ? O 3 S 3 \frac{\partial L_3}{\partial W_0} = \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial W_o} = \frac{\partial L_3}{\partial O_3} S_3 ?W0??L3??=?O3??L3???Wo??O3??=?O3??L3??S3?
? L 3 ? W x = ? L 3 ? O 3 ? O 3 ? S 3 ? S 3 ? W x + ? L 3 ? O 3 ? O 3 ? S 3 ? S 3 ? S 2 ? S 2 ? W x + ? L 3 ? O 3 ? O 3 ? S 3 ? S 3 ? S 2 ? S 2 ? S 1 ? S 1 ? W x = ? L 3 ? O 3 W 0 ( X 3 + S 2 W s + S 1 W s 2 ) \frac{\partial L_3}{\partial W_x} = \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial W_x} + \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial W_x} + \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial S_1} \frac{\partial S_1}{\partial W_x} = \frac{\partial L_3}{\partial O_3} W_0 (X_3 + S_2 W_s + S_1 W_s^2) ?Wx??L3??=?O3??L3???S3??O3???Wx??S3??+?O3??L3???S3??O3???S2??S3???Wx??S2??+?O3??L3???S3??O3???S2??S3???S1??S2???Wx??S1??=?O3??L3??W0?(X3?+S2?Ws?+S1?Ws2?)
? L 3 ? W s = ? L 3 ? O 3 ? O 3 ? S 3 ? S 3 ? W s + ? L 3 ? O 3 ? O 3 ? S 3 ? S 3 ? S 2 ? S 2 ? W s + ? L 3 ? O 3 ? O 3 ? S 3 ? S 3 ? S 2 ? S 2 ? S 1 ? S 1 ? W s = ? L 3 ? O 3 W 0 ( S 2 + S 1 W s + S 0 W s 2 ) \frac{\partial L_3}{\partial W_s} = \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial W_s} + \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial W_s} + \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial S_1} \frac{\partial S_1}{\partial W_s} = \frac{\partial L_3}{\partial O_3} W_0 (S_2 + S_1 W_s + S_0 W_s^2) ?Ws??L3??=?O3??L3???S3??O3???Ws??S3??+?O3??L3???S3??O3???S2??S3???Ws??S2??+?O3??L3???S3??O3???S2??S3???S1??S2???Ws??S1??=?O3??L3??W0?(S2?+S1?Ws?+S0?Ws2?)
關于上面這個多元復合函數鏈式求導過程,通過如下對變量層級樹的遍歷可以更加直觀理解這一點:
可以看出對于 W o W_o Wo? 求偏導并沒有長期依賴,但是對于 W x W_x Wx?、 W s W_s Ws? 求偏導,會隨著時間序列產生長期依賴。因為 S t S_t St? 隨著時間序列向前傳播,而 S t S_t St? 又是 W x W_x Wx?、 W s W_s Ws? 的函數。
根據上述求偏導的過程,我們可以得出任意時刻對 W x W_x Wx?、 W s W_s Ws? 求偏導的公式:
? L t ? W x = ∑ k = 0 t ? L t ? O t ? O t ? S t ( ∏ j = k + 1 t ? S j ? S j ? 1 ) ? S k ? W x \frac{\partial L_t}{\partial W_x} = \sum_{k=0}^{t} \frac{\partial L_t}{\partial O_t} \frac{\partial O_t}{\partial S_t} \left(\prod_{j=k+1}^{t} \frac{\partial S_j}{\partial S_{j-1}}\right) \frac{\partial S_k}{\partial W_x} ?Wx??Lt??=k=0∑t??Ot??Lt???St??Ot?? ?j=k+1∏t??Sj?1??Sj?? ??Wx??Sk??
任意時刻對 W s W_s Ws? 求偏導的公式同上。
如果加上激活函數: S j = tanh ? ( W x X j + W s S j ? 1 + b 1 ) S_j = \tanh(W_x X_j + W_s S_{j-1} + b_1) Sj?=tanh(Wx?Xj?+Ws?Sj?1?+b1?)
則 ∏ j = k + 1 t ? S j ? S j ? 1 = ∏ j = k + 1 t tanh ? ′ W s \prod_{j=k+1}^{t} \frac{\partial S_j}{\partial S_{j-1}} = \prod_{j=k+1}^{t} \tanh' W_s j=k+1∏t??Sj?1??Sj??=j=k+1∏t?tanh′Ws?
加上激活函數tanh復合后的多元鏈式求導過程如下圖所示:
激活函數tanh和它的導數圖像如下。
由上圖可以看出 tanh ? ′ ≤ 1 \tanh' \leq 1 tanh′≤1,對于訓練過程大部分情況下tanh的導數是小于1的,因為很少情況下會出現 W x X j + W s S j ? 1 + b 1 = 0 W_x X_j + W_s S_{j-1} + b_1 = 0 Wx?Xj?+Ws?Sj?1?+b1?=0,如果 W s W_s Ws? 也是一個大于0小于1的值,則當t很大時 ∏ j = k + 1 t tanh ? ′ W s \prod_{j=k+1}^{t} \tanh' W_s ∏j=k+1t?tanh′Ws?,就會趨近于0,和 0.0 1 50 0.01^{50} 0.0150 趨近于0是一個道理。同理當 W s W_s Ws? 很大時 ∏ j = k + 1 t tanh ? ′ W s \prod_{j=k+1}^{t} \tanh' W_s ∏j=k+1t?tanh′Ws? 就會趨近于無窮,這就是RNN中梯度消失和爆炸的原因。
至于怎么避免這種現象,再看看 ? L t ? W x = ∑ k = 0 t ? L t ? O t ? O t ? S t ( ∏ j = k + 1 t ? S j ? S j ? 1 ) ? S k ? W x \frac{\partial L_t}{\partial W_x} = \sum_{k=0}^{t} \frac{\partial L_t}{\partial O_t} \frac{\partial O_t}{\partial S_t} \left(\prod_{j=k+1}^{t} \frac{\partial S_j}{\partial S_{j-1}}\right) \frac{\partial S_k}{\partial W_x} ?Wx??Lt??=k=0∑t??Ot??Lt???St??Ot?? ?j=k+1∏t??Sj?1??Sj?? ??Wx??Sk?? 梯度消失和爆炸的根本原因就是 ∏ j = k + 1 t ? S j ? S j ? 1 \prod_{j=k+1}^{t} \frac{\partial S_j}{\partial S_{j-1}} ∏j=k+1t??Sj?1??Sj?? 這一坨,要消除這種情況就需要把這一坨在求偏導的過程中去掉,至于怎么去掉,一種辦法就是使 ? S j ? S j ? 1 ≈ 1 \frac{\partial S_j}{\partial S_{j-1}} \approx 1 ?Sj?1??Sj??≈1 另一種辦法就是使 ? S j ? S j ? 1 ≈ 0 \frac{\partial S_j}{\partial S_{j-1}} \approx 0 ?Sj?1??Sj??≈0。其實這就是LSTM做的事情。
總結:
-
RNN 的梯度計算涉及到對激活函數的導數以及權重矩陣的連乘
- 以 sigmoid 函數為例,其導數的值域在 0 到 0.25 之間,當進行多次連乘時,這些較小的值相乘會導致梯度迅速變小。
- 如果權重矩陣的特征值也小于 1,那么在多個時間步的傳遞過程中,梯度就會呈指數級下降,導致越靠前的時間步,梯度回傳的值越少。
-
由于梯度消失,靠前時間步的參數更新幅度會非常小,甚至幾乎不更新。這使得模型難以學習到序列數據中長距離的依賴關系,對于較早時間步的信息利用不足,從而影響模型的整體性能和對序列數據的建模能力。
注意 : 注意: 注意:
RNN梯度爆炸好理解,就是 ? L t ? W x \frac{\partial L_t}{\partial W_x} ?Wx??Lt??梯度數值發散,甚至慢慢就NaN了;
那梯度消失就是 ? L t ? W x \frac{\partial L_t}{\partial W_x} ?Wx??Lt??梯度變成零嗎?
并不是,我們剛剛說梯度消失是 ∣ ? S j ? S j ? 1 ∣ \left|\frac{\partial S_j}{\partial S_{j-1}}\right| ??Sj?1??Sj?? ? 一直小于1,歷史梯度不斷衰減,但不意味著總的梯度就為0了。RNN中梯度消失的含義是:距離當前時間步越長,那么其反饋的梯度信號越不顯著,最后可能完全沒有起作用,這就意味著RNN對長距離語義的捕捉能力失效了。
說白了,你優化過程都跟長距離的反饋沒關系,怎么能保證學習出來的模型能有效捕捉長距離呢?
再次通俗解釋一下RNN梯度消失,其指的不是 ? L t ? W x \frac{\partial L_t}{\partial W_x} ?Wx??Lt??梯度值接近于0,而是靠前時間步的梯度 ? L 3 ? O 3 ? O 3 ? S 3 ? S 3 ? S 2 ? S 2 ? S 1 ? S 1 ? W x \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial S_1} \frac{\partial S_1}{\partial W_x} ?O3??L3???S3??O3???S2??S3???S1??S2???Wx??S1??值算出來很小,也就是靠前時間步計算出來的結果對序列最后一個預測詞的生成影響很小,也就是常說的RNN難以去建模長距離的依賴關系的原因;這并不是因為序列靠前的詞對最后一個詞的預測輸出不重要,而是由于損失函數在把有用的梯度更新信息反向回傳的過程中,被若干小于0的偏導連乘給一點點削減掉了。
LSTM如何解決遠距離時間步梯度消失問題
LSTM的更新公式比較復雜,它是:
f t = σ ( W f x t + U f h t ? 1 + b f ) f_t = \sigma (W_f x_t + U_f h_{t-1} + b_f) ft?=σ(Wf?xt?+Uf?ht?1?+bf?)
i t = σ ( W i x t + U i h t ? 1 + b i ) i_t = \sigma (W_i x_t + U_i h_{t-1} + b_i) it?=σ(Wi?xt?+Ui?ht?1?+bi?)
o t = σ ( W o x t + U o h t ? 1 + b o ) o_t = \sigma (W_o x_t + U_o h_{t-1} + b_o) ot?=σ(Wo?xt?+Uo?ht?1?+bo?)
c ^ t = tanh ? ( W c x t + U c h t ? 1 + b c ) \hat{c}_t = \tanh (W_c x_t + U_c h_{t-1} + b_c) c^t?=tanh(Wc?xt?+Uc?ht?1?+bc?)
c t = f t ° c t ? 1 + i t ° c ^ t c_t = f_t \circ c_{t-1} + i_t \circ \hat{c}_t ct?=ft?°ct?1?+it?°c^t?
h t = o t ° tanh ? ( c t ) h_t = o_t \circ \tanh(c_t) \qquad ht?=ot?°tanh(ct?)
我們可以像上面一樣計算 ? h t ? h t ? 1 \frac{\partial h_t}{\partial h_{t-1}} ?ht?1??ht??,但從 h t = o t ° tanh ? ( c t ) h_t = o_t \circ \tanh(c_t) ht?=ot?°tanh(ct?) 可以看出分析 c t c_t ct? 就等價于分析 h t h_t ht?,而計算 ? c t ? c t ? 1 \frac{\partial c_t}{\partial c_{t-1}} ?ct?1??ct?? 顯得更加簡單一些,因此我們往這個方向走。
同樣地,我們先只關心1維的情形,這時候根據求導公式,我們有
? c t ? c t ? 1 = f t + c t ? 1 ? f t ? c t ? 1 + c ^ t ? i t ? c t ? 1 + i t ? c ^ t ? c t ? 1 \frac{\partial c_t}{\partial c_{t-1}} = f_t + c_{t-1} \frac{\partial f_t}{\partial c_{t-1}} + \hat{c}_t \frac{\partial i_t}{\partial c_{t-1}} + i_t \frac{\partial \hat{c}_t}{\partial c_{t-1}} \qquad ?ct?1??ct??=ft?+ct?1??ct?1??ft??+c^t??ct?1??it??+it??ct?1??c^t??
右端第一項 f t f_t ft?,也就是我們所說的“遺忘門”,從下面的論述我們可以知道一般情況下其余三項都是次要項,因此 f t f_t ft? 是“主項”,由于 f t f_t ft? 在0~1之間,因此就意味著梯度爆炸的風險將會很小,至于會不會梯度消失,取決于 f t f_t ft? 是否接近于1。但非常碰巧的是,這里有個相當自洽的結論:如果我們的任務比較依賴于歷史信息,那么 f t f_t ft? 就會接近于1,這時候歷史的梯度信息也正好不容易消失;如果 f t f_t ft? 很接近于0,那么就說明我們的任務不依賴于歷史信息,這時候就算梯度消失也無妨了。
所以,現在的關鍵就是看“其余三項都是次要項”這個結論能否成立。后面的三項都是“一項乘以另一項的偏導”的形式,而且求偏導的項都是 σ \sigma σ或 tanh ? \tanh tanh激活, σ \sigma σ和 tanh ? \tanh tanh的偏導公式基本上是等價的,它們的導數均可以用它們自身來表示:
tanh ? x = 2 σ ( 2 x ) ? 1 \tanh x = 2\sigma(2x) - 1 tanhx=2σ(2x)?1
σ ( x ) = 1 2 ( tanh ? x 2 + 1 ) \sigma(x) = \frac{1}{2} \left( \tanh \frac{x}{2} + 1 \right) \qquad σ(x)=21?(tanh2x?+1)
( tanh ? x ) ′ = 1 ? tanh ? 2 x (\tanh x)' = 1 - \tanh^2 x (tanhx)′=1?tanh2x
σ ′ ( x ) = σ ( x ) ( 1 ? σ ( x ) ) \sigma'(x) = \sigma(x) (1 - \sigma(x)) σ′(x)=σ(x)(1?σ(x))
其中 σ ( x ) = 1 / ( 1 + e ? x ) \sigma(x) = 1/(1 + e^{-x}) σ(x)=1/(1+e?x) 是sigmoid函數。
因此后面三項是類似的,分析了其中一項就相當于分析了其余兩項。以第二項為例,代入 h t ? 1 = o t ? 1 tanh ? ( c t ? 1 ) h_{t-1} = o_{t-1} \tanh(c_{t-1}) ht?1?=ot?1?tanh(ct?1?),可以算得
c t ? 1 ? f t ? c t ? 1 = f t ( 1 ? f t ) o t ? 1 ( 1 ? tanh ? 2 c t ? 1 ) c t ? 1 U f c_{t-1} \frac{\partial f_t}{\partial c_{t-1}} = f_t (1 - f_t) o_{t-1} (1 - \tanh^2 c_{t-1}) c_{t-1} U_f \qquad ct?1??ct?1??ft??=ft?(1?ft?)ot?1?(1?tanh2ct?1?)ct?1?Uf?
注意到 f t , 1 ? f t , o t ? 1 f_t, 1 - f_t, o_{t-1} ft?,1?ft?,ot?1?都是在0~1之間,也可以證明 ∣ ( 1 ? tanh ? 2 c t ? 1 ) c t ? 1 ∣ < 0.45 |(1 - \tanh^2 c_{t-1}) c_{t-1}| < 0.45 ∣(1?tanh2ct?1?)ct?1?∣<0.45,因此它也在-1~1之間。所以 c t ? 1 ? f t ? c t ? 1 c_{t-1} \frac{\partial f_t}{\partial c_{t-1}} ct?1??ct?1??ft??就相當于1個 U f U_f Uf?乘上4個門,結果會變得更加小,所以只要初始化不是很糟糕,那么它都會被壓縮得相當小,因此占不到主導作用。
剩下兩項的結論也是類似的:
c ^ t ? i t ? c t ? 1 = i t ( 1 ? i t ) o t ? 1 ( 1 ? tanh ? 2 c t ? 1 ) c ^ t U i \hat{c}_t \frac{\partial i_t}{\partial c_{t-1}} = i_t (1 - i_t) o_{t-1} (1 - \tanh^2 c_{t-1}) \hat{c}_t U_i \qquad c^t??ct?1??it??=it?(1?it?)ot?1?(1?tanh2ct?1?)c^t?Ui?
i t ? c ^ t ? c t ? 1 = ( 1 ? c ^ t 2 ) o t ? 1 ( 1 ? tanh ? 2 c t ? 1 ) i t U c i_t \frac{\partial \hat{c}_t}{\partial c_{t-1}} = (1 - \hat{c}_t^2) o_{t-1} (1 - \tanh^2 c_{t-1}) i_t U_c it??ct?1??c^t??=(1?c^t2?)ot?1?(1?tanh2ct?1?)it?Uc?
所以,后面三項的梯度帶有更多的“門”,一般而言乘起來后會被壓縮的更厲害,因此占主導的項還是 f t f_t ft?, f t f_t ft? 在0~1之間這個特性決定了它梯度爆炸的風險很小,同時 f t f_t ft? 表明了模型對歷史信息的依賴性,也正好是歷史梯度的保留程度,兩者相互自洽,所以LSTM也能較好地緩解梯度消失問題。因此,LSTM同時較好地緩解了梯度消失/爆炸問題,現在我們訓練LSTM時,多數情況下只需要直接調用Adam等自適應學習率優化器,不需要人為對梯度做什么調整了。