文章目錄
- 1. 決策樹基本原理
- 1.1. 什么是決策樹?
- 1.2. 決策樹的基本構成:
- 1.3. 核心思想
- 2. 決策樹的構建過程
- 2.1. 特征選擇
- 2.1.1. 信息增益(ID3)
- 2.1.2. 基尼不純度(CART)
- 2.1.3. 均方誤差(MSE)
- 2.2. 節點劃分
- 2.3. 停止條件:
- 3. 決策樹的剪枝(防止過擬合)
- 4. 決策樹的優缺點
- 5. 常見決策樹算法
- 6. 樣例代碼:
- 7. 歸納
1. 決策樹基本原理
1.1. 什么是決策樹?
決策樹(Decision Tree)是一種非參數的監督學習算法,適用于分類和回歸任務。其核心思想是通過一系列規則(if-then結構)對數據進行遞歸劃分,最終形成一棵樹形結構,實現預測或分類。
1.2. 決策樹的基本構成:
- 根節點(Root Node):代表整個數據集,選擇第一個最優特征進行分裂。
- 內部節點(Internal Nodes):代表對某個特征的判斷,用來決定如何分裂數據。
- 葉子節點(Leaf Nodes):存放最終的預測結果,表示分類或回歸結果。
1.3. 核心思想
- 目標:構建一棵樹,使得每個分支節點代表一個特征判斷,每個葉子節點代表一個預測結果。
- 關鍵問題:
- 如何選擇劃分特征?(特征選擇準則)
- 何時停止劃分?(防止過擬合)
2. 決策樹的構建過程
決策樹的構建是一個遞歸分割(Recursive Partitioning)的過程
2.1. 特征選擇
選擇最佳特征:在每一步分裂中,算法會選擇一個最優的特征來進行數據劃分。
常用的準則:
- 信息增益(Information Gain, ID3算法)
- 信息增益比(Gain Ratio, C4.5算法)
- 基尼不純度(Gini Impurity, CART算法)
- 均方誤差(MSE, 回歸樹)
2.1.1. 信息增益(ID3)
-
衡量使用某特征劃分后信息不確定性減少的程度。
-
計算公式: 信息增益 = H ( D ) ? H ( D ∣ A ) 信息增益=H(D)?H(D∣A) 信息增益=H(D)?H(D∣A)
- H(D):數據集的熵(不確定性)。
- H(D∣A):在特征 A劃分后的條件熵。
2.1.2. 基尼不純度(CART)
-
衡量數據集的不純度,越小越好,表示數據集越純。
-
計算公式:
Gini ( D ) = 1 ? ∑ k = 1 K p k 2 \text{Gini}(D) = 1 - \sum_{k=1}^{K} p_k^2 Gini(D)=1?k=1∑K?pk2?- p k p_k pk? :數據集中第 k k k 類樣本的比例。
2.1.3. 均方誤差(MSE)
-
用于回歸問題,計算預測值與真實值的差異。
-
計算公式: M S E = 1 n ∑ ( y i ? y ^ i ) 2 MSE= \frac {1}{n}\sum(y_i ? \hat y_i) ^2 MSE=n1?∑(yi??y^?i?)2
- y i y_i yi?是實際值, y ^ i \hat y_i y^?i? 是預測值。
2.2. 節點劃分
- 分類任務:選擇使信息增益最大(或基尼不純度最小)的特征進行劃分。
- 回歸任務:選擇使均方誤差(MSE)最小的特征進行劃分。
2.3. 停止條件:
- 當前節點所有樣本屬于同一類別(純度100%)。
- 所有特征已用完,或繼續劃分無法顯著降低不純度。
- 達到預設的最大深度(max_depth)或最小樣本數(min_samples_split)。
3. 決策樹的剪枝(防止過擬合)
決策樹容易過擬合(訓練集表現好,測試集差)。為了防止過擬合,我們通常會使用剪枝技術。
-
預剪枝(Pre-pruning):在訓練時提前停止(如限制樹深度)。
-
后剪枝(Post-pruning):先訓練完整樹,再剪掉不重要的分支(如C4.5的REP方法)。
4. 決策樹的優缺點
- ? 優點
- 可解釋性強:規則清晰,易于可視化(if-then結構)。
- 無需數據標準化:對數據分布無嚴格要求。
- 可處理混合類型數據(數值型+類別型)。
- 適用于小規模數據。
- ? 缺點
- 容易過擬合(需剪枝或限制樹深度)。
- 對噪聲敏感(異常值可能導致樹結構不穩定)。
- 不穩定性:數據微小變化可能導致完全不同的樹。
- 不適合高維稀疏數據(如文本數據)。
5. 常見決策樹算法
算法 | 適用任務 | 特征選擇準則 | 特點 |
---|---|---|---|
ID3 | 分類 | 信息增益 | 只能處理離散特征,容易過擬合 |
C4.5 | 分類 | 信息增益比 | 可處理連續特征,支持剪枝 |
CART | 分類/回歸 | 基尼不純度(分類) 均方誤差(回歸) | 二叉樹結構,Scikit-learn默認實現 |
CHAID | 分類 | 卡方檢驗 | 適用于類別型數據 |
6. 樣例代碼:
# 導入必要的庫
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree# 加載鳶尾花數據集
data = load_iris()
X = data.data
y = data.target# 將數據分為訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 創建決策樹分類器
clf = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42)# 訓練決策樹
clf.fit(X_train, y_train)# 預測測試集
y_pred = clf.predict(X_test)# 輸出準確率
print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")# 可視化決策樹
plt.figure(figsize=(12, 8))
plot_tree(clf, filled=True, feature_names=data.feature_names, class_names=data.target_names)
plt.show()
7. 歸納
決策樹的核心:遞歸劃分數據,選擇最優特征,構建樹結構。
-
關鍵問題:
- 如何選擇劃分特征?(信息增益、基尼不純度)
- 如何防止過擬合?(剪枝、限制樹深度)
-
適用場景:
-
需要可解釋性的任務(如金融風控)。
-
小規模、低維數據分類/回歸
-