深度學習與圖像處理 | 基于PaddlePaddle的梯度下降算法實現(線性回歸投資預測)

?演示基于PaddlePaddle自動求導技術實現梯度下降,簡化求解過程。

01、梯度下降法

梯度下降法是機器學習領域非常重要和具有代表性的算法,它通過迭代計算來逐步尋找目標函數極小值。既然是一種迭代計算方法,那么最重要的就是往哪個方向迭代,梯度下降法選擇從目標函數的梯度切入。首先需要明確一個數學概念,即函數的梯度方向是函數值變化最快的方向。梯度下降法就是基于此來進行迭代。

圖2.28對應一個雙自變量函數

圖片

。想要求得該函數極小值,只需要隨機選擇一個初始點,然后計算當前點對應的梯度,按照梯度反方向下降一定高度,然后重新計算當前位置對應的梯度,繼續按照梯度反方向下降。按照上述方式迭代,最終就可以用最快的速度到達極小值附近。

■??梯度下降法示意圖

?

如果函數很復雜并有多個極小值點,那么選擇不同的初始值,按照梯度下降算法的計算方式很有可能會到達不同的極小值點,并且耗時也不一樣。因此,在工程實現上選擇一個好的初始值是非常重要的。

對于前面的直線擬合任務來說,其目標函數就是L,模型參數就是a和b。按照梯度下降算法的原理,對應實現步驟如下:

(1)初始化模型參數a和b;

(2)輸入每個樣本x,根據公式y=ax+b計算每個樣本數據的預測輸出值

圖片

(3)計算所有樣本的預測值

圖片

和真值y之間的平方差L;

(4)計算當前L對模型參數a和b的梯度值,即

圖片

按照下式更新參數a和b:

圖片

其中t表示當前迭代的輪次,

圖片

是一個提前設置好的參數,這個參數的作用是代表每一步迭代下降的跨度,專業術語也叫學習率;

(6)重復步驟(2)~(5),直至迭代次數超過某個預設值。

注意到,上述算法第(5)步中,需要計算目標函數L對a和b的偏導。盡管對于這個直線擬合任務來說其偏導求取非常簡單,但是依然需要手工進行求導。在2.3.3節中,介紹過可以通過PaddlePaddle來自動計算梯度,因此,可以使用PaddlePaddle來更便捷的實現這個梯度下降算法。

完整代碼如下(machine_learning/auto_diff.py):

import matplotlib.pyplot as plt
import numpy as np
import paddle
# 輸入數據
x_train = np.array(    [3.3, 4.4, 5.5, 6.7, 6.9, 4.2, 9.8, 6.2, 7.6, 2.2, 7, 10.8, 5.3, 8, 3.1],    dtype=np.float32,)
y_train = np.array(    [17, 28, 21, 32, 17, 16, 34, 26, 25, 12, 28, 35, 17, 29, 13], dtype=np.float32)
# numpy轉tensor
x_train = paddle.to_tensor(x_train)
y_train = paddle.to_tensor(y_train)
# 隨機初始化模型參數
a = np.random.randn(1)
a = paddle.to_tensor(a, dtype="float32", stop_gradient=False)
b = np.random.randn(1)
b = paddle.to_tensor(b, dtype="float32", stop_gradient=False)
# 循環迭代
for t in range(10):# 計算平方差損失   y_ = a * x_train + b loss = paddle.sum((y_ - y_train) ** 2)# 自動計算梯度loss.backward()# 更新參數(梯度下降),學習率默認使用1e-3a = a.detach() - 1e-3 * float(a.grad)b = b.detach() - 1e-3 * float(b.grad)a.stop_gradient = Falseb.stop_gradient = False# 輸出當前輪的目標函數值Lprint("epoch: {}, loss: {}".format(t, (float(loss))))
# 訓練結束,終止a和b的梯度計算
a.stop_gradient = True
b.stop_gradient = True
# 可視化輸出
x_pred = paddle.arange(0, 15)
y_pred = a * x_pred + b
plt.plot(x_train.numpy(), y_train.numpy(), "go", label="Original Data")
plt.plot(x_pred.numpy(), y_pred.numpy(), "r-", label="Fitted Line")plt.xlabel("investment")
plt.ylabel("income")plt.legend()plt.savefig("result.png")
# 預測第16年的收益值
x = 12.5
y = a * x + b
print(y.numpy())

?上述代碼對每輪迭代的目標函數進行了輸出,同時預測了第16年的收益值,結果如下:

