Python28-11 CatBoost梯度提升算法

圖片

CatBoost(Categorical Boosting)是由Yandex(一家俄羅斯互聯網企業,旗下的搜索引擎曾在俄國內擁有逾60%的市場占有率,同時也提供其他互聯網產品和服務)開發的一種基于梯度提升的機器學習算法。CatBoost特別擅長處理類別特征,并且能夠有效地避免過擬合和數據泄露問題。CatBoost的全稱是“Categorical Boosting”,它的設計初衷是為了在處理包含大量類別特征的數據時表現得更好。

CatBoost的特點

  1. 處理類別特征:CatBoost可以直接處理類別特征而不需要進行額外的編碼(如one-hot編碼)。

  2. 避免過擬合:CatBoost采用了一種新的處理類別特征的方法,有效地減少了過擬合。

  3. 高效性:CatBoost在訓練速度和預測速度方面都表現優異。

  4. 支持CPU和GPU訓練:CatBoost既可以在CPU上運行,也可以利用GPU進行加速訓練。

  5. 自動處理缺失值:CatBoost可以自動處理缺失值,無需額外的預處理步驟。

CatBoost的核心原理

CatBoost的核心原理基于梯度提升決策樹(GBDT),但在處理類別特征和避免過擬合方面進行了創新。以下是一些關鍵的技術點:

  1. 類別特征處理

    • CatBoost引入了一個稱為“均值編碼”的方法,基于類別的均值計算新特征。

    • 使用一種稱為“目標編碼”的技術,將類別特征轉化為數值特征時,通過使用目標值的平均值來減少數據泄露的風險。

    • 在訓練過程中,通過使用統計信息對數據進行處理,避免直接使用目標變量進行編碼。

  2. 有序提升(Ordered Boosting)

    • 為了防止數據泄露和過擬合,CatBoost在訓練時對數據進行了有序處理。

    • 有序提升通過在訓練過程中隨機打亂數據,并確保模型在某一時刻只看到過去的數據,而不會使用未來的信息進行決策。

  3. 計算優化

    • CatBoost通過預計算和緩存的方式加速了特征的計算過程。

    • 支持CPU和GPU訓練,能夠在大規模數據集上表現出色。

CatBoost的基本使用

以下是一個使用CatBoost進行分類任務的基本示例,我們使用Auto MPG(Miles Per Gallon)數據集,它是一個經典的回歸問題數據集,常用于機器學習和統計分析。該數據集記錄了不同型號汽車的燃油效率(即每加侖燃油行駛的英里數)以及其他多個相關特征。

數據集特征:

  • mpg: 每加侖燃油行駛的英里數(目標變量)。

  • cylinders: 氣缸數量,表示發動機的氣缸數。

  • displacement: 發動機排量(立方英寸)。

  • horsepower: 發動機功率(馬力)。

  • weight: 車輛重量(磅)。

  • acceleration: 0到60英里每小時的加速度時間(秒)。

  • model_year: 車輛生產年份。

  • origin: 車輛產地(1=美國,2=歐洲,3=日本)。

數據集前幾行:

????mpg??cylinders??displacement??horsepower??weight??acceleration??model_year??origin
0??18.0??????????8?????????307.0???????130.0??3504.0??????????12.0??????????70???????1
1??15.0??????????8?????????350.0???????165.0??3693.0??????????11.5??????????70???????1
2??18.0??????????8?????????318.0???????150.0??3436.0??????????11.0??????????70???????1
3??16.0??????????8?????????304.0???????150.0??3433.0??????????12.0??????????70???????1
4??17.0??????????8?????????302.0???????140.0??3449.0??????????10.5??????????70???????1

代碼示例:

import?pandas?as?pd??#?導入Pandas庫,用于數據處理
import?numpy?as?np??#?導入Numpy庫,用于數值計算
from?sklearn.model_selection?import?train_test_split??#?從sklearn庫導入train_test_split,用于劃分數據集
from?sklearn.metrics?import?mean_squared_error,?mean_absolute_error??#?導入均方誤差和平均絕對誤差,用于評估模型性能
from?catboost?import?CatBoostRegressor??#?導入CatBoost庫中的CatBoostRegressor,用于回歸任務
import?matplotlib.pyplot?as?plt??#?導入Matplotlib庫,用于繪圖
import?seaborn?as?sns??#?導入Seaborn庫,用于繪制統計圖#?設置隨機種子以便結果復現
np.random.seed(42)#?從UCI機器學習庫加載Auto?MPG數據集
url?=?"http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data"
column_names?=?['mpg',?'cylinders',?'displacement',?'horsepower',?'weight',?'acceleration',?'model_year',?'origin']
data?=?pd.read_csv(url,?names=column_names,?na_values='?',?comment='\t',?sep='?',?skipinitialspace=True)#?查看數據集的前幾行
print(data.head())#?處理缺失值
data?=?data.dropna()#?特征和目標變量
X?=?data.drop('mpg',?axis=1)??#?特征變量
y?=?data['mpg']??#?目標變量#?將類別特征轉換為字符串類型(CatBoost可以直接處理類別特征)
X['cylinders']?=?X['cylinders'].astype(str)
X['model_year']?=?X['model_year'].astype(str)
X['origin']?=?X['origin'].astype(str)#?將數據集劃分為訓練集和測試集
X_train,?X_test,?y_train,?y_test?=?train_test_split(X,?y,?test_size=0.2,?random_state=42)#?定義CatBoost回歸器
model?=?CatBoostRegressor(iterations=1000,??#?迭代次數learning_rate=0.1,??#?學習率depth=6,??#?決策樹深度loss_function='RMSE',??#?損失函數verbose=100??#?輸出訓練過程信息
)#?訓練模型
model.fit(X_train,?y_train,?eval_set=(X_test,?y_test),?early_stopping_rounds=50)#?進行預測
y_pred?=?model.predict(X_test)#?評估模型性能
mse?=?mean_squared_error(y_test,?y_pred)??#?計算均方誤差
mae?=?mean_absolute_error(y_test,?y_pred)??#?計算平均絕對誤差#?打印模型的評估結果
print(f'Mean?Squared?Error?(MSE):?{mse:.4f}')
print(f'Mean?Absolute?Error?(MAE):?{mae:.4f}')#?繪制真實值與預測值的對比圖
plt.figure(figsize=(10,?6))
plt.scatter(y_test,?y_pred,?alpha=0.5)??#?繪制散點圖
plt.plot([y_test.min(),?y_test.max()],?[y_test.min(),?y_test.max()],?'--k')??#?繪制對角線
plt.xlabel('True?Values')??#?X軸標簽
plt.ylabel('Predictions')??#?Y軸標簽
plt.title('True?Values?vs?Predictions')??#?圖標題
plt.show()#?特征重要性可視化
feature_importances?=?model.get_feature_importance()??#?獲取特征重要性
feature_names?=?X.columns??#?獲取特征名稱plt.figure(figsize=(10,?6))
sns.barplot(x=feature_importances,?y=feature_names)??#?繪制特征重要性條形圖
plt.title('Feature?Importances')??#?圖標題
plt.show()#?輸出
'''
mpg??cylinders??displacement??horsepower??weight??acceleration??\
0??18.0??????????8?????????307.0???????130.0??3504.0??????????12.0???
1??15.0??????????8?????????350.0???????165.0??3693.0??????????11.5???
2??18.0??????????8?????????318.0???????150.0??3436.0??????????11.0???
3??16.0??????????8?????????304.0???????150.0??3433.0??????????12.0???
4??17.0??????????8?????????302.0???????140.0??3449.0??????????10.5???model_year??origin??
0??????????70???????1??
1??????????70???????1??
2??????????70???????1??
3??????????70???????1??
4??????????70???????1??
0:?learn:?7.3598113?test:?6.6405869?best:?6.6405869?(0)?total:?1.7ms?remaining:?1.69s
100:?learn:?1.5990203?test:?2.3207830?best:?2.3207666?(94)?total:?132ms?remaining:?1.17s
200:?learn:?1.0613606?test:?2.2319632?best:?2.2284239?(183)?total:?272ms?remaining:?1.08s
Stopped?by?overfitting?detector??(50?iterations?wait)bestTest?=?2.21453232
bestIteration?=?238Shrink?model?to?first?239?iterations.
Mean?Squared?Error?(MSE):?4.9042
Mean?Absolute?Error?(MAE):?1.6381<Figure?size?1000x600?with?1?Axes>
<Figure?size?1000x600?with?1?Axes>
'''

Mean Squared Error (MSE): 均方誤差,表示預測值與實際值之間的平均平方差異。值越小,模型性能越好,在這里MSE的值是4.9042。

Mean Absolute Error (MAE): 平均絕對誤差,表示預測值與實際值之間的平均絕對差異。值越小,模型性能越好,在這里MAE的值是1.6381。

