使用純NumPy實現回歸任務:深入理解機器學習本質

在深度學習框架普及的今天,回歸基礎用NumPy從頭實現機器學習模型具有特殊意義。本文將完整演示如何用純NumPy實現二次函數回歸任務,揭示機器學習底層原理。整個過程不使用任何深度學習框架,每一行代碼都透明可見。


1. 環境配置與數據生成

import numpy as np
from matplotlib import pyplot as plt 設置隨機種子保證可復現性 
np.random.seed(100)  生成訓練數據:100個點在[-1,1]區間均勻分布 
x = np.linspace(-1, 1, 100).reshape(100, 1)基于y=3x2+2生成目標值,并添加高斯噪聲
y = 3 * np.power(x, 2) + 2 + 0.2 * np.random.rand(x.size).reshape(100, 1)

**數據可視化結果: **

散點圖展示了添加噪聲后的數據分布,我們的目標是找到最佳擬合曲線y=wx2+by=wx^2+by=wx2+b


2. 模型初始化與核心參數

隨機初始化待學習參數
w = np.random.rand(1, 1)  # 權重參數 (理論值應接近3)
b = np.random.rand(1, 1)  # 偏置項 (理論值應接近2)lr = 0.001  # 學習率 (梯度下降步長)
epochs = 800  # 訓練輪數

初始參數可視化:

print(f"初始參數: w={w[0][0]:.4f}, b={b[0][0]:.4f}")
典型輸出: w=0.7123, b=0.1582 (每次運行結果不同)

3. 訓練過程與數學原理

3.1 前向傳播計算預測值

y_pred = np.power(x, 2) * w + b

3.2 損失函數定義

采用均方誤差(MSE)的變體:

loss = 0.5 * (y_pred - y)  2 
total_loss = loss.sum()  # 所有樣本損失之和 

3.3 梯度計算解析

關鍵數學推導(鏈式法則):

權重w的梯度: ?Loss/?w = Σ(y_pred - y)*x2 
grad_w = np.sum((y_pred - y) * np.power(x, 2))偏置b的梯度: ?Loss/?b = Σ(y_pred - y)
grad_b = np.sum((y_pred - y))

3.4 參數更新(梯度下降)

w -= lr * grad_w  # w = w - η·(?Loss/?w)
b -= lr * grad_b  # b = b - η·(?Loss/?b)

4. 完整訓練代碼

for epoch in range(epochs):# 前向傳播y_pred = np.power(x, 2) * w + b # 損失計算 loss = 0.5 * (y_pred - y)  2total_loss = loss.sum()# 梯度計算grad_w = np.sum((y_pred - y) * np.power(x, 2))grad_b = np.sum((y_pred - y))# 參數更新w -= lr * grad_w b -= lr * grad_b# 每100輪打印訓練進展 if epoch % 100 == 0:print(f"Epoch {epoch}: w={w[0][0]:.4f}, b={b[0][0]:.4f}, Loss={total_loss:.4f}")

訓練過程輸出:

Epoch 0: w=0.9461, b=0.3827, Loss=160.9256 
Epoch 100: w=2.1433, b=1.8047, Loss=1.8925 
Epoch 200: w=2.6555, b=2.0404, Loss=0.4583
Epoch 300: w=2.8543, b=2.1023, Loss=0.2985 
...
Epoch 700: w=2.9887, b=2.0161, Loss=0.2502 

5. 訓練結果可視化

生成預測曲線 
x_test = np.linspace(-1, 1, 30).reshape(30, 1)
y_test = np.power(x_test, 2) * w + b 繪制結果對比圖 
plt.figure(figsize=(10, 6))
plt.scatter(x, y, color='blue', alpha=0.5, label='真實數據')
plt.plot(x_test, y_test, 'r-', linewidth=3, label='模型預測')
plt.plot(x_test, 3*x_test2+2, 'g--', label='理論曲線')
plt.xlim(-1, 1)
plt.ylim(2, 6)
plt.legend()
plt.title('NumPy實現回歸結果')
plt.show()輸出最終參數
print(f"訓練結果: w={w[0][0]:.4f} (接近理論值3), b={b[0][0]:.4f} (接近理論值2)")

可視化結果:

紅色實線為模型預測曲線,綠色虛線為理論曲線y=3x2+2y=3x^2+2y=3x2+2,藍色點為帶噪聲的訓練數據


6. 關鍵技術解析

1. 梯度下降的本質

通過參數空間中的"下坡運動"尋找最優解,學習率控制步長大小:

  • 學習率過大 → 震蕩發散
  • 學習率過小 → 收斂緩慢
  • 本例0.001是多次試驗后的平衡值

2. 手動求導的意義

# 關鍵導數計算
grad_w = np.sum((y_pred - y) * np.power(x, 2))

理解此式需掌握:

  • 鏈式法則:?Loss/?w = (?Loss/?y_pred)·(?y_pred/?w)
  • 損失函數導數:?Loss/?y_pred = (y_pred - y)
  • 模型輸出導數:?y_pred/?w = x2

