XGB-16:自定義目標和評估指標

概述

XGBoost被設計為一個可擴展的庫。通過提供自定義的訓練目標函數和相應的性能監控指標,可以擴展它。本文介紹了如何為XGBoost實現自定義的逐元評估指標和目標。

注意:

排序不能自定義

在接下來的兩個部分中,將逐步介紹如何實現平方對數誤差(Squared Log Error,SLE)目標函數

1 2 [ log ? ( p r e d + 1 ) ? log ? ( l a b e l + 1 ) ] 2 \frac{1}{2}[\log(pred + 1) - \log(label + 1)]^2 21?[log(pred+1)?log(label+1)]2

以及它的默認評估指標均方根對數誤差(Root Mean Squared Log Error,RMSLE

1 N [ log ? ( p r e d + 1 ) ? log ? ( l a b e l + 1 ) ] 2 \sqrt{\frac{1}{N}[\log(pred + 1) - \log(label + 1)]^2} N1?[log(pred+1)?log(label+1)]2 ?

定制目標函數

盡管XGBoost本身已經原生支持這些功能,但為了演示的目的,使用它來比較自己實現的結果和XGBoost內部實現的結果。完成本教程后,應該能夠為快速實驗提供自己的函數。最后,將提供一些關于非恒等鏈接函數的注釋,以及在scikit-learn接口中使用自定義度量和目標的示例。

如果計算所述目標函數的梯度:

g = ? o b j e c t i v e ? p r e d = log ? ( p r e d + 1 ) ? log ? ( l a b e l + 1 ) p r e d + 1 g = \frac{\partial{objective}}{\partial{pred}} = \frac{\log(pred + 1) - \log(label + 1)}{pred + 1} g=?pred?objective?=pred+1log(pred+1)?log(label+1)?

以及 hessian(目標的二階導數):

h = ? 2 o b j e c t i v e ? p r e d = ? log ? ( p r e d + 1 ) + log ? ( l a b e l + 1 ) + 1 ( p r e d + 1 ) 2 h = \frac{\partial^2{objective}}{\partial{pred}} = \frac{ - \log(pred + 1) + \log(label + 1) + 1}{(pred + 1)^2} h=?pred?2objective?=(pred+1)2?log(pred+1)+log(label+1)+1?

在模型訓練過程中,目標函數起著重要的作用:基于模型預測和觀察到的數據標簽(或目標),提供梯度信息,包括一階和二階梯度。因此,有效的目標函數應接受兩個輸入,即預測值和標簽。對于實現SLE,定義:

import numpy as np
import xgboost as xgb
from typing import Tupledef gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:'''Compute the gradient squared log error.'''y = dtrain.get_label()return (np.log1p(predt)-np.log1p(y)) / (predt+1)def hessian(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:'''Compute the hessian for squared log error.'''y = dtrain.get_label()return ((-np.log1p(predt)+np.log1p(y)+1) /np.power(predt+1, 2))def squared_log(predt: np.ndarray,dtrain: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]:'''Squared Log Error objective. A simplified version for RMSLE used asobjective function.'''predt[predt < -1] = -1 + 1e-6grad = gradient(predt, dtrain)hess = hessian(predt, dtrain)return grad, hess

在上面的代碼片段中,squared_log是想要的目標函數。它接受一個numpy數組predt作為模型預測值,以及用于獲取所需信息的訓練DMatrix,包括標簽和權重(此處未使用)。然后,在訓練過程中,通過將其作為參數傳遞給xgb.train,將此目標函數用作XGBoost的回調函數:

xgb.train({'tree_method': 'hist', 'seed': 1994},   # any other tree method is fine.dtrain=dtrain,num_boost_round=10,obj=squared_log)

注意,在定義目標函數時,從預測值中減去標簽或從標簽中減去預測值,這是很重要的。如果發現訓練錯誤上升而不是下降,這可能是原因。

定制度量函數

因此,在擁有自定義目標函數之后,還需要一個相應的度量標準來監控模型的性能。如上所述,SLE 的默認度量標準是 RMSLE。同樣,定義另一個類似的回調函數作為新的度量標準:

def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:''' Root mean squared log error metric.'''y = dtrain.get_label()predt[predt < -1] = -1 + 1e-6elements = np.power(np.log1p(y) - np.log1p(predt), 2)return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))

與目標函數類似,度量也接受 predtdtrain 作為輸入,但返回度量本身的名稱和一個浮點值作為結果。將其作為 custom_metric 參數傳遞給 XGBoost:

xgb.train({'tree_method': 'hist', 'seed': 1994,'disable_default_eval_metric': 1},dtrain=dtrain,num_boost_round=10,obj=squared_log,custom_metric=rmsle,evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],evals_result=results)

能夠看到 XGBoost 打印如下內容:

[0] dtrain-PyRMSLE:1.37153  dtest-PyRMSLE:1.31487
[1] dtrain-PyRMSLE:1.26619  dtest-PyRMSLE:1.20899
[2] dtrain-PyRMSLE:1.17508  dtest-PyRMSLE:1.11629
[3] dtrain-PyRMSLE:1.09836  dtest-PyRMSLE:1.03871
[4] dtrain-PyRMSLE:1.03557  dtest-PyRMSLE:0.977186
[5] dtrain-PyRMSLE:0.985783 dtest-PyRMSLE:0.93057
...

注意,參數 disable_default_eval_metric 用于禁用 XGBoost 中的默認度量。

完整可復制的源代碼參閱定義自定義回歸目標和度量的演示。

轉換鏈接函數

在使用內置目標函數時,原始預測值會根據目標函數進行轉換。當提供自定義目標函數時,XGBoost 不知道其鏈接函數,因此用戶需要對目標和自定義評估度量進行轉換。對于具有身份鏈接的目標,如平方誤差squared error,這很簡單,但對于其他鏈接函數,如對數鏈接或反鏈接,差異很大。

在 Python 包中,可以通過 predict 函數中的 output_margin 參數來控制預測的行為。當使用 custom_metric 參數而沒有自定義目標函數時,度量函數將接收經過轉換的預測,因為目標是由 XGBoost 定義的。然而,當同時提供自定義目標和度量時,目標和自定義度量都將接收原始預測。以下示例比較了多類分類模型中兩種不同的行為。首先,我們定義了兩個不同的 Python 度量函數,實現了相同的底層度量以進行比較。其中 merror_with_transform 在同時使用自定義目標時使用,否則會使用更簡單的 merror,因為 XGBoost 可以自行執行轉換。

import xgboost as xgb
import numpy as npdef merror_with_transform(predt: np.ndarray, dtrain: xgb.DMatrix):"""Used when custom objective is supplied."""y = dtrain.get_label()n_classes = predt.size // y.shape[0]# Like custom objective, the predt is untransformed leaf weight when custom objective# is provided.# With the use of `custom_metric` parameter in train function, custom metric receives# raw input only when custom objective is also being used.  Otherwise custom metric# will receive transformed prediction.assert predt.shape == (d_train.num_row(), n_classes)out = np.zeros(dtrain.num_row())for r in range(predt.shape[0]):i = np.argmax(predt[r])out[r] = iassert y.shape == out.shapeerrors = np.zeros(dtrain.num_row())errors[y != out] = 1.0return 'PyMError', np.sum(errors) / dtrain.num_row()

僅當想要使用自定義目標并且 XGBoost 不知道如何轉換預測時才需要上述函數。多類誤差函數的正常實現是:

def merror(predt: np.ndarray, dtrain: xgb.DMatrix):"""Used when there's no custom objective."""# No need to do transform, XGBoost handles it internally.errors = np.zeros(dtrain.num_row())errors[y != out] = 1.0return 'PyMError', np.sum(errors) / dtrain.num_row()

接下來需要自定義 softprob 目標:

def softprob_obj(predt: np.ndarray, data: xgb.DMatrix):"""Loss function.  Computing the gradient and approximated hessian (diagonal).Reimplements the `multi:softprob` inside XGBoost."""# Full implementation is available in the Python demo script linked below...return grad, hess

最后可以使用 objcustom_metric 參數訓練模型:

Xy = xgb.DMatrix(X, y)
booster = xgb.train({"num_class": kClasses, "disable_default_eval_metric": True},m,num_boost_round=kRounds,obj=softprob_obj,custom_metric=merror_with_transform,evals_result=custom_results,evals=[(m, "train")],
)

如果不需要自定義目標,只是想提供一個XGBoost中不可用的指標:

booster = xgb.train({"num_class": kClasses,"disable_default_eval_metric": True,"objective": "multi:softmax",},m,num_boost_round=kRounds,# Use a simpler metric implementation.custom_metric=merror,evals_result=custom_results,evals=[(m, "train")],
)

使用multi:softmax來說明轉換后預測的差異。使用softprob時,輸出預測數組的形狀是(n_samples, n_classes),而對于softmax,它是(n_samples, )。關于多類目標函數的示例也可以在創建自定義多類目標函數的示例中找到。此外,更多解釋請參見Intercept。

Scikit-Learn 接口

XGBoost的scikit-learn接口提供了一些工具,以改善與標準的scikit-learn函數的集成,用戶可以直接使用scikit-learn的成本函數(而不是評分函數):

from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_absolute_errorX, y = load_diabetes(return_X_y=True)
reg = xgb.XGBRegressor(tree_method="hist",eval_metric=mean_absolute_error,
)
reg.fit(X, y, eval_set=[(X, y)])

