LightGBM、XGBoost和CatBoost自定義損失函數和評估指標

LightGBM、XGBoost和CatBoost自定義損失函數和評估指標

    • 函數(縮放誤差)
    • 數學原理
      • 損失函數定義
      • 梯度計算
      • 評估指標
    • LightGBM實現
      • 自定義損失函數
      • 自定義評估指標
      • 使用方式
    • XGBoost實現
      • 自定義損失函數
      • 自定義評估指標
      • 使用方式
    • CatBoost實現
      • 自定義損失函數
      • 自定義評估指標
      • 使用方式
    • 框架對比
    • 實際應用
      • 適用場景
    • 常見問題
      • 1. 為什么要設置最小閾值?
      • 2. 梯度和Hessian計算錯誤怎么辦?
      • 3. 不同框架的性能差異
      • 4. 超參數調優建議

函數(縮放誤差)

傳統的均方誤差(MSE)和平均絕對誤差(MAE)對所有預測值給予相同的權重,但在某些場景下,更關心相對誤差而非絕對誤差。縮放誤差通過將誤差除以真實值來實現這一目標:

縮放誤差 = (真實值 - 預測值) / max(真實值, 閾值)

這樣設計的優勢:

  • 對于大數值和小數值的預測給予相對平等的權重
  • 避免大數值主導損失函數
  • 更適合預測范圍變化很大的場景

數學原理

損失函數定義

設損失函數為:

L(y, ?) = ((y - ?) / max(y, threshold))2

其中:

  • y 是真實值
  • ? 是預測值
  • threshold 是防止除零的最小閾值

梯度計算

對于梯度提升算法,我們需要計算損失函數對預測值的一階導數(梯度)和二階導數(Hessian):

d = max(y, threshold)e = (y - ?) / d

  • 一階導數(梯度)?L/?? = -2e/d
  • 二階導數(Hessian)?2L/??2 = 2/d2

評估指標

配套的評估指標使用縮放平均絕對誤差(Scaled MAE):

Scaled MAE = mean(|y - ?| / max(y, threshold))

LightGBM實現

自定義損失函數

def custom_loss_squared_lgb(y_pred, train_data):"""LightGBM自定義縮放均方誤差損失函數參數:y_pred: 預測值數組train_data: LightGBM的Dataset對象返回:tuple: (梯度數組, Hessian數組)"""y_true = train_data.get_label()  # 獲取真實標簽# 計算分母,防止除零denominator = np.maximum(y_true, threshold)# 計算縮放誤差error = (y_true - y_pred) / denominator# 計算梯度和Hessiangrad = -2 * error / denominatorhess = 2 / (denominator ** 2)return grad, hess

自定義評估指標

def mae_metric_lgb(preds, train_data):"""LightGBM自定義縮放MAE評估指標參數:preds: 預測值數組train_data: LightGBM的Dataset對象返回:tuple: (指標名稱, 指標值, 是否越大越好)"""y_true = train_data.get_label()denominator = np.maximum(y_true, threshold)error = np.abs(preds - y_true) / denominatorreturn 'scaled_mae', np.mean(error), False

使用方式

import lightgbm as lgb
import numpy as np# 參數配置
params = {'objective': custom_loss_squared_lgb,  # 使用自定義損失函數'boosting_type': 'gbdt','num_leaves': 31,'learning_rate': 0.01,'verbosity': -1
}# 訓練模型
model = lgb.train(params, train_set, valid_sets=[train_set, valid_set],feval=mae_metric_lgb,  # 使用自定義評估指標num_boost_round=1000,callbacks=[lgb.early_stopping(100)]
)

XGBoost實現

自定義損失函數

def custom_loss_squared_xgb(y_pred, train_data):"""XGBoost自定義縮放均方誤差損失函數參數:y_pred: 預測值數組train_data: XGBoost的DMatrix對象返回:tuple: (梯度數組, Hessian數組)"""y_true = train_data.get_label()  # 獲取真實標簽# 計算分母,防止除零denominator = np.maximum(y_true, threshold)# 計算縮放誤差error = (y_true - y_pred) / denominator# 計算梯度和Hessiangrad = -2 * error / denominatorhess = 2 / (denominator ** 2)return grad, hess

自定義評估指標

def mae_metric_xgb(y_pred, train_data):"""XGBoost自定義縮放MAE評估指標參數:y_pred: 預測值數組train_data: XGBoost的DMatrix對象返回:tuple: (指標名稱, 指標值)"""y_true = train_data.get_label()denominator = np.maximum(y_true, threshold)error = np.abs(y_true - y_pred) / denominatorreturn 'custom_mae', np.mean(error)

