Softmax函數是一種將一個含任意實數的K維向量轉化為另一個K維向量的函數,這個輸出向量的每個元素都在(0, 1)區間內,并且所有元素之和等于1。
因此,它可以被看作是某種概率分布,常用于多分類問題中作為輸出層的激活函數。這里我們以拓展邏輯回歸解決多分類的角度對Softmax函數進行理解:
假設共有 C C C 個類別,模型對輸入 x \mathbf{x} x 輸出 C C C個類別的得分,
則屬于類別 c c c 的后驗概率為:
P ( y = c ∣ x ) = e β c ? x ∑ j = 1 C e β j ? x P(y = c \mid \mathbf{x}) = \frac{e^{\beta_c^\top \mathbf{x}}}{\sum_{j=1}^{C} e^{\beta_j^\top \mathbf{x}}} P(y=c∣x)=∑j=1C?eβj??xeβc??x?
其中 β c \beta_c βc? 是第 c c c 類對應的參數向量, j j j 是求和的類別索引, x \mathbf{x} x 是輸入特征向量。
為什么使用指數函數 e e e?
Softmax 函數的形式為:
σ ( z ) i = e z i ∑ j = 1 C e z j , \sigma(\mathbf{z})_i = \frac{e^{z_i}}{\sum_{j=1}^{C} e^{z_j}}, σ(z)i?=∑j=1C?ezj?ezi??,
其中每個得分 z i z_i zi? 的形式為:
z i = β i ? x , z_i = \beta_i^\top \mathbf{x}, zi?=βi??x,
表示輸入特征向量 x \mathbf{x} x 與第 i i i 類對應的參數向量 β i \beta_i βi? 的線性組合。
使用指數函數 e z i e^{z_i} ezi? 有以下幾點重要理由:
-
非負性:對于任意實數 z i z_i zi?,都有 e z i > 0 e^{z_i} > 0 ezi?>0。這保證了 Softmax 輸出的概率值始終為正數。
-
保持序關系:指數函數是嚴格單調遞增函數。若 z i > z j z_i > z_j zi?>zj?,則 e z i > e z j e^{z_i} > e^{z_j} ezi?>ezj?,從而保留了原始得分之間的相對大小關系。
-
便于求導:指數函數具有良好的可導性,且其導數形式簡單 ( d d x e x = e x ) \left(\frac{d}{dx}e^x = e^x\right) (dxd?ex=ex),這對基于梯度下降等優化算法非常友好。
-
映射到概率分布:通過除以總和 ∑ j = 1 C e z j \sum_{j=1}^{C} e^{z_j} ∑j=1C?ezj?,使得所有類別的輸出加起來等于 1,形成一個合法的概率分布。
下面的示意圖清晰地表示 Softmax 函數的原理和計算過程。以下是一個完整的推導流程示例,包括線性回歸輸出、Softmax 激活函數的應用,以及最終的分類結果。
( 0.5 0 0.7 0.5 0.5 0.9 0.1 0.1 0.6 0.6 0.1 0 ) X × ( ? 0.15 0.95 2.2 ) β = ( 0.5 ? ( ? 0.15 ) + 0 ? 0.95 + 0.7 ? 2.2 0.5 ? ( ? 0.15 ) + 0.5 ? 0.95 + 0.9 ? 2.2 0.1 ? ( ? 0.15 ) + 0.1 ? 0.95 + 0.6 ? 2.2 0.6 ? ( ? 0.15 ) + 0.1 ? 0.95 + 0 ? 2.2 ) = ( 1.385 2.43 1.37 ? 0.095 ) 線性輸出 z \overset{X}{\begin{pmatrix} 0.5 & 0 & 0.7 \\ 0.5 & 0.5 & 0.9 \\ 0.1 & 0.1 & 0.6 \\ 0.6 & 0.1 & 0 \end{pmatrix}} \times \overset{\bm{\beta}}{ \begin{pmatrix} -0.15 \\ 0.95 \\ 2.2 \end{pmatrix}} =\begin{pmatrix} 0.5 \cdot (-0.15) + 0 \cdot 0.95 + 0.7 \cdot 2.2 \\ 0.5 \cdot (-0.15) + 0.5 \cdot 0.95 + 0.9 \cdot 2.2 \\ 0.1 \cdot (-0.15) + 0.1 \cdot 0.95 + 0.6 \cdot 2.2 \\ 0.6 \cdot (-0.15) + 0.1 \cdot 0.95 + 0 \cdot 2.2 \end{pmatrix}=\overset{\text{線性輸出 } \mathbf{z}}{ \begin{pmatrix} 1.385 \\ 2.43 \\ 1.37 \\ -0.095 \end{pmatrix}}