《零基礎入門AI:線性回歸進階(梯度下降算法詳解)》

在上一篇博客中,我們學習了線性回歸的基本概念、損失函數(如MSE)以及最小二乘法。最小二乘法通過求解解析解(直接計算出最優參數)的方式得到線性回歸模型,但它有一個明顯的局限:當特征數量很多時,計算過程會非常復雜(涉及矩陣求逆等操作)。今天我們來學習另一種更通用、更適合大規模數據的參數優化方法——梯度下降

一、什么是梯度下降?

梯度下降(Gradient Descent)是一種迭代優化算法,核心思想是:通過不斷地沿著損失函數"下降最快"的方向調整參數,最終找到損失函數的最小值(或近似最小值)。

我們可以用一個生活中的例子理解:假設你站在一座山上,周圍被大霧籠罩,你看不見山腳在哪里,但你想以最快的速度走到山腳下。此時,你能做的最合理的選擇就是:先感受一下腳下的地面哪個方向坡度最陡且向下,然后沿著那個方向走一步;走到新的位置后,再重復這個過程——感受坡度最陡的向下方向,再走一步;直到你感覺自己已經走到了最低點(腳下各個方向都不再向下傾斜)。

這個過程就是梯度下降的直觀體現:

  • 這座山就是我們的損失函數
  • 你的位置代表當前參數值
  • 你感受到的"坡度最陡的向下方向"就是負梯度方向
  • 你每次走的"一步"的長度就是學習率
  • 最終到達的"山腳下"就是損失函數的最小值點

在線性回歸中,我們的目標是找到最優參數(如權重w),使得損失函數L(w)達到最小值。梯度下降的作用就是幫助我們一步步調整這些參數,最終找到讓損失函數最小的參數值。

二、梯度下降的基本步驟

梯度下降的過程可以總結為4個核心步驟,我們以單特征且不含偏置項的線性回歸模型y = wx為例(即b=0,損失函數使用MSE),逐步說明:

步驟1:初始化參數

首先需要給參數w設定初始值。初始值可以是任意的(比如隨機值、0或1),因為梯度下降會通過迭代不斷優化它。

為什么初始值可以任意選擇?因為梯度下降是一個迭代優化的過程,無論從哪個點開始,只要迭代次數足夠多且學習率合適,最終都會收斂到損失函數的最小值附近。

例如:我們可以簡單地將初始值設為w = 0,然后開始優化過程。

步驟2:計算損失函數的梯度

“梯度"在單參數情況下就是損失函數對該參數的導數,它表示損失函數在當前參數位置的"變化率"和"變化方向”。

對單特征且b=0的模型y = wx,我們只需要計算一個導數:

  • 損失函數L對w的導數:?L/?w(表示當w變化時,損失函數L的變化率)

這個導數就是"梯度",它指向損失函數增長最快的方向。這很重要:梯度指向的是損失函數值上升最快的方向,所以要讓損失函數減小,我們需要向相反的方向移動。

步驟3:更新參數

為了讓損失函數減小,我們需要沿著梯度的反方向(即負梯度方向)調整參數。更新公式為:

w = w - α · (?L/?w)

其中α是"學習率"(后面會詳細解釋),它控制參數更新的"步長"。

為什么是減去梯度而不是加上?因為梯度指向損失函數增大的方向,所以減去梯度就意味著向損失函數減小的方向移動,這正是我們想要的。

步驟4:重復迭代,直到收斂

重復步驟2和步驟3:每次計算當前參數的梯度,然后沿負梯度方向更新參數。當滿足以下條件之一時,停止迭代(即"收斂"):

  • 梯度的絕對值接近0(此時損失函數變化很小,接近最小值);
  • 損失函數L(w)的變化量小于某個閾值(比如連續兩次迭代的損失差小于10??);
  • 達到預設的最大迭代次數(防止無限循環)。

"收斂"這個詞可以理解為:參數值已經穩定下來,繼續迭代也不會有明顯變化,此時我們可以認為找到了最優參數

三、梯度下降的公式推導(單特征且b=0)

要實現梯度下降,核心是求出損失函數對參數w的導數。我們以MSE損失函數為例(且b=0),詳細推導?L/?w的計算過程,每一步都會給出詳細說明。

