關于神經網絡中回歸的概念

神經網絡中的回歸詳解

引言

神經網絡(NeuralNetworks)是一種強大的機器學習模型,可用于分類和回歸任務。本文聚焦于神經網絡中的回歸(Regression),即預測連續輸出值(如房價、溫度)。

回歸問題:給定輸入特征x?\vec{x}x,預測連續目標yyy。神經網絡通過多層非線性變換學習復雜映射f:x??yf:\vec{x}\mapsto yf:x?y

基本概念回顧

神經元與層

  • 神經元(Neuron):基本單元。輸入x?=(x1,…,xn)\vec{x}=(x_1,\dots,x_n)x=(x1?,,xn?),權重w?=(w1,…,wn)\vec{w}=(w_1,\dots,w_n)w=(w1?,,wn?),偏置bbb
    計算:線性組合z=w??x?+b=∑i=1nwixi+bz=\vec{w}\cdot\vec{x}+b=\sum_{i=1}^nw_ix_i+bz=w?x+b=i=1n?wi?xi?+b
    然后激活:a=σ(z)a=\sigma(z)a=σ(z)σ\sigmaσ為激活函數。

  • (Layer):多個神經元組成。

    • 輸入層:原始特征。
    • 隱藏層:中間變換。
    • 輸出層:最終預測y^\hat{y}y^?(回歸中通常1個神經元,無激活或線性激活)。
  • 前饋神經網絡(FeedforwardNeuralNetwork,FNN):信息從輸入到輸出單向流動。也稱多層感知機(MLP)。

激活函數

激活引入非線性。常見:

  • Sigmoid:σ(z)=1/(1+e?z)\sigma(z)=1/(1+e^{-z})σ(z)=1/(1+e?z),輸出[0,1]。
  • Tanh:σ(z)=(ez?e?z)/(ez+e?z)\sigma(z)=(e^z-e^{-z})/(e^z+e^{-z})σ(z)=(ez?e?z)/(ez+e?z),輸出[-1,1]。
  • ReLU:σ(z)=max?(0,z)\sigma(z)=\max(0,z)σ(z)=max(0,z),簡單高效,避免梯度消失。
  • Linear:σ(z)=z\sigma(z)=zσ(z)=z,用于回歸輸出層。

隱藏層常用ReLU,輸出層線性以輸出任意實數。

神經網絡回歸模型結構

數學表示

假設網絡有LLL層。第lll層有mlm_lml?個神經元。

  • 輸入:a?(0)=x?∈Rm0\vec{a}^{(0)}=\vec{x}\in\mathbb{R}^{m_0}a(0)=xRm0?

  • lll層:
    z?(l)=W(l)a?(l?1)+b?(l) \vec{z}^{(l)}=W^{(l)}\vec{a}^{(l-1)}+\vec{b}^{(l)} z(l)=W(l)a(l?1)+b(l)
    a?(l)=σ(l)(z?(l)) \vec{a}^{(l)}=\sigma^{(l)}(\vec{z}^{(l)}) a(l)=σ(l)(z(l))
    其中W(l)∈Rml×ml?1W^{(l)}\in\mathbb{R}^{m_l\times m_{l-1}}W(l)Rml?×ml?1?為權重矩陣,b?(l)∈Rml\vec{b}^{(l)}\in\mathbb{R}^{m_l}b(l)Rml?為偏置。

  • 輸出:y^=a?(L)\hat{y}=\vec{a}^{(L)}y^?=a(L)(標量)。

整個網絡:y^=f(x?;θ)\hat{y}=f(\vec{x};\theta)y^?=f(x;θ)θ={W(l),b?(l)}l=1L\theta=\{W^{(l)},\vec{b}^{(l)}\}_{l=1}^Lθ={W(l),b(l)}l=1L?為參數。

示例結構

