MachineLearning(7)-決策樹基礎+sklearn.DecisionTreeClassifier簡單實踐

sklearn.DecisionTreeClassifier決策樹簡單使用

  • 1.決策樹算法基礎
  • 2.sklearn.DecisionTreeClassifier簡單實踐
    • 2.1 決策樹類
    • 2.3 決策樹構建
      • 2.3.1全數據集擬合,決策樹可視化
      • 2.3.2交叉驗證實驗
      • 2.3.3超參數搜索
      • 2.3.4模型保存與導入
      • 2.3.5固定隨機數種子
  • 參考資料

1.決策樹算法基礎

決策樹模型可以用來做 回歸/分類 任務。

每次選擇一個屬性/特征,依據特征的閾值,將特征空間劃分為 與 坐標軸平行的一些決策區域。如果是分類問題,每個決策區域的類別為該該區域中多數樣本的類別;如果為回歸問題,每個決策區域的回歸值為該區域中所有樣本值的均值。

決策樹復雜程度 依賴于 特征空間的幾何形狀。根節點->葉子節點的一條路徑產生一條決策規則。

決策樹最大優點:可解釋性強
決策樹最大缺點:不是分類正確率最高的模型

決策樹的學習是一個NP-Complete問題,所以實際中使用啟發性的規則來構建決策樹。
step1:選最好的特征來劃分數據集
step2:對上一步劃分的子集重復步驟1,直至停止條件(節點純度/分裂增益/樹深度)

不同的特征衡量標準,產生了不同的決策樹生成算法:

算法最優特征選擇標準
ID3信息增益:Gain(A)=H(D)?H(D∥A)Gain(A)=H(D)-H(D\|A)Gain(A)=H(D)?H(DA)
C4.5信息增益率:GainRatio(A)=Gain(A)/Split(A)GainRatio(A)=Gain(A)/Split(A)GainRatio(A)=Gain(A)/Split(A)
CARTgini指數增益:Gini(D)?Gini(D∥A)Gini(D)-Gini(D\|A)Gini(D)?Gini(DA)

k個類別,類別分布的gini 指數如下,gini指數越大,樣本的不確定性越大:
Gini(D)=∑k=1Kpk(1?pk)=1?∑k=1Kpk2Gini(D) =\sum_{k=1}^Kp_k(1-p_k)=1-\sum_{k=1}^Kp_k^2Gini(D)=k=1K?pk?(1?pk?)=1?k=1K?pk2?

CART – Classification and Regression Trees 的縮寫1984年提出的一個特征選擇算法,對特征進行是/否判斷,生成一棵二叉樹。且每次選擇完特征后不對特征進行剔除操作,所有同一條決策規則上可能出現重復特征的情況。

2.sklearn.DecisionTreeClassifier簡單實踐

Scikit-learn(sklearn)是機器學習中常用的第三方模塊,其建立在NumPy、Scipy、MatPlotLib之上,包括了回歸,降維,分類,聚類方法。

sklearn 通過以下兩個類實現了 決策分類樹決策回歸樹

sklearn 實現了ID3和Cart 算法,criterion默認為"gini"系數,對應為CART算法。還可設置為"entropy",對應為ID3。(計算機最擅長做的事:規則重復計算,sklearn通過對每個特征的每個切分點計算信息增益/gini增益,得到當前數據集合最優的特征及最優劃分點)

2.1 決策樹類

sklearn.tree.DecisionTreeClassifier(criterion=’gini’*,splitter=’best’, max_depth=None, 
min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0,
max_features=None, random_state=None, max_leaf_nodes=None, 
min_impurity_decrease=0.0, min_impurity_split=None, class_weight=None, presort=False)
DecisionTreeRegressor(criterion=’mse’, splitter=’best’, 
max_depth=None, min_samples_split=2, min_samples_leaf=1, 
min_weight_fraction_leaf=0.0, max_features=None, random_state=None, 
max_leaf_nodes=None, min_impurity_decrease=0.0, 
min_impurity_split=None, presort=False)
Criterion選擇屬性的準則–gini–cart算法
splitter特征劃分點的選擇策略:best 特征的所有劃分點中找最優
random 部分劃分點中找最優
max_depth決策樹的最大深度,none/int 限制/不限制決策樹的深度
min_samples_split節點 繼續劃分需要的最小樣本數,如果少于這個數,節點將不再劃分
min_samples_leaf限制葉子節點的最少樣本數量,如果葉子節點的樣本數量過少會被剪枝
min_weight_fraction_leaf葉子節點的剪枝規則
max_features選取用于分類的特征的數量
random_state隨機數生成的一些規則、
max_leaf_nodes限制葉子節點的數量,防止過擬合
min_impurity_decrease表示結點減少的最小不純度,控制節點的繼續分割規律
min_impurity_split表示結點劃分的最小不純度,控制節點的繼續分割規律
class_weight設置各個類別的權重,針對類別不均衡的數據集使用
不適用于決策樹回歸
presort控制決策樹劃分的速度

