CART決策樹(4-2)
CART(Classification and Regression Trees)決策樹是一種常用的機器學習算法,它既可以用于分類問題,也可以用于回歸問題。CART決策樹的主要原理是通過遞歸地將數據集劃分為兩個子集來構建決策樹。在分類問題中,CART決策樹通過選擇一個能夠最大化分裂后各個子集純度提升的特征進行分裂,從而將數據劃分為不同的類別。
CART決策樹的構建過程包括以下幾個步驟:
- 特征選擇:從數據集中選擇一個最優特征,用于劃分數據集。最優特征的選擇基于某種準則,如基尼指數(Gini Index)或信息增益(Information Gain)。
- 決策樹生成:根據選定的最優特征,將數據集劃分為兩個子集,并遞歸地在每個子集上重復上述過程,直到滿足停止條件(如子集大小小于某個閾值、所有樣本屬于同一類別等)。
- 剪枝:為了避免過擬合,可以對生成的決策樹進行剪枝操作,即刪除一些子樹或葉子節點,以提高模型的泛化能力。
CART決策樹的優點包括:
- 計算簡單,易于理解,可解釋性強。
- 不需要預處理,不需要提前歸一化,可以處理缺失值和異常值。
- 既可以處理離散值也可以處理連續值。
- 既可以用于分類問題,也可以用于回歸問題。
然而,CART決策樹也存在一些缺點:
- 不支持在線學習,當有新樣本產生時,需要重新構建決策樹模型。
- 容易出現過擬合現象,生成的決策樹可能對訓練數據有很好的分類能力,但對未知的測試數據卻未必有很好的分類能力。
- 對于一些復雜的關系,如異或關系,CART決策樹可能難以學習。
CART決策樹在許多領域都有廣泛的應用,如推薦系統中的商品推薦模型、金融風控中的信用評分和欺詐檢測、醫療診斷中的疾病預測等。此外,CART決策樹還可以用于社交媒體情感分析等領域。
- 數據
使用Universal Bank數據集。
示例:
????????
ID | Age | Experience | Income | ZIP Code | Family | CCAvg | Education | Mortgage | Personal Loan | Securities Account | CD Account | Online | CreditCard |
1 | 25 | 1 | 49 | 91107 | 4 | 1.6 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |
2 | 45 | 19 | 34 | 90089 | 3 | 1.5 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |
3 | 39 | 15 | 11 | 94720 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
4 | 35 | 9 | 100 | 94112 | 1 | 2.7 | 2 | 0 | 0 | 0 | 0 | 0 | 0 |
5 | 35 | 8 | 45 | 91330 | 4 | 1 | 2 | 0 | 0 | 0 | 0 | 0 | 1 |
6 | 37 | 13 | 29 | 92121 | 4 | 0.4 | 2 | 155 | 0 | 0 | 0 | 1 | 0 |
7 | 53 | 27 | 72 | 91711 | 2 | 1.5 | 2 | 0 | 0 | 0 | 0 | 1 | 0 |
8 | 50 | 24 | 22 | 93943 | 1 | 0.3 | 3 | 0 | 0 | 0 | 0 | 0 | 1 |
9 | 35 | 10 | 81 | 90089 | 3 | 0.6 | 2 | 104 | 0 | 0 | 0 | 1 | 0 |
10 | 34 | 9 | 180 | 93023 | 1 | 8.9 | 3 | 0 | 1 | 0 | 0 | 0 | 0 |
11 | 65 | 39 | 105 | 94710 | 4 | 2.4 | 3 | 0 | 0 | 0 | 0 | 0 | 0 |
12 | 29 | 5 | 45 | 90277 | 3 | 0.1 | 2 | 0 | 0 | 0 | 0 | 1 | 0 |
13 | 48 | 23 | 114 | 93106 | 2 | 3.8 | 3 | 0 | 0 | 1 | 0 | 0 | 0 |
14 | 59 | 32 | 40 | 94920 | 4 | 2.5 | 2 | 0 | 0 | 0 | 0 | 1 | 0 |
15 | 67 | 41 | 112 | 91741 | 1 | 2 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |
16 | 60 | 30 | 22 | 95054 | 1 | 1.5 | 3 | 0 | 0 | 0 | 0 | 1 | 1 |
17 | 38 | 14 | 130 | 95010 | 4 | 4.7 | 3 | 134 | 1 | 0 | 0 | 0 | 0 |
18 | 42 | 18 | 81 | 94305 | 4 | 2.4 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
19 | 46 | 21 | 193 | 91604 | 2 | 8.1 | 3 | 0 | 1 | 0 | 0 | 0 | 0 |
20 | 55 | 28 | 21 | 94720 | 1 | 0.5 | 2 | 0 | 0 | 1 | 0 | 0 | 1 |
21 | 56 | 31 | 25 | 94015 | 4 | 0.9 | 2 | 111 | 0 | 0 | 0 | 1 | 0 |
22 | 57 | 27 | 63 | 90095 | 3 | 2 | 3 | 0 | 0 | 0 | 0 | 1 | 0 |
23 | 29 | 5 | 62 | 90277 | 1 | 1.2 | 1 | 260 | 0 | 0 | 0 | 1 | 0 |
24 | 44 | 18 | 43 | 91320 | 2 | 0.7 | 1 | 163 | 0 | 1 | 0 | 0 | 0 |
25 | 36 | 11 | 152 | 95521 | 2 | 3.9 | 1 | 159 | 0 | 0 | 0 | 0 | 1 |
26 | 43 | 19 | 29 | 94305 | 3 | 0.5 | 1 | 97 | 0 | 0 | 0 | 1 | 0 |
27 | 40 | 16 | 83 | 95064 | 4 | 0.2 | 3 | 0 | 0 | 0 | 0 | 0 | 0 |
28 | 46 | 20 | 158 | 90064 | 1 | 2.4 | 1 | 0 | 0 | 0 | 0 | 1 | 1 |
29 | 56 | 30 | 48 | 94539 | 1 | 2.2 | 3 | 0 | 0 | 0 | 0 | 1 | 1 |
30 | 38 | 13 | 119 | 94104 | 1 | 3.3 | 2 | 0 | 1 | 0 | 1 | 1 | 1 |
31 | 59 | 35 | 35 | 93106 | 1 | 1.2 | 3 | 122 | 0 | 0 | 0 | 1 | 0 |
32 | 40 | 16 | 29 | 94117 | 1 | 2 | 2 | 0 | 0 | 0 | 0 | 1 | 0 |
33 | 53 | 28 | 41 | 94801 | 2 | 0.6 | 3 | 193 | 0 | 0 | 0 | 0 | 0 |
34 | 30 | 6 | 18 | 91330 | 3 | 0.9 | 3 | 0 | 0 | 0 | 0 | 0 | 0 |
35 | 31 | 5 | 50 | 94035 | 4 | 1.8 | 3 | 0 | 0 | 0 | 0 | 1 | 0 |
36 | 48 | 24 | 81 | 92647 | 3 | 0.7 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
37 | 59 | 35 | 121 | 94720 | 1 | 2.9 | 1 | 0 | 0 | 0 | 0 | 0 | 1 |
38 | 51 | 25 | 71 | 95814 | 1 | 1.4 | 3 | 198 | 0 | 0 | 0 | 0 | 0 |
39 | 42 | 18 | 141 | 94114 | 3 | 5 | 3 | 0 | 1 | 1 | 1 | 1 | 0 |
40 | 38 | 13 | 80 | 94115 | 4 | 0.7 | 3 | 285 | 0 | 0 | 0 | 1 | 0 |
41 | 57 | 32 | 84 | 92672 | 3 | 1.6 | 3 | 0 | 0 | 1 | 0 | 0 | 0 |
42 | 34 | 9 | 60 | 94122 | 3 | 2.3 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
43 | 32 | 7 | 132 | 90019 | 4 | 1.1 | 2 | 412 | 1 | 0 | 0 | 1 | 0 |
44 | 39 | 15 | 45 | 95616 | 1 | 0.7 | 1 | 0 | 0 | 0 | 0 | 1 | 0 |
45 | 46 | 20 | 104 | 94065 | 1 | 5.7 | 1 | 0 | 0 | 0 | 0 | 1 | 1 |
46 | 57 | 31 | 52 | 94720 | 4 | 2.5 | 1 | 0 | 0 | 0 | 0 | 0 | 1 |
47 | 39 | 14 | 43 | 95014 | 3 | 0.7 | 2 | 153 | 0 | 0 | 0 | 1 | 0 |
48 | 37 | 12 | 194 | 91380 | 4 | 0.2 | 3 | 211 | 1 | 1 | 1 | 1 | 1 |
49 | 56 | 26 | 81 | 95747 | 2 | 4.5 | 3 | 0 | 0 | 0 | 0 | 0 | 1 |
50 | 40 | 16 | 49 | 92373 | 1 | 1.8 | 1 | 0 | 0 | 0 | 0 | 0 | 1 |
51 | 32 | 8 | 8 | 92093 | 4 | 0.7 | 2 | 0 | 0 | 1 | 0 | 1 | 0 |
52 | 61 | 37 | 131 | 94720 | 1 | 2.9 | 1 | 0 | 0 | 0 | 0 | 1 | 0 |
53 | 30 | 6 | 72 | 94005 | 1 | 0.1 | 1 | 207 | 0 | 0 | 0 | 0 | 0 |
54 | 50 | 26 | 190 | 90245 | 3 | 2.1 | 3 | 240 | 1 | 0 | 0 | 1 | 0 |
55 | 29 | 5 | 44 | 95819 | 1 | 0.2 | 3 | 0 | 0 | 0 | 0 | 1 | 0 |
56 | 41 | 17 | 139 | 94022 | 2 | 8 | 1 | 0 | 0 | 0 | 0 | 1 | 0 |
57 | 55 | 30 | 29 | 94005 | 3 | 0.1 | 2 | 0 | 0 | 1 | 1 | 1 | 0 |
58 | 56 | 31 | 131 | 95616 | 2 | 1.2 | 3 | 0 | 1 | 0 | 0 | 0 | 0 |
59 | 28 | 2 | 93 | 94065 | 2 | 0.2 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
60 | 31 | 5 | 188 | 91320 | 2 | 4.5 | 1 | 455 | 0 | 0 | 0 | 0 | 0 |
61 | 49 | 24 | 39 | 90404 | 3 | 1.7 | 2 | 0 | 0 | 1 | 0 | 1 | 0 |
62 | 47 | 21 | 125 | 93407 | 1 | 5.7 | 1 | 112 | 0 | 1 | 0 | 0 | 0 |
63 | 42 | 18 | 22 | 90089 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
64 | 42 | 17 | 32 | 94523 | 4 | 0 | 2 | 0 | 0 | 0 | 0 | 1 | 0 |
65 | 47 | 23 | 105 | 90024 | 2 | 3.3 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
66 | 59 | 35 | 131 | 91360 | 1 | 3.8 | 1 | 0 | 0 | 0 | 0 | 1 | 1 |
67 | 62 | 36 | 105 | 95670 | 2 | 2.8 | 1 | 336 | 0 | 0 | 0 | 0 | 0 |
68 | 53 | 23 | 45 | 95123 | 4 | 2 | 3 | 132 | 0 | 1 | 0 | 0 | 0 |
69 | 47 | 21 | 60 | 93407 | 3 | 2.1 | 1 | 0 | 0 | 0 | 0 | 1 | 1 |
70 | 53 | 29 | 20 | 90045 | 4 | 0.2 | 1 | 0 | 0 | 0 | 0 | 1 | 0 |
71 | 42 | 18 | 115 | 91335 | 1 | 3.5 | 1 | 0 | 0 | 0 | 0 | 0 | 1 |
72 | 53 | 29 | 69 | 93907 | 4 | 1 | 2 | 0 | 0 | 0 | 0 | 1 | 0 |
73 | 44 | 20 | 130 | 92007 | 1 | 5 | 1 | 0 | 0 | 0 | 0 | 0 | 1 |
74 | 41 | 16 | 85 | 94606 | 1 | 4 | 3 | 0 | 0 | 0 | 0 | 1 | 1 |
75 | 28 | 3 | 135 | 94611 | 2 | 3.3 | 1 | 0 | 0 | 0 | 0 | 0 | 1 |
76 | 31 | 7 | 135 | 94901 | 4 | 3.8 | 2 | 0 | 1 | 0 | 1 | 1 | 1 |
注意:數據集中的編號(ID)和郵政編碼(ZIP CODE)特征因為在分類模型中無意義,所以在數據預處理階段將它們刪除。
- 使用CART決策樹對數據進行分類
- 使用留出法劃分數據集,訓練集:測試集為7:3。
# 使用留出法劃分數據集,訓練集:測試集為7:3
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
- 使用CART決策樹對訓練集進行訓練
# 使用CART決策樹對訓練集進行訓練,深度限制為10層
model = DecisionTreeClassifier(max_depth=10)
model.fit(X_train, y_train)
決策樹的深度限制為10層,max_depth=10。
- 使用訓練好的模型對測試集進行預測并輸出預測結果和模型準確度
# 使用訓練好的模型對測試集進行預測
y_pred = model.predict(X_test)# 輸出預測結果和模型準確度
accuracy = accuracy_score(y_test, y_pred)
print("模型準確度:", accuracy)
- 可視化訓練好的CART決策樹模型
# 可視化訓練好的CART決策樹模型
dot_data = export_graphviz(model, out_file=None,feature_names=X.columns,class_names=['0', '1'],filled=True, rounded=True,special_characters=True)
graph = graphviz.Source(dot_data)
graph.render("Universal_Bank_CART") # 保存為PDF文件
- 安裝graphviz模塊
首先在windows系統中安裝graphviz模塊
32位系統使用windows_10_cmake_Release_graphviz-install-10.0.1-win32.exe
64位系統使用windows_10_cmake_Release_graphviz-install-10.0.1-win64.exe
注意:安裝時使用下圖中圈出的選項
安裝完成后使用pip install graphviz指令在python環境中安裝graphviz庫。
- 使用graphviz模塊可視化模型
# 可視化訓練好的CART決策樹模型
dot_data = export_graphviz(model, out_file=None,feature_names=X.columns,class_names=['0', '1'],filled=True, rounded=True,special_characters=True)
graph = graphviz.Source(dot_data)
graph.render("Universal_Bank_CART") # 保存為PDF文件
完整代碼:
# 導入所需的庫
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.tree import export_graphviz
import graphviz# 讀取數據集
data = pd.read_csv("universalbank.csv")# 數據預處理:刪除無意義特征
data = data.drop(columns=['ID', 'ZIP Code'])# 劃分特征和標簽
X = data.drop(columns=['Personal Loan'])
y = data['Personal Loan']# 使用留出法劃分數據集,訓練集:測試集為7:3
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 使用CART決策樹對訓練集進行訓練,深度限制為10層
model = DecisionTreeClassifier(max_depth=10)
model.fit(X_train, y_train)# 使用訓練好的模型對測試集進行預測
y_pred = model.predict(X_test)# 輸出預測結果和模型準確度
accuracy = accuracy_score(y_test, y_pred)
print("模型準確度:", accuracy)# 可視化訓練好的CART決策樹模型
dot_data = export_graphviz(model, out_file=None,feature_names=X.columns,class_names=['0', '1'],filled=True, rounded=True,special_characters=True)
graph = graphviz.Source(dot_data)
graph.render("Universal_Bank_CART6") # 保存為PDF文件