機器學習——一元線性回歸(算法實現與評估)

一元線性回歸是統計學中最基礎的回歸分析方法,用于建立兩個變量之間的線性關系模型。

1. 模型表達式

一元線性回歸的數學模型為:


\hat{y}=kx+b

  • \hat{y}:因變量(預測值)
  • x:自變量(輸入變量)
  • k:回歸系數(斜率),表示xx每增加1單位時\hat{y}的變化量
  • b:截距項,表示當x=0時\hat{y}?的取值

2. 參數估計:最小二乘法

通過最小化預測值與實際值的**殘差平方和(RSS)**求解kk和bb:
目標函數


min\sum_{i=1}^{n}\left ( y_{i} -\hat{y}_{i}\right )_{}^{2}=min\sum_{i=1}^{n}\left ( y_{i} -kx_{i}-b\right )_{}^{2}

參數計算公式

  • 斜率k

k=\frac{\sum_{i=1}^{n}\left ( x_{i} -\hat{x}\right )\left ( y_{i} -\hat{y} \right )}{\left ( x_{i} -\hat{x}\right )_{}^{2}}

  • 截距b
  • b=\hat{y}-k\hat{x}

  • 其中,\hat{x}\hat{y}分別表示x和y的樣本均值。

3、線性回歸模型的評價

(1) MAE(平均絕對誤差)

  • 定義:所有樣本預測值與真實值之差的絕對值的平均值。
  • 公式

  • MAE=\frac{1}{n}\sum_{i=1}^{n}\left | y_{i} -\hat{y}\right |
  • 特點
    • 單位與因變量一致,便于直觀理解(如房價預測中的“萬元”)。
    • 對異常值不敏感,適用于需避免大誤差懲罰的場景(如穩健預測)。
    • 無法反映誤差方向,僅衡量平均偏差大小

(2) MSE(均方誤差)

  • 定義:預測值與真實值之差的平方和的均值。
  • 公式
  • MAE=\frac{1}{n}\sum_{i=1}^{n}\left ( y_{i}-\hat{y} \right )_{}^{2}
  • 特點
    • 單位是原變量的平方(如“萬元2”),解釋性較差,但數學性質優良(連續可導)。
    • 對異常值敏感,大誤差會被放大,適用于需強調大誤差的場景(如金融風險預測)。

(3) RMSE(均方根誤差)

  • 定義:MSE的平方根,將誤差單位還原為原變量單位。
  • 公式

RMSE=\sqrt{MSE}

  • 特點
    • 結合了MSE和MAE的優點:單位直觀且對大誤差敏感56。
    • 常用于實際業務中(如房價預測誤差表示為“5萬美元”)

(4) R2(決定系數)

  • 定義:模型解釋因變量方差的比例,衡量擬合優度。
  • 公式
  • ^{R_{}^{2}}=1-\frac{\left (y_{i} -\hat{y_{i}} \right )_{}^{2}}{\left (y_{i} -\bar{y} \right )_{}^{2}}
    ?
  • 特點
    • 取值范圍[0,1]:越接近1,模型解釋力越強;0表示模型不優于均值預測。
    • 無量綱性:不受數據量綱影響,適合跨數據集比較模型性能。
    • 局限性:樣本量小時可能高估擬合效果,且不直接反映誤差大小。

4、代碼實現

(1)、手動實現線性回歸

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt###準備數據集#加載boston(波士頓房價)數據集
data_url = "http://lib.stat.cmu.edu/datasets/boston"
raw_df = pd.read_csv(data_url, sep="\s+", skiprows=22, header=None)
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
target = raw_df.values[1::2, 2]x = data[:,5]
y = targetx = x[y<50]
y = y[y<50]#顯示
plt.scatter(x,y)
plt.show()#劃分數據集
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.3, random_state = 0)plt.scatter(x_train, y_train)
plt.show()###一元線性回歸
def fit(x, y):a_up = np.sum((x-np.mean(x))*(y - np.mean(y)))a_bottom = np.sum((x-np.mean(x))**2)a = a_up / a_bottomb = np.mean(y) - a * np.mean(x)return a, b
a, b = fit(x_train, y_train)
print(a,b) #結果:(8.056822140369603, -28.49306872447786)#訓練回歸線
plt.scatter(x_train, y_train)
plt.plot(x_train, a*x_train+ b, c='r')
plt.show()#測試回歸線
plt.scatter(x_test, y_test)
plt.plot(x_test, a*x_test+ b, c='r')
plt.show()

