GBDT算法原理及Python實現

一、概述

??GBDT(Gradient Boosting Decision Tree,梯度提升決策樹)是集成學習中提升(Boosting)方法的典型代表。它以決策樹(通常是 CART 樹,即分類回歸樹)作為弱學習器,通過迭代的方式,不斷擬合殘差(回歸任務)或負梯度(分類任務),逐步構建一系列決策樹,最終將這些樹的預測結果進行累加,得到最終的預測值。

二、算法原理

1. 梯度下降思想?

??梯度下降是一種常用的優化算法,用于尋找函數的最小值。在 GBDT 中,它扮演著至關重要的角色。假設我們有一個損失函數 L ( y , y ^ ) L\left( y,\hat{y} \right) L(y,y^?),其中 y y y是真實值, y ^ \hat y y^?是預測值。梯度下降的目標就是通過不斷調整模型參數,使得損失函數的值最小化。具體來說,每次迭代時,沿著損失函數關于參數的負梯度方向更新參數,以逐步接近最優解。在 GBDT 中,雖然沒有顯式地更新參數(通過構建多顆決策樹來擬合目標),但擬合的目標是損失函數的負梯度,本質上也是利用了梯度下降的思想。

2. 決策樹的構建?

??GBDT 使用決策樹作為弱學習器。決策樹是一種基于樹結構的預測模型,它通過對數據特征的不斷分裂,將數據劃分成不同的子集,每個子集對應樹的一個節點。在每個節點上,通過某種準則(如回歸任務中的平方誤差最小化,分類任務中的基尼指數最小化)選擇最優的特征和分裂點,使得劃分后的子集在目標變量上更加 “純凈” 或具有更好的區分度。通過遞歸地進行特征分裂,直到滿足停止條件(如達到最大樹深度、節點樣本數小于閾值等),從而構建出一棵完整的決策樹。

3. 迭代擬合的過程?

(1) 初始化模型

??首先,初始化一個簡單的模型,通常是一個常數模型,記為 f 0 ( X ) f_0(X) f0?(X) ,其預測值為所有樣本真實值的均值(回歸任務)或多數類(分類任務),記為 y ^ 0 \hat y_0 y^?0?。此時,模型的預測結果與真實值之間存在誤差。

(2) 計算殘差或負梯度

??在回歸任務中,計算每個樣本的殘差,即真實值 y i y_i yi?與當前模型預測值 y ^ i , t ? 1 \hat y_{i,t-1} y^?i,t?1?的差值 r i , t = y i ? y ^ i , t ? 1 r_{i,t}=y_i-\hat y_{i,t-1} ri,t?=yi??y^?i,t?1?,其中表示迭代的輪數。在分類任務中,計算損失函數關于當前模型預測值的負梯度 g i , t = ? ? L ( y i , y ^ i , t ? 1 ) ? y ^ i , t ? 1 g_{i,t}=-\frac{\vartheta L(y_i,\hat y_{i,t-1})}{\vartheta \hat y_{i,t-1}} gi,t?=??y^?i,t?1??L(yi?,y^?i,t?1?)?

(3) 擬合決策樹

??使用計算得到的殘差(回歸任務)或負梯度(分類任務)作為新的目標值,訓練一棵新的決策樹 f t ( X ) f_t(X) ft?(X)。這棵樹旨在擬合當前模型的誤差,從而彌補當前模型的不足。

(4) 更新模型

??根據新訓練的決策樹,更新當前模型。更新公式為 y ^ i , t = y ^ i , t ? 1 + α f t ( x i ) \hat y_{i,t}=\hat y_{i,t-1}+\alpha f_t(x_i) y^?i,t?=y^?i,t?1?+αft?(xi?),其中是學習率(也稱為步長),用于控制每棵樹對模型更新的貢獻程度。學習率較小可以使模型訓練更加穩定,但需要更多的迭代次數;學習率較大則可能導致模型收斂過快,甚至無法收斂。

(5) 重復迭代

??重復步驟 (2)–(4)步,不斷訓練新的決策樹并更新模型,直到達到預設的迭代次數、損失函數收斂到一定程度或滿足其他停止條件為止。最終,GBDT 模型由多棵決策樹組成,其預測結果是所有決策樹預測結果的累加。

算法過程圖示

在這里插入圖片描述
??GBDT 算法將梯度下降思想與決策樹相結合,通過迭代擬合殘差或負梯度,逐步構建一個強大的集成模型。它在處理復雜數據和非線性關系時表現較為出色,在數據挖掘、機器學習等領域得到了廣泛的應用。然而,GBDT 也存在一些缺點,如訓練時間較長、對異常值較為敏感等,在實際應用中需要根據具體情況進行優化和調整 。