已知條件

  • 模型:y_pred = wx(預測值,因b=0,無偏置項)

  • 真實值:y

  • 損失函數(MSE):

    L(w) = (1/2n)Σ(y? - y_pred,?)2 = (1/2n)Σ(y? - wx?)2
    

(注:公式中加入1/2是為了后續求導時抵消平方項的系數2,使計算更簡潔,不影響最終結果)

推導?L/?w(損失函數對w的導數)

  1. 先對單個樣本的損失求導:
    單個樣本的損失為l? = (1/2)(y? - wx?)2,對w求導:

    ?l?/?w = 2 · (1/2)(y? - wx?) · (-x?) = -(y? - wx?)x?
    

    這里用到了復合函數求導法則(鏈式法則):首先對平方項求導得到2·(1/2)(…),然后對括號內的內容求導,由于我們是對w求導,所以(wx?)對w的導數是x?,前面有個負號,所以整體是-(y? - wx?)x?。

  2. 對所有樣本的損失求和后求導:
    總損失L是所有單個樣本損失的平均值:L = (1/n)Σl?,因此:

    ?L/?w = (1/n)Σ(?l?/?w) = (1/n)Σ[-(y? - wx?)x?] = -(1/n)Σ(y? - y_pred,?)x?
    

    這一步的含義是:總損失對w的導數等于所有單個樣本損失對w的導數的平均值。

最終更新公式

將上面得到的導數代入參數更新公式(參數 = 參數 - 學習率 × 導數),得到:

w = w + α · (1/n)Σ(y? - y_pred,?)x?

(注:負負得正,公式中的減號變為加號)

這個公式的含義是:

  • 如果預測值y_pred,?小于真實值y?(即y? - y_pred,?為正),則w會增大;反之則減小。
  • 增大或減小的幅度取決于三個因素:誤差大小(y? - y_pred,?)、特征值x?的大小和學習率α。
  • 特征值x?越大,相同誤差下w的更新幅度也越大,這體現了特征對參數調整的影響。

四、學習率(α)的作用

學習率(Learning Rate)是梯度下降中最重要的超參數(需要人工設定的參數),它控制參數更新的"步長"。我們繼續用"下山"的例子來理解:

  • 如果學習率α太小:就像每次只邁一小步下山,雖然安全,但需要走很多步才能到達山腳(迭代次數多,效率低)。
  • 如果學習率α太大:就像每次邁一大步下山,可能會直接跨過山腳,甚至走到對面的山坡上(跳過最小值,甚至導致損失函數越來越大,無法收斂)。
  • 合適的學習率:步長適中,能快速逼近最小值,既不會太慢也不會跳過。

實際應用中,學習率通常需要通過嘗試確定,常見的初始值有0.1、0.01、0.001等。一種常用的策略是"學習率衰減":隨著迭代次數增加,逐漸減小學習率,這樣在開始時可以快速接近最小值,后期可以精細調整。

舉個形象的例子:假設你在下山,開始時你離山腳很遠,可以大踏步前進(較大的學習率);當快到山腳時,你會放慢腳步,小步移動(較小的學習率),以免走過頭。

完整示例(手動實現梯度下降,單特征,b=0)

import numpy as np
import matplotlib.pyplot as plt  # 可視化# 創建數據  植物的溫度、和生長高度 [[20,10],[22,10],[27,12],[25,16]]
data =np.array([[20,10],[22,10],[27,12],[25,16]])
# 劃分
x=data[:,0]
y=data[:,1]
print(x)
print(y)# 創建一個模型
def model(x,w):return x*w# 定義損失函數
# def loss(y_pred,y):
#      return np.sum((y_pred-y)**2)/len(y)# 手動將損失函數展開  便于下面寫梯度函數
def loss(w):return 2238*(w**2) - 1144*w + 600# 梯度函數  即,將損失函數求導
def gradient(w):return 2*2238*w - 1144# 梯度下降  給定初始系數w  迭代100次 優化w
w=0
learning_rate = 1e-5  # 降低學習率避免溢出
for i in range(100):w=w-learning_rate*gradient(w)print('e:',loss(w),'w:',w)# 繪制損失函數
plt.plot(np.linspace(0,1,100),loss(np.linspace(0,1,100)))# 繪制模型
def draw_line(w):point_x = np.linspace(0, 30, 100)point_y = model(point_x, w)plt.plot(point_x, point_y, label=f'Fitted line (w={w:.4f})')plt.scatter(x, y, color='red', label='Data points')plt.legend()plt.xlabel("Temperature")plt.ylabel("Height")plt.title("Linear Regression via Gradient Descent")plt.grid(True)plt.show()# draw_line(w)

