一、概述
??CatBoost 是在傳統GBDT基礎上改進和優化的一種算法,由俄羅斯 Yandex 公司開發,于2017 年開源,在處理類別型特征和防止過擬合方面有獨特優勢。
??在實際數據中,存在大量的類別型特征,如性別、顏色、類別等,傳統的算法通常需要在預處理中對這些特征進行獨熱編碼(One-Hot Encoding)或標簽編碼(Label Encoding)。但這些方法存在一些問題,獨熱編碼會增加數據的維度,導致模型訓練時間變長;標簽編碼可能會引入不必要的順序關系,影響模型的準確性。CatBoost 采用了一種獨特的處理方式,稱為 “Ordered Target Statistics”(有序目標統計),它通過對數據進行排序,利用數據的順序信息來計算類別型特征的統計量,從而將特征有效地融入到模型中,避免了傳統編碼方式的弊端。
??另外,在構建決策樹時,CatBoost 采用了對稱樹的結構,與傳統的非對稱決策樹相比,對稱樹在生長過程中,每層的節點數量相同,結構更加規整。這種結構使得模型在訓練過程中更加穩定,能夠減少過擬合的風險,同時也有助于提高訓練速度。
二、算法原理
1.對稱樹結構
??對稱樹結構在形式上是完全二叉樹結構,是指在構建決策樹時,對于每個節點的分裂,都考慮所有可能的特征和閾值組合,并且在樹的同一層中,所有節點的分裂方式是對稱的。具體可描述為
??特征選擇:在構建對稱樹時,CatBoost 會對所有可用的特征進行評估,計算每個特征對于目標變量的重要性。通過一些統計指標,如信息增益、基尼系數等,來衡量特征對數據劃分的有效性,選擇具有最高重要性的特征作為當前節點的分裂特征。
??閾值確定:對于選定的分裂特征,CatBoost 會遍歷該特征的所有可能取值,尋找一個最優的分裂閾值,使得分裂后的兩個子節點能夠最大程度地分離不同類別的數據,或者使目標變量在兩個子節點上的分布具有最大的差異。
??對稱分裂:一旦確定了分裂特征和閾值,就在當前節點上按照這個特征和閾值進行分裂,將數據集分為左右兩個子節點。在樹的同一層中,所有節點都按照相同的特征選擇和閾值確定方法進行分裂,形成對稱的樹結構。
2.訓練過程
(1) 初始化弱學習器
??首先,初始化一個弱學習器,通常是一個決策樹(是否對稱樹結構均可),記為 f 0 ( X ) f_0(X) f0?(X),其預測結果為初始的預測值 y ^ 0 \hat y_0 y^?0?。此時,初始預測值與真實值之間存在誤差。
(2) 計算殘差或負梯度
??在回歸任務中,計算每個樣本的殘差,即真實值 y i y_i yi?與當前模型預測值 y ^ i , t ? 1 \hat y_{i,t-1} y^?i,t?1?的差值 r i , t = y i ? y ^ i , t ? 1 r_{i,t}=y_i-\hat y_{i,t-1} ri,t?=yi??y^?i,t?1?,其中表示迭代的輪數。在分類任務中,計算損失函數關于當前模型預測值的負梯度
g i , t = ? ? L ( y i , y ^ i , t ? 1 ) ? y ^ i , t ? 1 g_{i,t}=-\frac{\vartheta L(y_i,\hat y_{i,t-1})}{\vartheta \hat y_{i,t-1}} gi,t?=??y^?i,t?1??L(yi?,y^?i,t?1?)?
(3) 構建決策樹
??使用計算得到的殘差(回歸任務)或負梯度(分類任務)作為新的目標值,使用“對稱樹結構” 的方式來構建一棵新的決策樹 f t ( X ) f_t(X) ft?(X)。同時采用一些限制決策樹深度、控制葉子節點數量的正則化技術。
(4) 更新模型
??根據新訓練的決策樹,更新當前模型。更新公式為 y ^ i , t = y ^ i , t ? 1 + α f t ( x i ) \hat y_{i,t}=\hat y_{i,t-1}+\alpha f_t(x_i) y^?i,t?=y^?i,t?1?+αft?(xi?),其中是學習率(也稱為步長),用于控制每棵樹對模型更新的貢獻程度。學習率較小可以使模型訓練更加穩定,但需要更多的迭代次數;學習率較大則可能導致模型收斂過快,甚至無法收斂。
(5) 重復迭代
??重復步驟 (2)–(4)步,不斷訓練新的決策樹并更新模型,直到達到預設的迭代次數、損失函數收斂到一定程度或滿足其他停止條件為止。最終,CatBoost模型由多棵決策樹組成,其預測結果是所有決策樹預測結果的累加。
過程示意圖
三、應用場景
1. 結構化數據預測
??在金融領域,CatBoost 可以用于信用評估、風險預測等任務。通過分析客戶的各種屬性(如年齡、收入、信用記錄等分類和數值特征),預測客戶的信用等級和違約風險,幫助金融機構做出更準確的決策。在電商領域,它可以用于商品推薦、銷售預測等。根據用戶的購買歷史、瀏覽行為等特征,預測用戶對不同商品的興趣,為用戶提供個性化的推薦服務,同時也可以幫助商家預測商品的銷量,合理安排庫存。
2.時間序列分析
??CatBoost 在時間序列預測方面也有一定的應用。它可以處理具有復雜模式和趨勢的時間序列數據,如股票價格預測、能源消耗預測等。通過提取時間序列中的各種特征(如趨勢、季節性、周期性等),結合其他相關的影響因素,構建預測模型,為決策提供支持。
3.圖像和文本數據的輔助分析
??雖然 CatBoost 主要適用于結構化數據,但在一些情況下,它也可以與其他深度學習算法結合,用于圖像和文本數據的輔助分析。例如,在圖像分類任務中,可以先使用深度學習模型提取圖像的特征,然后將這些特征與其他相關的結構化數據(如拍攝時間、地點等)一起輸入到 CatBoost 模型中,進行進一步的分類和預測。
四、Python實現
(環境:Python 3.11,scikit-learn 1.6.1)
分類情形
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import catboost as cb
from sklearn import metrics# 生成數據集
X, y = make_classification(n_samples = 1000, n_features = 6, random_state = 42)
# 將數據集劃分為訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 42)# 創建CatBoost分類模型
model = cb.CatBoostClassifier()
# 訓練模型
model.fit(X_train, y_train)# 預測
y_pre = model.predict(X_test)
# 性能評價
accuracy = metrics.accuracy_score(y_test,y_pre)print('預測結果為:',y_pre)
print('準確率為:',accuracy)
回歸情形
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
import catboost as cb
from sklearn.metrics import mean_squared_error# 生成數據集
X, y = make_regression(n_samples = 1000, n_features = 6, random_state = 42)
# 將數據集劃分為訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 42)# 創建CatBoost回歸模型
model = cb.CatBoostRegressor()
# 訓練模型
model.fit(X_train, y_train)# 進行預測
y_pred = model.predict(X_test)# 計算均方誤差評估模型性能
mse = mean_squared_error(y_test, y_pred)print(f"均方誤差: {mse}")
五、小結
??CatBoost 算法憑借其獨特的算法原理和核心特點,在機器學習領域中占據了一席之地。它在處理類別型特征、防止過擬合、訓練速度和易用性等方面都表現出色,適用于多種應用場景。無論是在結構化數據預測、時間序列分析還是與其他類型數據的結合應用中,CatBoost 都展現出了強大的能力。隨著數據科學的發展,CatBoost 可逐漸在更多領域得到應用,為解決實際問題提供更多有效的幫助。
End.
下載