三、Python實現

(環境:Python 3.11,scikit-learn 1.5.1)

分類情形

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
from sklearn import metrics# 生成樣本數據
X, y = make_classification(n_samples=1000, n_features=50, n_informative=10, n_redundant=5, random_state=1)# 劃分訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)# 創建GDBT分類模型
gbc = GradientBoostingClassifier(n_estimators=100, learning_rate=1.0, max_depth=1, random_state=1)# 訓練模型
gbc.fit(X_train, y_train)# 進行預測
y_pred = gbc.predict(X_test)# 計算準確率
accuracy = metrics.accuracy_score(y_test,y_pred)
print('準確率為:',accuracy)

在這里插入圖片描述

回歸情形

from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error# 生成樣本數據
X, y = make_regression(n_samples=1000, n_features=10, n_informative=5, noise=0.1, random_state=42)# 劃分訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 創建GDBT回歸模型
model = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, random_state=42)# 訓練模型
model.fit(X_train, y_train)# 在測試集上進行預測
y_pred = model.predict(X_test)# 計算均方誤差
mse = mean_squared_error(y_test, y_pred)
print(f"MSE: {mse}")

在這里插入圖片描述

End.



下載

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/80973.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/80973.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/80973.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

WordPress開心導航站_一站式網址_資源與資訊垂直行業主題模板

一款集網址、資源與資訊于一體的導航類主題,專為追求高效、便捷用戶體驗的垂直行業網站而設計無論您是構建行業資訊門戶、資源聚合平臺還是個人興趣導航站,這款開心版導航主題都能成為您理想的選擇。 核心特色: 一體化解決方案:整合了網址導航、資源下載…

馬井堂-區塊鏈技術:架構創新、產業變革與治理挑戰(馬井堂)

區塊鏈技術:架構創新、產業變革與治理挑戰 摘要 區塊鏈技術作為分布式賬本技術的革命性突破,正在重構數字時代的信任機制。本文系統梳理區塊鏈技術的核心技術架構,分析其在金融、供應鏈、政務等領域的實踐應用,探討共識算法優化、…

從像素到駕駛決策:Python與OpenCV賦能自動駕駛圖像識別

從像素到駕駛決策:Python與OpenCV賦能自動駕駛圖像識別 引言:圖像識別的力量驅動自動駕駛 自動駕駛技術正以令人驚嘆的速度改變交通方式,而其中最核心的技術之一便是圖像識別。作為車輛的“視覺系統”,圖像識別可以實時獲取道路信息,識別交通標志、車輛、行人等關鍵目標…

Spring計時器StopWatch 統計各個方法執行時間和占比

Spring計時器StopWatch 用法代碼 返回結果是毫秒 一毫秒等于千分之一秒(0.001秒)。因此,如果你有一個以毫秒為單位的時間值,你可以通過將這個值除以1000來將其轉換為秒。例如,500毫秒等于0.5秒。 import org.springf…

2.2.2goweb內置的 HTTP 處理程序2

http.StripPrefix http.StripPrefix 是 Go 語言 net/http 包中的一個函數,它的主要作用是創建一個新的 HTTP 處理程序。這個新處理程序會在處理請求之前,從請求的 URL 路徑中移除指定的前綴,然后將處理工作委托給另一個提供的處理程序。 使…

【Fifty Project - D20】

今日完成記錄 TimePlan完成情況7:30 - 11:30收拾行李閃現廣州 & 《挪威的森林》√10:00 - 11:00Leetcode√16:00 - 17:00健身√ Leetcode 每日一題 每日一題來到了滑動窗口系列,今天是越…

【圖片識別改名】批量讀取圖片區域文字識別后批量改名,基于Python和騰訊云的實現方案

項目場景 ??辦公文檔管理??:將掃描的發票、合同等文檔按編號、日期自動重命名。例如,識別“編號:2023001 日期:20230403”生成“2023001_20230403.jpg”。??產品圖片整理??:電商產品圖片按產品編號、名稱自動命名。例如,…

生物化學筆記:神經生物學概論04 視覺通路簡介視網膜視網膜神經細胞大小神經節細胞(視錯覺)

視覺通路簡介 神經節細胞的胞體構成一明確的解剖層次,其外鄰神經纖維層,內接內叢狀層,該層在鼻側厚約10~20μm,最厚在黃斑區約60~80μm。 全部細胞數約為120萬個(1000000左右)。 每個細胞有一軸突&#xff…

「Mac暢玩AIGC與多模態08」開發篇04 - 基于 OpenAPI Schema 開發專用 Agent 插件