對于自定義目標函數,用戶可以在不訪問DMatrix的情況下定義目標函數:

def softprob_obj(labels: np.ndarray, predt: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:rows = labels.shape[0]classes = predt.shape[1]grad = np.zeros((rows, classes), dtype=float)hess = np.zeros((rows, classes), dtype=float)eps = 1e-6for r in range(predt.shape[0]):target = labels[r]p = softmax(predt[r, :])for c in range(predt.shape[1]):g = p[c] - 1.0 if c == target else p[c]h = max((2.0 * p[c] * (1.0 - p[c])).item(), eps)grad[r, c] = ghess[r, c] = hgrad = grad.reshape((rows * classes, 1))hess = hess.reshape((rows * classes, 1))return grad, hessclf = xgb.XGBClassifier(tree_method="hist", objective=softprob_obj)

參考

  • https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html
  • https://xgboost.readthedocs.io/en/latest/python/examples/custom_rmsle.html#sphx-glr-python-examples-custom-rmsle-py

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/719681.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/719681.shtml
英文地址,請注明出處:http://en.pswp.cn/news/719681.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【EAI 027】Learning Interactive Real-World Simulators

Paper Card 論文標題&#xff1a;Learning Interactive Real-World Simulators 論文作者&#xff1a;Mengjiao Yang, Yilun Du, Kamyar Ghasemipour, Jonathan Tompson, Leslie Kaelbling, Dale Schuurmans, Pieter Abbeel 作者單位&#xff1a;UC Berkeley, Google DeepMind, …

【 Docker 容器詳細介紹和說明】

Docker 容器詳細介紹和說明 Docker 容器詳細介紹和說明Docker 安裝步驟&#xff08;以Ubuntu為例&#xff09;&#xff1a;使用Docker創建并運行容器&#xff1a;VSCode遠程連接Docker容器&#xff1a;步驟1&#xff1a;配置Docker環境步驟2&#xff1a;配置PyCharm步驟3&#…

日本發動全面侵華戰爭他們在怕什么?為何不敢動陜西,

日本全面侵華戰爭之謎&#xff1a;恐懼與野心的交織 在二十世紀三十年代&#xff0c;日本帝國主義以令人發指的暴行和殘忍手段&#xff0c;對中國發動了全面侵華戰爭。然而&#xff0c;在這場戰爭中&#xff0c;有一個引人關注的現象&#xff1a;日本侵略者在進攻過程中&#…

python和nodejs一鍵安裝當前項目所有依賴

python和nodejs一鍵安裝當前項目所有依賴。群里有人問怎么快速安裝網上下載的源碼里面的依賴。所以在這里分享一下。更多問題可以自己加群917400262問我。 目錄導航 1.0 python一鍵安裝當前項目所有依賴2.0 nodejs一鍵安裝當前項目所有依賴 1.0 python一鍵安裝當前項目所有依賴…

snakemake: 基礎知識

為了有效地學習和使用 Snakemake&#xff0c;你需要具備一定的基礎知識。這些基礎知識將幫助你更好地理解 Snakemake 的工作原理和如何在你的項目中應用它。以下是學習 Snakemake 所需的一些基礎知識&#xff1a; 1. Python 編程 Snakemake 是用 Python 編寫的&#xff0c;并…

聊聊國內「類Sora模型」發展現狀,和 Sora 的差距到底有多大?

2024 年 2 月 16 日。 就在谷歌發布他新一代的多模態大模型 Gemini 1.5 Pro 的同一天&#xff0c;OpenAI 帶著新一代的文生視頻模型 Sora 再次抓住了全世界人們的眼球。 “顛覆”、“炸裂”、“變天”、“瘋狂”&#xff0c;類似的形容詞一夜之間簇擁在 Sora 周圍&#xff0c;…

網絡傳輸基本流程(封裝,解包)+圖解(同層直接通信的證明),報頭分離問題,協議定位問題,協議多路復用

目錄 網絡傳輸基本流程 引入 封裝 過程梳理 圖解 報文 解包 過程梳理 圖解 -- 同層直接通信的證明 總結 解包時的報頭分離問題 舉例 -- 倒水 介紹 自底向上傳輸時的協議定位問題 介紹 解決方法 協議多路復用 介紹 優勢 網絡傳輸基本流程 引入 首先,我們明確…

VS查看C++頭文件(.h文件)的函數列表

這里使用的是VS2019舉例 如下圖查看Actor.h文件中的函數列表 設置步驟如下圖

【d35】【Java】【力扣】28. 找出字符串中第一個匹配項的下標

題目 給你兩個字符串 haystack 和 needle &#xff0c;請你在 haystack 字符串中找出 needle 字符串的第一個匹配項的下標&#xff08;下標從 0 開始&#xff09;。如果 needle 不是 haystack 的一部分&#xff0c;則返回 -1 。 示例 1&#xff1a; 輸入&#xff1a;haystac…

【大數據】通過 docker-compose 快速部署 MinIO 保姆級教程

文章目錄 一、概述二、MinIO 與 Ceph 對比1&#xff09;架構設計對比2&#xff09;數據一致性對比3&#xff09;部署和管理對比4&#xff09;生態系統和兼容性對比 三、前期準備1&#xff09;部署 docker2&#xff09;部署 docker-compose 四、創建網絡五、MinIO 編排部署1&…

【SQL】608. 樹節點(流控制語句 CASE + IF語句)

前述 知識點推薦學習&#xff1a; sql中的 IF 條件語句的用法 MySQL&#xff1a;if語句、if…else語句、case語句&#xff0c;使用方法解析 題目描述 leetcode 題目&#xff1a;608. 樹節點 思路 關鍵點&#xff1a;如何確定有沒有子節點 根節點&#xff1a;父節點為空內節…

基于Redo log Undo log的MySQL的崩潰恢復

基于Redo log & Undo log的MySQL的崩潰恢復 Redo log Undo log Redo log 重做日志,記錄,修改過的數據 Undo log 回滾日志,記錄修改之前的數據 兩個我不做詳細的介紹了,redo log就是記錄哪些地方被修改了 undo log是記錄修改之前我們的數據長什么樣 更新流程 我們來捋一…

python封裝,繼承,復寫詳解

目錄 1.封裝 2.繼承 復寫和使用父類成員 1.封裝 class phone:__voltage 0.5def __keepsinglecore(self):print("單核運行")def callby5g(self):if self.__voltage > 1:print("5g通話開啟")else:self.__keepsinglecore()print("不能開啟5g通…

Redis集群(主從)

1.主從集群 集群結構: 一.單機安裝redis 1.上傳壓縮包并解壓&#xff0c;編譯 tar -xzf redis-6.2.4.tar.gz cd redis-6.2.4 make && make install 2.修改redis.config的配置并啟動redis # 綁定地址&#xff0c;默認是127.0.0.1&#xff0c;會導致只能在本地訪問。…

Tomcat布署及優化-----JDK和Tomcat

1.Tomcat簡介 Tomcat 是 Java 語言開發的&#xff0c;Tomcat 服務器是一個免費的開放源代碼的 Web 應用服務器&#xff0c;Tomcat 屬于輕量級應用服務器&#xff0c;在中小型系統和并發訪問用戶不是很多的場合下被普遍使用&#xff0c;是開發和調試 JSP 程序的首選。一般來說&…

C++ //練習 10.2 重做上一題,但讀取string序列存入list中。

C Primer&#xff08;第5版&#xff09; 練習 10.2 練習 10.2 重做上一題&#xff0c;但讀取string序列存入list中。 環境&#xff1a;Linux Ubuntu&#xff08;云服務器&#xff09; 工具&#xff1a;vim 代碼塊 /******************************************************…

Vue前端加密后的數據發送到服務器端

首先&#xff0c;定義了一個名為 PUBLIC_KEY 的公鑰和一個名為 PRIVATE_KEY 的私鑰。然后&#xff0c;通過 JSEncrypt 創建了兩個實例 encrypt 和 decrypt&#xff0c;分別用于加密和解密操作。 對于加密操作&#xff0c;調用了 encrypt.setPublicKey() 方法設置公鑰&#xff…

升級Centos7的openssh到openssh-9.6p1版本 shell腳本 漏掃整改

升級Centos7的openssh到openssh-9.6p1版本 shell腳本 漏掃整改 #!/bin/bash# 聲明: 該腳本適用于升級Centos7的openssh到openssh-9.6p1版本# 定義源碼包版本號 OPENSSH_VERSIONopenssh-9.6p1 OPENSSL_VERSIONopenssl-3.2.1 ZILB_VERSIONzlib-1.3.1# 安裝編譯環境 yum -y insta…

【前端面試題5】利用 border 屬性畫一個三角形

舉例1&#xff1a;利用 border 屬性畫一個三角形&#xff08;小技巧&#xff09; 完整代碼如下&#xff1a; div{width: 0;height: 0;border: 50px solid transparent;border-top-color: red;border-bottom: none; }步驟如下&#xff1a; &#xff08;1&#xff09;當我們設…

【QT+QGIS跨平臺編譯】之五十六:【QGIS_CORE跨平臺編譯】—【qgsmeshcalclexer.cpp生成】

文章目錄 一、Flex二、生成來源三、構建過程一、Flex Flex (fast lexical analyser generator) 是 Lex 的另一個替代品。它經常和自由軟件 Bison 語法分析器生成器 一起使用。Flex 最初由 Vern Paxson 于 1987 年用 C 語言寫成。 “flex 是一個生成掃描器的工具,能夠識別文本中…