一般多任務,大家都喜歡疊加很多損失,由此產生很多損失權重系數。此外,有的學者直接對梯度進行操作。咋一看,上面三個系數貌似重復多余,直接用其中一個系數代替不行嗎?為此,回顧了下神經網絡的前向傳播和反向求導公式,感覺有點拉大旗作虎皮的意味。標題本來是“Rethinking”,想著會有一些新發現,但隨后就改成了“Relooking”蒜鳥。
形式化
直觀來說,損失權重 λ λ λ 、梯度權重 α α α、學習率 η η η可以看做是三個標量系數,即trade-off parameter 或 weighting coefficient。
L = λ 1 L 1 + λ 2 L 2 ? θ L = α 1 ? L 1 + α 2 ? L 2 θ : = θ ? η ? ? θ L \begin{aligned} L &=\lambda_1 L_1+\lambda_2 L_2\\ \nabla_\theta L &=\alpha_1 \nabla L_1+\alpha_2 \nabla L_2\\ \theta :&= \theta-\eta \cdot \nabla_\theta L \end{aligned} L?θ?Lθ:?=λ1?L1?+λ2?L2?=α1??L1?+α2??L2?=θ?η??θ?L?
作用:
- 損失權重 λ λ λ:對相應任務的損失值進行縮放。 λ λ λ越大,表明該項貢獻越大(越重要),則要放大其損失值,促使模型對該項的優化。反之,越小,則是該項損失趨近0,貢獻被忽略。
- 梯度權重 α α α:在反向傳播中,直接對梯度值進行縮放。
- 學習率 η η η:對所有梯度統一縮放,以控制模型參數的更新步長。 η η η越大,則模型參數的步長越大。
案例講解
下面以一個神經網絡的為例,從底層原理來看它們的作用。
1. 網絡結構定義
考慮一個雙層網絡:
- 輸入: x x x
- 參數: W 1 , b 1 , W 2 , b 2 W_1, b_1, W_2, b_2 W1?,b1?,W2?,b2?
- 激活函數: g ( ? ) g(\cdot) g(?) (如ReLU)
- 輸出層未激活
2. 前向傳播
流程:Fc1 --> Activation --> Fc2。
z 1 = W 1 x + b 1 a 1 = g ( z 1 ) z 2 = W 2 a 1 + b 2 \begin{align} z_1 &= W_1 x + b_1 \\ a_1 &= g(z_1) \\ z_2 &= W_2 a_1 + b_2 \\ \end{align} z1?a1?z2??=W1?x+b1?=g(z1?)=W2?a1?+b2???
3. 多任務損失計算
為了方便展示損失任務的權重系數,這里假設兩個損失函數。其中,主任務交叉熵損失,輔助任務均方誤差損失。
L = λ 1 ? CE ( z 2 , y ce ) + λ 2 ? MSE ( a 1 , y mse ) = λ 1 ? l o s s 1 + λ 2 ? l o s s 2 \begin{align} L &= \lambda_1 \cdot \text{CE}(z_2, y_{\text{ce}}) + \lambda_2 \cdot \text{MSE}(a_1, y_{\text{mse}}) \\ &= \lambda_1 \cdot loss_1 + \lambda_2 \cdot loss_2 \end{align} L?=λ1??CE(z2?,yce?)+λ2??MSE(a1?,ymse?)=λ1??loss1?+λ2??loss2???
4. 反向傳播梯度計算
? L ? W 1 = ? ( λ 1 l o s s 1 + λ 2 l o s s 2 ) ? W 1 = α 1 ( λ 1 ? l o s s 1 ? W 1 ) + α 2 ( λ 2 ? l o s s 2 ? W 1 ) = α 1 λ 1 ? l o s s 1 ? W 1 + α 2 λ 2 ? l o s s 2 ? W 1 \begin{align} \frac{\partial L}{\partial W_1} &= \frac{\partial (\lambda_1 loss_1 + \lambda_2 loss_2)}{\partial W_1} \\ &= \alpha_1 \left( \lambda_1 \frac{\partial loss_1}{\partial W_1}\right) + \alpha_2 \left( \lambda_2 \frac{\partial loss_2}{\partial W_1}\right) \\ &= \alpha_1 \lambda_1 \frac{\partial loss_1}{\partial W_1} + \alpha_2 \lambda_2 \frac{\partial loss_2}{\partial W_1} \\ \end{align} ?W1??L??=?W1??(λ1?loss1?+λ2?loss2?)?=α1?(λ1??W1??loss1??)+α2?(λ2??W1??loss2??)=α1?λ1??W1??loss1??+α2?λ2??W1??loss2????
5. 參數更新
W 1 ← W 1 ? η ? ? L ? W 1 W_1 \leftarrow W_1 - \eta \cdot \frac{\partial L}{\partial W_1} W1?←W1??η??W1??L?
即:
Δ W 1 = ? η [ α 1 ? 梯度權重 ( λ 1 ? 損失權重 ? l o s s 1 ? W 1 ) + α 2 ? 梯度權重 ( λ 2 ? 損失權重 ? l o s s 2 ? W 1 ) ] \Delta W_1 = -\eta \left[ \overbrace{\alpha_1}^{\text{梯度權重}} \left( \overbrace{\lambda_1}^{\text{損失權重}} \frac{\partial loss_1}{\partial W_1} \right) + \overbrace{\alpha_2}^{\text{梯度權重}} \left( \overbrace{\lambda_2}^{\text{損失權重}} \frac{\partial loss_2}{\partial W_1} \right) \right] ΔW1?=?η ?α1? ?梯度權重? ?λ1? ?損失權重??W1??loss1?? ?+α2? ?梯度權重? ?λ2? ?損失權重??W1??loss2?? ? ?
總結
- 根據step4可知,一般不需要對梯度進行懲罰操作,且過于復雜,直接對損失函數施加權重具有同樣的功能。
- 根據step5可知,學習率全局縮放梯度向量,即調整整體的步長。
- 如梯度裁剪或者梯度歸一化等特殊情況才在內部對梯度操作,非必須,一般不作用于梯度。
注:上述情況與GPT 4O交流的結果。以當前本人的水平,還無法體會到更深層次的含義。