簡單回歸網絡:輸入2維,1隱藏層(3神經元),輸出1維。

  • 輸入層:x?=(x1,x2)\vec{x}=(x_1,x_2)x=(x1?,x2?)
  • 隱藏層:W(1)∈R3×2W^{(1)}\in\mathbb{R}^{3\times2}W(1)R3×2b?(1)∈R3\vec{b}^{(1)}\in\mathbb{R}^3b(1)R3,激活ReLU。
  • 輸出層:W(2)∈R1×3W^{(2)}\in\mathbb{R}^{1\times3}W(2)R1×3b?(2)∈R\vec{b}^{(2)}\in\mathbb{R}b(2)R,激活線性。

損失函數

回歸常用均方誤差(MeanSquaredError,MSE):
L(y^,y)=12(y^?y)2 \mathcal{L}(\hat{y},y)=\frac{1}{2}(\hat{y}-y)^2 L(y^?,y)=21?(y^??y)2
批次樣本:$ \mathcal{L}=\frac{1}{N}\sum_{i=1}N\frac{1}{2}(\hat{y}_i-y_i)2 $

其他:MAE(L=∣y^?y∣\mathcal{L}=|\hat{y}-y|L=y^??y),HuberLoss(對異常值魯棒)。

訓練過程:反向傳播與梯度下降

前向傳播

從輸入計算到輸出,得到y^\hat{y}y^?L\mathcal{L}L

反向傳播(Backpropagation)

