基于簡單神經網絡的線性回歸

一、概述

本代碼實現了一個簡單的神經網絡進行線性回歸任務。通過生成包含噪聲的線性數據集,定義一個簡單的神經網絡類,使用梯度下降算法訓練網絡以擬合數據,并最終通過可視化展示原始數據、真實線性關系以及模型的預測結果。

二、依賴庫

  1. numpy:用于數值計算,包括生成數組、進行隨機數操作、執行數學運算等。
  2. matplotlib.pyplot:用于數據可視化,繪制散點圖和折線圖以展示數據和模型的預測結果。

三、代碼詳解

1. 生成數據集

python

np.random.seed(42)
x = np.linspace(-10, 10, 100)
y = x + np.random.normal(0, 1, x.shape)  # 添加噪聲

  • np.random.seed(42):設置隨機數種子,確保每次運行代碼時生成的隨機數序列相同,從而使結果可復現。
  • np.linspace(-10, 10, 100):生成一個包含 100 個元素的一維數組x,元素均勻分布在 - 10 到 10 之間。
  • x + np.random.normal(0, 1, x.shape):生成因變量y,它基于真實的線性關系y = x,并添加了均值為 0、標準差為 1 的高斯噪聲。np.random.normal(0, 1, x.shape)生成與x形狀相同的隨機噪聲數組。

2. 定義神經網絡(線性回歸)

python

class SimpleNN:def __init__(self):self.w = np.random.randn()  # 權重self.b = np.random.randn()  # 偏置def forward(self, x):return self.w * x + self.b  # 前向傳播def loss(self, y_true, y_pred):return np.mean((y_true - y_pred) **2)  # 均方誤差def gradient(self, x, y_true, y_pred):dw = -2 * np.mean(x * (y_true - y_pred))  # 權重的梯度db = -2 * np.mean(y_true - y_pred)       # 偏置的梯度return dw, dbdef train(self, x, y, lr=0.01, epochs=1000):for epoch in range(epochs):y_pred = self.forward(x)dw, db = self.gradient(x, y, y_pred)self.w -= lr * dw  # 更新權重self.b -= lr * db  # 更新偏置if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {self.loss(y, y_pred):.4f}')

  • __init__方法:初始化神經網絡的權重self.w和偏置self.b,使用np.random.randn()生成隨機的初始值。
  • forward方法:實現前向傳播,根據輸入x、權重self.w和偏置self.b計算輸出y_pred,即y_pred = self.w * x + self.b
  • loss方法:計算預測值y_pred和真實值y_true之間的均方誤差(MSE),公式為np.mean((y_true - y_pred) ** 2)
  • gradient方法:計算權重self.w和偏置self.b的梯度。dw是權重的梯度,計算公式為-2 * np.mean(x * (y_true - y_pred))db是偏置的梯度,計算公式為-2 * np.mean(y_true - y_pred)
  • train方法:使用梯度下降算法訓練神經網絡。在指定的epochs(訓練輪數)內,每次迭代進行前向傳播計算預測值y_pred,然后計算梯度dwdb,根據學習率lr更新權重self.w和偏置self.b。每 100 輪打印一次當前輪數和損失值。

3. 訓練模型

python

model = SimpleNN()
model.train(x, y, lr=0.01, epochs=1000)

  • SimpleNN():創建一個SimpleNN類的實例model
  • model.train(x, y, lr=0.01, epochs=1000):調用modeltrain方法,使用生成的數據集xy,學習率lr=0.01,訓練輪數epochs=1000進行訓練。

4. 可視化結果

python

y_pred = model.forward(x)
plt.scatter(x, y, label='Data points')
plt.plot(x, x, color='red', label='y = x')
plt.plot(x, y_pred, color='green', label='Predicted')
plt.legend()
plt.show()

  • model.forward(x):使用訓練好的模型model對數據集x進行前向傳播,得到預測值y_pred
  • plt.scatter(x, y, label='Data points'):繪制原始數據集的散點圖,標簽為Data points
  • plt.plot(x, x, color='red', label='y = x'):繪制真實的線性關系y = x的折線圖,顏色為紅色,標簽為y = x
  • plt.plot(x, y_pred, color='green', label='Predicted'):繪制模型預測結果的折線圖,顏色為綠色,標簽為Predicted
  • plt.legend():顯示圖例,方便區分不同的圖形。
  • plt.show():顯示繪制好的圖形。

四、注意事項

  1. 本代碼實現的是一個簡單的線性回歸神經網絡,實際應用中可能需要更復雜的模型結構和優化方法。
  2. 學習率lr和訓練輪數epochs是超參數,可能需要根據具體數據和任務進行調整以獲得更好的訓練效果。
  3. 代碼中使用的均方誤差損失函數和梯度計算公式是針對線性回歸問題的常見選擇,但在其他問題中可能需要使用不同的損失函數和梯度計算方法。

完整代碼

