🌟 大背景:訓練神經網絡 = 下山尋寶
訓練神經網絡就像你蒙著眼在一座大山里,想找最低點(最小損失)。你只能靠腳下的坡度(梯度)來決定往哪兒走。
- 你的位置 = 模型參數(權重 www)
- 坡度 = 梯度 ?J?J?J
- 走一步的大小 = 學習率 ηηη
- 山的結構 = 神經網絡的結構
現在,我們看看路上會遇到哪些坑。
1. 參數初始化:你一開始站哪兒?
💬 大白話:
你不能隨便往地上一站就開始走。
如果一開始站的位置太差(比如平地或懸崖),可能一輩子都找不到寶藏。
在神經網絡里,參數初始化就是決定模型一開始的“姿勢”。
? 常見錯誤:全0初始化
w = 0 # 所有權重都設為0
問題:所有神經元“長得一模一樣”,更新也一樣 → 對稱性問題,學不到東西。
? 正確做法:用隨機小數初始化
但隨機也不能亂來!太小 → 信號傳著傳著沒了;太大 → 信號爆炸。
于是,大神們提出了科學初始化方法:
1.1 Lecun Initialization(1998)
適用:tanh
激活函數 + 線性激活(如MLP)
核心思想:讓前向傳播時,信號的方差保持穩定。
數學公式:
權重 www 從均值為 0、方差為 1nin\frac{1}{n_{in}}nin?1? 的正態分布中采樣:
w~N(0,1nin)
w ~ N(0, \frac{1}{n_{in}})
w~N(0,nin?1?)
其中 ninn_{in}nin? 是前一層的神經元數量。
例子:
前一層有 512 個神經元 → Var(w)=1512≈0.00195→w~N(0,0.00195)Var(w) = \frac{1}{512} ≈ 0.00195 → w ~ N(0, 0.00195)Var(w)=5121?≈0.00195→w~N(0,0.00195)
1.2 Xavier (Glorot) Initialization(2010)
適用:tanh
或 sigmoid
激活函數
核心思想:同時考慮前向和反向傳播,讓梯度也不爆炸不消失。
數學公式:
權重 www 從均值為 0、方差為 2nin+nout\frac{2}{n_{in} + n_{out}}nin?+nout?2? 的正態分布中采樣:
w~N(0,2nin+nout) w ~ N(0, \frac{2}{n_{in} + n_{out}}) w~N(0,nin?+nout?2?)
其中:
- ninn_{in}nin?:前一層神經元數
- noutn_{out}nout?:后一層神經元數
例子:
nin=512,nout=256→Var(w)=2512+256=2768≈0.0026→w~N(0,0.0026)n_{in}=512, n_{out}=256 → Var(w) = \frac{2}{512+256} = \frac{2}{768} ≈ 0.0026 → w ~ N(0, 0.0026)nin?=512,nout?=256→Var(w)=512+2562?=7682?≈0.0026→w~N(0,0.0026)
? 比 Lecun 更“平衡”,兼顧前后向。
1.3 Kaiming (He) Initialization(2015)
適用:ReLU
及其變體(如 Leaky ReLU
)激活函數
為什么需要它?
因為 ReLU(x)=max?(0,x)ReLU(x) = \max(0,x)ReLU(x)=max(0,x) 會“殺死”一半的神經元(負數變0),所以信號會衰減。
核心思想:補償 ReLU
的“死亡率”,讓信號方差穩定。
數學公式:
權重 www 從均值為 0、方差為 2nin\frac{2}{n_{in}}nin?2? 的正態分布中采樣:
w~N(0,2nin) w ~ N(0, \frac{2}{n_{in}}) w~N(0,nin?2?)
例子:
nin=512→Var(w)=2/512=0.0039→w~N(0,0.0039)n_{in}=512 → Var(w) = 2/512 = 0.0039 → w ~ N(0, 0.0039)nin?=512→Var(w)=2/512=0.0039→w~N(0,0.0039)
? 現代深度學習(如ResNet)幾乎都用 Kaiming 初始化。
📊 總結對比
方法 | 適用激活函數 | 方差公式 |
---|---|---|
Lecun | tanh , 線性 | 1/nin1 / n_{in}1/nin? |
Xavier | tanh , sigmoid | 2/(nin+nout)2 / (n_{in} + n_{out})2/(nin?+nout?) |
Kaiming | ReLU , Leaky ReLU | 2/nin2 / n_{in}2/nin? |
2. 梯度消失 & 梯度爆炸
💬 大白話:
你走著走著,發現:
- 梯度消失:坡度太小,幾乎感覺不到方向,走不動了。
- 梯度爆炸:坡度太大,一步邁過頭,直接飛出山外。
📌 根源:鏈式法則的“連乘效應”
假設一個3層網絡,損失 J
對第一層權重 w1
的梯度:
?J/?w1=(?J/?a3)×(?a3/?z3)×(?z3/?a2)×(?a2/?z2)×(?z2/?a1)×(?a1/?z1)×(?z1/?w1) ?J/?w_1 = (?J/?a_3) × (?a_3/?z_3) × (?z_3/?a_2) × (?a_2/?z_2) × (?z_2/?a_1) × (?a_1/?z_1) × (?z_1/?w_1) ?J/?w1?=(?J/?a3?)×(?a3?/?z3?)×(?z3?/?a2?)×(?a2?/?z2?)×(?z2?/?a1?)×(?a1?/?z1?)×(?z1?/?w1?)
其中 a=σ(z)a = σ(z)a=σ(z),σσσ 是激活函數。
📐 例子:Sigmoid 激活函數
Sigmoid 導數:σ′(z)=σ(z)(1?σ(z))σ'(z) = σ(z)(1-σ(z))σ′(z)=σ(z)(1?σ(z)),最大值 0.25
假設每層梯度都被乘以 0.25:
- 3層:0.253=0.01560.25^3 = 0.01560.253=0.0156
- 10層:0.2510≈9.5e?70.25^{10} ≈ 9.5e-70.2510≈9.5e?7 → 幾乎為0 → 梯度消失!
📐 例子:權重太大
如果 www 初始化太大,比如 w=5w=5w=5,那 z=w×az = w×az=w×a 會很大 → σ(z)≈1σ(z) ≈ 1σ(z)≈1 → σ′(z)≈0σ'(z) ≈ 0σ′(z)≈0 → 同樣梯度消失。
或者 w=10w=10w=10,zzz 超大 → 梯度超大 → 一步更新 w=w?η×?Jw = w - η×?Jw=w?η×?J,直接飛出合理范圍 → 梯度爆炸。
? 解決方案
- 換激活函數:用
ReLU
,它的導數在正區間是 1,不會讓梯度變小。 - 用 Kaiming 初始化:專門針對
ReLU
設計。 - 加 Batch Normalization:讓每層輸入保持穩定分布。
- 殘差連接:見下文。
3. 殘差連接(Residual Connection)
💬 大白話:
梯度消失是因為路太長,梯度傳不回去。
那怎么辦?修條“捷徑”!
讓信息和梯度可以“抄近道”直接傳回去。
📐 數學公式(ResNet)
普通層:
a[l]=σ(w[l]×a[l?1]+b[l])
a[l] = σ( w[l] × a[l-1] + b[l] )
a[l]=σ(w[l]×a[l?1]+b[l])
殘差塊:
a[l]=σ(F(a[l?1])+a[l?1])
a[l] = σ( F(a[l-1]) + a[l-1] )
a[l]=σ(F(a[l?1])+a[l?1])
其中 F(a[l?1])F(a[l-1])F(a[l?1]) 是“主路”(比如兩層卷積),a[l?1]a[l-1]a[l?1] 是“捷徑”。
📐 為什么能解決梯度消失?
看梯度 ?J/?a[l?1]?J/?a[l-1]?J/?a[l?1]:
?J/?a[l?1]=?J/?a[l]×?a[l]/?a[l?1] ?J/?a[l-1] = ?J/?a[l] × ?a[l]/?a[l-1] ?J/?a[l?1]=?J/?a[l]×?a[l]/?a[l?1]
而:
?a[l]/?a[l?1]=?/?a[l?1][F(a[l?1])+a[l?1]]=F′+I
?a[l]/?a[l-1] = ?/?a[l-1] [ F(a[l-1]) + a[l-1] ] = F' + I
?a[l]/?a[l?1]=?/?a[l?1][F(a[l?1])+a[l?1]]=F′+I
其中 III 是單位矩陣(來自 a[l?1]a[l-1]a[l?1] 的導數)。
關鍵:即使 F′≈0F' ≈ 0F′≈0(主路梯度消失),F′+IF' + IF′+I 的最小特征值也有 1!
所以梯度至少能以 111 的比例傳回去,不會消失。
🖼? 就像你下山,主路被雪埋了,但旁邊有條水泥小路(+1),你還能走回去。
? 效果:
ResNet 可以輕松訓練 100層、1000層 的網絡!
4. 學習率與訓練不穩定性
💬 大白話:
你走一步的大小(學習率 η
)得合適。
η
太小:走得慢,天黑了還沒到底。η
太大:容易邁過頭,來回震蕩,甚至飛出山外(發散)。
📐 數學公式(梯度下降)
wt+1=wt?η×?J(wt) w_{t+1} = w_t - η × ?J(w_t) wt+1?=wt??η×?J(wt?)
📌 例子:J(w)=w2J(w) = w2J(w)=w2(拋物線,最小值在 w=0w=0w=0)
- 梯度:?J(w)=2w?J(w) = 2w?J(w)=2w
- 從 w=3w=3w=3 開始
學習率 ηηη | 更新過程 | 結果 |
---|---|---|
η=0.1η=0.1η=0.1 | 3→2.4→1.92→...3 → 2.4 → 1.92 → ...3→2.4→1.92→... | 慢慢收斂 |
η=0.6η=0.6η=0.6 | 3→?0.6→0.12→?0.024→...3 → -0.6 → 0.12 → -0.024 → ...3→?0.6→0.12→?0.024→... | 震蕩收斂 |
η=1.1η=1.1η=1.1 | 3→?3.6→4.32→...3 → -3.6 → 4.32 → ...3→?3.6→4.32→... | 發散! |
🎯 訓練不穩定性表現:
- 損失函數上躥下跳
- 損失變成 NaNNaNNaN(數值溢出)
? 解決方案
- 調小學習率:最直接。
- 學習率預熱(Warm-up):開始用小學習率,等穩定了再加大。
- 學習率衰減:訓練后期逐漸減小 ηηη。
- 用自適應優化器:如 Adam,它會自動調每個參數的“步子”。
📊 總結對比表
問題 | 原因 | 數學關鍵 | 解決方案 |
---|---|---|---|
初始化不當 | 信號爆炸/消失 | Var(w)Var(w)Var(w) 不合理 | Lecun/Xavier/Kaiming |
梯度消失 | 鏈式法則連乘 | ∏σ′(z)≈0∏ σ'(z) ≈ 0∏σ′(z)≈0 | ReLU + Kaiming + BN + Residual |
殘差連接 | 深網絡梯度傳不回 | ?a[l]/?a[l?1]=F′+I?a[l]/?a[l-1] = F' + I?a[l]/?a[l?1]=F′+I | 加“捷徑” |
學習率太大 | 更新步子太大 | w=w?η×?Jw = w - η×?Jw=w?η×?J 爆炸 | 調 ηηη,用 Adam |
💡 一句話記住
- 初始化:別亂站,按激活函數選“科學站位”。
- 梯度消失:路太長,坡沒了 → 修“捷徑”(殘差)。
- 學習率:步子太大扯著蛋,太小半天到不了。