2.3 決策樹構建

采用sklearn內置數據集鳶尾花數據集做實驗。

導入第三方庫

from sklearn import tree
from sklearn.tree import DecisionTreeClassifier 
from sklearn.datasets import load_iris
import graphviz
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score
import joblib
plt.switch_backend('agg')

2.3.1全數據集擬合,決策樹可視化

def demo1():# 全數據集擬合,決策樹可視化iris = load_iris()x, y = load_iris(return_X_y = True)                     # x[list]-feature,y[]-label clf = tree.DecisionTreeClassifier()                     # 實例化了一個類,可以指定類參數,定制決策樹模型clf = clf.fit(x,y)                                      # 訓練模型print("feature name ", iris.feature_names)              # 特征列表, 自己的數據可視化時,構建一個特征列表即可print("label name ",iris.target_names)                  # 類別列表dot_data = tree.export_graphviz(clf, out_file = None, feature_names = iris.feature_names, class_names = iris.target_names )    graph = graphviz.Source(dot_data)                        # 能繪制樹節點的一個接口graph.render("iris")                                     # 存成pdf圖
tree.export_graphviz 參數
feature_names特征列表list,和訓練時的特征列表排列順序對其即可
class_names類別l列表ist,和訓練時的label列表排列順序對其即可
filledFalse/True,會依據criterion的純度將節點顯示成不同的顏色

value中的值顯示的是各個類別樣本的數量(二分類就是[負樣本數,正樣本數])

在這里插入圖片描述

2.3.2交叉驗證實驗

def demo2():# n-折實驗iris = load_iris()iris_feature = iris.data                                # 與demo1中的x,y是同樣的數據iris_target = iris.target# 數據集合劃分參數:train_x, test_x, train_y, test_y = train_test_split(iris_feature,iris_target,test_size = 0.2, random_state = 1)dt_model = DecisionTreeClassifier()dt_model.fit(train_x, train_y)                          # 模型訓練predict_y = dt_model.predict(test_x)                    # 模型預測輸出# score = dt_model.score(test_x,test_y)                 # 模型測試性能: 輸入:feature_test,target_test , 輸出acc# print(score)                                          # 性能指標print("label: \n{0}".format(test_y[:5]))                # 輸出前5個labelprint("predict: \n{0}".format(predict_y[:5]))           # 輸出前5個label# sklearn 內置acc, recall, precision統計接口print("test acc: %.3f"%(accuracy_score(test_y, predict_y)))# print("test recall: %.3f"%(recall_score(test_y, predict_y)))  # 多類別統計召回率需要指定平均方式# print("test precision: %.3f"%(precision_score(test_y, predict_y))) # 多類別統計準確率需要指定平均方式

2.3.3超參數搜索

def model_search(feas,labels):# 模型參數選擇,全數據5折交叉驗證,出結果min_impurity_de_entropy = np.linspace(0, 0.01, 10)      # 純度增益下界,劃分后降低量少于這個值,將不進行分裂min_impurity_split_entropy = np.linspace(0, 0.4, 10)    # 當前節點純度小于這個值將不分裂,較高版本中已經取消這個參數max_depth_entropy = np.arange(1,11)                     # 決策樹的深度# param_grid = {"criterion" : ["entropy"], "min_impurity_decrease" : min_impurity_de_entropy,"max_depth" : max_depth_entropy,"min_impurity_split" :  min_impurity_split_entropy }param_grid = {"criterion" : ["entropy"], "max_depth" : max_depth_entropy, "min_impurity_split" :  min_impurity_split_entropy }clf = GridSearchCV(DecisionTreeClassifier(), param_grid, cv = 5)  # 遍歷以上超參, 通過多次五折交叉驗證得出最優的參數選擇clf.fit(feas, label)                                    print("best param:", clf.best_params_)                  # 輸出最優參數選擇print("best score:", clf.best_score_)

