機器學習進階,梯度提升機(GBM)與XGBoost

梯度提升機(Gradient Boosting Machine, GBM),特別是其現代高效實現——XGBoost。這是繼隨機森林后自然進階的方向,也是當前結構化數據競賽和工業界應用中最強大、最受歡迎的算法之一。

為什么推薦XGBoost?

  1. 與隨機森林互補:同屬集成學習,但Random Forest是Bagging思想,而XGBoost是Boosting思想。學習它可以幫助你全面理解集成學習的兩種主流范式。
  2. State-of-the-Art性能:在表格型數據上,XGBoost通常比隨機森林表現更好,是Kaggle等數據科學競賽中的"大殺器"。
  3. 高效且可擴展:專為速度和性能設計,支持并行處理,能處理大規模數據。
  4. 內置正則化:相比傳統GBM,XGBoost自帶正則化項,更不容易過擬合。

核心概念:Boosting vs Bagging

● Bagging(隨機森林):并行構建多個獨立的弱模型,然后通過投票/平均得到最終結果。
● Boosting(XGBoost):串行構建多個相關的弱模型,每個新模型都專注于糾正前一個模型的錯誤。

完整代碼示例

下面我們使用XGBoost來解決同樣的鳶尾花分類問題,并與隨機森林進行對比。

