梯度下降
一 、為什么要用到梯度下降?
正規方程的缺陷:
非凸函數問題:損失函數非凸時,導數為0會得到多個極值點(非唯一解)
計算效率低:逆矩陣運算時間復雜度 O(n3),特征量翻倍時計算時間增為8倍(16特征需512秒)。
?
結論:梯度下降是高效求解大規模、非凸問題的通用優化算法。
二、梯度下降核心思想
目標:以最快的速度找到損失函數 loss的最小值點(最優參數 W)。
原理類比:
人在山地走向谷底,每一步沿當前最陡峭下坡方向行走。
步驟 1:判斷 “下坡最陡的方向”
你低頭觀察腳下的地面:
左邊地面微微向下傾斜,坡度較緩;
正前方地面明顯向下傾斜,坡度最陡;
右邊地面甚至有點向上傾斜(上坡)。
這里的 “坡度” 就是梯度—— 它不僅告訴你 “哪個方向是下坡”,還告訴你 “哪個方向下坡最陡”(梯度的方向),以及 “陡到什么程度”(梯度的大小)。
步驟 2:沿最陡方向走一小步
既然正前方下坡最陡,你就朝著正前方走一步(步長不能太大,否則可能踩空或錯過轉彎)。這一步對應參數更新:
方向:沿 “最陡下坡方向”(負梯度方向,因為梯度本身是 “上坡最陡” 的方向);
步長:對應 “學習率”(不能太大,否則可能直接沖到山的另一側;也不能太小,否則走得太慢)。
步驟 3:重復調整方向,逐步逼近山腳
走完一步后,你站在新的位置,再次觀察腳下的坡度(重新計算梯度),發現此時 “左前方” 變成了最陡的下坡方向。于是你調整方向,沿左前方再走一步…… 這個過程不斷重復:每次都根據當前位置的坡度調整方向,走一小步,直到走到坡度幾乎為 0 的平地(山腳)。
梯度 g是損失函數 loss對參數 W 的偏導數。
如果 g < 0, w 就變大 ; 如果g > 0 , w 就變小(目標左邊是斜率為負右邊為正 )
沿梯度反方向更新參數:W=W?α?g(α 為學習率)。
然后判斷是否收斂(loss變化很小就收斂),如果收斂就跳出迭代,如果沒收斂就再次更新參數 W...
三、單參數(w)梯度下降實現
1. 更新公式
2. 參數α更新邏輯
位置 | 梯度?gg | 更新方向 | 操作 |
---|---|---|---|
最小值左側 | g<0 | w 增大 | w=w?(負值)→右移? |
最小值右側 | g>0 | w?減小 | w=w?(正值)→左移? |
示例流程(初始?w=0.2,α=0.01):
計算 : 假設w=0.2時g=0.24?→?w_new=0.2?0.01×0.24=0.1976
迭代更新直至收斂(g最小)。
# 定義總損失
def loss(w):return 10*(w**2)-15.9*w+6.5# 定義梯度
def g(w):return 20*w-15.9# 定義模型
def model(x,w):return x*w# 繪制模型
def draw_line(w):pt_x = np.linspace(0,5,100)pt_y = model(pt_x,w)plt.plot(pt_x,pt_y)#隨機初始化w
w =10# 迭代
for i in range(100):print('w:',w,'loss:',loss(w))# 學習率lr = 1/(i+100)# 更新ww = w-lr*g(w)x=np.array([4.2,4.2, 2.7, 0.8, 3.7, 1.7, 3.2])
y=np.array([3.8,2.7, 2.4, 1., 2.8, 0.9, 2.9])
plt.plot(x,y,'o')
draw_line(w)
plt.show()
四、學習率(α)
學習率α是控制參數更新的 “步長”,是影響收斂的核心超參數:
- 過小:收斂緩慢,需大量迭代;
- 過大:可能跳過最優解,導致損失震蕩甚至發散;
一般我們把它設置為0.1,0.01,0.001甚至更小。一般情況下學習率在迭代過程中是不變的,但是也可以設置為動態調整,即隨著迭代次數逐漸變小,越接近目標W '步子'邁的更小,以更精準地找到W。
五、多參數(如 w0,w1)梯度下降實現
假設損失函數是有兩個w1,w2特征的椎體
初始化:隨機生成正態分布參數 W(如 w0,w1)。
計算梯度g:求當前 loss 的梯度 g。
更新參數:W=W?α?g
收斂判斷:
loss變化量 < 閾值
或達到預設迭代次數(如1000次)。
終止:滿足條件則輸出 W;否則返回步驟2。
假設loss = (100w1 + 200w2 +1000)**2
import numpy as np
# 假設總損失
def loss(w1,w2):return (100*w1 + 200*w2 +1000)**2
?
# 梯度
# 以w1為參數的梯度
def g1(w1,w2):return 2*(100*w1 + 200*w2 +1000)*100
?
# 以w2為參數的梯度
def g2(w1,w2):return 2*(100*w1 + 200*w2 +1000)*200
?
# 初始化w1,w2
w1 = 10
w2 = 10
for i in range(50):print('w1:',w1,'w2:',w2,'loss:',loss(w1,w2))w1,w2 = w1-0.001*g1(w1,w2), w2-0.01*g2(w1,w2)
?