epoch: 0, loss: 9141.90625
epoch: 1, loss: 1103.665283203125
epoch: 2, loss: 397.5347900390625
epoch: 3, loss: 335.0281982421875
epoch: 4, loss: 329.02337646484375
epoch: 5, loss: 327.982421875
epoch: 6, loss: 327.381103515625
epoch: 7, loss: 326.8223571777344
epoch: 8, loss: 326.2711486816406
epoch: 9, loss: 325.72442626953125
[44.96492]

可以看到,隨著迭代的不斷進行,目標函數逐漸減少,說明模型的預測輸出越來越接近真值。最終訓練好的模型所預測的第16年的收益值與上一節使用導數法求解的標準解非常接近,驗證了梯度下降算法的有效性。

梯度下降法擬合結果如圖2.29所示。

■??梯度下降法擬合結果

?

從擬合結果看到,利用PaddlePaddle自動幫助求導,通過梯度下降迭代更新模型參數,最后得到了令人滿意的結果,擬合出來的直線基本吻合數據的分布。整個過程不需要手工計算梯度,實現非常簡單。

注意,上述代碼使用了隨機值來初始化模型參數,因此每次運算的結果可能略有不同。另外,使用了固定的學習率1e-3,并得到了一個比較好的訓練結果。如果訓練過程中目標函數沒有逐步下降,那么就需要適當調整學習率重新訓練。

本節案例是一個非常簡單的使用PaddlePaddle進行機器學習的示例,旨在幫助讀者熟悉和鞏固PaddlePaddle的基本使用方法。雖然任務簡單,但是該示例“五臟俱全”,整個建模學習過程分為4個部分,如圖2.30所示。

■??基于PaddlePaddle的梯度下降法步驟

?對于后面的深度學習任務,也會按照上述方式進行模型訓練。下面正式開始介紹如何基于PaddlePaddle實現更復雜的深度學習圖像應用。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/917060.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/917060.shtml
英文地址,請注明出處:http://en.pswp.cn/news/917060.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

負載均衡集群HAproxy