2.3.4模型保存與導入

模型保存

joblib.dump(clf,"./dtc_model.pkl")

模型導入

model_path = “./dtc_model.pkl”
clf = joblib.load(model_path)

2.3.5固定隨機數種子

1.五折交叉驗證,數據集劃分隨機數設置 random_state

train_test_split(feas, labels, test_size = 0.2, random_state = 1 )

2.模型隨機數設置 andom_state

DecisionTreeClassifier(random_state = 1)

參考資料

1.官網類接口說明:
https://scikit-learn.org/dev/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier

可視化接口說明https://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html

2.決策樹超參數調參技巧:https://www.jianshu.com/p/230be18b08c2

3.Sklearn.metrics 簡介及應用示例:https://blog.csdn.net/Yqq19950707/article/details/90169913

4.sklearn的train_test_split()各函數參數含義解釋(非常全):https://www.cnblogs.com/Yanjy-OnlyOne/p/11288098.html

5.sklearn.tree.DecisionTreeClassifier 詳細說明:https://www.jianshu.com/p/8f3f1e706f11

6.使用scikit-learn中的metrics以及DecisionTreeClassifier重做《機器學習實戰》中的隱形眼鏡分類問題:http://keyblog.cn/article-235.html

7.決策樹算法:https://www.cnblogs.com/yanqiang/p/11600569.html

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/444969.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/444969.shtml
英文地址,請注明出處:http://en.pswp.cn/news/444969.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

游戲服務器體系結構

本文描述了一個我所設計的游戲服務器體系結構,其目的是實現游戲服務器的動態負載平衡,將對象從繁忙的服務器轉移到相對空閑的服務器中.設計并沒有經過具體的測試與驗證,僅僅是將自己目前的一些想法記錄下來.隨著新構思的出現,可能會有所變化. 以下是服務器的邏輯視圖,其中忽略…

游戲服務器架構探討

要描述一項技術或是一個行業,一般都會從其最古老的歷史開始說起,我本也想按著這個套路走,無奈本人乃一八零后小輩,沒有經歷過那些苦澀的卻令人羨慕的單機游戲開發,也沒有響當當的拿的出手的優秀作品,所以也…

leetcode72 編輯距離

給定兩個單詞 word1 和 word2,計算出將 word1 轉換成 word2 所使用的最少操作數 。 你可以對一個單詞進行如下三種操作: 插入一個字符 刪除一個字符 替換一個字符 示例 1: 輸入: word1 "horse", word2 "ros" 輸出: 3 解釋: ho…

即時通訊系統架構

有過幾款IM系統開發經歷,目前有一款還在線上跑著。準備簡單地介紹一下大型商業應用的IM系統的架構。設計這種架構比較重要的一點是低耦合,把整個系統設計成多個相互分離的子系統。我把整個系統分成下面幾個部分:(1)狀態…

leetcode303 區域和檢索