圖片

  1. 散點圖:圖中的每個點表示一個測試樣本。橫坐標表示該樣本的真實值(MPG),縱坐標表示模型的預測值(MPG)。

  2. 對角線:圖中的黑色虛線是45度對角線,表示理想情況下的預測結果,即預測值等于真實值。

  3. 點的分布:

    • 靠近對角線:表示模型的預測值與真實值非常接近,預測準確。

    • 遠離對角線:表示預測值與真實值有較大差距,預測不準確。

通過圖中的點可以看到大部分點都集中在對角線附近,這表明模型的預測性能良好,但也有一些點離對角線較遠,表示這些樣本的預測值與真實值存在一定的差距。

圖片

  1. 條形圖:每個條形表示一個特征在模型中的重要性。條形越長,表示該特征對模型預測的貢獻越大。

  2. 特征名稱:在Y軸上列出了所有特征的名稱。

  3. 特征重要性值:在X軸上顯示了每個特征的相對重要性值。

從圖中可以看到:

  1. model_year:在所有特征中最重要,表示汽車的生產年份對預測燃油效率有很大的影響。

  2. weight:汽車的重量是第二重要的特征,對燃油效率也有顯著影響。

  3. displacement?和?horsepower:發動機的排量和功率對燃油效率也有較大的貢獻。

在實例中,我們使用CatBoost處理Auto MPG數據集,其主要目的是構建一個回歸模型,以預測汽車的燃油效率(即每加侖燃油行駛的英里數,MPG)。

以上內容總結自網絡,如有幫助歡迎轉發,我們下次再見!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/43423.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/43423.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/43423.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

什么是ThingsKit物聯網平臺?

在信息化時代的浪潮中&#xff0c;物聯網&#xff08;IoT&#xff09;作為新一代信息技術的核心&#xff0c;已經逐漸滲透到我們生活的方方面面。而在這個大背景下&#xff0c;Thingskit物聯網平臺以其獨特的技術優勢和應用場景&#xff0c;成為了物聯網領域的一顆璀璨明星。本…

AI和人工智能是啥關系?

AI&#xff08;人工智能&#xff09;與通用人工智能&#xff08;AGI&#xff09;是人工智能領域中的兩個重要概念&#xff0c;它們在定義、技術基礎以及應用領域等方面有所區別。人工智能&#xff08;AI&#xff09;&#xff0c;是指使計算機和其他機器模擬人類智能的技術&…

3.flink架構

目錄 概述 概述 Flink是一個分布式的帶有狀態管理的計算框架&#xff0c;為了執行流應用程序&#xff0c;可以和 Hadoop YARN 、K8s 進行整合&#xff0c;當然也可以是一個 standalone 。 官方地址&#xff1a;速遞 k8s 是未來的一種趨勢&#xff0c;對資源管控能力強。

Windows 控制中心在哪里打開,七種方法教會你

在 Windows 操作系統中&#xff0c;控制中心的概念可能稍有些混淆&#xff0c;因為 Windows 通常使用“控制面板”這一術語來指代用于配置系統設置和更改硬件及軟件設置的中心區域。 不過&#xff0c;隨著 Windows 的更新&#xff0c;微軟也在逐步將一些設置遷移到“設置”應用…

關于Linux的操作作業!24道題

&#x1f3c6;本文收錄于「Bug調優」專欄&#xff0c;主要記錄項目實戰過程中的Bug之前因后果及提供真實有效的解決方案&#xff0c;希望能夠助你一臂之力&#xff0c;幫你早日登頂實現財富自由&#x1f680;&#xff1b;同時&#xff0c;歡迎大家關注&&收藏&&…

js如何要讓一個對象繼承另一個對象的原型屬性和方法

js如何要讓一個對象繼承另一個對象的原型屬性和方法 1、使用 Object.create() const parent {greet: function() {console.log("Hello from parent!");} };const child Object.create(parent); child.greet(); // 輸出: Hello from parent!2、使用 proto 屬性 …

【算法】貪婪算法介紹及實現方法

貪婪算法簡介 貪婪算法&#xff08;Greedy Algorithm&#xff09;是一種在每一步選擇中都采取當前狀態下最好或最優&#xff08;即最有利&#xff09;的選擇&#xff0c;從而希望導致結果是全局最好或最優的算法。貪婪算法通常用于解決優化問題&#xff0c;如最小化成本、最大…

Tomcat打破雙親委派模型的方式

文章目錄 1、前言2、標準的雙親委派模型3、Tomcat的類加載器架構4、Tomcat打破雙親委派模型的方式5、總結 1、前言 雙親委派模型是一種類加載機制&#xff0c;它確保了類加載器層次結構中的父加載器先于子加載器嘗試加載類。這種機制有助于防止類的重復加載和類之間的不兼容。…

