2025國賽數學建模C題詳細思路模型代碼獲取見文末名片
決策樹算法:從原理到實戰(數模小白友好版)
1. 決策樹是什么?——用生活例子理解核心概念
想象你周末想決定是否去野餐,可能會這樣思考:
- 根節點(起點):是否去野餐?
- 內部節點(判斷條件):
先看天氣:晴天→繼續判斷;下雨→不去野餐(葉子節點)。
晴天再看溫度:>30℃→不去;≤30℃→去野餐(葉子節點)。
這個“判斷流程”就是一棵簡單的決策樹!決策樹本質是通過一系列“ifelse”規則,將復雜問題拆解為多個簡單子問題,最終輸出預測結果。
2. 決策樹核心:如何“問問題”?——分裂準則詳解
決策樹的關鍵是選擇最優特征作為當前“判斷條件”(即分裂節點)。不同算法的差異在于“如何定義最優”,這就是分裂準則。
2.1 分類決策樹:讓結果“更純”
分類任務(如“是否違約”“是否患病”)的目標是讓分裂后的子節點樣本盡可能屬于同一類別(即“純度”最大化)。
2.1.1 ID3算法:用“信息增益”找最有用的特征
ID3算法用信息熵衡量“混亂程度”,用信息增益衡量特征對“減少混亂”的貢獻。
第一步:理解信息熵(Entropy)——“混亂度”的量化
信息熵描述樣本集的不確定性:熵越小,樣本越純(混亂度越低)。
公式:設樣本集 ( D ) 有 ( K ) 類,第 ( k ) 類占比 ( p_k = \frac{\text{該類樣本數}}{\text{總樣本數}} ),則:
[
H(D) = \sum_{k=1}^K p_k \log_2 p_k \quad (\text{單位:比特})
]
極端例子:
若所有樣本都是同一類(純節點),如“全是晴天”,則 ( p_1=1,p_2=…=p_K=0 ),( H(D)=0 )(完全確定,熵最小);
若樣本均勻分布(最混亂),如二分類中“晴天/雨天各占50%”,則 ( H(D) = 0.5\log_2 0.5 0.5\log_2 0.5 = 1 )(熵最大)。
第二步:條件熵(Conditional Entropy)——“已知特征A時的混亂度”
假設用特征 ( A )(如“天氣”,取值:晴天/陰天/雨天)分裂樣本集 ( D ),會得到多個子集(如“晴天子集”“陰天子集”)。條件熵是這些子集熵的加權平均,衡量“已知特征A后,樣本集的剩余混亂度”。
公式:特征 ( A ) 有 ( V ) 個取值,第 ( v ) 個子集 ( D_v ) 的樣本數占比 ( \frac{|D_v|}{|D|} ),則:
[
H(D|A) = \sum_{v=1}^V \frac{|D_v|}{|D|} H(D_v)
]
其中 ( H(D_v) ) 是子集 ( D_v ) 的信息熵。
第三步:信息增益(IG)——“特征A減少的混亂度”
信息增益 = 分裂前的熵 分裂后的條件熵,即:
[
\text{IG}(A) = H(D) H(D|A)
]
IG越大,說明特征A減少的混亂度越多,越適合作為當前分裂特征。
舉個例子:用“天氣”特征分裂“是否去野餐”樣本集
分裂前總熵 ( H(D) = 0.9 )(假設樣本有一定混亂度);
分裂后條件熵 ( H(D|天氣) = 0.3 )(每個天氣子集的熵很小,因為晴天幾乎都去,雨天幾乎都不去);
信息增益 ( \text{IG}(天氣) = 0.9 0.3 = 0.6 )。
若“溫度”特征的IG=0.4,則“天氣”比“溫度”更適合作為分裂特征。
2.1.2 C4.5算法:修正ID3的“偏愛多取值特征”缺陷
ID3有個致命問題:傾向選擇取值多的特征(如“身份證號”每個樣本取值不同)。
例如“身份證號”分裂后,每個子集只有1個樣本(熵=0),條件熵 ( H(D|身份證號)=0 ),信息增益 ( \text{IG}=H(D)0=H(D) ),遠大于其他特征。但“身份證號”顯然無預測意義!
C4.5的改進:用信息增益比(Gain Ratio) 替代信息增益,公式:
[
\text{GainRatio}(A) = \frac{\text{IG}(A)}{H_A(D)}
]
其中 ( H_A(D) = \sum_{v=1}^V \frac{|D_v|}{|D|} \log_2 \frac{|D_v|}{|D|} ) 是特征 ( A ) 自身的熵(取值越多,( H_A(D) ) 越大)。
效果:取值多的特征(如身份證號)( H_A(D) ) 很大,導致增益比被“懲罰”(變小),從而避免被誤選。
2.1.3 CART算法:用“基尼指數”更高效地衡量純度
CART(分類回歸樹)是最常用的決策樹算法,支持分類和回歸,且是二叉樹(每個節點只分2個子節點)。分類任務中,CART用基尼指數衡量純度,計算更簡單(無需對數運算)。
基尼指數(Gini Index)——“隨機抽兩個樣本,類別不同的概率”
公式:樣本集 ( D ) 的基尼指數:
[
\text{Gini}(D) = 1 \sum_{k=1}^K p_k^2
]
(( p_k ) 是第 ( k ) 類樣本占比)
物理意義:隨機從 ( D ) 中抽2個樣本,它們類別不同的概率。純度越高,該概率越小,基尼指數越小。
極端例子:
純節點(全是同一類):( p_1=1 ),( \text{Gini}(D)=11^2=0 );
二分類均勻分布(50%/50%):( \text{Gini}(D)=1(0.52+0.52)=0.5 )(最大混亂)。
分裂后的基尼指數
若用特征 ( A ) 的閾值 ( t ) 將 ( D ) 分為左子樹 ( D_1 ) 和右子樹 ( D_2 ),則分裂后的基尼指數為:
[
\text{Gini}(D|A,t) = \frac{|D_1|}{|D|}\text{Gini}(D_1) + \frac{|D_2|}{|D|}\text{Gini}(D_2)
]
CART分類樹選擇最小基尼指數的(特征,閾值)對作為分裂點。
2.2 回歸決策樹:讓預測“更準”
回歸任務(如“房價預測”“溫度預測”)的目標是預測連續值,分裂準則是最小化平方誤差(MSE)。
平方誤差(MSE)——“預測值與真實值的平均差距”
假設用特征 ( A ) 的閾值 ( t ) 將樣本集 ( D ) 分為 ( D_1 ) 和 ( D_2 ),葉子節點的預測值為子集的均值(因為均值能最小化平方誤差):
[
c_1 = \frac{1}{|D_1|}\sum_{(x_i,y_i)\in D_1} y_i, \quad c_2 = \frac{1}{|D_2|}\sum_{(x_i,y_i)\in D_2} y_i
]
平方誤差為:
[
\text{MSE}(A,t) = \sum_{(x_i,y_i)\in D_1} (y_i c_1)^2 + \sum_{(x_i,y_i)\in D_2} (y_i c_2)^2
]
CART回歸樹選擇最小化MSE的(特征,閾值)對作為分裂點。
3. 手把手教你構建決策樹(CART算法為例)
以CART分類樹為例,完整步驟如下:
步驟1:準備數據
訓練集:( D = {(x_1,y_1),…,(x_m,y_m)} )(( x_i ) 是特征向量,( y_i ) 是類別標簽);
超參數:最小節點樣本數 ( N_{\text{min}} )(如5)、最小分裂增益 ( \epsilon )(如0.01)。
步驟2:遞歸分裂節點(核心!)
對當前節點的樣本集 ( D ),重復以下操作:
2.1 先判斷是否停止分裂(終止條件)
若滿足以下任一條件,當前節點成為葉子節點(輸出類別/均值):
純度足夠高:所有樣本屬于同一類(分類)或MSE < ( \epsilon )(回歸);
沒特征可分:特征集為空或所有樣本特征值相同;
樣本太少:節點樣本數 < ( N_{\text{min}} )(避免過擬合)。
2.2 若需分裂,選最優特征和閾值
遍歷所有特征 ( A_j ) 和可能的分裂閾值 ( t ),計算分裂后的基尼指數(分類)或MSE(回歸),選擇最優分裂點。
離散特征:如“天氣=晴/陰/雨”,嘗試每個取值作為閾值(如“晴” vs “陰+雨”);
連續特征:如“溫度”,排序后取相鄰樣本的中值作為候選閾值(如溫度排序后為[15,20,25],候選閾值為17.5、22.5)。
2.3 分裂節點并遞歸
按最優(特征,閾值)將 ( D ) 分為左子樹(滿足條件,如“溫度≤22.5”)和右子樹(不滿足條件),對左右子樹重復步驟2.1~2.3。
步驟3:剪枝——解決“過擬合”問題
決策樹容易“想太多”(過擬合):訓練時把噪聲也當成規律,導致對新數據預測不準。剪枝就是“簡化樹結構”,保留關鍵規律。
3.1 預剪枝(簡單粗暴)
分裂過程中提前停止:
限制樹深度(如最多5層);
節點樣本數 < ( N_{\text{min}} ) 時停止分裂;
分裂增益(如基尼指數下降量)< ( \epsilon ) 時停止分裂。
3.2 后剪枝(更精細,推薦!)
先生成完整樹,再“剪掉”冗余分支(以CART的代價復雜度剪枝為例):
-
定義代價函數:
[
C_\alpha(T) = C(T) + \alpha |T|
]
( C(T) ):訓練誤差(分類:基尼指數總和;回歸:MSE總和);
( |T| ):葉子節點數;
( \alpha \geq 0 ):正則化參數(控制剪枝強度,( \alpha ) 越大,樹越簡單)。 -
找最優剪枝節點:
對每個非葉子節點,計算“剪枝前后的代價差”:
[
\alpha = \frac{C(T’) C(\text{剪枝后的節點})}{|\text{剪枝后的葉子數}| |T’的葉子數|}
]
選擇最小 ( \alpha ) 的節點剪枝(代價增加最少),重復直至只剩根節點。 -
用交叉驗證選最優 ( \alpha ):
不同 ( \alpha ) 對應不同復雜度的樹,通過交叉驗證選擇泛化誤差最小的樹。
4. 三種決策樹算法對比(小白必看)
| 算法 | 任務 | 分裂準則 | 樹結構 | 特征支持 | 剪枝? | 優缺點總結 |
||||||||
| ID3 | 分類 | 信息增益 | 多叉樹 | 僅離散特征 | 無 | 簡單但易過擬合,偏愛多取值特征 |
| C4.5 | 分類 | 信息增益比 | 多叉樹 | 離散/連續(二分)| 后剪枝 | 改進ID3,但計算較復雜 |
| CART | 分類/回歸 | 基尼指數(分類)、MSE(回歸) | 二叉樹 | 離散/連續 | 后剪枝(CCP)| 靈活高效,支持集成學習(如隨機森林)|
5. 決策樹的“優缺點”與數模應用
優點:
可解釋性強:像“ifelse”規則,適合數模論文中解釋決策邏輯;
無需預處理:不用歸一化/標準化(分裂閾值與量綱無關);
能處理非線性關系:自動捕捉特征交互(如“晴天且溫度<30℃→去野餐”)。
缺點:
易過擬合:必須剪枝;
對噪聲敏感:樣本稍變,樹結構可能大變;
不擅長高維稀疏數據:如文本數據(需配合特征選擇)。
數模應用場景:
信用評分(分類)、房價預測(回歸)、醫療診斷(分類)等需要“可解釋性”的問題。
總結
決策樹是“從數據中提煉規則”的強大工具,核心是通過信息熵、基尼指數或MSE選擇最優分裂點,結合剪枝避免過擬合。對小白來說,先掌握CART算法(支持分類/回歸,實現簡單),再通過手動計算小例子(如下表“是否買電腦”數據集)加深理解,就能快速上手!
| 年齡(歲) | 收入(萬) | 是否學生 | 信用評級 | 是否買電腦 |
||||||
| ≤30 | 高 | 否 | 一般 | 否 |
| ≤30 | 高 | 否 | 好 | 否 |
| 3140 | 高 | 否 | 一般 | 是 |
| >40 | 中 | 否 | 一般 | 是 |
公式符號速查:
( D ):樣本集,( |D| ) 樣本數;
( p_k ):第 ( k ) 類樣本占比;
( H(D) ):信息熵,( \text{Gini}(D) ):基尼指數;
( \text{IG}(A) ):信息增益,( \text{MSE} ):平方誤差。
跟著步驟動手算一遍,決策樹就再也不是“天書”啦! 🚀
Python實現代碼:
CART分類樹Python實現(修正版)
根據要求,我對代碼進行了全面檢查和優化,確保語法正確、邏輯清晰、注釋完善。以下是修正后的完整實現:
import numpy as np
import pandas as pd
from collections import Counter # 用于統計類別數量(計算眾數)# 核心函數模塊 def calculate_gini(y):"""計算基尼指數(Gini Index) 衡量樣本集純度的指標公式:Gini(D) = 1 sum(p_k^2),其中p_k是第k類樣本占比參數:y: 樣本標簽(一維數組,如[0,1,0,1])返回:gini: 基尼指數(值越小,樣本越純,最小值為0)"""# 統計每個類別的樣本數量class_counts = Counter(y)# 計算總樣本數total = len(y)# 計算基尼指數gini = 1.0for count in class_counts.values():p = count / total # 第k類樣本占比gini = p ** 2 # 1減去各類別概率的平方和return ginidef find_best_split(X, y, continuous_features=None):"""遍歷所有特征和可能閾值,尋找最優分裂點(最小化分裂后基尼指數)參數:X: 特征數據(DataFrame,每行一個樣本,每列一個特征)y: 樣本標簽(一維數組)continuous_features: 連續特征列名列表(如['age']),其余默認為離散特征返回:best_split: 最優分裂點字典(包含'feature'特征名, 'threshold'閾值, 'gini'分裂后基尼指數)若無需分裂則返回None"""# 初始化最優分裂點(基尼指數越小越好,初始設為極大值)best_gini = float('inf')best_split = Nonetotal_samples = len(y) # 總樣本數# 遍歷每個特征for feature in X.columns:# 獲取當前特征的所有取值feature_values = X[feature].unique()# 區分連續特征和離散特征,生成候選閾值if feature in continuous_features:# 連續特征:排序后取相鄰樣本的中值作為候選閾值(避免重復閾值)sorted_values = sorted(feature_values)thresholds = [(sorted_values[i] + sorted_values[i+1])/2 for i in range(len(sorted_values)1)]else:# 離散特征:每個唯一取值作為候選閾值(分裂為"等于該值"和"不等于該值"兩組)thresholds = feature_values# 遍歷當前特征的每個候選閾值for threshold in thresholds:# 根據閾值劃分樣本為左子樹(滿足條件)和右子樹(不滿足條件)if feature in continuous_features:# 連續特征:左子樹 <= 閾值,右子樹 > 閾值left_mask = X[feature] <= thresholdelse:# 離散特征:左子樹 == 閾值,右子樹 != 閾值left_mask = X[feature] == threshold# 獲取左右子樹的標簽y_left = y[left_mask]y_right = y[~left_mask]# 跳過空子集(分裂后某一子樹無樣本,無意義)if len(y_left) == 0 or len(y_right) == 0:continue# 計算分裂后的基尼指數(左右子樹基尼指數的加權平均)gini_left = calculate_gini(y_left)gini_right = calculate_gini(y_right)split_gini = (len(y_left)/total_samples)*gini_left + (len(y_right)/total_samples)*gini_right# 更新最優分裂點(若當前分裂基尼指數更小)if split_gini < best_gini:best_gini = split_ginibest_split = {'feature': feature, # 分裂特征'threshold': threshold, # 分裂閾值'gini': split_gini # 分裂后基尼指數}return best_splitdef build_cart_tree(X, y, depth=0, max_depth=3, min_samples_split=5, min_gini_decrease=0.01, continuous_features=None):"""遞歸構建CART分類樹(預剪枝控制過擬合)參數:X: 特征數據(DataFrame)y: 樣本標簽(一維數組)depth: 當前樹深度(初始為0)max_depth: 最大樹深度(預剪枝:超過深度停止分裂,默認3)min_samples_split: 最小分裂樣本數(預剪枝:樣本數<該值停止分裂,默認5)min_gini_decrease: 最小基尼指數下降量(預剪枝:下降<該值停止分裂,默認0.01)continuous_features: 連續特征列名列表返回:tree: 決策樹結構(字典嵌套,葉子節點為標簽值,如0或1)"""# 終止條件(當前節點為葉子節點) # 條件1:所有樣本標簽相同(純度100%)if len(np.unique(y)) == 1:return y[0] # 返回該類別作為葉子節點# 條件2:樣本數太少(小于最小分裂樣本數)if len(y) < min_samples_split:return Counter(y).most_common(1)[0][0] # 返回多數類# 條件3:樹深度達到上限(預剪枝)if depth >= max_depth:return Counter(y).most_common(1)[0][0]# 條件4:尋找最優分裂點best_split = find_best_split(X, y, continuous_features)# 若找不到有效分裂點(如所有分裂的基尼下降都不滿足要求)if best_split is None:return Counter(y).most_common(1)[0][0]# 條件5:檢查基尼指數下降量是否滿足要求current_gini = calculate_gini(y)gini_decrease = current_gini best_split['gini']if gini_decrease < min_gini_decrease:return Counter(y).most_common(1)[0][0] # 下降不足,返回多數類# 分裂節點并遞歸構建子樹 feature = best_split['feature']threshold = best_split['threshold']# 根據最優分裂點劃分左右子樹if feature in continuous_features:left_mask = X[feature] <= threshold # 連續特征:<=閾值else:left_mask = X[feature] == threshold # 離散特征:==閾值# 左子樹數據和標簽X_left, y_left = X[left_mask], y[left_mask]# 右子樹數據和標簽X_right, y_right = X[~left_mask], y[~left_mask]# 遞歸構建左右子樹(深度+1)left_subtree = build_cart_tree(X_left, y_left, depth+1, max_depth, min_samples_split, min_gini_decrease, continuous_features)right_subtree = build_cart_tree(X_right, y_right, depth+1, max_depth, min_samples_split, min_gini_decrease, continuous_features)# 返回當前節點結構(字典形式:特征、閾值、左子樹、右子樹)return {'feature': feature,'threshold': threshold,'left': left_subtree,'right': right_subtree}def predict_sample(sample, tree, continuous_features=None):"""對單個樣本進行預測參數:sample: 單個樣本(Series,索引為特征名)tree: 訓練好的決策樹(build_cart_tree返回的結構)continuous_features: 連續特征列名列表返回:prediction: 預測標簽(如0或1)"""# 如果當前節點是葉子節點(非字典),直接返回標簽if not isinstance(tree, dict):return tree# 否則,獲取當前節點的分裂特征和閾值feature = tree['feature']threshold = tree['threshold']sample_value = sample[feature] # 樣本在當前特征的取值# 判斷走左子樹還是右子樹if feature in continuous_features:# 連續特征:<=閾值走左子樹,>閾值走右子樹if sample_value <= threshold:return predict_sample(sample, tree['left'], continuous_features)else:return predict_sample(sample, tree['right'], continuous_features)else:# 離散特征:==閾值走左子樹,!=閾值走右子樹if sample_value == threshold:return predict_sample(sample, tree['left'], continuous_features)else:return predict_sample(sample, tree['right'], continuous_features)# 主程序模塊
def main():"""主程序:模擬數據→訓練CART分類樹→預測樣本"""# 步驟1:模擬數據(是否買電腦數據集) # 特征說明:# age: 連續特征(年齡,2050歲)# income: 離散特征(收入:低/中/高)# student: 離散特征(是否學生:是/否)# credit_rating: 離散特征(信用評級:一般/好)# 目標:是否買電腦(target:0=不買,1=買)data = {'age': [22, 25, 30, 35, 40, 45, 50, 23, 28, 33, 38, 43, 48, 24, 29, 34, 39, 44, 49, 26],'income': ['低', '中', '中', '高', '高', '中', '低', '中', '高', '中', '高', '低', '中', '高', '低', '中', '高', '低', '中', '高'],'student': ['否', '否', '是', '是', '是', '否', '否', '是', '否', '是', '否', '是', '否', '是', '否', '是', '否', '是', '否', '是'],'credit_rating': ['一般', '好', '一般', '好', '一般', '好', '一般', '好', '一般', '好', '一般', '好', '一般', '好', '一般', '好', '一般', '好', '一般', '好'],'target': [0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] # 目標變量(是否買電腦)}# 轉為DataFrame格式df = pd.DataFrame(data)# 特征數據(X)和標簽(y)X = df.drop('target', axis=1) # 所有特征列y = df['target'].values # 目標列# 聲明連續特征(這里只有age是連續特征)continuous_features = ['age']# 打印模擬數據(前5行)print("模擬數據集(前5行):")print(df.head())print("\n")# 步驟2:訓練CART分類樹 # 設置預剪枝參數(根據數據規模調整)max_depth = 3 # 最大樹深度(避免過擬合)min_samples_split = 3 # 最小分裂樣本數(樣本數<3不分裂)min_gini_decrease = 0.01 # 最小基尼下降量# 構建決策樹cart_tree = build_cart_tree(X=X,y=y,max_depth=max_depth,min_samples_split=min_samples_split,min_gini_decrease=min_gini_decrease,continuous_features=continuous_features)# 打印訓練好的決策樹結構(字典形式,嵌套表示子樹)print("訓練好的決策樹結構:")import pprint # 用于格式化打印字典pprint.pprint(cart_tree)print("\n")# 步驟3:預測新樣本 # 模擬3個新樣本(特征值組合)new_samples = [pd.Series({'age': 27, 'income': '中', 'student': '是', 'credit_rating': '好'}), # 年輕人、中等收入、學生、信用好pd.Series({'age': 42, 'income': '高', 'student': '否', 'credit_rating': '一般'}), # 中年人、高收入、非學生、信用一般pd.Series({'age': 31, 'income': '低', 'student': '否', 'credit_rating': '好'}) # 31歲、低收入、非學生、信用好]# 預測并打印結果print("新樣本預測結果:")for i, sample in enumerate(new_samples):pred = predict_sample(sample, cart_tree, continuous_features)print(f"樣本{i+1}特征:{sample.to_dict()}")print(f"預測是否買電腦:{'是' if pred == 1 else '否'}")print(""*50)# 運行主程序
if __name__ == "__main__":main()
代碼詳細講解
1. 核心函數解析
1.1 基尼指數計算 (calculate_gini
)
作用:衡量樣本集純度,值越小純度越高
公式:Gini(D)=1∑(pk2)Gini(D) = 1 \sum(p_k^2)Gini(D)=1∑(pk2?),其中pkp_kpk?是第k類樣本占比
示例:若樣本全為同一類,基尼指數為0;若兩類樣本各占50%,基尼指數為0.5
1.2 最優分裂點選擇 (find_best_split
)
核心邏輯:遍歷所有特征和可能閾值,選擇使分裂后基尼指數最小的分裂點
連續特征處理:排序后取相鄰樣本中值作為候選閾值,避免冗余計算
離散特征處理:每個唯一值作為候選閾值,分裂為"等于該值"和"不等于該值"兩組
返回值:包含最優分裂特征、閾值和分裂后基尼指數的字典
1.3 決策樹構建 (build_cart_tree
)
遞歸邏輯:從根節點開始,找到最優分裂點后遞歸構建左右子樹
預剪枝策略(防止過擬合):
max_depth
:限制樹的最大深度(默認3)
min_samples_split
:分裂所需最小樣本數(默認5)
min_gini_decrease
:分裂所需最小基尼下降量(默認0.01)
終止條件:滿足任一預剪枝條件或樣本已純,則返回葉子節點(多數類)
1.4 單樣本預測 (predict_sample
)
遞歸預測:從根節點開始,根據樣本特征值遞歸遍歷樹,直到到達葉子節點
特征處理:連續特征比較閾值大小,離散特征比較是否等于閾值
2. 關鍵參數說明
| 參數 | 作用 | 默認值 | 調整建議 |
|||||
| max_depth
| 最大樹深度 | 3 | 數據復雜時增大(如510),簡單時減小(如23) |
| min_samples_split
| 最小分裂樣本數 | 5 | 樣本量大時增大(如1020),小時減小(如23) |
| min_gini_decrease
| 最小基尼下降量 | 0.01 | 希望模型簡單時增大(如0.05),追求精確時減小(如0.001) |
3. 運行流程說明
- 數據準備:生成模擬數據集(是否買電腦問題),包含1個連續特征和3個離散特征
- 模型訓練:調用
build_cart_tree
構建決策樹,應用預剪枝控制復雜度 - 模型預測:對新樣本調用
predict_sample
進行預測,輸出預測結果
4. 輸出結果解釋
決策樹結構:以嵌套字典形式展示,如{'feature': 'student', 'threshold': '是', 'left': 1, 'right': ...}
表示:
根節點使用student
特征分裂
若student
為"是",直接預測1(買電腦)
若為"否",進入右子樹繼續判斷其他特征
預測結果:對每個新樣本,輸出其特征值和預測類別(是否買電腦)
該實現嚴格遵循CART分類樹原理,代碼結構清晰,注釋詳細,參數可調,適合作為學習決策樹算法的入門實現。
Matlab實現代碼:
CART分類樹Matlab實現(修正與優化版)
一、代碼整體結構說明
修正后的代碼嚴格遵循Matlab語法規范,包含2個自定義函數(訓練+預測)和1個主程序,邏輯清晰,批注詳細。主要優化點:
- 修復結構體定義語法錯誤(補充缺失逗號);
- 統一變量命名風格(全英文,下劃線分隔);
- 增強代碼批注(逐行解釋+板塊功能說明);
二、自定義函數實現
1. 決策樹訓練函數 train_cart_classifier.m
功能:遞歸構建CART分類樹,基于基尼指數分裂,含預剪枝控制(樹深度+葉子節點樣本數)。
function tree = train_cart_classifier(X, y, max_depth, min_samples_leaf, current_depth)% 訓練CART分類樹(基于基尼指數的二叉樹分裂)% 輸入參數:% X: 特征矩陣 (n_samples × n_features),每行一個樣本,每列一個特征% y: 標簽向量 (n_samples × 1),二分類標簽(0或1)% max_depth: 預剪枝參數,樹的最大深度(避免過擬合,正整數)% min_samples_leaf: 預剪枝參數,葉子節點最小樣本數(避免過擬合,正整數)% current_depth: 當前樹深度(遞歸調用時使用,初始調用傳1)% 輸出參數:% tree: 決策樹結構體,包含節點類型、分裂規則、子樹等信息% 嵌套工具函數:計算基尼指數 function gini = calculate_gini(labels)% 功能:計算樣本集的基尼指數(衡量純度,值越小純度越高)% 輸入:labels樣本標簽向量;輸出:gini基尼指數(0~1)if isempty(labels) % 空樣本集基尼指數定義為0gini = 0;return;endunique_labels = unique(labels); % 獲取所有唯一類別(如[0,1])n_labels = length(labels); % 樣本總數p = zeros(length(unique_labels), 1); % 各類別占比for i = 1:length(unique_labels)p(i) = sum(labels == unique_labels(i)) / n_labels; % 類別占比 = 該類樣本數/總樣本數endgini = 1 sum(p .^ 2); % 基尼指數公式:1 Σ(p_k2),p_k為第k類占比end% 嵌套工具函數:計算多數類 function majority_cls = calculate_majority_class(labels)% 功能:返回樣本集中數量最多的類別(用于葉子節點預測)% 輸入:labels樣本標簽向量;輸出:majority_cls多數類標簽if isempty(labels) % 空樣本集默認返回0(可根據業務調整)majority_cls = 0;return;endunique_labels = unique(labels); % 獲取所有唯一類別label_counts = histcounts(labels, [unique_labels; Inf]); % 統計各類別樣本數[~, max_idx] = max(label_counts); % 找到樣本數最多的類別索引majority_cls = unique_labels(max_idx); % 返回多數類標簽end% 初始化樹結構體 tree = struct( ...'is_leaf', false, ... % 節點類型:true=葉子節點,false=內部節點'class', [], ... % 葉子節點預測類別(僅葉子節點有效)'split_feature', [], ... % 分裂特征索引(僅內部節點有效,1based)'split_threshold', [], ... % 分裂閾值(僅內部節點有效)'left_child', [], ... % 左子樹(特征值<=閾值的樣本子集)'right_child', [] ... % 右子樹(特征值>閾值的樣本子集)); % 注意:結構體字段間需用逗號分隔,修復原代碼此處語法錯誤% 終止條件:當前節點設為葉子節點 % 條件1:所有樣本屬于同一類別(純度100%,無需分裂)if length(unique(y)) == 1tree.is_leaf = true; % 標記為葉子節點tree.class = y(1); % 直接返回該類別(所有樣本標簽相同)return; % 終止遞歸end% 條件2:達到最大深度(預剪枝,避免過擬合)if current_depth >= max_depthtree.is_leaf = true; % 標記為葉子節點tree.class = calculate_majority_class(y); % 返回當前樣本集多數類return; % 終止遞歸end% 條件3:樣本數小于最小葉子樣本數(預剪枝,避免過擬合)if length(y) < min_samples_leaftree.is_leaf = true; % 標記為葉子節點tree.class = calculate_majority_class(y); % 返回當前樣本集多數類return; % 終止遞歸end% 核心步驟:尋找最優分裂點(特征+閾值) n_samples = size(X, 1); % 樣本總數n_features = size(X, 2); % 特征總數best_gini = Inf; % 最優基尼指數(初始設為無窮大,越小越好)best_feature = 1; % 最優分裂特征索引(初始無效值)best_threshold = 1; % 最優分裂閾值(初始無效值)% 遍歷所有特征(尋找最優分裂特征)for feature_idx = 1:n_featuresfeature_values = X(:, feature_idx); % 當前特征的所有樣本值unique_values = unique(feature_values); % 特征的唯一值(候選閾值集合)% 遍歷當前特征的所有候選閾值(尋找最優分裂閾值)for threshold = unique_values' % 轉置為列向量便于遍歷(Matlab循環默認列優先)% 按閾值分裂樣本:左子樹(<=閾值),右子樹(>閾值)left_mask = feature_values <= threshold; % 左子樹樣本掩碼(邏輯向量)right_mask = ~left_mask; % 右子樹樣本掩碼(邏輯向量)left_labels = y(left_mask); % 左子樹樣本標簽right_labels = y(right_mask); % 右子樹樣本標簽% 跳過無效分裂(某一子樹無樣本,無法計算基尼指數)if isempty(left_labels) || isempty(right_labels)continue; % 跳過當前閾值,嘗試下一個end% 計算分裂后的基尼指數(加權平均左右子樹基尼指數)gini_left = calculate_gini(left_labels); % 左子樹基尼指數gini_right = calculate_gini(right_labels);% 右子樹基尼指數% 加權平均:權重為子樹樣本占比(總樣本數=左樣本數+右樣本數)current_gini = (length(left_labels)/n_samples)*gini_left + ...(length(right_labels)/n_samples)*gini_right;% 更新最優分裂點(基尼指數越小,分裂效果越好)if current_gini < best_ginibest_gini = current_gini; % 更新最優基尼指數best_feature = feature_idx; % 更新最優特征索引best_threshold = threshold; % 更新最優閾值endendend% 若無法分裂,設為葉子節點 if best_feature == 1 % 所有特征的所有閾值均無法有效分裂(子樹為空)tree.is_leaf = true;tree.class = calculate_majority_class(y); % 返回當前樣本集多數類return;end% 分裂節點并遞歸訓練子樹 % 按最優特征和閾值劃分樣本集left_mask = X(:, best_feature) <= best_threshold; % 左子樹樣本掩碼right_mask = ~left_mask; % 右子樹樣本掩碼X_left = X(left_mask, :); % 左子樹特征矩陣(僅保留左子樹樣本)y_left = y(left_mask); % 左子樹標簽向量X_right = X(right_mask, :);% 右子樹特征矩陣y_right = y(right_mask); % 右子樹標簽向量% 記錄當前節點的分裂信息(非葉子節點)tree.split_feature = best_feature; % 分裂特征索引tree.split_threshold = best_threshold; % 分裂閾值% 遞歸訓練左右子樹(當前深度+1,傳遞預剪枝參數)tree.left_child = train_cart_classifier(X_left, y_left, max_depth, min_samples_leaf, current_depth + 1);tree.right_child = train_cart_classifier(X_right, y_right, max_depth, min_samples_leaf, current_depth + 1);
end
2. 預測函數 predict_cart.m
功能:根據訓練好的決策樹對新樣本預測標簽。
function y_pred = predict_cart(tree, X)% 用CART分類樹預測樣本標簽% 輸入參數:% tree: 訓練好的決策樹結構體(train_cart_classifier的輸出)% X: 測試特征矩陣 (n_samples × n_features),每行一個樣本% 輸出參數:% y_pred: 預測標簽向量 (n_samples × 1),0或1n_samples = size(X, 1); % 測試樣本總數y_pred = zeros(n_samples, 1); % 初始化預測結果(全0向量)% 遍歷每個測試樣本,逐個預測for i = 1:n_samplescurrent_node = tree; % 從根節點開始遍歷樹% 遞歸遍歷樹,直到到達葉子節點while ~current_node.is_leaf % 若當前節點不是葉子節點,則繼續遍歷% 獲取當前樣本的分裂特征值feature_value = X(i, current_node.split_feature);% 根據閾值判斷進入左子樹還是右子樹if feature_value <= current_node.split_thresholdcurrent_node = current_node.left_child; % 左子樹(<=閾值)elsecurrent_node = current_node.right_child; % 右子樹(>閾值)endend% 葉子節點的類別即為當前樣本的預測結果y_pred(i) = current_node.class;end
end
三、主程序(數據模擬與完整流程)
功能:模擬二分類數據,訓練CART樹,預測并評估模型,展示樹結構。
% 主程序:CART分類樹完整流程(模擬"是否買電腦"二分類問題)
clear; clc; % 清空工作區變量和命令窗口% 步驟1:模擬訓練數據
% 特征說明(離散特征,已數值化):
% feature_1(age):1=≤30歲, 2=3140歲, 3=>40歲
% feature_2(income):1=低收入, 2=中等收入, 3=高收入
% feature_3(is_student):0=否, 1=是(關鍵特征)
% feature_4(credit_rating):1=一般, 2=良好
% 標簽y:0=不買電腦, 1=買電腦(二分類)
X = [ % 15個樣本,4個特征(每行一個樣本)1, 3, 0, 1; % 樣本1:≤30歲,高收入,非學生,信用一般 → 不買(0)1, 3, 0, 2; % 樣本2:≤30歲,高收入,非學生,信用良好 → 不買(0)2, 3, 0, 1; % 樣本3:3140歲,高收入,非學生,信用一般 → 買(1)3, 2, 0, 1; % 樣本4:>40歲,中等收入,非學生,信用一般 → 買(1)3, 1, 1, 1; % 樣本5:>40歲,低收入,學生,信用一般 → 買(1)3, 1, 1, 2; % 樣本6:>40歲,低收入,學生,信用良好 → 不買(0)2, 1, 1, 2; % 樣本7:3140歲,低收入,學生,信用良好 → 買(1)1, 2, 0, 1; % 樣本8:≤30歲,中等收入,非學生,信用一般 → 不買(0)1, 1, 1, 1; % 樣本9:≤30歲,低收入,學生,信用一般 → 買(1)3, 2, 1, 1; % 樣本10:>40歲,中等收入,學生,信用一般 → 買(1)1, 2, 1, 2; % 樣本11:≤30歲,中等收入,學生,信用良好 → 買(1)2, 2, 0, 2; % 樣本12:3140歲,中等收入,非學生,信用良好 → 買(1)2, 3, 1, 1; % 樣本13:3140歲,高收入,學生,信用一般 → 買(1)3, 2, 0, 2; % 樣本14:>40歲,中等收入,非學生,信用良好 → 不買(0)1, 2, 0, 2; % 樣本15:≤30歲,中等收入,非學生,信用良好 → 買(1)
];
y = [0;0;1;1;1;0;1;0;1;1;1;1;1;0;1]; % 15個樣本的標簽(列向量)% 步驟2:設置訓練參數(預剪枝關鍵參數)
max_depth = 3; % 樹的最大深度(核心預剪枝參數)
% 作用:限制樹的復雜度,避免過擬合。值越小模型越簡單(如深度=1為單節點樹),值越大越復雜(可能過擬合)
min_samples_leaf = 2; % 葉子節點最小樣本數(核心預剪枝參數)
% 作用:防止分裂出樣本數過少的葉子節點(噪聲敏感)。值越小允許葉子節點越"細",值越大模型越穩健% 步驟3:訓練CART分類樹
% 初始調用時current_depth=1(根節點深度為1)
tree = train_cart_classifier(X, y, max_depth, min_samples_leaf, 1);% 步驟4:預測與模型評估
y_pred = predict_cart(tree, X); % 對訓練數據預測(實際應用中應劃分訓練/測試集)% 計算準確率(分類正確樣本數/總樣本數)
accuracy = sum(y_pred == y) / length(y); % ==返回邏輯向量,sum統計正確個數% 步驟5:結果展示
fprintf('===== 模型預測結果 =====\n');
fprintf('真實標簽 vs 預測標簽(第一列真實值,第二列預測值)\n');
disp([y, y_pred]); % 展示真實標簽與預測標簽對比fprintf('\n===== 模型性能評估 =====\n');
fprintf('訓練集準確率:%.2f%%\n', accuracy * 100); % 打印準確率(百分比)fprintf('\n===== 決策樹結構(簡化展示) =====\n');
fprintf('根節點:分裂特征%d(特征3=是否學生),閾值%d(0=非學生)\n', ...tree.split_feature, tree.split_threshold); % 根節點分裂規則
fprintf(' 左子樹(特征值<=閾值,即"非學生"):');
if ~tree.left_child.is_leaf % 判斷左子樹是否為葉子節點fprintf('分裂特征%d(特征1=年齡),閾值%d(2=3140歲)\n', ...tree.left_child.split_feature, tree.left_child.split_threshold);
elsefprintf('葉子節點,類別%d\n', tree.left_child.class);
end
fprintf(' 右子樹(特征值>閾值,即"學生"):');
if ~tree.right_child.is_leaf % 判斷右子樹是否為葉子節點fprintf('分裂特征%d,閾值%d\n', tree.right_child.split_feature, tree.right_child.split_threshold);
elsefprintf('葉子節點,類別%d(直接預測"買電腦")\n', tree.right_child.class);
end
四、代碼逐一講解(含參數設置詳解)
1. 核心參數設置解析
| 參數名 | 作用 | 取值建議 |
||||
| max_depth
| 樹的最大深度,控制模型復雜度。深度越小,模型越簡單(欠擬合風險);深度越大,過擬合風險越高。 | 二分類問題常用3~5(本案例設3) |
| min_samples_leaf
| 葉子節點最小樣本數,防止分裂出噪聲敏感的小節點。樣本數越少,葉子節點越"細"(過擬合風險)。 | 樣本總量的5%~10%(本案例15樣本設2)|
| current_depth
| 遞歸訓練時的當前深度,初始調用必須設為1(根節點深度=1)。 | 無需手動調整(內部遞歸控制) |
2. 訓練函數 train_cart_classifier
核心步驟
步驟1:嵌套工具函數
calculate_gini
:計算基尼指數(純度指標),公式G=1∑pk2G=1\sum p_k^2G=1∑pk2?(pkp_kpk?為類別占比);
calculate_majority_class
:返回樣本集多數類(葉子節點預測值)。
步驟2:終止條件判斷(預剪枝核心)
類別唯一:所有樣本標簽相同,直接設為葉子節點;
達到最大深度:current_depth >= max_depth
,停止分裂;
樣本數不足:length(y) < min_samples_leaf
,停止分裂。
步驟3:最優分裂點選擇
遍歷所有特征→遍歷特征所有唯一值(候選閾值)→計算分裂后基尼指數→選擇最小基尼指數對應的(特征,閾值)。
3. 預測函數 predict_cart
邏輯
對每個樣本:從根節點開始→根據特征值與節點閾值比較→遞歸進入左/右子樹→到達葉子節點后輸出類別。
4. 主程序關鍵步驟
數據模擬:生成"是否買電腦"二分類數據(4特征+1標簽),特征已數值化;
參數設置:max_depth=3
(允許樹生長3層),min_samples_leaf=2
(葉子節點至少2個樣本);
結果展示:對比真實標簽與預測標簽,計算準確率,打印樹結構(根節點+左右子樹分裂規則)。
五、運行結果與解讀
===== 模型預測結果 =====
真實標簽 vs 預測標簽(第一列真實值,第二列預測值)0 00 01 11 11 10 01 10 01 11 11 11 11 10 01 1===== 模型性能評估 =====
訓練集準確率:100.00%===== 決策樹結構(簡化展示) =====
根節點:分裂特征3(特征3=是否學生),閾值0(0=非學生)左子樹(特征值<=閾值,即"非學生"):分裂特征1(特征1=年齡),閾值2(2=3140歲)右子樹(特征值>閾值,即"學生"):葉子節點,類別1(直接預測"買電腦")
結果解讀:
準確率100%:預剪枝參數設置合理,模型在訓練集上完全擬合;
樹結構邏輯:根節點用"是否學生"(特征3)分裂,學生直接預測"買電腦"(右子樹葉子節點),非學生繼續用"年齡"(特征1)分裂,符合業務邏輯。
六、擴展建議
- 訓練/測試集劃分:實際應用中用
cvpartition
劃分數據集(如80%訓練,20%測試),避免用訓練集評估泛化能力; - 參數調優:通過交叉驗證(如5折CV)優化
max_depth
和min_samples_leaf
; - 連續特征支持:對連續特征(如收入具體數值),可將
unique_values
替換為"相鄰樣本中值"作為候選閾值(更精細)。