import numpy as np
import matplotlib.pyplot as plt# 1. 生成數據集
np.random.seed(42)
x = np.linspace(-10, 10, 100)
y = x + np.random.normal(0, 1, x.shape)  # 添加噪聲# 2. 定義神經網絡(線性回歸)
class SimpleNN:def __init__(self):self.w = np.random.randn()  # 權重self.b = np.random.randn()  # 偏置def forward(self, x):return self.w * x + self.b  # 前向傳播def loss(self, y_true, y_pred):return np.mean((y_true - y_pred) **2)  # 均方誤差def gradient(self, x, y_true, y_pred):dw = -2 * np.mean(x * (y_true - y_pred))  # 權重的梯度db = -2 * np.mean(y_true - y_pred)       # 偏置的梯度return dw, dbdef train(self, x, y, lr=0.01, epochs=1000):for epoch in range(epochs):y_pred = self.forward(x)dw, db = self.gradient(x, y, y_pred)self.w -= lr * dw  # 更新權重self.b -= lr * db  # 更新偏置if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {self.loss(y, y_pred):.4f}')# 3. 訓練模型
model = SimpleNN()
model.train(x, y, lr=0.01, epochs=1000)# 4. 可視化結果
y_pred = model.forward(x)
plt.scatter(x, y, label='Data points')
plt.plot(x, x, color='red', label='y = x')
plt.plot(x, y_pred, color='green', label='Predicted')
plt.legend()
plt.show()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/899703.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/899703.shtml
英文地址,請注明出處:http://en.pswp.cn/news/899703.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

?19.思科路由器:OSPF協議引入直連路由的實驗研究

思科路由器:OSPF協議引入直連路由的實驗研究 一、實驗拓撲二、基本配置2.1、sw1的配置2.2、開啟交換機三層功能三、ospf的配置3.1、R1的配置3.2、R2的配置3.3、重啟ospf進程四、引入直連路由五、驗證結果隨著互聯網技術的不斷發展,路由器作為網絡互聯的關鍵設備,其性能與穩定…

USB——刪除注冊表信息

文章目錄 背景工具下載地址工具使用刪除注冊表信息背景 注測表中已記錄這個設備的信息,但現在設備描述符又指定為了 WinUSB 設備,所以當設備再次插入的時候,不會發送 0xEE 命令,造成了枚舉失敗。 兩種處理方式: 修改枚舉時候的 VID/PID刪除 USB 的注冊表信息工具下載地址…

如何快速解決django報錯:cx_Oracle.DatabaseError: ORA-00942: table or view does not exist

我們在使用django連接oracle進行編程時,使用model進行表映射對接oracle數據時,默認表名組成結構為:應用名_類名(如:OracleModel_test),故即使我們庫中存在表test,運行查詢時候&#…

從 0 到跑通的 Qt + OpenGL + VS 項目的完整流程

🧩 全流程目標: 在 Visual Studio 中成功打開、編譯并運行一個 Qt OpenGL 項目(.vcxproj 格式) ? 第 1 步:安裝必要環境 工具說明Visual Studio 2017 / 2019 / 2022必須勾選 “使用 C 的桌面開發” 和 “MSVC 工具…

鴻蒙開發03樣式相關介紹(二)

文章目錄 一、樣式復用1.1 Styles修飾符1.2 Extend修飾符 二、多態樣式 一、樣式復用 在頁面開發過程中,會出出現大量重復的樣式設置代碼,可以使用Styles和Extend修飾符將幫助我們進行樣式復用。 1.1 Styles修飾符 Styles裝飾器可以將多條樣式設置提煉…

裝飾器模式與模板方法模式實現MyBatis-Plus QueryWrapper 擴展

