【Sklearn】基于決策樹算法的數據分類預測(Excel可直接替換數據)
- 1.模型原理
- 1.1 模型原理
- 1.2 數學模型
- 2.模型參數
- 3.文件結構
- 4.Excel數據
- 5.下載地址
- 6.完整代碼
- 7.運行結果
1.模型原理
決策樹是一種基于樹狀結構的分類和回歸模型,它通過一系列的決策規則來將數據劃分為不同的類別或預測值。決策樹的模型原理和數學模型如下:
1.1 模型原理
決策樹的基本思想是從根節點開始,通過一系列的節點和分支,根據不同特征的取值將數據集劃分成不同的子集,直到達到葉節點,然后將每個葉節點分配到一個類別或預測值。決策樹的構建過程就是確定如何選擇特征以及如何劃分數據集的過程。
決策樹的主要步驟:
-
選擇特征: 從所有特征中選擇一個最佳特征作為當前節點的劃分特征,這個選擇通常基于某個度量(如信息增益、基尼系數)來評估不同特征的重要性。
-
劃分數據集: 根據選擇的特征,將數據集劃分成多個子集,每個子集對應一個分支。
-
遞歸構建: 對每個子集遞歸地重復步驟1和步驟2,直到滿足停止條件,如達到最大深度、樣本數不足等。
-
葉節點賦值: 在構建過程中,根據訓練數據的真實標簽或均值等方式,為葉節點分配類別或預測值。
1.2 數學模型
在數學上,決策樹可以表示為一個樹狀結構,其中每個節點表示一個特征的劃分,每個分支代表一個特征取值的分支。具體來說,每個節點可以由以下元素定義:
-
劃分特征: 表示選擇哪個特征進行劃分。
-
劃分閾值: 表示在劃分特征上的取值閾值,用于將數據分配到不同的子集。
-
葉節點值: 表示在達到葉節點時所預測的類別或預測值。
在決策樹的訓練過程中,我們尋找最優的劃分特征和劃分閾值,以最大程度地減少不純度(或最大程度地增加信息增益、降低基尼指數等)。
數學模型可以用以下形式表示:
f ( x ) = { C 1 , if? x belongs?to?region? R 1 C 2 , if? x belongs?to?region? R 2 ? ? C k , if? x belongs?to?region? R k f(x) = \begin{cases} C_1, & \text{if } x \text{ belongs to region } R_1 \\ C_2, & \text{if } x \text{ belongs to region } R_2 \\ \vdots & \vdots \\ C_k, & \text{if } x \text{ belongs to region } R_k \end{cases} f(x)=? ? ??C1?,C2?,?Ck?,?if?x?belongs?to?region?R1?if?x?belongs?to?region?R2??if?x?belongs?to?region?Rk??
其中, C i C_i Ci?表示葉節點的類別或預測值, R i R_i Ri?表示根據特征劃分得到的子集。
總之,決策樹通過遞歸地選擇最佳特征和閾值,將數據集劃分為多個子集,最終形成一個樹狀結構的模型,用于分類或回歸預測。
2.模型參數
DecisionTreeClassifier
是scikit-learn
庫中用于構建決策樹分類器的類。它具有多個參數,用于調整決策樹的構建和性能。以下是一些常用的參數及其說明:
-
criterion: 衡量分割質量的標準。可以是"gini"(基尼系數)或"entropy"(信息熵)。默認為"gini"。
-
splitter: 用于選擇節點分割的策略。可以是"best"(選擇最優的分割)或"random"(隨機選擇分割)。默認為"best"。
-
max_depth: 決策樹的最大深度。如果為None,則節點會擴展,直到所有葉節點都是純的,或者包含少于min_samples_split個樣本。默認為None。
-
min_samples_split: 節點分裂所需的最小樣本數。如果一個節點的樣本數少于這個值,就不會再分裂。默認為2。
-
min_samples_leaf: 葉節點所需的最小樣本數。如果一個葉節點的樣本數少于這個值,可以合并到一個葉節點。默認為1。
-
min_weight_fraction_leaf: 葉節點所需的最小權重分數總和。與min_samples_leaf類似,但是使用樣本權重而不是樣本數量。默認為0。
-
max_features: 尋找最佳分割時要考慮的特征數量。可以是整數、浮點數、字符串或None。默認為None。
-
random_state: 隨機數生成器的種子,用于隨機性控制。默認為None。
-
max_leaf_nodes: 最大葉節點數。如果設置,算法會通過去掉最不重要的葉節點來合并其他節點。默認為None。
-
min_impurity_decrease: 分割需要達到的最小不純度減少量。如果分割不會降低不純度超過這個閾值,則節點將被視為葉節點。默認為0。
-
class_weight: 類別權重,用于處理不平衡數據集。
這些是DecisionTreeClassifier
中一些常用的參數。根據你的數據和問題,你可以根據需要調整這些參數的值,以獲得更好的模型性能。在實際應用中,根據數據的特點進行調參非常重要。
3.文件結構
iris.xlsx % 可替換數據集
Main.py % 主函數
4.Excel數據
5.下載地址
- 資源下載地址
6.完整代碼
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as snsdef decision_tree_classification(data_path, test_size=0.2, random_state=42):# 加載數據data = pd.read_excel(data_path)# 分割特征和標簽X = data.iloc[:, :-1] # 所有列除了最后一列y = data.iloc[:, -1] # 最后一列# 劃分訓練集和測試集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)# 創建決策樹分類器# 1. ** criterion: ** 衡量分割質量的標準。可以是"gini"(基尼系數)或"entropy"(信息熵)。默認為"gini"。# 2. ** splitter: ** 用于選擇節點分割的策略。可以是"best"(選擇最優的分割)或"random"(隨機選擇分割)。默認為"best"。# 3. ** max_depth: ** 決策樹的最大深度。如果為None,則節點會擴展,直到所有葉節點都是純的,或者包含少于min_samples_split個樣本。默認為None。# 4. ** min_samples_split: ** 節點分裂所需的最小樣本數。如果一個節點的樣本數少于這個值,就不會再分裂。默認為2。# 5. ** min_samples_leaf: ** 葉節點所需的最小樣本數。如果一個葉節點的樣本數少于這個值,可以合并到一個葉節點。默認為1。# 6. ** min_weight_fraction_leaf: ** 葉節點所需的最小權重分數總和。與min_samples_leaf類似,但是使用樣本權重而不是樣本數量。默認為0。# 7. ** max_features: ** 尋找最佳分割時要考慮的特征數量。可以是整數、浮點數、字符串或None。默認為None。# 8. ** random_state: ** 隨機數生成器的種子,用于隨機性控制。默認為None。# 9. ** max_leaf_nodes: ** 最大葉節點數。如果設置,算法會通過去掉最不重要的葉節點來合并其他節點。默認為None。# 10. ** min_impurity_decrease: ** 分割需要達到的最小不純度減少量。如果分割不會降低不純度超過這個閾值,則節點將被視為葉節點。默認為0。# 11. ** class_weight: ** 類別權重,用于處理不平衡數據集。# 使用gini作為分割標準,設置最大深度為3,最小樣本數為5model = DecisionTreeClassifier(criterion='gini', max_depth=3, min_samples_split=5)# 在訓練集上訓練模型model.fit(X_train, y_train)# 在測試集上進行預測y_pred = model.predict(X_test)# 計算準確率accuracy = accuracy_score(y_test, y_pred)return confusion_matrix(y_test, y_pred), y_test.values, y_pred, accuracyif __name__ == "__main__":# 使用函數進行分類任務data_path = "iris.xlsx"confusion_mat, true_labels, predicted_labels, accuracy = decision_tree_classification(data_path)print("真實值:", true_labels)print("預測值:", predicted_labels)print("準確率:{:.2%}".format(accuracy))# 繪制混淆矩陣plt.figure(figsize=(8, 6))sns.heatmap(confusion_mat, annot=True, fmt="d", cmap="Blues")plt.title("Confusion Matrix")plt.xlabel("Predicted Labels")plt.ylabel("True Labels")plt.show()# 用圓圈表示真實值,用叉叉表示預測值# 繪制真實值與預測值的對比結果plt.figure(figsize=(10, 6))plt.plot(true_labels, 'o', label="True Labels")plt.plot(predicted_labels, 'x', label="Predicted Labels")plt.title("True Labels vs Predicted Labels")plt.xlabel("Sample Index")plt.ylabel("Label")plt.legend()plt.show()
7.運行結果