# xgboost_module.py
# -*- coding: utf-8 -*-"""
XGBoost分類器示例 - 鳶尾花數據集
模塊化實現,包含數據加載、模型訓練、評估、可視化和高級功能
"""# 1. 導入必要的庫
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score, classification_report
from sklearn.ensemble import RandomForestClassifier
import xgboost as xgb
import warnings
warnings.filterwarnings('ignore')# 設置全局樣式
plt.style.use('seaborn-v0_8')
np.random.seed(42)  # 設置隨機種子以確保結果可重現# 2. 數據加載模塊
def load_data():"""加載鳶尾花數據集"""iris = load_iris()X = iris.datay = iris.targetfeature_names = iris.feature_namestarget_names = iris.target_namesreturn X, y, feature_names, target_names# 3. 數據預處理模塊
def prepare_data(X, y, test_size=0.2, random_state=42):"""準備訓練和測試數據集"""X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state, stratify=y)print(f"訓練集大小: {X_train.shape[0]}")print(f"測試集大小: {X_test.shape[0]}")return X_train, X_test, y_train, y_test# 4. 隨機森林基準模型模塊
def train_random_forest(X_train, y_train, **params):"""訓練隨機森林模型作為基準"""# 設置默認參數default_params = {'n_estimators': 100,'max_depth': 3,'random_state': 42}# 更新默認參數default_params.update(params)# 初始化并訓練模型model = RandomForestClassifier(**default_params)model.fit(X_train, y_train)print("\n=== 隨機森林模型訓練完成 ===")print(f"使用參數: {default_params}")return model# 5. XGBoost模型訓練模塊
def train_xgboost(X_train, y_train, **params):"""訓練XGBoost模型"""# 設置默認參數default_params = {'n_estimators': 100,'max_depth': 3,'learning_rate': 0.1,'random_state': 42,'use_label_encoder': False,'eval_metric': 'logloss'}# 更新默認參數default_params.update(params)# 初始化并訓練模型model = xgb.XGBClassifier(**default_params)model.fit(X_train, y_train)print("\n=== XGBoost模型訓練完成 ===")print(f"使用參數: {default_params}")return model# 6. 模型評估模塊
def evaluate_model(model, X_test, y_test, model_name="模型"):"""評估模型性能"""# 預測y_pred = model.predict(X_test)# 計算準確率accuracy = accuracy_score(y_test, y_pred)print(f"\n=== {model_name}性能 ===")print(f"測試集準確率: {accuracy:.4f}")return accuracy, y_pred# 7. 交叉驗證比較模塊
def compare_cv_models(models, X, y, cv=5):"""使用交叉驗證比較多個模型"""print("\n=== 交叉驗證比較 ===")results = {}for name, model in models.items():scores = cross_val_score(model, X, y, cv=cv, scoring='accuracy')results[name] = scoresprint(f"{name} 交叉驗證平均分: {scores.mean():.4f}{scores.std():.4f})")return results# 8. 特征重要性可視化模塊
def plot_feature_importance(models, feature_names):"""可視化多個模型的特征重要性"""n_models = len(models)plt.figure(figsize=(5 * n_models, 5))for i, (name, model) in enumerate(models.items(), 1):plt.subplot(1, n_models, i)# 獲取特征重要性if hasattr(model, 'feature_importances_'):importances = model.feature_importances_else:# 對于XGBoost模型importances = model.get_booster().get_score(importance_type='weight')# 轉換為數組格式importances_array = np.zeros(len(feature_names))for j, feat in enumerate(feature_names):importances_array[j] = importances.get(f"f{j}", 0)importances = importances_array# 排序并繪制indices = np.argsort(importances)[::-1]plt.bar(range(len(feature_names)), importances[indices])plt.xticks(range(len(feature_names)), [feature_names[i] for i in indices], rotation=45)plt.title(f'{name} - Feature Importance')plt.tight_layout()plt.show()# 9. 高級功能:早停法訓練模塊
def train_xgboost_early_stopping(X_train, y_train, X_test, y_test, **params):"""使用早停法訓練XGBoost模型"""# 設置默認參數default_params = {'max_depth': 3,'learning_rate': 0.1,'objective': 'multi:softmax','num_class': 3,'eval_metric': 'mlogloss'}# 更新默認參數default_params.update(params)# 轉換為XGBoost的DMatrix格式dtrain = xgb.DMatrix(X_train, label=y_train)dtest = xgb.DMatrix(X_test, label=y_test)# 訓練并使用早停法evals = [(dtrain, 'train'), (dtest, 'test')]model = xgb.train(default_params, dtrain, num_boost_round=1000,evals=evals,early_stopping_rounds=10,verbose_eval=False)print("\n=== 早停法訓練完成 ===")print(f"在 {model.best_iteration} 輪停止")print(f"最佳驗證分數: {model.best_score:.4f}")return model# 10. 預測模塊
def make_predictions(model, new_samples, target_names, model_type='sklearn'):"""使用模型進行新樣本預測"""if model_type == 'xgboost_early_stop':# 對于早停法訓練的XGBoost模型dnew = xgb.DMatrix(new_samples)predictions = model.predict(dnew)# 早停法訓練的模型不直接提供概率,需要額外處理print("注意: 早停法訓練的XGBoost模型不直接提供概率輸出")predictions_proba = Noneelse:# 對于標準sklearn接口的模型predictions = model.predict(new_samples)predictions_proba = model.predict_proba(new_samples)print("\n=== 新樣本預測 ===")for i, sample in enumerate(new_samples):predicted_class = target_names[int(predictions[i])]print(f"樣本 {i+1} {sample}:")print(f"  預測類別: {predicted_class}")if predictions_proba is not None:print(f"  類別概率: {dict(zip(target_names, predictions_proba[i].round(4)))}")return predictions, predictions_proba# 11. 主函數 - 整合所有模塊
def main():"""主函數,整合所有模塊"""# 加載數據X, y, feature_names, target_names = load_data()print("=== 鳶尾花數據集 ===")print(f"數據集形狀: {X.shape}")print(f"特征名稱: {feature_names}")print(f"類別名稱: {target_names}")# 準備數據X_train, X_test, y_train, y_test = prepare_data(X, y)# 訓練隨機森林模型rf_model = train_random_forest(X_train, y_train)rf_accuracy, rf_pred = evaluate_model(rf_model, X_test, y_test, "隨機森林")# 訓練XGBoost模型xgb_model = train_xgboost(X_train, y_train)xgb_accuracy, xgb_pred = evaluate_model(xgb_model, X_test, y_test, "XGBoost")# 交叉驗證比較models = {'隨機森林': rf_model,'XGBoost': xgb_model}cv_results = compare_cv_models(models, X, y)# 特征重要性可視化plot_feature_importance(models, feature_names)# 詳細分類報告print("\n=== XGBoost詳細分類報告 ===")print(classification_report(y_test, xgb_pred, target_names=target_names))# 高級功能:早停法訓練xgb_early_model = train_xgboost_early_stopping(X_train, y_train, X_test, y_test)# 進行預測new_samples = [[5.1, 3.5, 1.4, 0.2],  # 很可能為setosa[6.7, 3.0, 5.2, 2.3]   # 很可能為virginica]predictions, predictions_proba = make_predictions(xgb_model, new_samples, target_names)return {'rf_model': rf_model,'xgb_model': xgb_model,'xgb_early_model': xgb_early_model,'rf_accuracy': rf_accuracy,'xgb_accuracy': xgb_accuracy,'cv_results': cv_results,'predictions': predictions}# 12. 執行主程序
if __name__ == "__main__":results = main()

代碼解析與學習要點

  1. 參數對比:
    ○ XGBoost有與隨機森林相似的參數(n_estimators, max_depth)
    ○ 但也有特有參數如learning_rate(學習率),控制每棵樹的貢獻程度
  2. 性能比較:
    ○ 代碼中比較了兩種算法的準確率和交叉驗證結果
    ○ 通常情況下,XGBoost會略優于隨機森林
  3. 特征重要性:
    ○ 可視化對比兩種算法計算的特征重要性
    ○ 注意:兩種算法計算重要性的方法不同,結果可能有差異
  4. 高級功能:
    ○ 演示了早停法(Early Stopping),這是防止過擬合的重要技術
    ○ 展示了DMatrix數據格式,這是XGBoost的高效數據容器
  5. 預測概率:
    ○ XGBoost可以提供每個類別的預測概率,這對于不確定性分析很有用

代碼運行結果

=== 鳶尾花數據集 ===
數據集形狀: (150, 4)
特征名稱: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
類別名稱: ['setosa' 'versicolor' 'virginica']
訓練集大小: 120
測試集大小: 30=== 隨機森林模型訓練完成 ===
使用參數: {'n_estimators': 100, 'max_depth': 3, 'random_state': 42}=== 隨機森林性能 ===
測試集準確率: 0.9667=== XGBoost模型訓練完成 ===
使用參數: {'n_estimators': 100, 'max_depth': 3, 'learning_rate': 0.1, 'random_state': 42, 'use_label_encoder': False, 'eval_metric': 'logloss'}=== XGBoost性能 ===
測試集準確率: 0.9333=== 交叉驗證比較 ===
隨機森林 交叉驗證平均分: 0.9667 (±0.0211)
XGBoost 交叉驗證平均分: 0.9467 (±0.0267)=== XGBoost詳細分類報告 ===precision    recall  f1-score   supportsetosa       1.00      1.00      1.00        10versicolor       0.90      0.90      0.90        10virginica       0.90      0.90      0.90        10accuracy                           0.93        30macro avg       0.93      0.93      0.93        30
weighted avg       0.93      0.93      0.93        30=== 早停法訓練完成 ===
在 33 輪停止
最佳驗證分數: 0.1948=== 新樣本預測 ===
樣本 1 [5.1, 3.5, 1.4, 0.2]:預測類別: setosa類別概率: {'setosa': 0.9911, 'versicolor': 0.0067, 'virginica': 0.0023}
樣本 2 [6.7, 3.0, 5.2, 2.3]:預測類別: virginica類別概率: {'setosa': 0.0019, 'versicolor': 0.0025, 'virginica': 0.9956}

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/95639.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/95639.shtml
英文地址,請注明出處:http://en.pswp.cn/web/95639.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【ARMv7】開篇:掌握ARMv7架構Soc開發技能

本專欄,開始與大家共同總結使用ARMv7系列CPU的Soc開發技能。大概匯總了一下,后面再逐步完善下面的思維導圖。簡單說說:與通用的ARMv7-A/R相比,以STM32F為代表的ARMv7-M架構有以下關鍵區別和重點:無MMU,有MP…

【學術會議論文投稿】JavaScript在數據可視化領域的探索與實踐

【ACM出版 | EI快檢索 | 高錄用】2024年智能醫療與可穿戴智能設備國際學術會議(SHWID 2024)_艾思科藍_學術一站式服務平臺 更多學術會議請看 學術會議-學術交流征稿-學術會議在線-艾思科藍 目錄 引言 JavaScript可視化庫概覽 D3.js基礎入門 1. 引入…

CSS基礎學習步驟

好的,這是一份為零基礎初學者量身定制的 **CSS 學習基礎詳細步驟**。我們將從最根本的概念開始,通過一步一步的實踐,帶你穩穩地入門。 第一步:建立核心認知 - CSS 是做什么的? 1. 理解角色: HTML&…

MTK Linux DRM分析(三十七)- MTK phy-mtk-hdmi.c 和 phy-mtk-hdmi-mt8173.c

一、簡介 HDMI PHY驅動 HDMI 的物理層接口主要就是 HDMI Type-A 接口(19 pin),除此之外還有 Type-B、Type-C(Mini HDMI)、Type-D(Micro HDMI)、Type-E(車載專用)。 1. HDMI Type-A(常見 19-pin 標準接口) HDMI Type-A Connector Pinout ========================…

【人工智能學習之MMdeploy部署踩坑總結】

【人工智能學習之MMdeploy部署踩坑總結】報錯1:TRTNet: device must be a GPU!報錯2:Failed to create Net backend: tensorrt報錯3:Failed to load library libonnxruntime_providers_shared.so1. 確認庫文件是否存在2. 重新安裝 ONNX Runti…

力扣516 代碼隨想錄Day16 第一題

找二叉樹左下角的值class Solution { public:int maxd0;int result;void traversal(TreeNode* root,int depth){if(root->leftNULL&&root->rightNULL){if(depth>maxd){maxddepth;resultroot->val;}}if(root->left){depth;traversal(root->left,depth…

網格圖--Day07--網格圖DFS--LCP 63. 彈珠游戲,305. 島嶼數量 II,2061. 掃地機器人清掃過的空間個數,489. 掃地機器人,2852. 所有單元格的遠離程度之和

網格圖–Day07–網格圖DFS–LCP 63. 彈珠游戲,305. 島嶼數量 II,2061. 掃地機器人清掃過的空間個數,489. 掃地機器人,2852. 所有單元格的遠離程度之和 今天要訓練的題目類型是:【網格圖DFS】,題單來自靈茶山…

多功能修改電腦機器碼序列號工具 綠色版

多功能修改電腦機器碼序列號工具 綠色版電腦機器碼序列號修改軟件是一款非常使用的數據化虛擬修改工具。機器碼修改軟件可以虛擬的定制您電腦上的硬件信息,軟件不會對您的電腦造成傷害。軟件不需要您有專業的知識,就可以模擬一份硬件信息。機器碼修改軟…

React Hooks深度解析:useState、useEffect及自定義Hook最佳實踐

React Hooks自16.8版本引入以來,徹底改變了我們編寫React組件的方式。它們讓函數組件擁有了狀態管理和生命周期方法的能力,使代碼更加簡潔、可復用且易于測試。本文將深入探討三個最重要的Hooks:useState、useEffect,以及如何創建…

期權平倉后權利金去哪了?

本文主要介紹期權平倉后權利金去哪了?期權平倉后權利金的去向需結合交易角色(買方/賣方)、平倉方式及市場價格變動綜合分析,具體可拆解為以下邏輯鏈條。期權平倉后權利金去哪了?1. 買方平倉:權利金的“差價…

2025國賽C題題目及最新思路公布!

C 題 NIPT 的時點選擇與胎兒的異常判 問題 1 試分析胎兒 Y 染色體濃度與孕婦的孕周數和 BMI 等指標的相關特性,給出相應的關系模 型,并檢驗其顯著性。 思路1:針對附件中孕婦的 NIPT 數據,首先對數據進行預處理,并對多…

NLP技術爬取

“NLP技術爬取”這個詞組并不指代一種單獨的爬蟲技術,而是指將自然語言處理(NLP)技術應用于網絡爬蟲的各個環節,以解決傳統爬蟲難以處理的問題,并從中挖掘出更深層次的價值。簡單來說,它不是指“用NLP去爬”…

讓錄音變得清晰的軟件:語音降噪AI模型與工具推薦

在數字內容創作日益普及的今天,無論是播客、線上課程、視頻口播,還是遠程會議,清晰的錄音質量都是提升內容專業度和觀眾體驗的關鍵因素之一。然而,由于環境噪音、設備限制等因素,錄音中常常夾雜各種干擾聲音。本文將介…

大話 IOT 技術(1) -- 架構篇

文章目錄前言拋出問題現有條件初步設想HTTP 與 MQTT中間的服務端完整的鏈路測試的虛擬設備實現后話當你迷茫的時候,請點擊 物聯網目錄大綱 快速查看前面的技術文章,相信你總能找到前行的方向 前言 Internet of Things (IoT) 就是物聯網,萬物…

【wpf】WPF 自定義控件綁定數據對象的最佳實踐

WPF 自定義控件綁定數據對象的最佳實踐:以 ImageView 為例 在 WPF 中開發自定義控件時,如何優雅地綁定數據對象,是一個經常遇到的問題。最近在實現一個自定義的 ImageView 控件時,我遇到了一個典型場景: 控件內部需要使…

[Dify 專欄] 如何通過 Prompt 在 Dify 中模擬 Persona:即便沒有專屬配置,也能讓 AI 扮演角色

在 AI 應用開發中,“Persona(角色扮演)”常被視為塑造 AI 個性與專業邊界的重要手段。然而,許多開發者在使用 Dify 時會疑惑:為什么我在 Chat 應用 / Agent 應用 / Workflow 里都找不到所謂的 Persona 配置項? 答案是:Dify 平臺目前并沒有內建的 Persona 配置入口。角色…

解決雙向循環鏈表中對存儲數據進行奇偶重排輸出問題

1. 概念 對鏈表而言,雙向均可遍歷是最方便的,另外首尾相連循環遍歷也可大大增加鏈表操作的便捷性。因此,雙向循環鏈表,是在實際運用中是最常見的鏈表形態。 2. 基本操作 與普通的鏈表完全一致,雙向循環鏈表雖然指針較多,但邏輯是完全一樣。基本的操作包括: 節點設計 初…

Kubernetes集群升級與etcd備份恢復指南

目錄 Kubernetes etcd備份恢復 集群管理命令 環境變量 查看etcd版本 查看etcd集群節點信息 查看集群健康狀態 查看告警事件 添加成員(單節點部署的etcd無法直接擴容)(不用做) 更新成員 刪除成員 數據庫操作命令 增加(put) 查詢(get) 刪除(…

【LeetCode熱題100道筆記】旋轉圖像

題目描述 給定一個 n n 的二維矩陣 matrix 表示一個圖像。請你將圖像順時針旋轉 90 度。 你必須在 原地 旋轉圖像,這意味著你需要直接修改輸入的二維矩陣。請不要 使用另一個矩陣來旋轉圖像。 示例 1:輸入:matrix [[1,2,3],[4,5,6],[7,8,9]…

SpringBoot【集成p6spy】使用p6spy-spring-boot-starter集成p6spy監控數據庫(配置方法舉例)

使用p6spy-spring-boot-starter集成p6spy監控數據庫1.簡單說明2.核心依賴3.主要配置4.簡單測試5.其他配置1.簡單說明 p6spy 類似于 druid 可以攔截 SQL 可以用于項目調試,直接引入 p6spy 的博文已經很多了,這里主要是介紹一下 springboot 使用 p6spy-sp…