計算梯度?L/?θ\partial\mathcal{L}/\partial\theta?L/?θ

  • 輸出層誤差:δ(L)=?L/?z?(L)=(y^?y)?σ(L)′(z?(L))\delta^{(L)}=\partial\mathcal{L}/\partial\vec{z}^{(L)}=(\hat{y}-y)\cdot\sigma^{(L)'}(\vec{z}^{(L)})δ(L)=?L/?z(L)=(y^??y)?σ(L)(z(L))(線性激活時σ′=1\sigma'=1σ=1,故δ(L)=y^?y\delta^{(L)}=\hat{y}-yδ(L)=y^??y)。
  • 向后傳播:δ(l)=(W(l+1))Tδ(l+1)⊙σ(l)′(z?(l))\delta^{(l)}=(W^{(l+1)})^T\delta^{(l+1)}\odot\sigma^{(l)'}(\vec{z}^{(l)})δ(l)=(W(l+1))Tδ(l+1)σ(l)(z(l))⊙\odot為逐元素乘。
  • 梯度:
    ?L?W(l)=δ(l)(a?(l?1))T \frac{\partial\mathcal{L}}{\partial W^{(l)}}=\delta^{(l)}(\vec{a}^{(l-1)})^T ?W(l)?L?=δ(l)(a(l?1))T
    ?L?b?(l)=δ(l) \frac{\partial\mathcal{L}}{\partial\vec{b}^{(l)}}=\delta^{(l)} ?b(l)?L?=δ(l)

優化:梯度下降

更新參數:θ:=θ?η?θL\theta:=\theta-\eta\nabla_\theta\mathcal{L}θ:=θ?η?θ?Lη\etaη為學習率。

變體:

  • SGD:隨機梯度下降,每批次更新。
  • Momentum:添加動量v:=βv?η?v:=\beta v-\eta\nablav:=βv?η?θ:=θ+v\theta:=\theta+vθ:=θ+v
  • Adam:自適應學習率,結合動量和RMSProp。

完整訓練算法

  1. 初始化θ\thetaθ(e.g.,Xavier初始化)。
  2. 對于每個epoch:
    a. 前向:計算y^\hat{y}y^?L\mathcal{L}L
    b. 反向:計算梯度。
    c. 更新θ\thetaθ
  3. 監控驗證損失,早停防止過擬合。

數學推導示例:簡單網絡

假設單隱藏層,輸入1維xxx,隱藏1神經元,輸出y^\hat{y}y^?

  • 前向:
    z(1)=w1x+b1z^{(1)}=w_1x+b_1z(1)=w1?x+b1?a(1)=σ(z(1))a^{(1)}=\sigma(z^{(1)})a(1)=σ(z(1))(ReLU)。
    z(2)=w2a(1)+b2z^{(2)}=w_2a^{(1)}+b_2z(2)=w2?a(1)+b2?y^=z(2)\hat{y}=z^{(2)}y^?=z(2)(線性)。
  • 損失:L=12(y^?y)2\mathcal{L}=\frac{1}{2}(\hat{y}-y)^2L=21?(y^??y)2
  • 梯度:
    ?L/?y^=y^?y\partial\mathcal{L}/\partial\hat{y}=\hat{y}-y?L/?y^?=y^??y
    ?L/?w2=(y^?y)a(1)\partial\mathcal{L}/\partial w_2=(\hat{y}-y)a^{(1)}?L/?w2?=(y^??y)a(1)
    ?L/?b2=y^?y\partial\mathcal{L}/\partial b_2=\hat{y}-y?L/?b2?=y^??y
    ?L/?a(1)=(y^?y)w2\partial\mathcal{L}/\partial a^{(1)}=(\hat{y}-y)w_2?L/?a(1)=(y^??y)w2?
    ?L/?z(1)=?L/?a(1)?σ′(z(1))\partial\mathcal{L}/\partial z^{(1)}=\partial\mathcal{L}/\partial a^{(1)}\cdot\sigma'(z^{(1)})?L/?z(1)=?L/?a(1)?σ(z(1))(ReLU’:1 ifz(1)>0z^{(1)}>0z(1)>0,else0)。
    ?L/?w1=?L/?z(1)?x\partial\mathcal{L}/\partial w_1=\partial\mathcal{L}/\partial z^{(1)}\cdot x?L/?w1?=?L/?z(1)?x
    ?L/?b1=?L/?z(1)\partial\mathcal{L}/\partial b_1=\partial\mathcal{L}/\partial z^{(1)}?L/?b1?=?L/?z(1)

正則化與優化技巧

  • 過擬合防治

    • L1/L2正則:添加λ∑∣w∣\lambda\sum|w|λwλ∑w2\lambda\sum w^2λw2到損失。
    • Dropout:訓練時隨機丟棄神經元(概率p)。
    • 數據增強:增加訓練數據。
    • 早停:驗證損失上升時停止。
  • 初始化:He初始化forReLU:w~N(0,2/ml?1)w\sim\mathcal{N}(0,\sqrt{2/m_{l-1}})wN(0,2/ml?1??)

  • 批標準化(BatchNormalization):在每層后標準化z?(l)\vec{z}^{(l)}z(l),加速訓練。

  • 學習率調度:余弦退火或指數衰減。

  • 超參數調優:層數、神經元數、學習率、批大小。用GridSearch或BayesianOptimization。

優點與缺點

  • 優點

    • 處理非線性關系:通用函數逼近器。
    • 自動特征提取:隱藏層學習高級表示。
    • 可擴展:深層網絡捕捉復雜模式。
  • 缺點

    • 計算密集:訓練需GPU。
    • 黑箱:解釋性差(用SHAP或LIME改善)。
    • 需大量數據:小數據集易過擬合。
    • 梯度消失/爆炸:深層網絡問題(用ReLU、殘差連接緩解)。

應用場景

  • 房價預測:輸入面積、位置等,輸出價格。
  • 時間序列預測:RNN/LSTM變體,但基本FNN可用于簡單回歸。
  • 圖像回歸:CNN提取特征,后接全連接回歸(如年齡估計)。
  • 金融:股票價格預測。

實際例子

例子1:線性回歸模擬

用單層無激活網絡模擬線性回歸y=2x+1y=2x+1y=2x+1

  • 輸入xxx,輸出y^=wx+b\hat{y}=wx+by^?=wx+b
  • 損失MSE。
  • 訓練后w≈2w\approx2w2b≈1b\approx1b1

例子2:非線性回歸

預測y=sin?(x)+噪聲y=\sin(x)+噪聲y=sin(x)+噪聲

  • 網絡:輸入1,隱藏[64,64]ReLU,輸出1線性。
  • 數據:1000點x∈[?π,π]x\in[-π,π]x[?π,π]
  • 訓練:Adam,MSE,epochs=1000。
    網絡學習正弦曲線。

代碼實現(Python with PyTorch)

import torch
import torch.nn as nn
import torch.optim as optimclass RegressionNet(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(1, 64)self.fc2 = nn.Linear(64, 64)self.fc3 = nn.Linear(64, 1)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))return self.fc3(x)# 數據
x = torch.randn(1000, 1) * 3.14
y = torch.sin(x) + 0.1 * torch.randn(1000, 1)# 訓練
model = RegressionNet()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)for epoch in range(1000):optimizer.zero_grad()output = model(x)loss = criterion(output, y)loss.backward()optimizer.step()

