背景
本文內容還是對之前關于面試題transformer的一個延伸,詳細講解一下softmax
面試常問系列(二)-神經網絡參數初始化之自注意力機制-CSDN博客
Softmax函數的梯度特性與輸入值的幅度密切相關,這是Transformer中自注意力機制需要縮放點積結果的關鍵原因。以下從數學角度展開分析:
1. Softmax 函數回顧
給定輸入向量?z?= [z?, z?, ..., z?],Softmax 輸出概率為:
????????????????????????
其中?S?是歸一化因子。
2. 梯度計算目標
計算 Softmax 對輸入?z?的梯度,即?對所有?i,j∈{1,…,k}。
3. 梯度推導
根據鏈式法則,對?σi??關于?zj??求導:
具體推到過程就不展示了,感興趣的有需要的可以評論下。因為本次重點不是通用的softmax分析,而是偏實戰分析。
4. 與交叉熵損失結合的梯度
在實際應用中,Softmax 通常與交叉熵損失 結合使用。此時梯度計算更簡單:
其中?是真實標簽的 one-hot 編碼。
5. 推導
- 交叉熵損失對??
?的梯度:
? ? 2. 通過鏈式法則:
????3. 代入在上面求解出的:
- 當
時,
- 當
時,
? ? 4.合并上述結果
6.?梯度消失問題
- 極端輸入值:若
遠大于其他
,則
,其他
。此時:
- 對
的梯度:
(若yk?=1,梯度接近0)。
- 對其他zi?的梯度:
,梯度趨近于0。
- 對
- 后果:梯度消失導致參數更新困難,模型難以訓練。
7.?縮放的作用
在Transformer中,點積結果除以dk??后:
- 輸入值范圍受限:縮放后
的方差為1,避免極端值。
- 梯度穩定性提升:
分布更均勻,
和
不會趨近于0,梯度保持有效。
5.?直觀示例
- 未縮放:若dk?=512,點積標準差結果可能達±22,Softmax輸出接近0或1,梯度消失。
- 縮放后:點積結果范圍約±5,σ(zi?)分布平緩,梯度穩定。
- 這個示例在最開始的跳轉鏈接有詳細解釋,可以參考。
總結
Softmax的梯度對輸入值敏感,過大輸入會導致梯度消失。Transformer通過除以dk??控制點積方差,確保Softmax輸入值合理,從而保持梯度穩定,提升訓練效率。這一設計是深度學習中處理高維數據時的重要技巧。