五、多特征的梯度下降(以2個特征為例)

現實中,我們遇到的問題往往有多個特征(比如用"面積"和"房間數"預測房價)。下面我們推導2個特征的線性回歸模型的梯度下降公式,方法與單特征類似,但需要考慮更多參數。

模型與損失函數

  • 2個特征的模型:y_pred = w?x? + w?x?(x?、x?是兩個特征,w?、w?是對應的權重,因b=0,無偏置項)

  • 損失函數(MSE):

    L(w?,w?) = (1/2n)Σ(y? - (w?x?,? + w?x?,?))2
    

推導各參數的偏導數

與單特征思路一致,我們分別對w?、w?求偏導:

  1. 對w?的偏導:

    ?L/?w? = -(1/n)Σ(y? - y_pred,?)x?,?
    

    推導過程與單特征中w的導數完全相同,只是這里特征是x?,所以最后乘以x?,?。

  2. 對w?的偏導:

    ?L/?w? = -(1/n)Σ(y? - y_pred,?)x?,?
    

    同理,這里特征是x?,所以最后乘以x?,?。

參數更新公式

將上述偏導數代入更新公式,得到:

w? = w? + α · (1/n)Σ(y? - y_pred,?)x?,?
w? = w? + α · (1/n)Σ(y? - y_pred,?)x?,?

多特征的擴展規律

從2個特征的推導可以看出,梯度下降的公式可以很容易擴展到k個特征的情況:

  • 模型:y_pred = w?x? + w?x? + … + w?x?

  • 對第j個權重w?的更新公式:

    w? = w? + α · (1/n)Σ(y? - y_pred,?)x?,?
    

(x?,?表示第i個樣本的第j個特征值)

這個規律非常重要,它告訴我們:無論有多少個特征,梯度下降的更新規則都是相似的——每個權重w?的更新量都與對應特征x?和誤差(y? - y_pred,?)的乘積有關。

完整示例(手動實現梯度下降,兩個特征,b=0)

import numpy as np
import matplotlib.pyplot as plt
# 如果使用中文顯示,建議添加以下配置
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用來正常顯示中文標簽
plt.rcParams['axes.unicode_minus'] = False    # 用來正常顯示負號# 創建數據  [[1,1,3],[2,1,4],[1,2,5],[2,2,6]]
data = np.array([[1,1,3],[2,1,4],[1,2,5],[2,2,6]])
# 劃分
x=data[:,:-1]
y=data[:,-1]
print(x)
print(y)# 創建模型
def model(x,w):return np.sum(x*w)# 創建損失函數
# def loss(w,x):# return np.sum((np.sum(x*w,axis=1)-y)**2)# return np.sum((model(x,w)-y)**2)
def loss(w1,w2):return 5*w1**2 + 5*w2**2 +9*w1*w2 -28*w1-29*w2 +43# 創建梯度函數
def gradient_w1(w1,w2):return 10*w1+9*w2-28def gradient_w2(w1,w2):return 9*w1+10*w2-29# 初始化w1,w2
w1=0
w2=0# 迭代100次 優化w1,w2
for i in range(100):w1,w2=w1-0.01*gradient_w1(w1,w2),w2-0.01*gradient_w2(w1,w2)print('e:',loss(w1,w2),'w1:',w1,'w2:',w2)# # 繪制模型  沒寫出來(所以注釋了)
# def draw_line(w1,w2):
#     point_x=np.linspace(0,5,100)
#     point_y=model(point_x,w1,w2)
#     plt.plot(point_x,point_y)# draw_line(w1,w2)

六、梯度下降與最小二乘法的對比

特點梯度下降最小二乘法
本質迭代優化(數值解)直接求解方程(解析解)
計算復雜度低(適合大規模數據/多特征)高(涉及矩陣求逆)
適用性幾乎所有損失函數僅適用于凸函數且有解析解
超參數依賴需要調整學習率等無需超參數
內存需求低(可分批處理數據)高(需要一次性加載所有數據)