總結

神經網絡回歸通過多層變換、反向傳播和優化學習連續映射。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/96853.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/96853.shtml
英文地址,請注明出處:http://en.pswp.cn/web/96853.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

JAVASCRIPT 前端數據庫-V9--仙盟數據庫架構-—仙盟創夢IDE

老版本 在v1 版本中我們講述了 基礎版的應用JAVASCRIPT 前端數據庫-V1--仙盟數據庫架構-—-—仙盟創夢IDE-CSDN博客接下載我們做一個更復雜的的其他場景由于,V1查詢字段必須 id接下來我們修改了了代碼JAVASCRIPT 前端數據庫-V2--仙盟數據庫架構-—-—仙盟創夢IDE-CS…

k8s核心資料基本操作

NamespaceNamespace是kubernetes系統中的一種非常重要資源,它的主要作用是用來實現多套環境的資源隔離或者多租戶的資源隔離。默認情況下,kubernetes集群中的所有的Pod都是可以相互訪問的。但是在實際中,可能不想讓兩個Pod之間進行互相的訪問…

PostgreSQL——分區表

分區表一、分區表的意義二、傳統分區表2.1、繼承表2.2、創建分區表2.3、使用分區表2.4、查詢父表還是子表2.5、constraint_exclusion參數2.6、添加分區2.7、刪除分區2.8、分區表相關查詢2.9、傳統分區表注意事項三、內置分區表3.1、創建分區表3.2、使用分區表3.3、內置分區表原…

Linux任務調度全攻略

Linux下的任務調度分為兩類,系統任務調度和用戶任務調度。系統任務調度:系統周期性所要執行的工作,比如寫緩存數據到硬盤、日志清理等。在/etc目錄下有一個crontab文件,這個就是系統任務調度的配置文件。/etc/crontab文件包括下面…

回溯算法通關秘籍:像打怪一樣刷題

🚀 回溯算法通關秘籍:像打怪一樣刷題! 各位同學,今天咱們聊聊 回溯算法(Backtracking)。它聽起來玄乎,但其實就是 “暴力搜索 剪枝” 的優雅版。 打個比方:回溯就是在迷宮里探險&am…

嵌入式Linux常用命令