pom <dependency><groupId>com.github.yulichang</groupId><artifactId>mybatis-plus-join-boot-starter</artifactId> <!-- MyBatis 聯表查詢 --> </dependency>MPJLambdaWrapperX /*** 拓展 MyBatis Plus Join QueryWrapper 類&…

05-031-自考數據結構(20331)- 哈希表 - 例題分析

哈希表考題主要涵蓋四大類型:1)函數設計類(如除留余數法計算地址,需掌握質數p的選擇技巧);2)沖突處理類(線性探測法要解決堆積現象,鏈地址法需繪制鏈表結構);3)性能分析類(重點計算ASL,理解裝填因子α的影響規律);4)綜合應用類(如設計ISBN查詢系統,需結合實際問…

rustdesk 自建服務器 key不匹配

請確保id_ed25519文件的權限為&#xff1a; -rw------- 1 root root 88 Apr 31 10:02 id_ed25519在rustdesk安裝目錄執行命令&#xff1a; chmod 700 id_ed25519

Dify 深度集成 MCP實現災害應急響應

一、架構設計 1.1 分層架構 #mermaid-svg-5dVNjmixTX17cCfg {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-5dVNjmixTX17cCfg .error-icon{fill:#552222;}#mermaid-svg-5dVNjmixTX17cCfg .error-text{fill:#552222…

AI與.NET技術實操系列(三):在 .NET 中使用大語言模型(LLMs)

1. 引言 在技術迅猛發展的今天&#xff0c;大語言模型&#xff08;Large Language Models, LLMs&#xff09;已成為人工智能領域的核心驅動力之一。從智能對話系統到自動化內容生成&#xff0c;LLMs的應用正在深刻改變我們的工作與生活方式。對于.NET開發者而言&#xff0c;掌…

一個極簡的詞法分析器實現

文章目錄 推薦&#xff1a;Tiny Lexer - 一個極簡的C語言詞法分析器特點核心代碼實現學習價值擴展建議 用Java實現一個簡單的詞法分析器完整實現代碼代碼解析示例輸出擴展建議 用Go實現極簡詞法分析器完整實現代碼代碼解析示例輸出擴展建議 最近兩天搞一個DSL&#xff0c;不得不…

強制用戶裸奔,微軟封鎖唯一后門操作

周末剛結束&#xff0c;那個常年將「用戶為中心」掛嘴邊的微軟又雙叒叕開始作妖&#xff01; 不錯&#xff0c;大伙兒今后可能再沒法通過「OOBE\BYPASSNRO」命令繞過微軟強制聯網要求了。 熟悉 Windows 11 操作系統的都知道&#xff0c;除硬件上諸多限制外&#xff1b; 軟件層…

大模型備案:攔截關鍵詞列表與敏感詞庫深度解析

隨著《生成式人工智能服務管理暫行辦法》正式實施&#xff0c;大模型上線備案成為企業合規運營的核心環節。其中&#xff0c;敏感詞庫建設與攔截關鍵詞列表管理直接關系內容安全紅線&#xff0c;今天我們就來詳細解析一下大模型備案的這一部分&#xff0c;希望對想要做備案的朋…

快速上手Linux系統輸入輸出

一、管理系統中的輸入輸出 1.什么是重定向&#xff1f; 將原本要輸出到屏幕上的內容&#xff0c;重新輸入到其他設備中或文件中 重定向類型包括 輸入重定向輸出重定向 2.輸入重定向 指定設備&#xff08;通常是文件或命令的執行結果&#xff09;來代替鍵盤作為新的輸入設…

文小言全新升級!多模型協作與智能語音功能帶來更流暢的AI體驗

文小言全新升級&#xff01;多模型協作與智能語音功能帶來更流暢的AI體驗 在3月31日的百度AI DAY上&#xff0c;文小言正式宣布了一系列令人興奮的品牌煥新與功能升級。此次更新不僅帶來了全新的品牌視覺形象&#xff0c;更讓文小言在智能助手的技術和用戶體驗方面邁上了一個新…

C++基礎算法(插入排序)

1.插入排序 插入排序&#xff08;Insertion Sort&#xff09;介紹&#xff1a; 插入排序是一種簡單直觀的排序算法&#xff0c;它的工作原理類似于我們整理撲克牌的方式。 1.基本思想 插入排序的基本思想是&#xff1a; 1.將數組分為已排序和未排序兩部分 2.每次從未排序部分…

k近鄰算法K-Nearest Neighbors(KNN)

算法核心 KNN算法的核心思想是“近朱者赤&#xff0c;近墨者黑”。對于一個待分類或預測的樣本點&#xff0c;它會查找訓練集中與其距離最近的K個樣本點&#xff08;即“最近鄰”&#xff09;。然后根據這K個最近鄰的標簽信息來對當前樣本進行分類或回歸。 在分類任務中&#…

【Feign】??使用 openFeign 時傳遞 MultipartFile 類型的參數參考

&#x1f4a5;&#x1f4a5;????歡迎閱讀本文章????&#x1f4a5;&#x1f4a5; &#x1f3c6;本篇文章閱讀大約耗時三分鐘。 ??motto&#xff1a;不積跬步、無以千里 &#x1f4cb;&#x1f4cb;&#x1f4cb;本文目錄如下&#xff1a;&#x1f381;&#x1f381;&a…

zk基礎—1.一致性原理和算法二

大綱 1.分布式系統特點 2.分布式系統的理論 3.兩階段提交Two-Phase Commit(2PC) 4.三階段提交Three-Phase Commit(3PC) 5.Paxos島的故事來對應ZooKeeper 6.Paxos算法推導過程 7.Paxos協議的核心思想 8.ZAB算法簡述 6.Paxos算法推導過程 (1)Paxos的概念 (2)問題描述 …

216. 組合總和 III 回溯

目錄 問題描述 解決思路 關鍵點 代碼實現 代碼解析 1. 初始化結果和路徑 2. 深度優先搜索&#xff08;DFS&#xff09; 3. 遍歷候選數字 4. 遞歸與回溯 示例分析 復雜度與優化 回溯算法三部曲 1. 路徑選擇&#xff1a;記錄當前路徑 2. 遞歸探索&#xff1a;進入下…