HAProxy 簡介HAProxy 是一款高性能的負載均衡器和代理服務器,支持 TCP 和 HTTP 應用。廣泛用于高可用性集群,能夠有效分發流量到多個后端服務器,確保服務的穩定性和可擴展性。HAProxy 核心功能負載均衡:支持輪詢(round…

重生之我在10天內卷贏C++ - DAY 1

坐穩了,我們的C重生之旅現在正式發車!請系好安全帶,前方高能,但絕對有趣!🚀 重生之我在10天內卷贏C - DAY 1導師寄語:嘿,未來的編程大神!歡迎來到C的世界。我知道&#x…

[mind-elixir]Mind-Elixir 的交互增強:單擊、雙擊與鼠標 Hover 功能實現

[mind-elixir]Mind-Elixir 的交互增強:單擊、雙擊與鼠標 Hover 功能實現 功能簡述 通過防抖,實現單擊雙擊區分通過mousemove事件,實現hover效果 實現思路 (一)單擊與雙擊事件 功能描述 單擊節點時,可以觸發…

c++-迭代器類別仿函數常用算法函數

C常用算法函數 1. 前置知識 1.1 迭代器的類別 C中,迭代器是 STL 容器庫的核心組件之一,具有舉足輕重的作用,它提供了一種 統一的方式來訪問和遍歷容器,而無需關心底層數據結構的具體實現。迭代器類似指針,但比指針更通…

Python深度學習框架TensorFlow與Keras的實踐探索

基礎概念與安裝配置 TensorFlow核心架構解析 TensorFlow是由Google Brain團隊開發的開源深度學習框架,其核心架構包含數據流圖(Data Flow Graph)和張量計算系統。數據流圖通過節點表示運算操作(如卷積、激活函數)&…

c# net6.0+ 安裝中文智能提示

https://github.com/stratosblue/IntelliSenseLocalizer 1、安裝tool dotnet tool install -g islocalizer 2、 安裝IntelliSense 文件,安裝其他net版本修改下版本號 安裝中文net6.0采集包 islocalizer install auto -m net6.0 -l zh-cn 安裝中英文雙語net6.0采集包…

【建模與仿真】二階鄰居節點信息驅動的節點重要性排序算法

導讀: 在復雜網絡中,挖掘重要節點對精準推薦、交通管控、謠言控制和疾病遏制等應用至關重要。為此,本文提出一種局部信息驅動的節點重要性排序算法Leaky Noisy Integrate-and-Fire (LNIF)。該算法通過獲取節點的二階鄰居信息計算節點重要性&…

指令微調Qwen3實現文本分類任務

參考文檔: SwanLab入門深度學習:Qwen3大模型指令微調 - 肖祥 - 博客園 vLLM:讓大語言模型推理更高效的新一代引擎 —— 原理詳解一_vllm 原理-CSDN博客 概述 為了實現對100個標簽的多標簽文本分類任務,前期調用gpt-4o進行prom…

【機器學習-3】 | 決策樹與鳶尾花分類實踐篇

0 序言 本文將深入探討決策樹算法,先回顧下前邊的知識,從其基本概念、構建過程講起,帶你理解信息熵、信息增益等核心要點。 接著在引入新知識點,介紹Scikit - learn 庫中決策樹的實現與應用,再通過一個具體項目的方式來…

【數字投影】折幕影院都是沉浸式嗎?

折幕影院作為一種現代化的展示形式,其核心特點在于通過多塊屏幕拼接和投影融合技術,打造更具包圍感的視覺體驗。折幕影院設計通常采用多折幕結構,如三折幕、五折幕等,利用多臺投影機的協同工作,呈現無縫銜接的超大畫面…

數據結構——圖(三、圖的 廣度/深度 優先搜索)

一、廣度優先搜索(BFS)①找到與一個頂點相鄰的所有頂點 ②標記哪些頂點被訪問過 ③需要一個輔助隊列#define MaxVertexNum 100 bool visited[MaxVertexNum]; //訪問標記數組 void BFSTraverse(Graph G){ //對圖進行廣度優先遍歷,處理非連通圖的函數 for(int i0;i…

直擊WAIC | 百度袁佛玉:加速具身智能技術及產品研發,助力場景應用多樣化落地

7月26日,2025世界人工智能大會暨人工智能全球治理高級別會議(WAIC)在上海開幕。同期,由國家地方共建人形機器人創新中心(以下簡稱“國地中心”)與中國電子學會聯合承辦,百度智能云、中國聯通上海…

2025年人形機器人動捕技術研討會將在本周四召開

2025年7月31日愛迪斯通所主辦的【2025人形機器動作捕捉技術研討會】是攜手北京天樹探界公司線下活動結合線上直播的形式,會議將聚焦在“動作捕捉軟硬件協同,加速人形機器人訓練”,將深度講解多項核心技術,包含全球知名的慣性動捕大…

Apple基礎(Xcode①-項目結構解析)

要運行設備之前先選擇好設備Product---->Destination---->選擇設備首次運行手機提示如出現 “未受信任的企業級開發者” → 手機打開 設置 ? 通用 ? VPN與設備管理 → 信任你的 Apple ID 即可ContentView 是 SwiftUI 項目里 最頂層、最主界面 的那個“頁面”&#xff0…

微服務 02

一、網關路由網關就是網絡的關口。數據在網絡間傳輸,從一個網絡傳輸到另一網絡時就需要經過網關來做數據的路由和轉發以及數據安全的校驗。路由是網關的核心功能之一,決定如何將客戶端請求映射到后端服務。1、快速入門創建新模塊,引入網關依賴…

04動手學深度學習筆記(上)

04數據操作 import torch(1)張量表示一個數據組成的數組,這個數組可能有多個維度。 xtorch.arange(12) xtensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])(2)通過shape來訪問張量的形狀和張量中元素的總數 x.shapetorch.Size([12])(3)number of elements表…

MCU中的RTC(Real-Time Clock,實時時鐘)是什么?

MCU中的RTC(Real-Time Clock,實時時鐘)是什么? 在MCU(微控制器單元)中,RTC(Real-Time Clock,實時時鐘) 是一個獨立計時模塊,用于在系統斷電或低功耗狀態下持續記錄時間和日期。以下是關于RTC的詳細說明: 1. RTC的核心功能 精準計時:提供年、月、日、時、分、秒、…

Linux 進程調度管理

進程調度器可粗略分為兩類:實時調度器(kernel),系統中重要的進程由實時調度器調度,獲得CPU能力強。非實時調度器(user),系統中大部分進程由非實時調度器調度,獲得CPU能力弱。實時調度器實時調度器支持的調度策略&#…

基于 C 語言視角:流程圖中分支與循環結構的深度解析

前言(約 1500 字)在 C 語言程序設計中,控制結構是構建邏輯的核心骨架,而流程圖作為可視化工具,是將抽象代碼邏輯轉化為直觀圖形的橋梁。對于入門 C 語言的工程師而言,掌握流程圖與分支、循環結構的對應關系…

threejs創建自定義多段柱

最近在研究自定義建模,有一個多斷柱模型比較有意思,分享下,就是利用幾組點串,比如上中下,然后每組點又不一樣多,點續還不一樣,(比如第一個環的第一個點在左邊,第二個環在右邊)&#…