數學原理:
(1) 前向傳播的方差一致性
假設輸入 x 的均值為 0,方差為 σx2σ_x^2σx2?,權重 W的均值為 0,方差為 σW2σ_W^2σW2?,則輸出 z=Wxz=Wxz=Wx的方差為:
Var(z)=nin?Var(W)?Var(x)
Var(z)=n_{in}?Var(W)?Var(x)
Var(z)=nin??Var(W)?Var(x)
為了使 Var(z)=Var(x),需要:
nin?Var(W)=1?????Var(W)=1nin
n_{in}?Var(W)=1?????Var(W)=\frac{1}{n_{in}}
nin??Var(W)=1?????Var(W)=nin?1?
其中 ninn_{in}nin?是輸入維度(fan_in)。這里乘以 nin 的原因是,輸出 z 是由 nin 個輸入 x 的線性組合得到的,每個輸入 x 都與一個權重 W 相乘。因此,輸出 z 的方差是 nin 個獨立的 Wx 項的方差之和。
(2) 反向傳播的梯度方差一致性
在反向傳播過程中,梯度 ?L?x\frac{?L}{?x}?x?L? 是通過鏈式法則計算得到的,其中 L 是損失函數,x 是輸入,z 是輸出。梯度?L?x\frac{?L}{?x}?x?L?可以表示為:
?L?x=?L?z.?z?x
\frac{?L}{?x}=\frac{?L}{?z}.\frac{?z}{?x}
?x?L?=?z?L?.?x?z?
假設 z=Wx,其中 W 是權重矩陣,那么 ?z?x=W\frac{?z}{?x}=W?x?z?=W。因此,梯度 ?L?x\frac{?L}{?x}?x?L?可以寫為: ?L?x=?L?zW\frac{?L}{?x}=\frac{?L}{?z}W?x?L?=?z?L?W
反向傳播時梯度 ?L?x\frac{?L}{?x}?x?L? 的方差應與 ?L?z\frac{?L}{?z}?z?L? 相同,因此:
nout?Var(W)=1?????Var(W)=1nout
n_{out}?Var(W)=1?????Var(W)=\frac{1}{n_{out}}
nout??Var(W)=1?????Var(W)=nout?1?
其中 noutn_{out}nout?是輸出維度(fan_out)。為了保持梯度的方差一致性,我們需要確保每個輸入維度 nin 的梯度方差與輸出維度 nout 的梯度方差相同。因此,我們需要將 W 的方差乘以 nout,以確保梯度的方差在反向傳播過程中保持一致。
(3) 綜合考慮
為了同時平衡前向傳播和反向傳播,Xavier 采用:
Var(W)=2nin+nout
Var(W)=\frac{2}{n_{in}+n_{out}}
Var(W)=nin?+nout?2?
權重從以下分布中采樣:
均勻分布:
W~U(?6nin+nout,6nin+nout)
W\sim\mathrm{U}\left(-\frac{\sqrt{6}}{\sqrt{n_\mathrm{in}+n_\mathrm{out}}},\frac{\sqrt{6}}{\sqrt{n_\mathrm{in}+n_\mathrm{out}}}\right)
W~U(?nin?+nout??6??,nin?+nout??6??)
在Xavier初始化中,我們選擇 a=?6nin+nouta=?\sqrt{\frac{6}{n_{in}+n_{out}}}a=?nin?+nout?6?? 和 b=6nin+noutb=\sqrt{\frac{6}{n_{in}+n_{out}}}b=nin?+nout?6??,這樣方差為:
Var(W)=(b?a)212=(26nin+nout)212=4?6nin+nout12=2nin+nout
Var(W)=\frac{(b?a)^2}{12}=\frac{(2\sqrt{\frac{6}{n_{in}+n_{out}}})^2}{12}=\frac{4?\frac{6}{nin+nout}}{12}=\frac{2}{n_{in}+n_{out}}
Var(W)=12(b?a)2?=12(2nin?+nout?6??)2?=124?nin+nout6??=nin?+nout?2?
正態分布:
W~N(0,2nin+nout)
W\sim\mathrm{N}\left(0,\frac{2}{n_\mathrm{in}+n_\mathrm{out}}\right)
W~N(0,nin?+nout?2?)
N(0,std2) \mathcal{N}(0, \text{std}^2) N(0,std2)
其中 ninn_{\text{in}}nin? 是當前層的輸入神經元數量,noutn_{\text{out}}nout?是輸出神經元數量。
在前向傳播中,輸出的方差受 ninn_{in}nin? 影響。在反向傳播中,梯度的方差受 noutn_{out}nout? 影響。