(2)、sklearn 實現一元線性回歸

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt #準備數據集
data_url = "http://lib.stat.cmu.edu/datasets/boston"
raw_df = pd.read_csv(data_url, sep="\s+", skiprows=22, header=None)
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
target = raw_df.values[1::2, 2]x=data[:,5]
y=target x=x[y<50]
y=y[y<50]#劃分數據集
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.3, random_state = 0)#實現回歸from sklearn.linear_model import LinearRegression
lin_reg = LinearRegression()lin_reg.fit(x_train.reshape(-1,1), y_train)y_predict = lin_reg.predict(x_test.reshape(-1,1))plt.scatter(x_test, y_test)
plt.plot(x_test, y_predict, c='r')
plt.show()

(3)、sklearn 實現多元線性回歸(注意多元不用歸一化)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt #準備數據集
data_url = "http://lib.stat.cmu.edu/datasets/boston"
raw_df = pd.read_csv(data_url, sep="\s+", skiprows=22, header=None)
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
target = raw_df.values[1::2, 2]x=data
y=target x=x[y<50]
y=y[y<50]#劃分數據集
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.3, random_state = 0)#實現回歸from sklearn.linear_model import LinearRegression
lin_reg = LinearRegression()lin_reg.fit(x_train, y_train)lin_reg.score(x_test, y_test) #結果:0.7455942658788952

5、模型評價

import numpy as np
import pandas as pd
from sklearn import datasets
import matplotlib.pyplot as plt###數據準備data_url = "http://lib.stat.cmu.edu/datasets/boston"
raw_df = pd.read_csv(data_url, sep=r"\s+", skiprows=22, header=None)
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
target = raw_df.values[1::2, 2]x=data[:, -1].reshape(-1, 1) 
y=target.reshape(-1, 1) from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.3, random_state = 0)###實現一元線性回歸from sklearn.linear_model import LinearRegression
linearReg = LinearRegression()
#適配數據
model = linearReg.fit(x_train, y_train)
#得到預測函數
y_predict = model.predict(x_test)
#顯示
plt.scatter(x_test, y_test, s = 10)
plt.plot(x_test, y_predict, c = 'r')
plt.show()### MSEy_real = y_test
mse = np.sum((y_real - y_predict) ** 2) / len(y_test)
print(mse )#公式計算 39.81715050474416
from sklearn.metrics import mean_squared_error
mean_squared_error(y_real, y_predict)# sklearn計算 39.81715050474416### RMSE
rmse = np.sqrt(mse) 
print(rmse )#公式計算 6.310083240714354
mean_squared_error(y_real, y_predict)# sklearn計算 6.310083240714354### MAE
mae = np.sum(np.abs(y_real - y_predict)) / len(y_test) 
print(mae ) #公式計算 4.4883446998468415
from sklearn.metrics import mean_absolute_error
mean_absolute_error(y_real, y_predict) # sklearn計算  4.4883446998468415###R方
r2 = 1 - (np.sum((y_real - y_predict) ** 2)) / (np.sum((y_real - np.mean(y_real)) ** 2))
print(r2)  #公式計算 0.5218049526125568 等效:1 - mse / np.var(y_real)
from sklearn.metrics import r2_score
r2_score(y_real, y_predict)# sklearn計算  0.5218049526125568

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/73144.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/73144.shtml
英文地址,請注明出處:http://en.pswp.cn/web/73144.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Ubuntu下用QEMU模擬運行OpenBMC

1、前言 在調試過程中&#xff0c;安裝了很多依賴庫&#xff0c;具體沒有記錄。關于kvm&#xff0c;也沒理清具體有什么作用。本文僅記錄&#xff0c;用QEMU成功的將OpenBMC跑起來的過程&#xff0c;做備忘&#xff0c;也供大家參考。 2、環境信息 VMware Workstation 15 Pro…

