一、什么是決策樹
1.基本概念
決策樹是一種樹形結構,由結點(node) 和有向邊(directed edge) 組成。其中結點分為兩類:
- 內部結點(internal node):表示一個屬性(特征)
- 葉結點(leaf node):表示一個類別
決策樹是常用的分類機器學習方法。
2.實際舉例說明
以 “相親對象分類系統” 為例構建簡單決策樹:
- 內部結點(長方形):特征 “有無房子”“有無上進心”
- 葉結點(橢圓形):類別 “值得考慮”“備胎”“Say Goodbye”
- 分類邏輯:
- 相親對象有房子→劃分為 “值得認真考慮”
沒有房子但有上進心→劃分為 “備胎”既沒有房子也沒有上進心→劃分為 “Say Goodbye
實際分類中存在多個特征量,可構建多種決策樹,核心問題是如何篩選出最優決策樹。
二、介紹建立決策樹的算法
決策樹算法的核心差異在于特征選擇指標,常見算法對比如下:
算法 | 特征選擇指標 | 核心邏輯 |
ID3 | 信息增益 | 信息增益越大,特征對降低數據不確定性的能力越強,優先作為上層結點 |
C4.5 | 信息增益率 | 解決 ID3 對多值特征的偏好問題,通過 “增益率 = 信息增益 / 特征固有值” 平衡選擇 |
CART | 基尼指數 | 基尼指數越小,數據純度越高,優先選擇使基尼指數下降最多的特征 |
本文重點講解ID3 算法,以下是其核心概念與公式:
1. 某個分類的信息
單個分類的信息表示該分類的不確定性,公式為:
其中,P(x_i) 是選擇該分類的概率。
2. 熵(Entropy)
熵是隨機變量不確定性的度量,定義為信息的期望值,公式為:
其中,n 是分類的數目;熵值越大,數據不確定性越高。
3. 經驗熵(Empirical Entropy)
4. 條件熵(Conditional Entropy)
已知隨機變量 X 的條件下,隨機變量 Y 的不確定性,公式為:
其中,p_i?是 X=x_i?的概率,H(Y|X=x_i)?是 X=x_i?時 Y 的熵。
5. 信息增益(Information Gain)
樣本集 D?的經驗熵 H(D)?與特征 A?給定條件下 D?的經驗條件熵 H(D|A)?之差,公式為:
關鍵結論:特征的信息增益值越大,該特征對分類的貢獻越強,應優先作為決策樹的上層結點。
三、決策樹的一般流程
決策樹構建分為 6 個步驟,適用于各類決策樹算法:
- 收集數據:通過爬蟲、問卷、數據庫查詢等方式獲取原始數據,無固定方法。
- 準備數據:樹構造算法僅支持標稱型數據(離散類別數據),需將數值型數據離散化(如將 “年齡 20-30” 劃分為 “青年”)。
- 分析數據:構建樹后,通過可視化、誤差分析等方式驗證樹結構是否符合預期。
- 訓練算法:根據特征選擇指標(如 ID3 的信息增益),遞歸構建決策樹的數據結構。
- 測試算法:使用測試集計算決策樹的錯誤率,評估模型性能。
- 使用算法:將訓練好的決策樹應用于實際場景(如貸款審批、客戶分類),并持續迭代優化。
四、實際舉例構建決策樹
以 “貸款申請分類” 為例,使用 ID3 算法構建決策樹。
1. 數據集準備
貸款申請樣本數據表(原始)
ID | 年齡 | 有工作 | 有自己的房子 | 信貸情況 | 類別(是否給貸款) |
1 | 青年 | 否 | 否 | 一般 | 否 |
2 | 青年 | 否 | 否 | 好 | 否 |
3 | 青年 | 是 | 否 | 好 | 是 |
4 | 青年 | 是 | 是 | 一般 | 是 |
5 | 青年 | 否 | 否 | 一般 | 否 |
6 | 中年 | 否 | 否 | 一般 | 否 |
7 | 中年 | 否 | 否 | 好 | 否 |
8 | 中年 | 是 | 是 | 好 | 是 |
9 | 中年 | 否 | 是 | 非常好 | 是 |
10 | 中年 | 否 | 是 | 非常好 | 是 |
11 | 老年 | 否 | 是 | 非常好 | 是 |
12 | 老年 | 否 | 是 | 好 | 是 |
13 | 老年 | 是 | 否 | 好 | 是 |
14 | 老年 | 是 | 否 | 非常好 | 是 |
15 | 老年 | 否 | 否 | 一般 | 否 |
數據編碼(標稱化處理)
- 年齡:0 = 青年,1 = 中年,2 = 老年
- 有工作:0 = 否,1 = 是
- 有自己的房子:0 = 否,1 = 是
- 信貸情況:0 = 一般,1 = 好,2 = 非常好
- 類別:no = 否,yes = 是
數據集代碼定義
from math import log
def createDataSet():dataSet = [[0, 0, 0, 0, 'no'], # 樣本1[0, 0, 0, 1, 'no'], # 樣本2[0, 1, 0, 1, 'yes'], # 樣本3[0, 1, 1, 0, 'yes'], # 樣本4[0, 0, 0, 0, 'no'], # 樣本5[1, 0, 0, 0, 'no'], # 樣本6[1, 0, 0, 1, 'no'], # 樣本7[1, 1, 1, 1, 'yes'], # 樣本8[1, 0, 1, 2, 'yes'], # 樣本9[1, 0, 1, 2, 'yes'], # 樣本10[2, 0, 1, 2, 'yes'], # 樣本11[2, 0, 1, 1, 'yes'], # 樣本12[2, 1, 0, 1, 'yes'], # 樣本13[2, 1, 0, 2, 'yes'], # 樣本14[2, 0, 0, 0, 'no'] # 樣本15]labels = ['年齡', '有工作', '有自己的房子', '信貸情況'] # 特征標簽labels1 = ['放貸', '不放貸'] # 分類標簽return dataSet, labels, labels1 # 返回數據集、特征標簽、分類標簽
2. 計算經驗熵 H (D)
數學計算
樣本集 D?共 15 個樣本,其中 “放貸(yes)”9 個,“不放貸(no)”6 個,經驗熵為:
代碼實現
def calcShannonEnt(dataSet):numEntires = len(dataSet) # 數據集行數(樣本數)labelCounts = {} # 存儲每個標簽的出現次數for featVec in dataSet:currentLabel = featVec[-1] # 提取最后一列(分類標簽)if currentLabel not in labelCounts.keys():labelCounts[currentLabel] = 0labelCounts[currentLabel] += 1 # 標簽計數shannonEnt = 0.0 # 初始化經驗熵for key in labelCounts:prob = float(labelCounts[key]) / numEntires # 標簽出現概率shannonEnt -= prob * log(prob, 2) # 計算經驗熵return shannonEnt# 測試代碼
if __name__ == '__main__':dataSet, features, labels1 = createDataSet()print("數據集:", dataSet)print("經驗熵H(D):", calcShannonEnt(dataSet)) # 輸出:0.9709505944546686
3. 計算信息增益(選擇最優特征)
數學計算(以 “有自己的房子” 為例)
設特征 A_3(有自己的房子),取值為 “是(1)” 和 “否(0)”:
- 子集 D_1(A_3=1):共 9 個樣本,均為 “yes”,經驗熵 H(D_1)=0
- 子集 D_2(A_3=0):共 6 個樣本,“yes” 3 個、“no” 3 個
- 經驗熵?
- 條件熵?
- 信息增益
(注:原文計算結果為 0.420,此處以原文代碼輸出為準)
其他特征的信息增益計算結果:
- 年齡(A_1):0.083
- 有工作(A_2):0.324
- 信貸情況(A_4):0.363
結論:特征 “有自己的房子(A_3)” 信息增益最大,作為決策樹的根節點。
代碼實現
"""
函數:按照給定特征劃分數據集
參數:dataSet - 待劃分數據集axis - 特征索引value - 特征取值
返回:retDataSet - 劃分后的子集
"""
def splitDataSet(dataSet, axis, value):retDataSet = []for featVec in dataSet:if featVec[axis] == value:reducedFeatVec = featVec[:axis] # 去掉當前特征列reducedFeatVec.extend(featVec[axis+1:]) # 拼接剩余列retDataSet.append(reducedFeatVec)return retDataSet"""
函數:選擇最優特征
參數:dataSet - 數據集
返回:bestFeature - 最優特征索引
"""
def chooseBestFeatureToSplit(dataSet):numFeatures = len(dataSet[0]) - 1 # 特征數量(減去分類列)baseEntropy = calcShannonEnt(dataSet) # 基礎經驗熵bestInfoGain = 0.0 # 最優信息增益bestFeature = -1 # 最優特征索引for i in range(numFeatures):featList = [example[i] for example in dataSet] # 提取第i列特征uniqueVals = set(featList) # 特征的唯一取值newEntropy = 0.0 # 條件熵for value in uniqueVals:subDataSet = splitDataSet(dataSet, i, value) # 劃分子集prob = len(subDataSet) / float(len(dataSet)) # 子集概率newEntropy += prob * calcShannonEnt(subDataSet) # 累加條件熵infoGain = baseEntropy - newEntropy # 計算信息增益print(f"第{i}個特征({labels[i]})的增益為:{infoGain:.3f}")if infoGain > bestInfoGain:bestInfoGain = infoGainbestFeature = ireturn bestFeature# 測試代碼
if __name__ == '__main__':dataSet, labels, labels1 = createDataSet()bestFeature = chooseBestFeatureToSplit(dataSet)print(f"最優特征索引值:{bestFeature}(對應特征:{labels[bestFeature]})")# 輸出:最優特征索引值:2(對應特征:有自己的房子)
4. 生成決策樹(遞歸構建)
核心邏輯
- 若樣本集所有樣本屬于同一類別,直接返回該類別(葉節點);
- 若無特征可劃分或樣本特征全相同,返回出現次數最多的類別(葉節點);
- 選擇最優特征作為當前節點,按特征取值劃分子集;
- 對每個子集遞歸執行上述步驟,生成子樹。
代碼實現
import operator"""
函數:統計出現次數最多的類別
參數:classList - 類別列表
返回:sortedClassCount[0][0] - 最多類別
"""
def majorityCnt(classList):classCount = {}for vote in classList:if vote not in classCount.keys():classCount[vote] = 0classCount[vote] += 1# 按類別次數降序排序sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)return sortedClassCount[0][0]"""
函數:創建決策樹
參數:dataSet - 訓練集labels - 特征標簽featLabels - 存儲選擇的最優特征
返回:myTree - 決策樹(字典結構)
"""
def createTree(dataSet, labels, featLabels):classList = [example[-1] for example in dataSet] # 提取所有類別# 情況1:所有樣本類別相同if classList.count(classList[0]) == len(classList):return classList[0]# 情況2:無特征可劃分或特征全相同if len(dataSet[0]) == 1 or len(labels) == 0:return majorityCnt(classList)# 情況3:遞歸構建樹bestFeat = chooseBestFeatureToSplit(dataSet) # 最優特征索引bestFeatLabel = labels[bestFeat] # 最優特征標簽featLabels.append(bestFeatLabel)myTree = {bestFeatLabel: {}} # 決策樹字典del(labels[bestFeat]) # 刪除已使用的特征標簽featValues = [example[bestFeat] for example in dataSet] # 最優特征的所有取值uniqueVals = set(featValues) # 唯一取值for value in uniqueVals:subLabels = labels[:] # 復制特征標簽(避免遞歸修改原列表)# 遞歸生成子樹myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels, featLabels)return myTree# 測試代碼
if __name__ == '__main__':dataSet, labels, labels