深入理解矩陣乘積的導數:以線性回歸損失函數為例
在機器學習和數據分析領域,矩陣微積分扮演著至關重要的角色。特別是當我們涉及到優化問題,如最小化損失函數時,對矩陣表達式求導變得必不可少。本文將通過一個具體的例子——線性回歸中的均方誤差損失函數,來詳細解釋如何使用分配律(FOIL,First, Outer, Inner, Last)來展開矩陣乘積,并計算其導數。
線性回歸與均方誤差
線性回歸是預測連續數值型響應變量的一種統計方法。在簡單線性回歸中,我們嘗試找到一條直線,最好地擬合輸入變量 (X) 和輸出變量 (y) 之間的關系。模型可以表示為:
y = X w + b y = Xw + b y=Xw+b
其中,(X) 是設計矩陣,(w) 是權重向量,(b) 是偏置項。在多元線性回歸中,模型擴展為:
y = X w + ? y = Xw + \epsilon y=Xw+?
這里,(\epsilon) 表示誤差項。
均方誤差損失函數
為了訓練模型,我們需要定義一個損失函數來衡量模型預測值與實際值之間的差異。均方誤差(MSE)是常用的損失函數之一,定義為:
L ( w ) = ( y ? X w ) T ( y ? X w ) L(w) = (y - Xw)^T(y - Xw) L(w)=(y?Xw)T(y?Xw)
這個函數衡量了預測值 (Xw) 與真實值 (y) 之間的平方差。
展開損失函數
為了找到最小化損失函數的 (w) 值,我們需要對 (L(w)) 求導。首先,我們展開 (L(w)):
L ( w ) = ( y T ? w T X T ) ( y ? X w ) L(w) = (y^T - w^T X^T)(y - Xw) L(w)=(yT?wTXT)(y?Xw)
應用分配律(FOIL)展開這個乘積:
- First: (y^T y)
- Outer: (-y^T Xw)
- Inner: (-w^T X^T y)
- Last: (w^T X^T Xw)
將這些項組合起來,我們得到:
L ( w ) = y T y ? y T X w ? w T X T y + w T X T X w L(w) = y^T y - y^T Xw - w^T X^T y + w^T X^T Xw L(w)=yTy?yTXw?wTXTy+wTXTXw
求導數
接下來,我們對 (L(w)) 關于 (w) 求導。注意到 (y^T y) 是常數項,其導數為0。對于其他項,我們有:
- (-y^T Xw) 的導數是 (-X^T y)。
- (-w^T X^T y) 的導數是 (-X y)。
- (w^T X^T Xw) 的導數需要使用矩陣微積分的鏈式法則,結果為 (2X^T Xw)。
因此,(L(w)) 的導數為:
? L ? w = ? X T y ? X y + 2 X T X w \frac{\partial L}{\partial w} = -X^T y - X y + 2X^T Xw ?w?L?=?XTy?Xy+2XTXw
簡化后得到:
? L ? w = 2 X T X w ? X T y ? X y \frac{\partial L}{\partial w} = 2X^T Xw - X^T y - X y ?w?L?=2XTXw?XTy?Xy
結論
通過展開損失函數并計算其導數,我們得到了一個關鍵的梯度表達式,它將用于梯度下降算法中更新權重 (w)。這個過程展示了矩陣微積分在機器學習中的重要性,特別是在處理線性模型和優化問題時。理解如何正確地展開和求導矩陣表達式是進行有效模型訓練的基礎。