Gradle/Maven 本地倉庫默認路徑遷移 (減少系統磁盤占用)

Gradle 配置環境變量 GRADLE_USER_HOME&#xff0c;如D:/.gradle同時將 %userprofile%/.gradle 移動到配置路徑 Maven 修改settings.xml文件&#xff0c;localRepository同時將 %userprofile%/.m2/repository 移動到配置路徑 IDEA默認用的bundle maven, 路徑為安裝目錄下 p…

MinGW與使用VScode寫C語言適配

壓縮包 通過網盤分享的文件&#xff1a;MinGW.zip 鏈接: https://pan.baidu.com/s/1QB-Zkuk2lCIZuVSHc-5T6A 提取碼: 2c2q 需要下載的插件 1.翻譯 找到VScode頁面&#xff0c;從上數第4個&#xff0c;點擊擴展&#xff08;以下通此&#xff09; 搜索---Chinese--點擊---安裝--o…

【C++初階】從零開始模擬實現vector(含迭代器失效詳細講解)

目錄 1、基本結構 1.1成員變量 1.2無參構造函數 1.3有參構造函數 preserve()的實現 代碼部分&#xff1a; push_back()的實現 代碼部分&#xff1a; 代碼部分&#xff1a; 1.4拷貝構造函數 代碼部分&#xff1a; 1.5支持{}初始化的構造函數 代碼部分&#xff1a; …

Java實習生面試題(2025.3.23 be)

一、v-if與v-show的區別 v-show 和 v-if 都是 Vue 中的條件渲染指令&#xff0c;它們的主要區別在于渲染策略&#xff1a;v-if 會根據條件決定是否編譯元素&#xff0c;而 v-show 則始終編譯元素&#xff0c;只是通過改變 CSS 的 display 屬性來控制顯示與隱藏。 二、mybatis-…

stm32標準庫開發需要的基本文件結構

使用STM32標準庫&#xff08;STM32 Standard Peripheral Library&#xff0c;SPL&#xff09;開發時&#xff0c;項目中必須包含一些必要的文件&#xff0c;這些文件確保項目能夠正常運行并與MCU硬件交互。以下詳細說明&#xff1a; 一、標準庫核心文件夾說明 使用標準庫開發S…

學生管理系統(需求文檔)

需求&#xff1a; 采取控制臺的方式去書寫學生管理系統 分析&#xff1a; 初始菜單&#xff1a; “----------歡迎來到java學生管理系統----------” “1:添加學生” “2&#xff1a;刪除學生” “3&#xff1a;修改學生” “4&#xff1a;查詢學生” “5&#xff1a;…

Java算法OJ(13)雙指針

目錄 1.前言 2.正文 2.1快樂數 2.2盛最多水的容器 2.3有效的三角形的個數 2.4和為s的兩個數 2.5三數之和 2.6四數之和 3.小結 1.前言 哈嘍大家好吖&#xff0c;今天繼續加練算法題目&#xff0c;一共六道雙指針&#xff0c;希望能對大家有所幫助&#xff0c;廢話不多…

SpringBoot分布式定時任務實戰:告別重復執行的煩惱

場景再現&#xff1a;你剛部署完基于SpringBoot的集群服務&#xff0c;凌晨3點突然收到監控告警——優惠券發放量超出預算兩倍&#xff01;檢查日志發現&#xff0c;兩個節點同時執行了定時任務。這種分布式環境下的定時任務難題&#xff0c;該如何徹底解決&#xff1f; 本文將…

MySQL 設置允許遠程連接完整指南:安全與效率并重

一、為什么需要遠程連接MySQL&#xff1f; 在分布式系統架構中&#xff0c;應用程序與數據庫往往部署在不同服務器。例如&#xff1a; Web服務器&#xff08;如NginxPHP&#xff09;需要連接獨立的MySQL數據庫數據分析師通過BI工具直連生產庫多服務器集群間的數據同步 但直接…

系統架構書單推薦(一)領域驅動設計與面向對象