📟 核心文件與目錄操作pwd-> 功能: 打印當前工作目錄的絕對路徑。-> 示例: pwd -> 輸出 /home/user/projectls [選項] [目錄]-> 功能: 列出目錄內容。-> 常用選項:-l: 長格式顯示(詳細信息)-a: 顯示所有文件(包括隱…

深入理解 Linux 內核進程管理

在 Linux 系統中,進程是資源分配和調度的基本單位,內核對進程的高效管理直接決定了系統的性能與穩定性。本文將從進程描述符的結構入手,逐步剖析進程的創建、線程實現與進程終結的完整生命周期,帶您深入理解 Linux 內核的進程管理…

ACP(三):讓大模型能夠回答私域知識問題

讓大模型能夠回答私域知識問題 未經過特定訓練答疑機器人,是無法準確回答“我們公司項目管理用什么工具”這類內部問題。根本原因在于,大模型的知識來源于其訓練數據,這些數據通常是公開的互聯網信息,不包含任何特定公司的內部文檔…

使用Xterminal連接Linux服務器

使用Xterminal連接Linux服務器(VMware虛擬機)的步驟如下,前提是虛擬機已獲取IP(如 192.168.31.105)且網絡互通: 一、準備工作(服務器端確認)確保SSH服務已安裝并啟動 Linux服務器需要…

ChatBot、Copilot、Agent啥區別

以下內容為AI生成ChatBot(聊天機器人)、Copilot(副駕駛)和Agent(智能體/代理)是AI應用中常見的三種形態,它們在人機交互、自動化程度和任務處理能力上有著顯著的區別。特征維度ChatBot (聊天機器…

2025 年大語言模型架構演進:DeepSeek V3、OLMo 2、Gemma 3 與 Mistral 3.1 核心技術剖析

編者按: 在 Transformer 架構誕生八年之際,我們是否真的見證了根本性的突破,還是只是在原有設計上不斷打磨?今天我們為大家帶來的這篇文章,作者的核心觀點是:盡管大語言模型在技術細節上持續優化&#xff0…

基于Matlab GUI的心電信號QRS波群檢測與心率分析系統

心電信號(Electrocardiogram, ECG)是臨床診斷心臟疾病的重要依據,其中 QRS 波群的準確檢測對于心率分析、心律失常診斷及自動化心電分析系統具有核心意義。本文設計并實現了一套基于 MATLAB GUI 的心電信號處理與分析系統,集成了數…

1臺SolidWorks服務器能帶8-10人并發使用

在工業設計和機械工程領域,SolidWorks作為主流的三維CAD軟件,其服務器部署方案直接影響企業協同效率。通過云飛云共享云桌面技術實現多人并發使用SolidWorks時,實際承載量取決于硬件配置、網絡環境、軟件優化等多維度因素的綜合作用。根據專業…

String、StringBuilder和StringBuffer的區別

目錄一. String:不可變的字符串二.StringBuilder:可變字符串三.StringBuffer:線程安全的可變字符串四.總結在 Java 開發中,字符串處理是日常編碼中最頻繁的操作之一。String、StringBuilder 和 StringBuffer 這三個類雖然都用于操…

Power Automate List Rows使用Fetchxml查詢的一個bug

看一段FetchXML, 這段查詢在XRMtoolbox中的fech test工具里執行完全ok<fetch version"1.0" mapping"logical" distinct"true" no-lock"false"> <entity name"new_projectchange"> <link-entity name"sy…

Letta(MemGPT)有狀態AI代理的開源框架

1. 項目概述Letta&#xff08;前身為 MemGPT&#xff09;是一個用于構建有狀態AI代理的開源框架&#xff0c;專注于提供長期記憶和高級推理能力。該項目是MemGPT研究論文的實現&#xff0c;引入了"LLM操作系統"的概念用于內存管理。核心特點有狀態代理&#xff1a;具…

除了ollama還有哪些模型部署方式?多樣化模型部署方式

在人工智能的浪潮中&#xff0c;模型部署是釋放其強大能力的關鍵一環。大家都知道ollama&#xff0c;它在模型部署領域有一定知名度&#xff0c;操作相對簡單&#xff0c;受到不少人的青睞。但其實&#xff0c;模型部署的世界豐富多樣&#xff0c;今天要給大家介紹一款工具&…

Linux系統學習之進階命令匯總

文章目錄一、系統信息1.1 查看系統信息&#xff1a;uname1.2 查看主機名&#xff1a;hostname1.3 查看cpu信息&#xff1a;1.4 當前已加載的內核模塊: lsmod1.5 查看磁盤空間使用情況: df1.6 管理磁盤分區: fdisk1.7 查看目錄或文件磁盤使用情況: du1.8 查看I/O使用情況: iosta…

算法面試(2)------休眠函數sleep_for和sleep_until

操作系統&#xff1a;ubuntu22.04 IDE:Visual Studio Code 編程語言&#xff1a;C11 算法描述 這兩個函數都定義在 頭文件中&#xff0c;屬于 std::this_thread 命名空間&#xff0c;用于讓當前線程暫停執行一段時間。函數功能sleep_for(rel_time)讓當前線程休眠一段相對時間&…

貪心算法應用:5G網絡切片問題詳解

Java中的貪心算法應用&#xff1a;5G網絡切片問題詳解 1. 5G網絡切片問題概述 5G網絡切片是將物理網絡劃分為多個虛擬網絡的技術&#xff0c;每個切片可以滿足不同業務需求&#xff08;如低延遲、高帶寬等&#xff09;。網絡切片資源分配問題可以抽象為一個典型的優化問題&…