MySQL數據庫基本操作-DDL和DML

1. DDL解釋 DDL(Data Definition Language)&#xff0c;數據定義語言&#xff0c;該語言部分包括以下內容&#xff1a; 對數據庫的常用操作對表結構的常用操作修改表結構 2. 對數據庫的常用操作 功能SQL查看所有的數據庫show databases&#xff1b;查看有印象的數據庫show d…

16 - Python語言進階

Python語言進階 數據結構和算法 算法&#xff1a;解決問題的方法和步驟 評價算法的好壞&#xff1a;漸近時間復雜度和漸近空間復雜度。 漸近時間復雜度的大O標記&#xff1a; - 常量時間復雜度 - 布隆過濾器 / 哈希存儲 - 對數時間復雜度 - 折半查找&#xff08;二分查找&am…

關于TCP的三次握手流程

三次握手流程 第一次握手&#xff1a;客戶端向服務端發起建立連接請求&#xff0c;客戶端會隨機生成一個起始序列號x&#xff0c;客戶端向服務端發送的字段包含標志位SYN1&#xff0c;序列號segx。第一次握手后客戶端的狀態為SYN-SENT。此時服務端的狀態為LISTEN 第二次握手&…

使用耳機殼UV樹脂制作私模定制耳塞的價格如何呢?

使用耳機殼UV樹脂制作私模定制耳塞的價格如何呢&#xff1f; 耳機殼UV樹脂制作私模定制耳塞的價格因多個因素而異&#xff0c;如材料、工藝、設計、定制復雜度等。 根據我目前所了解到的信息&#xff0c;使用UV樹脂制作私模定制耳塞的價格可能在數百元至數千元不等。具體價格…

LVS+Nginx高可用集群---Nginx進階與實戰

1.Nginx中解決跨域問題 兩個站點的域名不一樣&#xff0c;就會有一個跨域問題。 跨域問題&#xff1a;了解同源策略&#xff1a;協議&#xff0c;域名&#xff0c;端口號都相同&#xff0c;只要有一個不相同那么就是非同源。 CORS全稱Cross-Origin Resource Sharing&#xff…

大模型知識大全1-基礎知識【大模型】

文章目錄 大模型簡介以后的介紹流程基礎知識訓練流程介紹pre-train對齊和指令微調規模拓展涌現能力 系統學習大模型的記錄https://github.com/LLMBook-zh/LLMBook-zh.github.io 大模型簡介 歷史我就不寫了&#xff0c;簡單說說大模型的應用和特點。人類使用大模型其實分為兩個…

linux高級編程(OSI/UDP(用戶數據報))

OSI七層模型&#xff1a; OSI 模型 --> 開放系統互聯模型 --> 分為7層&#xff1a; 理想模型 --> 尚未實現 1.應用層 QQ 應用程序的接口 2.表示層 加密解密 gzip 將接收的數據進行解釋&#xff…

【shell】—雙引號引用變量

文章目錄 一、舉例—單、雙引號引用變量的結果差異二、使用雙引號引用變量的場景1、使用雙引號—可以防止字符串被分割2、使用雙引號—特殊字符變為普通字符3、使用雙引號—保存原始命令的輸出格式4、使用雙引號—具有強約束的單引號變為普通單引號字符5、注意 一、舉例—單、雙…

挑戰杯 opencv python 深度學習垃圾圖像分類系統

0 前言 &#x1f525; 優質競賽項目系列&#xff0c;今天要分享的是 &#x1f6a9; opencv python 深度學習垃圾分類系統 &#x1f947;學長這里給一個題目綜合評分(每項滿分5分) 難度系數&#xff1a;3分工作量&#xff1a;3分創新點&#xff1a;4分 這是一個較為新穎的競…

昇思25天學習打卡營第13天|應用實踐之ResNet50遷移學習

基本介紹 今日的應用實踐的模型是計算機實踐領域中十分出名的模型----ResNet模型。ResNet是一種殘差網絡結構&#xff0c;它通過引入“殘差學習”的概念來解決隨著網絡深度增加時訓練困難的問題&#xff0c;從而能夠訓練更深的網絡結構。現很多網絡極深的模型或多或少都受此影響…

數據鏈路層(超詳細)

引言 數據鏈路層是計算機網絡協議棧中的第二層&#xff0c;位于物理層之上&#xff0c;負責在相鄰節點之間的可靠數據傳輸。數據鏈路層使用的信道主要有兩種類型&#xff1a;點對點信道和廣播信道。點對點信道是指一對一的通信方式&#xff0c;而廣播信道則是一對多的通信方式…