sklearn.DecisionTreeClassifier決策樹簡單使用
- 1.決策樹算法基礎
- 2.sklearn.DecisionTreeClassifier簡單實踐
- 2.1 決策樹類
- 2.3 決策樹構建
- 2.3.1全數據集擬合,決策樹可視化
- 2.3.2交叉驗證實驗
- 2.3.3超參數搜索
- 2.3.4模型保存與導入
- 2.3.5固定隨機數種子
- 參考資料
1.決策樹算法基礎
決策樹模型可以用來做 回歸/分類 任務。
每次選擇一個屬性/特征,依據特征的閾值,將特征空間劃分為 與 坐標軸平行的一些決策區域。如果是分類問題,每個決策區域的類別為該該區域中多數樣本的類別;如果為回歸問題,每個決策區域的回歸值為該區域中所有樣本值的均值。
決策樹復雜程度 依賴于 特征空間的幾何形狀。根節點->葉子節點的一條路徑產生一條決策規則。
決策樹最大優點:可解釋性強
決策樹最大缺點:不是分類正確率最高的模型
決策樹的學習是一個NP-Complete問題,所以實際中使用啟發性的規則來構建決策樹。
step1:選最好的特征來劃分數據集
step2:對上一步劃分的子集重復步驟1,直至停止條件(節點純度/分裂增益/樹深度)
不同的特征衡量標準,產生了不同的決策樹生成算法:
算法 | 最優特征選擇標準 |
---|---|
ID3 | 信息增益:Gain(A)=H(D)?H(D∥A)Gain(A)=H(D)-H(D\|A)Gain(A)=H(D)?H(D∥A) |
C4.5 | 信息增益率:GainRatio(A)=Gain(A)/Split(A)GainRatio(A)=Gain(A)/Split(A)GainRatio(A)=Gain(A)/Split(A) |
CART | gini指數增益:Gini(D)?Gini(D∥A)Gini(D)-Gini(D\|A)Gini(D)?Gini(D∥A) |
k個類別,類別分布的gini 指數如下,gini指數越大,樣本的不確定性越大:
Gini(D)=∑k=1Kpk(1?pk)=1?∑k=1Kpk2Gini(D) =\sum_{k=1}^Kp_k(1-p_k)=1-\sum_{k=1}^Kp_k^2Gini(D)=k=1∑K?pk?(1?pk?)=1?k=1∑K?pk2?
CART – Classification and Regression Trees 的縮寫1984年提出的一個特征選擇算法,對特征進行是/否判斷,生成一棵二叉樹。且每次選擇完特征后不對特征進行剔除操作,所有同一條決策規則上可能出現重復特征的情況。
2.sklearn.DecisionTreeClassifier簡單實踐
Scikit-learn(sklearn)是機器學習中常用的第三方模塊,其建立在NumPy、Scipy、MatPlotLib之上,包括了回歸,降維,分類,聚類方法。
sklearn 通過以下兩個類實現了 決策分類樹 和 決策回歸樹
sklearn 實現了ID3和Cart 算法,criterion默認為"gini"系數,對應為CART算法。還可設置為"entropy",對應為ID3。(計算機最擅長做的事:規則重復計算,sklearn通過對每個特征的每個切分點計算信息增益/gini增益,得到當前數據集合最優的特征及最優劃分點)
2.1 決策樹類
sklearn.tree.DecisionTreeClassifier(criterion=’gini’*,splitter=’best’, max_depth=None,
min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0,
max_features=None, random_state=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None, class_weight=None, presort=False)
DecisionTreeRegressor(criterion=’mse’, splitter=’best’,
max_depth=None, min_samples_split=2, min_samples_leaf=1,
min_weight_fraction_leaf=0.0, max_features=None, random_state=None,
max_leaf_nodes=None, min_impurity_decrease=0.0,
min_impurity_split=None, presort=False)
Criterion | 選擇屬性的準則–gini–cart算法 |
---|---|
splitter | 特征劃分點的選擇策略:best 特征的所有劃分點中找最優 |
random 部分劃分點中找最優 | |
max_depth | 決策樹的最大深度,none/int 限制/不限制決策樹的深度 |
min_samples_split | 節點 繼續劃分需要的最小樣本數,如果少于這個數,節點將不再劃分 |
min_samples_leaf | 限制葉子節點的最少樣本數量,如果葉子節點的樣本數量過少會被剪枝 |
min_weight_fraction_leaf | 葉子節點的剪枝規則 |
max_features | 選取用于分類的特征的數量 |
random_state | 隨機數生成的一些規則、 |
max_leaf_nodes | 限制葉子節點的數量,防止過擬合 |
min_impurity_decrease | 表示結點減少的最小不純度,控制節點的繼續分割規律 |
min_impurity_split | 表示結點劃分的最小不純度,控制節點的繼續分割規律 |
class_weight | 設置各個類別的權重,針對類別不均衡的數據集使用 |
不適用于決策樹回歸 | |
presort | 控制決策樹劃分的速度 |
2.3 決策樹構建
采用sklearn內置數據集鳶尾花數據集做實驗。
導入第三方庫
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
import graphviz
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score
import joblib
plt.switch_backend('agg')
2.3.1全數據集擬合,決策樹可視化
def demo1():# 全數據集擬合,決策樹可視化iris = load_iris()x, y = load_iris(return_X_y = True) # x[list]-feature,y[]-label clf = tree.DecisionTreeClassifier() # 實例化了一個類,可以指定類參數,定制決策樹模型clf = clf.fit(x,y) # 訓練模型print("feature name ", iris.feature_names) # 特征列表, 自己的數據可視化時,構建一個特征列表即可print("label name ",iris.target_names) # 類別列表dot_data = tree.export_graphviz(clf, out_file = None, feature_names = iris.feature_names, class_names = iris.target_names ) graph = graphviz.Source(dot_data) # 能繪制樹節點的一個接口graph.render("iris") # 存成pdf圖
tree.export_graphviz 參數 | |
---|---|
feature_names | 特征列表list,和訓練時的特征列表排列順序對其即可 |
class_names | 類別l列表ist,和訓練時的label列表排列順序對其即可 |
filled | False/True,會依據criterion的純度將節點顯示成不同的顏色 |
value中的值顯示的是各個類別樣本的數量(二分類就是[負樣本數,正樣本數])
2.3.2交叉驗證實驗
def demo2():# n-折實驗iris = load_iris()iris_feature = iris.data # 與demo1中的x,y是同樣的數據iris_target = iris.target# 數據集合劃分參數:train_x, test_x, train_y, test_y = train_test_split(iris_feature,iris_target,test_size = 0.2, random_state = 1)dt_model = DecisionTreeClassifier()dt_model.fit(train_x, train_y) # 模型訓練predict_y = dt_model.predict(test_x) # 模型預測輸出# score = dt_model.score(test_x,test_y) # 模型測試性能: 輸入:feature_test,target_test , 輸出acc# print(score) # 性能指標print("label: \n{0}".format(test_y[:5])) # 輸出前5個labelprint("predict: \n{0}".format(predict_y[:5])) # 輸出前5個label# sklearn 內置acc, recall, precision統計接口print("test acc: %.3f"%(accuracy_score(test_y, predict_y)))# print("test recall: %.3f"%(recall_score(test_y, predict_y))) # 多類別統計召回率需要指定平均方式# print("test precision: %.3f"%(precision_score(test_y, predict_y))) # 多類別統計準確率需要指定平均方式
2.3.3超參數搜索
def model_search(feas,labels):# 模型參數選擇,全數據5折交叉驗證,出結果min_impurity_de_entropy = np.linspace(0, 0.01, 10) # 純度增益下界,劃分后降低量少于這個值,將不進行分裂min_impurity_split_entropy = np.linspace(0, 0.4, 10) # 當前節點純度小于這個值將不分裂,較高版本中已經取消這個參數max_depth_entropy = np.arange(1,11) # 決策樹的深度# param_grid = {"criterion" : ["entropy"], "min_impurity_decrease" : min_impurity_de_entropy,"max_depth" : max_depth_entropy,"min_impurity_split" : min_impurity_split_entropy }param_grid = {"criterion" : ["entropy"], "max_depth" : max_depth_entropy, "min_impurity_split" : min_impurity_split_entropy }clf = GridSearchCV(DecisionTreeClassifier(), param_grid, cv = 5) # 遍歷以上超參, 通過多次五折交叉驗證得出最優的參數選擇clf.fit(feas, label) print("best param:", clf.best_params_) # 輸出最優參數選擇print("best score:", clf.best_score_)
2.3.4模型保存與導入
模型保存
joblib.dump(clf,"./dtc_model.pkl")
模型導入
model_path = “./dtc_model.pkl”
clf = joblib.load(model_path)
2.3.5固定隨機數種子
1.五折交叉驗證,數據集劃分隨機數設置 random_state
train_test_split(feas, labels, test_size = 0.2, random_state = 1 )
2.模型隨機數設置 andom_state
DecisionTreeClassifier(random_state = 1)
參考資料
1.官網類接口說明:
https://scikit-learn.org/dev/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier
可視化接口說明https://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html
2.決策樹超參數調參技巧:https://www.jianshu.com/p/230be18b08c2
3.Sklearn.metrics 簡介及應用示例:https://blog.csdn.net/Yqq19950707/article/details/90169913
4.sklearn的train_test_split()各函數參數含義解釋(非常全):https://www.cnblogs.com/Yanjy-OnlyOne/p/11288098.html
5.sklearn.tree.DecisionTreeClassifier 詳細說明:https://www.jianshu.com/p/8f3f1e706f11
6.使用scikit-learn中的metrics以及DecisionTreeClassifier重做《機器學習實戰》中的隱形眼鏡分類問題:http://keyblog.cn/article-235.html
7.決策樹算法:https://www.cnblogs.com/yanqiang/p/11600569.html