本文主要是個人在學習過程中所涉獵的一些經典書籍&#xff0c;有些已經閱讀完&#xff0c;有些還在閱讀中。于我而言&#xff0c;希望追求軟件系統設計相關的原則、方法、思想、本質的東西&#xff0c;并希望通過不斷的學習、實踐和積累&#xff0c;提升自身的知識和認知。希望…

動態規劃-01背包

兜兜轉轉了半天&#xff0c;發現還是Carl寫的好。 看過動態規劃-基礎的讀者&#xff0c;大概都清楚。 動態規劃是將大問題&#xff0c;分解成子問題。并將子問題的解儲存下來&#xff0c;避免重復計算。 而背包問題&#xff0c;就是動態規劃延申出來的一個大類。 而01背包&…

使用VS2022編譯CEF

前提 選擇編譯的版本 CEF自動編譯&#xff0c;在這里可以看到最新的穩定版和Beta版。 從這里得出&#xff0c;最新的穩定版是134.0.6998.118&#xff0c;對應的cef branch是6998。通過這個信息可以在Build requirements查到相關的軟件配置信息。 這里主要看Windows下的編譯要…

C++20:玩轉 string 的 starts_with 和 ends_with

文章目錄 一、背景與動機二、string::starts_with 和 string::ends_with&#xff08;一&#xff09;語法與功能&#xff08;二&#xff09;使用示例1\. 判斷字符串開頭2\. 判斷字符串結尾 &#xff08;三&#xff09;優勢 三、string_view::starts_with 和 string_view::ends_w…

智能飛鳥監測 守護高壓線安全

飛鳥檢測新紀元&#xff1a;視覺分析技術的革新應用 在現代化社會中&#xff0c;飛鳥檢測成為了多個領域不可忽視的重要環節。無論是高壓線下的安全監測、工廠內的生產秩序維護&#xff0c;還是農業區的作物保護&#xff0c;飛鳥檢測都扮演著至關重要的角色。傳統的人工檢測方…

ADC噪聲全面分析 -04- 有效噪聲帶寬簡介

為什么要了解ENBW&#xff1f; 了解模數轉換器 (ADC) 噪聲可能具有挑戰性&#xff0c;即使對于最有經驗的模擬設計人員也是如此。 Delta-sigma ADC 具有量化和熱噪聲的組合&#xff0c;這取決于 ADC 的分辨率、參考電壓和輸出數據速率 (ODR)。 在系統級別&#xff0c;額外的信…

STM32單片機uCOS-Ⅲ系統10 內存管理

目錄 一、內存管理的基本概念 二、內存管理的運作機制 三、內存管理的應用場景 四、內存管理函數接口講解 1、內存池創建函數 OSMemCreate() 2、內存申請函數 OSMemGet() 3、內存釋放函數 OSMemPut() 五、實現 一、內存管理的基本概念 在計算系統中&#xff0c;變量、中…

考研課程安排(自用)

文章目錄 408數據結構&#xff08;王道&#xff09;計算機組成原理&#xff08;王道&#xff09;操作系統&#xff08;王道&#xff09;計算機網絡&#xff08;湖科大版&#xff09; 數學一高等數學&#xff08;微積分&#xff09;線性代數和概率論 408 數據結構&#xff08;王…

ultraiso制作u盤啟動

UltraISO制作U盤啟動盤的方法 UltraISO是一款功能強大的工具&#xff0c;可以幫助用戶將ISO鏡像文件寫入U盤&#xff0c;從而制作成可啟動的系統安裝盤。以下是詳細的步驟和注意事項&#xff1a; 1. ?準備工作? ?硬件準備?&#xff1a;一個容量至少為8GB的U盤&#xff0…

C語言-發布訂閱模式詳解與實踐

文章目錄 C語言發布訂閱模式詳解與實踐1. 什么是發布訂閱模式&#xff1f;2. 為什么需要發布訂閱模式&#xff1f;3. 實際應用場景4. 代碼實現4.1 UML 關系圖4.2 頭文件 (pubsub.h)4.3 實現文件 (pubsub.c)4.4 使用示例 (main.c) 5. 代碼分析5.1 關鍵設計點5.2 實現特點 6. 編譯…