簡單來說,當特征數量較少時,最小二乘法可能更簡單直接;但當特征數量很多(比如超過1000個)時,梯度下降通常是更好的選擇。

總結

梯度下降是機器學習中最基礎也最常用的優化算法,它通過"沿損失函數負梯度方向迭代更新參數"的方式,找到使損失最小的參數值。與最小二乘法相比,梯度下降更適合處理大規模數據和復雜模型。

本文我們從概念、步驟、公式推導(單特征且b=0和雙特征)、學習率作用等方面詳細講解了梯度下降,希望能幫助你理解其核心邏輯。掌握梯度下降不僅對理解線性回歸至關重要,也是學習更復雜機器學習算法(如神經網絡)的基礎。

下一篇博客中,我們將通過實際案例演示如何用梯度下降實現線性回歸,進一步加深理解。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/917176.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/917176.shtml
英文地址,請注明出處:http://en.pswp.cn/news/917176.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

基于C語言實現的KV存儲引擎(一)

基于C語言實現的KV存儲引擎項目簡介整體架構網絡模塊的實現recatorproactorNtyco項目簡介 本文主要是基于 C 語言來實現一個簡單的 KV 存儲架構,目的就是將網絡模塊跟實際開發結合起來。 首先我們知道對于數據的存儲可以分為兩種方式,一種是在內存中進…

c++和python聯合編程示例

安裝 C與 Python 綁定工具 pip install pybind11這其實相當于使用 python 安裝了一個 c的庫 pybind11,這個庫只由頭文件構成, 支持基礎數據類型傳遞以及 python 的 numpy 和 c的 eigen 庫之間的自動轉換。 編寫 CMakeList.txt cmake_minimum_required(VERSION 3.14)…

【OD機試題解法筆記】貪心歌手

題目描述 一個歌手準備從A城去B城參加演出。 按照合同,他必須在 T 天內趕到歌手途經 N 座城市歌手不能往回走每兩座城市之間需要的天數都可以提前獲知。歌手在每座城市都可以在路邊賣唱賺錢。 經過調研,歌手提前獲知了每座城市賣唱的收入預期&#xff1a…

AI: 告別過時信息, 用RAG和一份PDF 為LLM打造一個隨需更新的“外腦”

嘿,各位技術同學!今天,我們來聊一個大家在使用大語言模型(LLM)時都會遇到的痛點:知識過時。 無論是像我一樣,用 Gemini Pro 學習日新月異的以太坊,還是希望它能精確掌握某個特定工具…

深度學習(魚書)day08--誤差反向傳播(后三節)