給定一個整數數組 nums,求出數組從索引 i 到 j (i ≤ j) 范圍內元素的總和,包含 i, j 兩點。 示例: 給定 nums [-2, 0, 3, -5, 2, -1],求和函數為 sumRange() sumRange(0, 2) -> 1 sumRange(2, 5) -> -1 sumRange(0,…

算法(24)-股票買賣

股票買賣1.動態規劃框架LeetCode-121 一次買賣LeetCode-122 不限次數LeetCode-309 不限次數冷凍期LeetCode-714 不限次數手續費LeetCode-123 兩次買賣LeetCode-188 k次買賣2.貪心特解LeetCode-121 一次買賣LeetCode-122 不限次數解題思路參考buladong解題,詳細信息可…

網絡游戲的客戶端同步問題 .

有關位置同步的方案實際上已經比較成熟,網上也有比較多的資料可供參考。在《帶寬限制下的視覺實體屬性傳播》一文中,作者也簡單提到了位置同步方案的構造過程,但涉及到細節的地方沒有深入,這里專門針對這一主題做些回顧。 最直接的…

leetcode319 燈泡的開關

初始時有 n 個燈泡關閉。 第 1 輪,你打開所有的燈泡。 第 2 輪,每兩個燈泡你關閉一次。 第 3 輪,每三個燈泡切換一次開關(如果關閉則開啟,如果開啟則關閉)。第 i 輪,每 i 個燈泡切換一次開關。 …

網游服務器端設計思考:心跳設計

網絡游戲服務器的主要作用是模擬整個游戲世界,客戶端用過網絡連接把一些信息數據發給服務器,在操作合法的情況下,更新服務器上該客戶端對應的player實體、所在場景等,并把這些操作及其影響廣播出去。讓別的客戶端能顯示這些操作。…

算法(25)-括號

各種括號1.LeetCode-22 括號生成--各種括號排列組合2.LeetCode-20 有效括號(是否)--堆棧3.LeetCode-32 最長有效括號(長度)--dp4.LeetCode-301刪除無效括號 --多種刪除方式1.LeetCode-22 括號生成–各種括號排列組合 數字 n 代表生成括號的對數,請你設計一個函數&a…

(二十)深入淺出TCPIP之epoll的一些思考

Epoll基本介紹 在linux的網絡編程中,很長的時間都在使用select來做事件觸發。在linux新的內核中,有了一種替換它的機制,就是epoll。相比于 select,epoll最大的好處在于它不會隨著監聽fd數目的增長而降低效率。因為在內核中的select實現中,它是采用輪詢來處理的,輪詢的fd…

leetcode542 01矩陣

給定一個由 0 和 1 組成的矩陣,找出每個元素到最近的 0 的距離。 兩個相鄰元素間的距離為 1 。 示例 1: 輸入: 0 0 0 0 1 0 0 0 0 輸出: 0 0 0 0 1 0 0 0 0 示例 2: 輸入: 0 0 0 0 1 0 1 1 1 輸出: 0 0 0 0 1 0 1 2 1 注意: 給定矩陣的元素個數不超過 10000。…

RPC、RMI與MOM與組播 通信原理 .

遠程過程調用(RPC): 即對遠程站點機上的過程進行調用。當站點機A上的一個進程調用另一個站點機上的過程時,A上的調用進程掛起,B上的被調用過程執行,并將結果返回給調用進程,使調用進程繼續執行【…

網關服務器 .

之前想著要把什么什么給寫一下,每次都太懶了,都是想起了才來寫一下。今天只討論游戲服務器的網關服務器。 1.轉發 轉發客戶端和服務器間的消息,網關將場景、會話、數據、名字、平臺等服務器的數據轉發給客戶端,接收客戶端的數據&a…

算法(26)-最長系列

最長系列1.LeetCode-32 最長有效括號--子串2.LeetCode-300 最長上升子序列--長度3.LeetCode-32 最長回文子串--是什么5.LeetCode-512 最長回文子序列--長度6.LeetCode-1143 最長公共子序列--長度6.LeetCode-128 最長連續序列--長度7.LeetCode-14 最長公共前綴-字符串8.劍指offe…

一個簡單的游戲服務器框架 .

最近一段時間不是很忙,就寫了一個自己的游戲服務器框架雛形,很多地方還不夠完善,但是基本上也算是能夠跑起來了。我先從上層結構說起,一直到實現細節吧,想起什么就寫什么。 第一部分 服務器邏輯 服務器這邊簡單的分為三…

游戲登陸流程 .

當公司有很多游戲的時候,那么公司往往會有一個統一的賬號管理平臺,就就像盛大通行證、網易通行證,戰網平臺,這些平臺統一管理游戲的賬號數據。 打個比方,現在我們玩星辰變,那么玩家登陸游戲的時候…

leetcode97 交錯字符串

給定三個字符串 s1, s2, s3, 驗證 s3 是否是由 s1 和 s2 交錯組成的。 示例 1: 輸入: s1 "aabcc", s2 "dbbca", s3 "aadbbcbcac" 輸出: true 示例 2: 輸入: s1 "aabcc", s2 "dbbca", s3 "aadbbbaccc" 輸…

算法(27)-最大系列

最大系列1.LeetCode-239 滑動窗口的最大值2.LeetCode-53 連續子數組的最大和3.LeetCode-152 乘積最大的子數組。4.劍指 Offer 14- I. 剪繩子為k個整數段,使各個段成績最大1.dp數學推導1.LeetCode-239 滑動窗口的最大值 窗口由左往右最大值數組Left,和由…

mysql數據庫表的導入導出

MySQL寫入數據通常用insert語句,如 復制代碼 代碼如下: insert into person values(張三,20),(李四,21),(王五,70)…; 但有時為了更快速地插入大批量數據或…