3. 批量梯度下降特點

  • 每次迭代使用全部樣本(不同于隨機梯度下降)
  • 計算穩定但內存消耗大
  • 適合中小規模數據集

7. 拓展思考

1. 學習率動態調整

# 添加學習率衰減 
if epoch % 200 == 0:lr *= 0.8  # 每200輪衰減20%

2. 添加正則化項(L2正則化)

# 修改損失函數
lambda_reg = 0.01  # 正則化系數 
loss = 0.5*(y_pred-y)2 + 0.5*lambda_reg*(w2)

3. 動量優化(Momentum)

# 添加動量項
beta = 0.9  # 動量系數 
v_w = beta*v_w + (1-beta)*grad_w 
w -= lr * v_w

8. 總結與啟示

NumPy實現的價值

  • 透明機制:每個運算步驟完全可見
  • ?? 數學本質:揭示梯度下降和反向傳播核心原理
  • 🔍 調試優勢:便于定位問題和理解優化過程

局限性:

  • 📈 僅適合簡單模型
  • ?? 復雜網絡需大量重復代碼
  • 缺乏自動微分等高級功能

通過這個基礎實現,我們能更深刻地理解PyTorch/TensorFlow等框架封裝的高級功能背后的數學原理,為后續學習打下堅實基礎。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/93072.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/93072.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/93072.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

java理解

springboot 打包 mvn install:install-file -Dfile=<path-to-jar> -DgroupId=<group-id> -DartifactId=<artifact-id> -Dversion=<version> -Dpackaging=jar <path-to-jar> 是你的 JAR 文件的路徑。 <group-id> 是你的項目的組 ID。 <…

圖論核心算法詳解:從存儲結構到最短路徑(附C++實現)

目錄 一、圖的基礎概念與術語 二、圖的存儲結構 1. 鄰接矩陣 實現思路&#xff1a; 2. 鄰接表 實現思路&#xff1a; 應用場景&#xff1a; 時間復雜度分析&#xff1a; 三、圖的遍歷算法 1. 廣度優先搜索&#xff08;BFS&#xff09; 核心思想&#xff1a; 應用場…

力扣top100(day03-02)--圖論

本文為力扣TOP100刷題筆記 筆者根據數據結構理論加上最近刷題整理了一套 數據結構理論加常用方法以下為該文章&#xff1a; 力扣外傳之數據結構&#xff08;一篇文章搞定數據結構&#xff09; 200. 島嶼數量 class Solution {// DFS輔助方法&#xff0c;用于標記和"淹沒&q…

建造者模式:從“參數地獄”到優雅構建

深夜&#xff0c;一條緊急告警刺穿寂靜&#xff1a;核心報表服務因NullPointerException全線崩潰。排查根源&#xff0c;罪魁禍首竟是一個擁有10多個參數的“上帝構造函數”。本文將從這個災難現場出發&#xff0c;引入“鏈式建造者模式”進行重構&#xff0c;并深入Spring AI、…

jenkins在windows配置sshpass

我的服務器里jenkins是通過docker安裝的&#xff0c;jenkins與項目都部署在同一臺服務器上還好&#xff0c;但是當需要通過jenkins構建&#xff0c;再通過scp遠程推送到別的服務器上&#xff0c;就出問題了&#xff0c;畢竟不是手動執行scp命令&#xff0c;可以手動輸入密碼&am…

Linux操作系統從入門到實戰(十八)在Linux里面怎么查看進程

Linux操作系統從入門到實戰&#xff08;十八&#xff09;在Linux里面怎么查看進程前言一、如何識別一個進程&#xff1f;—— PID二、怎么查看進程的信息&#xff1f;方式1&#xff1a;通過/proc目錄方式2&#xff1a;用ps命令三、父進程是什么&#xff1f;—— PPID四、bash是…

[TryHackMe](知識學習)---基于堆棧得到緩沖區溢出

1.了解緩沖區溢出WINDOWS程序動態調試工具immunity debuggerhttps://www.immunityinc.com/products/debugger/2.Mona腳本#!/usr/bin/env python3import socket, time, sysip "10.201.99.37"port 1337 timeout 5 prefix "OVERFLOW1 "string prefix &q…

LRU算法與LFU算法

知識點&#xff1a; LRU是Least Recently Used的縮寫&#xff0c;意思是最近最少使用&#xff0c;它是一種Cache替換算法 Cache的容量有限&#xff0c;因此當Cache的容量用完后&#xff0c;而又有新的內容需要添加進來時&#xff0c; 就需要挑選 并舍棄原有的部分內容&#xf…

目標檢測公開數據集全解析:從經典到前沿