深度學習(魚書)day08–誤差反向傳播(后三節)一、激活函數層的實現 這里,我們把構成神經網絡的層實現為一個類。先來實現激活函數的ReLU層和Sigmoid層。ReLU層 激活函數ReLU(Rectified Linear Unit&#xff…

C# 中生成隨機數的常用方法

1. 使用 Random 類(簡單場景) 2. 使用 RandomNumberGenerator 類(安全場景) 3. 生成指定精度的隨機小數 C# 中生成隨機數的常用方法: 隨機數類型實現方式示例代碼特點與適用場景隨機整數(無范圍&#xf…

Flink 算子鏈設計和源代碼實現

1、JobGraph (JobManager) JobGraph 生成時,通過 ChainingStrategy 連接算子,最終在 Task 中生成 ChainedDriver 鏈表。StreamingJobGraphGeneratorcreateJobGraph() 構建jobGrapch 包含 JobVertex setChaining() 構建算子鏈isCha…

對接八大應用渠道

背景最近公司想把游戲包上到各個渠道上,因此需要對接各種渠道,渠道如下,oppo、vivo、華為、小米、應用寶、taptap、榮耀、三星等應用渠道 主要就是對接登錄、支付接口(后續不知道會不會有其他的)&#x…

學習:入門uniapp Vue3組合式API版本(17)

42.打包發行微信小程序的上線全流程 域名 配置 發行 綁定手機號 上傳 提交后等待,上傳 43.打包H5并發布上線到unicloud的前端頁面托管 完善配置 unicloud 手機號實名信息不一致:請確保手機號的實名信息與開發者姓名、身份證號一致,請前往開…

SOLIDWORKS材料明細表設置,屬于自己的BOM表模板

上一期我們了解了如何在SOLIDWORKS工程圖中添加材料明細表?接下來,我們將進行對SOLIDWORKS材料明細表的設置、查看縮略圖、模板保存的深度講解。01 材料明細表設置菜單欄生成表格后左側菜單欄會顯示關于材料明細表的相關設置信息。我們先了解一下菜單欄設置詳情&am…

全棧:Maven的作用是什么?本地倉庫,私服還有中央倉庫的區別?Maven和pom.xml配置文件的關系是什么?

Maven和pom.xml配置文件的關系是什么: Maven是一個構建工具和依賴管理工具,而pom.xml(Project Object Model)是Maven的核心配置文件。 SSM 框架的項目不一定是 Maven 項目,但推薦使用 Maven進行管理。 SSM 框架的項目可…

超越 ChatGPT:智能體崛起,開啟全自主 AI 時代

引言 短短三年,生成式 AI 已從對話助手跨越到能自主規劃并完成任務的“智能體(Agentic AI)”時代。這場演進不僅體現在模型規模的提升,更在于系統架構、交互范式與安全治理的全面革新。本文按時間線梳理關鍵階段與核心技術,為您呈現 AI 智能體革命的脈絡與未來趨勢。 1. …

一杯就夠:讓大腦瞬間在線、讓肌肉滿電的 “Kick-out Drink” 全解析

一杯就夠:讓大腦瞬間在線、讓肌肉滿電的 “Kick-out Drink” 全解析“每天清晨,當鬧鐘還在哀嚎,你舉杯一飲,睡意像被扔出擂臺——這,就是 Kick-out Drink 的全部浪漫。”清晨 30 分鐘后,250 mL 常溫水里溶解…

系統開機時自動執行指令

使用 systemd 創建一個服務單元可以讓系統開機時自動執行指令,假設需要執行的指令如下,運行可執行文件(/home/demo/可執行文件),并輸入參數(–input/home/config/demo.yaml): /home/…

Docker 初學者需要了解的幾個知識點 (七):php.ini

這段配置是 php.ini 文件中針對 PHP 擴展和 Xdebug 調試工具的設置,主要用于讓 PHP 支持數據庫連接和代碼調試(尤其在 Docker 環境中),具體解釋如下:[PHP] extensionpdo_mysql extensionmysqli xdebug.modedebug xdebu…

【高階版】R語言空間分析、模擬預測與可視化高級應用

隨著地理信息系統(GIS)和大尺度研究的發展,空間數據的管理、統計與制圖變得越來越重要。R語言在數據分析、挖掘和可視化中發揮著重要的作用,其中在空間分析方面扮演著重要角色,與空間相關的包的數量也達到130多個。在本…

dolphinscheduler中一個腳本用于從列定義中提取列名列表

dolphinscheduler中,我們從一個mysql表導出數據,上傳到hdfs, 再創建一個臨時表,所以需要用到列名定義和列名列表。 原來定義兩個變量,不僅繁鎖,還容易出現差錯,比如兩者列序不對。 所以考慮只定義列定義變量…

JavaWeb(蒼穹外賣)--學習筆記16(定時任務工具Spring Task,Cron表達式)

前言 本篇文章是學習B站黑馬程序員蒼穹外賣的學習筆記📑。我的學習路線是Java基礎語法-JavaWeb-做項目,管理端的功能學習完之后,就進入到了用戶端微信小程序的開發,用戶端開發的流程大致為用戶登錄—商品瀏覽(其中涉及…

靈敏度,精度,精確度,精密度,精準度,準確度,分辨率,分辨力——概念

文章目錄前提總結前提 我最近在整理一份數據指標要求的時候,總是混淆這幾個概念:靈敏度,精度,精確度,精密度,精準度,準確度,分辨率,分辨力,搜了一些文章&…

python-異常(筆記)

#后續代碼可以正常運行 try:f open("xxx.txt","r",encodingutf-8)except:print("except error")#捕獲指定異常,其他異常報錯程序中止,管不到 try:print(name) except NameError as you_call:print("name error"…