使用方式

import xgboost as xgb
import numpy as np# 參數配置
params = {'booster': 'gbtree','learning_rate': 0.01,'max_depth': 6,'random_state': 42
}# 訓練模型
model = xgb.train(params,train_matrix,num_boost_round=1000,evals=[(train_matrix, 'train'), (valid_matrix, 'valid')],obj=custom_loss_squared_xgb,  # 自定義損失函數feval=mae_metric_xgb,         # 自定義評估指標early_stopping_rounds=100,verbose_eval=50
)

CatBoost實現

CatBoost的自定義函數需要用類的形式實現。

自定義損失函數

class CustomCatBoostObjective(object):"""CatBoost自定義縮放均方誤差損失函數"""def calc_ders_range(self, approxes, targets, weights):"""計算梯度和Hessian參數:approxes: 當前預測值列表targets: 真實標簽列表weights: 樣本權重列表(可選)返回:list: [(梯度, Hessian), ...] 的列表"""assert len(approxes) == len(targets)if weights is not None:assert len(weights) == len(approxes)result = []for index in range(len(targets)):y_true = targets[index]y_pred = approxes[index]# 計算分母,防止除零denominator = max(y_true, threshold)# 計算縮放誤差error = (y_true - y_pred) / denominator# 計算梯度和Hessiangrad = -2 * error / denominatorhess = 2 / (denominator ** 2)# 應用樣本權重if weights is not None:grad *= weights[index]hess *= weights[index]result.append((grad, hess))return result

自定義評估指標

class CustomCatBoostEval(object):"""CatBoost自定義縮放MAE評估指標"""def is_max_optimal(self):"""指標是否越大越好"""return Falsedef evaluate(self, approxes, targets, weights):"""計算評估指標參數:approxes: 預測值列表的列表 [[pred1, pred2, ...]]targets: 真實標簽列表weights: 樣本權重列表(可選)返回:tuple: (誤差總和, 權重總和)"""assert len(approxes) == 1assert len(targets) == len(approxes[0])error_sum = 0.0weight_sum = 0.0for i in range(len(targets)):y_true = targets[i]y_pred = approxes[0][i]# 計算縮放誤差denominator = max(y_true, threshold)error = abs(y_true - y_pred) / denominator# 應用樣本權重if weights is not None:error *= weights[i]weight_sum += weights[i]else:weight_sum += 1.0error_sum += errorreturn error_sum, weight_sumdef get_final_error(self, error, weight):"""計算最終的評估指標值"""return error / (weight + 1e-38)

使用方式

from catboost import CatBoostRegressor, Pool
import numpy as np# 創建數據池
train_pool = Pool(X_train, y_train)
valid_pool = Pool(X_valid, y_valid)# 參數配置
params = {'objective': CustomCatBoostObjective(),'eval_metric': CustomCatBoostEval(),'iterations': 1000,'learning_rate': 0.01,'depth': 6,'random_state': 42,'verbose': False
}# 訓練模型
model = CatBoostRegressor(**params)
model.fit(train_pool,eval_set=valid_pool,early_stopping_rounds=100,verbose_eval=50,use_best_model=True
)

框架對比

特性LightGBMXGBoostCatBoost
損失函數形式函數函數類方法
參數名稱objectiveobjobjective
數據獲取train_data.get_label()dtrain.get_label()直接傳入 targets
評估指標形式函數函數類方法
評估返回格式(name, value, is_higher_better)(name, value)error_sum, weight_sum
權重支持自動處理自動處理需手動處理
實現復雜度簡單簡單中等

實際應用

適用場景

  1. 新能源功率預測:風電、光伏功率預測范圍從0到滿功率
  2. 金融風險評估:不同規模公司的風險評估
  3. 銷售預測:不同產品類別的銷售額預測
  4. 網絡流量預測:不同時段流量變化很大

常見問題

1. 為什么要設置最小閾值?

問題:直接用真實值作為分母會遇到什么問題?

答案

  • 當真實值為0或接近0時,會導致除零錯誤或梯度爆炸
  • 設置最小閾值可以保證數值穩定性
  • 閾值的選擇應根據數據的實際分布來確定

2. 梯度和Hessian計算錯誤怎么辦?

問題:如何驗證梯度計算的正確性?

答案:可以用數值微分驗證:

def verify_gradients(y_true, y_pred, eps=1e-6):"""驗證梯度計算的正確性"""# 解析梯度denominator = np.maximum(y_true, threshold)error = (y_true - y_pred) / denominatorgrad_analytical = -2 * error / denominator# 數值梯度loss_plus = ((y_true - (y_pred + eps)) / denominator) ** 2loss_minus = ((y_true - (y_pred - eps)) / denominator) ** 2grad_numerical = (loss_plus - loss_minus) / (2 * eps)# 比較diff = np.abs(grad_analytical - grad_numerical)print(f"最大梯度差異: {np.max(diff)}")return np.allclose(grad_analytical, grad_numerical, atol=1e-5)

3. 不同框架的性能差異

問題:三個框架在使用自定義損失函數時的性能如何?

答案

  • LightGBM:通常最快,內存效率高
  • XGBoost:穩定性好,文檔完善
  • CatBoost:對類別特征處理好,但自定義函數實現相對復雜

4. 超參數調優建議

# LightGBM調優示例
from optuna import create_studydef objective(trial):params = {'objective': custom_loss_squared_lgb,'boosting_type': 'gbdt','num_leaves': trial.suggest_int('num_leaves', 10, 100),'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.1),'feature_fraction': trial.suggest_float('feature_fraction', 0.4, 1.0),'bagging_fraction': trial.suggest_float('bagging_fraction', 0.4, 1.0),'verbosity': -1}model = lgb.train(params,train_data,valid_sets=[valid_data],feval=mae_metric_lgb,num_boost_round=1000,callbacks=[lgb.early_stopping(100)],verbose_eval=False)y_pred = model.predict(X_valid)scaled_mae = np.mean(np.abs(y_valid - y_pred) / np.maximum(y_valid, threshold))return scaled_maestudy = create_study(direction='minimize')
study.optimize(objective, n_trials=100)

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/922207.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/922207.shtml
英文地址,請注明出處:http://en.pswp.cn/news/922207.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

2025-09-08升級問題記錄: 升級SDK從Android11到Android12

將 Android 工程的 targetSdkVersion 從 30 (Android 11)升級到 31(Android 12)需要關注一些重要的行為變更和適配點。 主要適配要點: 適配類別關鍵變更點適配緊迫性簡要說明組件導出屬性聲明了 Intent Filter 的組件…

利用OpenCV實現模板與多個對象匹配

代碼實現:import cv2 import numpy as npimg_rgb cv2.imread(mobanpipei.jpg) img_gray cv2.cvtColor(img_rgb, cv2.COLOR_BGR2GRAY) template cv2.imread(jianto.jpg, flags0) h, w template.shape[:2]# 讀取圖像# # 順時針旋轉 90 度(k1&#xff0…

OS28.【Linux】自制簡單的Shell的修bug記錄

目錄 1.問題代碼 2.排查 前期檢查 查找是誰修改了environ[0] 使用gdb下斷點 查看后續的影響 分析出問題的split_commandline函數 3.反思 4.正確代碼 5.結論 6.除此之外...... ★提示: 此bug非常隱蔽,不仔細分析很難查出問題,非常鍛煉調試能力! 1.問題代碼 #includ…

Debian 系統上安裝與配置 MediaMTX

🎯 在 Debian 系統上安裝與配置 MediaMTX(原 rtsp-simple-server):打造輕量級流媒體服務器 作者:遠在太平洋 環境:Debian 10/11/12 | Ubuntu 可參考 關鍵詞:MediaMTX、rtsp-simple-server、RTSP…

分布式專題——10.4 ShardingSphere-Proxy服務端分庫分表

1 為什么要有服務端分庫分表? ShardingSphere-Proxy 是 ShardingSphere 提供的服務端分庫分表工具,定位是“透明化的數據庫代理”。 它模擬 MySQL 或 PostgreSQL 的數據庫服務,應用程序(Application)只需像訪問單個數據…

Mysql相關的面試題1

什么是聚集索引(聚簇索引)?什么是二級索引(非聚簇索引)? 聚集索引就是葉子節點關聯行數據的索引,二級索引就是葉子節點關聯主鍵的索引,聚集索引必須有且僅有一個,二級索引…

電涌保護器:為現代生活筑起一道隱形防雷網

何為電涌保護器?電涌保護器(Surge Protective Device,簡稱SPD)主要用于控制信號系統,保護電氣電子設備信號線路免受雷電電磁脈沖、感應過電壓、操作過電壓的影響,廣泛應用于工控、消防、安防監控、交通、電…

【uniapp微信小程序】掃普通鏈接二維碼打開小程序

需求:用戶A保存自己的邀請碼海報,用戶B掃描該普通連接二維碼,打開微信小程序,并且攜帶用戶A的邀請碼信息,用戶B登錄時,跟用戶A關聯,成為用戶A的下級。 tips:保存海報到手機相冊可以參…

LeetCode 378 - 有序矩陣中第 K 小的元素

文章目錄摘要描述題解答案題解代碼分析代碼解析示例測試及結果輸出結果時間復雜度空間復雜度總結摘要 在開發中,我們經常遇到需要處理大規模有序數據的場景,比如數據庫分頁、排行榜查詢、或者處理排序過的矩陣。LeetCode 第 378 題“有序矩陣中第 K 小的…

【Lua】Windows 下編寫 C 擴展模塊:VS 編譯與 Lua 調用全流程

? 目錄 ?🛫 導讀需求環境1?? 核心原理:Windows下Lua與C的交互邏輯2?? Windows下編寫步驟:以mymath模塊為例2.1 步驟1:準備Windows開發環境方式1:官網下載Lua源碼并編譯(可控性高)方式2&am…

Python快速入門專業版(二十九):函數返回值:多返回值、None與函數嵌套調用

目錄引一、多返回值:一次返回多個結果的優雅方式1. 多返回值的本質:隱式封裝為元組示例1:返回多個值的函數及接收方式2. 多返回值的接收技巧技巧1:用下劃線_忽略不需要的返回值技巧2:用*接收剩余值(Python …

python使用pip安裝的包與卸載

1:基本卸載命令 # 卸載單個包 pip uninstall package_name# 示例:卸載requests包 pip uninstall requests2:卸載多個包 # 一次性卸載多個包 pip uninstall package1 package2 package3# 示例 pip uninstall requests numpy pandas3&#xff1…

超級流水線和標量流水線的原理

一、什么是流水線?要理解這兩個概念,首先要明白流水線(Pipelining) 的基本思想。想象一個汽車裝配工廠:* 沒有流水線:一個工人負責組裝一整輛汽車,裝完一輛再裝下一輛。效率很低。* 有了流水線&…

【Ansible】管理復雜的Play和Playbook知識點

1.什么是主機模式?答:主機模式是Ansible中用于從Inventory中篩選目標主機的規則,通過靈活的模式定義可精準定位需要執行任務的主機。2.主機模式的作用答:篩選目標:從主機清單中選擇一個或多個主機/組,作為P…

FastGPT源碼解析 Agent 智能體應用創建流程和代碼分析

FastGPT對話智能體創建流程和代碼分析 平臺作為agent平臺,平臺所有功能都是圍繞Agent創建和使用為核心的。平臺整合各種基礎能力,如大模型、知識庫、工作流、插件等模塊,通過可視化,在界面上創建智能體,使用全部基礎能…

缺失數據處理全指南:方法、案例與最佳實踐

如何處理缺失數據:方法、案例與最佳實踐 1. 引言 在數據分析和機器學習中,缺失數據是一個普遍存在的問題。如何處理缺失值,往往直接影響到后續分析和建模的效果。處理不當,不僅會浪費數據,還可能導致模型預測結果的不準…

為什么Cesium不使用vue或者react,而是 保留 Knockout

1. Knockout-ES5 插件的語法簡化優勢 自動深度監聽:Cesium 通過集成 Knockout-ES5 插件,允許開發者直接使用普通變量語法(如 viewModel.property newValue)替代繁瑣的 observable() 包裝,無需手動聲明每個可觀察屬性。…

Word怎么設置頁碼總頁數不包含封面和目錄頁

有時候使用頁碼格式是[第x頁/共x頁]或[x/x]時會遇到word總頁數和實際想要的頁數不一致,導致顯示不統一,這里介紹一個簡單的辦法,適用于比較簡單的情況。 一、wps版本 文章分節 首先將目錄頁與正文頁進行分節:在目錄頁后面選擇插入…

突破機器人通訊架構瓶頸,CAN/FD、高速485、EtherCAT,哪種總線才是最優解?

引言: 從協作機械臂到人形機器人,一文拆解主流總線技術選型困局 在機器人技術飛速發展的今天,從工廠流水線上的協作機械臂到科技展會上的人形機器人,它們的“神經系統”——通訊總線,正面臨著前所未有的挑戰。特斯拉O…

Java核心概念詳解:JVM、JRE、JDK、Java SE、Java EE (Jakarta EE)

1. Java是什么? Java首先是一種編程語言。它擁有特定的語法、關鍵字和結構,開發者可以用它來編寫指令,讓計算機執行任務。核心特點: Java最著名的特點是“一次編寫,到處運行”(Write Once, Run Anywhere - …