這里寫目錄標題
- 1. RNN存在哪些問題呢?
- 1.1 梯度彌散和梯度爆炸
- 1.2 RNN為什么會出現梯度彌散和梯度爆炸呢?
- 2. 解決梯度爆炸方法
- 3. Gradient Clipping的實現
- 4. 解決梯度彌散的方法
1. RNN存在哪些問題呢?
1.1 梯度彌散和梯度爆炸
梯度彌散是梯度趨近于0
梯度爆炸是梯度趨近無窮大
1.2 RNN為什么會出現梯度彌散和梯度爆炸呢?
先看RNN的梯度推導公式,如下圖:
從hk的梯度求導公式和hk的計算過程可以看出,hk的計算和Whh相關,也就是梯度也與Whh有關,因此從h1 時刻到hk時刻,Whh被乘了k-1次,即Whhk-1,那么當W>1時,就使得Wrk隨著k(句子長度)的增大,梯度趨近無窮大,會出現梯度爆炸,而W<1時,Wrk隨著k(句子長度)的增大,梯度會趨近于0,會出現梯度彌散。
綜上:RNN并不是可以處理無限長的句子,其隨著句子的增長可能出現梯度彌散和梯度爆炸的問題
2. 解決梯度爆炸方法
上圖為一篇解決梯度爆炸的paper,其中左邊的圖描述的是梯度爆炸產生的原因,當W出現巨變的時候會導致loss的方向發生變化,從而偏移原來正確的方向,出現梯度爆炸。
解決梯度爆炸的方法是給w.grad設置一個閾值,比如是15,當大于閾值時,將w.grad’=w.grad/||w.grad||15=115=15,從而保證了loss的方向不變,loss雖然可能有一些跳變,比如:從0.23~0.32,,但慢慢的還會下降。
這種方法叫gradient clipping
3. Gradient Clipping的實現
只需獲取到模型參數后調用torch.nn.utils.clip_grad_norm_(p,10)即可,10為閾值。
見下圖,注意torch.nn.utils.clip_grad_norm_(p,10)和print是平齊的。
4. 解決梯度彌散的方法
下文LSTM會講。