一、概述 本篇介紹如何在 macOS 環境下,通過編寫 OpenAPI Schema,開發自定義的專用插件,讓智能體可以調用外部 API,擴展功能至任意在線服務。實踐內容基于 Dify 平臺,適配 macOS 開發環境。 二、環境準備 1. 確認本地開發環境 macOS 系統Dify 平臺已完成部署并可訪問本…

【計算機視覺】深度解析MediaPipe:谷歌跨平臺多媒體機器學習框架實戰指南

深度解析MediaPipe:谷歌跨平臺多媒體機器學習框架實戰指南 技術架構與設計哲學核心設計理念系統架構概覽 核心功能與預構建解決方案1. 人臉檢測2. 手勢識別3. 姿勢估計4. 物體檢測與跟蹤 實戰部署指南環境配置基礎環境準備獲取源碼 構建第一個示例(手部追…

NVIDIA高級輔助駕駛領域的創新實踐與云計算教育啟示

AI與高級輔助駕駛的時代浪潮 人工智能正在重塑現代交通的面貌,而高級輔助駕駛技術無疑是這場變革中最具顛覆性的力量之一。作為全球AI計算的領軍企業,NVIDIA憑借其全棧式技術生態和創新實踐,為高級輔助駕駛的產業化落地樹立了標桿。從芯片到…

頭歌實訓之存儲過程、函數與觸發器

🌟 各位看官好,我是maomi_9526! 🌍 種一棵樹最好是十年前,其次是現在! 🚀 今天來學習C語言的相關知識。 👍 如果覺得這篇文章有幫助,歡迎您一鍵三連,分享給更…

醫學圖像處理軟件中幾種MPR

1:設備廠商的MPR 2:后處理的MPR 3:閱片PACS的MPR 4:手術導航 手術規劃的MPR 設備廠商的MPR需求更多是掃描線、需要3DMPR ,三條定位線的任意角度旋轉。 后處理的MPR,需求更多的是算法以及UI工具的研發&a…

java 類的實例化過程,其中的相關順序 包括有繼承的子類等復雜情況,靜態成員變量的初始化順序,這其中jvm在干什么

Java類的實例化過程及初始化順序 Java類的實例化過程涉及多個步驟,特別是在存在繼承關系和靜態成員的情況下。下面我將詳細解釋整個過程,包括JVM在其中的角色。 1. 類加載階段(JVM的工作) 在實例化一個類之前,JVM首…

Sce2DriveX: 用于場景-到-駕駛學習的通用 MLLM 框架——論文閱讀

《Sce2DriveX: A Generalized MLLM Framework for Scene-to-Drive Learning》2025年2月發表,來自中科院軟件所和中科院大學的論文。 端到端自動駕駛直接將原始傳感器輸入映射到低級車輛控制,是Embodied AI的重要組成部分。盡管在將多模態大語言模型&…

【題解-Acwing】870. 約數個數

題目:870. 約數個數 題目描述 給定 n 個正整數 ai,請你輸出這些數的乘積的約數個數,答案對 109+7 取模。 輸入 第一行包含整數 n。 接下來 n 行,每行包含一個整數 ai。 輸出 輸出一個整數,表示所給正整數的乘積的約數個數,答案需對 109+7 取模。 數據范圍 1 ≤ …

創龍全志T536全國產(4核A55 ARM+RISC-V+NPU 17路UART)工業開發板硬件說明書

前 言 本文檔主要介紹TLT536-EVM評估板硬件接口資源以及設計注意事項等內容。 T536MX-CXX/T536MX-CEN2處理器的IO電平標準一般為1.8V、3.3V,上拉電源一般不超過3.3V或1.8V,當外接信號電平與IO電平不匹配時,中間需增加電平轉換芯片或信號隔離芯片。按鍵或接口需考慮ESD設計…

Redis 持久化雙雄:RDB 與 AOF 深度解析

Redis 是一種內存數據庫,為了保證數據在服務器重啟或故障時不丟失,提供了兩種持久化方式:RDB(Redis Database)和 AOF(Append Only File)。以下是它們的詳細介紹: 一、RDB 持久化 工…

數據結構|并查集

Hello !朋友們,這是我在學習過程中梳理的筆記,以作以后復習回顧,有時略有潦草,一些話是我用自己的話描述的,可能不夠準確,還是感謝大家的閱讀! 目錄 一、并查集Quickfind 二、兩種算…

【GPU 微架構技術】Pending Request Table(PRT)技術詳解

PRT(Pending Request Table)是 GPU 中用于管理 未完成內存請求(outstanding memory requests)的一種硬件結構,旨在高效處理大規模并行線程的內存訪問需求。與傳統的 MSHR(Miss Status Handling Registers&a…