目標檢測公開數據集全解析&#xff1a;從經典到前沿 一、引言 目標檢測&#xff08;Object Detection&#xff09;是計算機視覺領域的核心任務之一&#xff0c;旨在在圖像或視頻中識別并定位感興趣的物體。與圖像分類不同&#xff0c;目標檢測不僅需要判斷物體的類別&#xf…

數據備份與進程管理

一、數據備份1.Linux服務器中需要備份的數據&#xff08;1&#xff09;Linux系統重要數據&#xff1a;/root/目錄&#xff0c;/home/目錄&#xff0c;/etc/目錄&#xff08;2&#xff09;安裝服務的數據&#xff1a;Apache&#xff08;配置文件&#xff0c;網頁主目錄&#xff…

docker volume卷入門教程

1. 基礎概念 Docker卷是專門用于持久化容器數據的存儲方案&#xff0c;獨立于容器生命周期。其核心優勢包括&#xff1a; 數據持久化&#xff1a;容器刪除后數據仍保留跨容器共享&#xff1a;多個容器可訪問同一卷備份與遷移&#xff1a;支持直接復制卷數據驅動支持&#xff1a…

計算機網絡——協議

1. 計算機網絡分層1.1 OSI 7層模型應用層表示層會話層傳輸層網絡層數據鏈路層物理層1.2 TCP/IP 4 層模型應用層運輸層網際層網絡接口層1.3 5層體系機構應用層傳輸層網絡層數據鏈路層物理層2. 應用層協議2.1 HTTP協議2.1.1 基本介紹HTTP&#xff08;HyperText Transfer Protocol…

【React】hooks 中的閉包陷阱

在 React Hooks 中的 閉包陷阱&#xff08;Closure Trap&#xff09;在 useEffect、事件回調、定時器等場景里很常見。1. 閉包陷阱是什么 當你在函數組件里定義一個回調&#xff08;比如事件處理函數&#xff09;&#xff0c;這個回調會捕獲當時渲染時的變量值。如果后面狀態更…

校園快遞小程序(騰訊地圖API、二維碼識別、Echarts圖形化分析)

&#x1f388;系統亮點&#xff1a;騰訊地圖API、二維碼識別、Echarts圖形化分析&#xff1b;一.系統開發工具與環境搭建1.系統設計開發工具后端使用Java編程語言的Spring boot框架 項目架構&#xff1a;B/S架構 運行環境&#xff1a;win10/win11、jdk17小程序&#xff1a; 技術…

Python網絡爬蟲(二) - 解析靜態網頁

文章目錄一、網頁解析技術介紹二、Beautiful Soup庫1. Beautiful Soup庫介紹2. Beautiful Soup庫幾種解析器比較3. 安裝Beautiful Soup庫3.1 安裝 Beautiful Soup 43.2 安裝解析器4. Beautiful Soup使用步驟4.1 創建Beautiful Soup對象4.2 獲取標簽4.2.1 通過標簽名獲取4.2.2 通…

【Linux基礎知識系列】第九十四篇 - 如何使用traceroute命令追蹤路由

在網絡環境中&#xff0c;了解數據包從源主機到目標主機的路徑是非常重要的。這不僅可以幫助我們分析網絡連接問題&#xff0c;還可以用于診斷網絡延遲、丟包等問題。traceroute命令是一個強大的工具&#xff0c;它能夠追蹤數據包在網絡中的路徑&#xff0c;顯示每一跳的延遲和…

達夢數據閃回查詢-快速恢復表

Time:2025/08/12Author:skatexg一、環境說明DM數據庫&#xff1a;DM8.0及以上版本二、適用場景研發在誤操作或變更數據后&#xff0c;想馬上恢復表到某個時間點&#xff0c;可以通過閃回查詢功能快速實現&#xff08;通過全量備份恢復時間長&#xff0c;成本高&#xff09;三、…

力扣(LeetCode) ——225 用隊列實現棧(C語言)

題目&#xff1a;用隊列實現棧示例1&#xff1a; 輸入&#xff1a; [“MyStack”, “push”, “push”, “top”, “pop”, “empty”] [[], [1], [2], [], [], []] 輸出&#xff1a; [null, null, null, 2, 2, false] 解釋&#xff1a; MyStack myStack new MyStack(); mySta…

微軟推出AI惡意軟件檢測智能體 Project Ire

開篇 在8月5號&#xff0c;微軟研究院發布了一篇博客文章&#xff0c;在該篇博客中推出了一款名為Project Ire的AI Agent。該Agent可以在無需人類協助的情況下&#xff0c;自主分析和分類二進制文件。它可以在無需了解二進制文件來源或用途的情況下&#xff0c;對文件進行完全的…

哪些對會交由SpringBoot容器管理?

在 Spring Boot 中,交由容器管理的對象通常稱為“Spring Bean”,這些對象的創建、依賴注入、生命周期等由 Spring 容器統一管控。以下是常見的會被 Spring Boot 容器管理的對象類型及識別方式: 一、通過注解聲明的組件(最常見) Spring Boot 通過類級別的注解自動掃描并注…