3 決策樹
3.1 需求規格說明
【問題描述】
ID3算法是一種貪心算法,用來構造決策樹。ID3算法起源于概念學習系統(CLS),以信息熵的下降速度為選取測試屬性的標準,即在每個節點選取還尚未被用來劃分的具有最高信息增益的屬性作為劃分標準,然后繼續這個過程,直到生成的決策樹能完美分類訓練樣例。
具體方法是:從根結點開始,對結點計算所有可能的特征的信息增益,選擇信息增益最大的特征作為結點的特征,由該特征的不同取值建立子結點;再對子結點遞歸的調用以上方法,構建決策樹;知道所以特征的信息增益均很小或沒有特征可以選擇為止。
參考論文:Quinlan J R. Induction of Decision Trees[J]. Machine Learning, 1986, 1(1): 81-106.
參考資料:決策樹—ID3、C4.5、CART_決策樹流程圖-CSDN博客
【數據說明】
DT_data.csv為樣本數據,共14條記錄。每一條記錄共4維特征,分別為Weather(天氣), Temperature(溫度),Humidity(濕度),Wind(風力);其中Date(約會)為標簽列。
【基本要求】(60%)
(1)根據樣本數據,建立決策樹。
(2)輸入測試數據,得到預測是否約會(yes/no)。
【提高要求】(40%)
(1)對決策樹(ID3)中特征節點分類選擇指標(信息增益)進行優化,選擇信息增益率作為決策樹(C4.5)中特征節點分類指標,。
決策樹(C4.5)及信息增益率參考資料:數據挖掘--決策樹C4.5算法(例題)_c4.5算法例題-CSDN博客
3.2 總體分析與設計
(1)設計思想
①存儲結構
在決策樹的實現中,我采用了以下存儲結構:
樣本數據結構(Sample):定義了一個結構體Sample來存儲每個樣本的特征和標簽。特征包括天氣(Weather)、溫度(Temperature)、濕度(Humidity)和風力(Wind),標簽為約會結果(Date),表示為yes或no。
決策樹節點結構(TreeNode):定義了結構體TreeNode來表示決策樹的節點。每個節點包含一個屬性(attribute),用于分裂的屬性名;一個分支映射(branches),存儲該屬性不同取值對應的子節點;以及一個標簽(label),表示葉子節點的分類結果。
決策樹類(DecisionTree):定義了一個類DecisionTree,包含根節點(root)和算法類型(algorithmType)。該類提供構建決策樹(buildTree)、預測(predict)和可視化(visualizeTree)等方法。
②主要算法思想
決策樹的構建主要基于ID3和C4.5算法。這兩種算法都是利用信息增益或信息增益率作為屬性選擇的標準,遞歸地構建決策樹,直到所有樣本都能被正確分類或者沒有更多的屬性可以用來進一步分裂。
ID3算法:以信息熵的下降速度為選取測試屬性的標準,選擇信息增益最大的屬性進行分裂。
C4.5算法:在ID3的基礎上進行了優化,使用信息增益率來選擇屬性,以減少屬性值中某個類別占比過大時的影響。
(2)設計表示
決策樹構建的UML類圖如圖3.2-1所示。
圖3.2-1 程序UML類圖
這里Question_3類主要是為了讀取數據并預生成標簽列,而Decision類則主要進行的是決策樹的構建以及其可視化的工作。
(3)詳細設計表示
決策樹構建的程序流程圖如圖3.2-2所示。
圖3.2-2 決策時實現流程圖
該程序的流程步驟如下所示:
1.初始化:構造函數中初始化根節點和算法類型。
2.屬性選擇:計算每個屬性的信息增益或信息增益率,并選擇最佳屬性。
3.節點分裂:根據選擇的屬性值分裂節點,遞歸構建子樹。
4.葉子節點生成:當所有樣本屬于同一類別或沒有更多屬性時,生成葉子節點。
5.預測:從根節點開始,根據樣本的特征值遞歸查找對應的分支,直到到達葉子節點,得到預測結果。
6.可視化:遞歸遍歷決策樹,使用QT的GraphicsView控件實現繪制節點和分支,展示決策樹結構。
由于本題中,采取了兩種不同的算法,故這里我還繪制了構建決策樹的流程圖,如圖3.2-3所示。
圖3.2-3 構建決策樹流程圖
3.3 編碼
【問題1】:在生成決策樹和剪枝的時候指針容易丟失的問題
【解決方法】:在解決生成決策樹中指針容易丟失的問題時,首先需要確保正確管理內存。動態分配內存時,使用new關鍵字進行分配,而在不再需要時使用delete進行釋放。此外,釋放內存后,將指針設置為nullptr,以避免野指針問題。另一方面,在實際調試過程中,通過手動繪圖,監控指針指向,保證生成決策樹和剪枝時要邏輯合理。反復調試保證代碼正確。
【問題2】:樹的結點只記錄了屬性,無法獲取結點的特征屬性,可視化效果較差
【解決方法】:使用map存儲該結點的特征屬性和特征屬性取值,每個結點存儲上一層分類的值以及下一層分類的最優屬性,通過存儲上一層分類的值,我們可以了解節點的父節點是什么,從而構建出完整的樹形結構。而下一層分類的最優屬性則可以幫助我們了解節點的子節點應該具備哪些特征,以便進一步展開子節點。在此基礎上,完成樹結構的繪制,實現可視化,能清晰明了的告訴用戶在哪種情況下最有可能去約會。
3.4 程序及算法分析
①使用說明
打開程序后,即可出現初始化程序樣式,如圖3.4-1所示。
圖3.4-1 初始化程序
接下來,可以點擊相應的算法,并點擊“選擇CSV文件”將提供的測試CSV輸入并進行訓練,此時程序會調用visualizeTree函數會將訓練結果同步可視化到界面中顯示,如圖3.4-2所示。
圖3.4-2 選擇決策樹訓練集文件
這里我應用了ID3算法為例子,選擇不同的選項后,再點擊“分析結果”即可得到相應的輸出結果,如圖3.4-3所示。
圖3.4-3 輸出結果
同樣,如果選擇“C45”算法則會調用C45算法進行決策樹的構建,這里需要說明的是C4.5算法和ID3算法的一個顯著差異就是應用了信息熵增益率來取代信息熵增益,這樣可以顯著消除眾數分布的影響,并且C45算法所構建的決策樹樣式也是和ID3算法有較大差異的,如圖3.4-4所示。
圖3.4-4 決策樹構建差別
而應用C4.5算法進行預測時顯然會得到更加精確的預測結果,如圖3.4-5所示。
圖3.4-5 C4.5算法預測結果
3.5 小結
決策樹是一種常用的監督學習算法,主要用于分類和回歸問題。它的每個內部節點代表一個屬性或特征,每個分支代表一個決策規則,每個葉節點則代表一個結果或決策。決策樹的優點在于它的模型結果易于理解和解釋,因為決策過程類似于人類的決策過程。此外,它需要的數據預處理較少,不需要進行歸一化或標準化,且可以處理數值和分類數據。
在ID3算法中,我通過計算信息增益來選擇最佳屬性進行節點分裂。信息增益是衡量屬性對于分類結果的信息量的增加程度,它幫助我們確定哪個屬性最能減少分類的不確定性。然而,我發現單一地使用信息增益作為分裂標準有時會導致選擇偏向于那些值較多的屬性。為了改進這一問題,我在C4.5算法中引入了信息增益率。這一改進讓我更加全面地考慮了屬性的選擇。信息增益率不僅考慮了信息增益的大小,還考慮了分裂信息值的大小。這樣,算法在選擇分裂屬性時,能夠更準確地衡量分裂后的信息量變化。而在實現ID3和C4.5決策樹算法的過程中,我遇到了許多困難。決策樹的復雜性不在算法邏輯上,而是涉及到對數據結構的深入理解。指針的正確使用對于決策樹的構建至關重要,一旦出錯,可能會導致樹的結構混亂或內存泄漏。最棘手的是指針丟失問題。在動態構建決策樹時,我經常遇到指針指向無效內存地址或未初始化的情況。為了解決這個問題,我仔細檢查了內存管理代碼,并確保在使用指針之前進行了正確的初始化。同時,我也加強了對new和delete操作符的使用,確保正確地分配和釋放內存。
但是本題程序還有待改進,比如繪制時的標記與結點圖案重合的情況,我還需要進一步詳細的處理。而且未討論如何處理連續數值型數據以構建更精確的決策樹,例如引入CART算法更好地處理連續屬性。當然,算法都是有缺陷的,需要針對不同的問題選擇最適用的算法才能將問題解決,同時優化算法的工作也需要繼續努力完成。
3.6 附錄
//決策樹構建
// 構造函數
DecisionTree::DecisionTree(AlgorithmType algoType) : root(nullptr), algorithmType(algoType) {}
// 析構函數
DecisionTree::~DecisionTree() {freeTree(root);
}
// 釋放內存
void DecisionTree::freeTree(TreeNode* node) {if (!node) return;for (auto& branch : node->branches) {freeTree(branch.second);}delete node;
}
// 計算信息熵
double DecisionTree::calculateEntropy(const std::vector<Sample>& samples) const {std::map<std::string, int> labelCount;for (const auto& sample : samples) {labelCount[sample.date]++;}double entropy = 0.0;int total = samples.size();for (const auto& pair : labelCount) {double p = static_cast<double>(pair.second) / total;entropy -= p * log2(p);}return entropy;
}
// 按屬性分組
std::map<std::string, std::vector<Sample>> DecisionTree::splitByAttribute(const std::vector<Sample>& samples, const std::string& attribute) const {std::map<std::string, std::vector<Sample>> groups;for (const auto& sample : samples) {std::string value;if (attribute == "Weather") value = sample.weather;if (attribute == "Temperature") value = sample.temperature;if (attribute == "Humidity") value = sample.humidity;if (attribute == "Wind") value = sample.wind;groups[value].push_back(sample);}return groups;
}
// 計算信息增益
double DecisionTree::calculateGain(const std::vector<Sample>& samples, const std::string& attribute) const {auto groups = splitByAttribute(samples, attribute);double totalEntropy = calculateEntropy(samples);double weightedEntropy = 0.0;int totalSamples = samples.size();for (const auto& pair : groups) {double p = static_cast<double>(pair.second.size()) / totalSamples;weightedEntropy += p * calculateEntropy(pair.second);}return totalEntropy - weightedEntropy;
}
//計算分裂信息
double DecisionTree::calculateSplitInfo(const std::vector<Sample>& samples, const std::string& attribute) const {auto groups = splitByAttribute(samples, attribute);double splitInfo = 0.0;int totalSamples = samples.size();for (const auto& pair : groups) {double p = static_cast<double>(pair.second.size()) / totalSamples;if (p > 0) {splitInfo -= p * log2(p);}}return splitInfo;
}
//計算信息增益率
double DecisionTree::calculateGainRatio(const std::vector<Sample>& samples, const std::string& attribute) const {double gain = calculateGain(samples, attribute);double splitInfo = calculateSplitInfo(samples, attribute);// 避免分母為0if (splitInfo == 0) return 0.0;return gain / splitInfo;
}
// 構建決策樹
TreeNode* DecisionTree::buildTree(const std::vector<Sample>& samples, const std::set<std::string>& attributes) {// 如果樣本全屬于一個類別std::string firstLabel = samples.front().date;bool allSame = std::all_of(samples.begin(), samples.end(), [&firstLabel](const Sample& s) {return s.date == firstLabel;});if (allSame) {TreeNode* leaf = new TreeNode();leaf->label = firstLabel;return leaf;}// 如果沒有剩余屬性if (attributes.empty()) {TreeNode* leaf = new TreeNode();std::map<std::string, int> labelCount;for (const auto& sample : samples) {labelCount[sample.date]++;}leaf->label = std::max_element(labelCount.begin(), labelCount.end(),[](const auto& a, const auto& b) {return a.second < b.second;})->first;return leaf;}// 選擇最佳屬性double maxMetric = -1;std::string bestAttribute;for (const auto& attr : attributes) {double metric;if (algorithmType == ID3) {metric = calculateGain(samples, attr); // ID3 使用信息增益}else {metric = calculateGainRatio(samples, attr); // C4.5 使用信息增益率}if (metric > maxMetric) {maxMetric = metric;bestAttribute = attr;}}// 創建當前節點TreeNode* node = new TreeNode();node->attribute = bestAttribute;// 按最佳屬性劃分auto groups = splitByAttribute(samples, bestAttribute);std::set<std::string> remainingAttributes = attributes;remainingAttributes.erase(bestAttribute);for (const auto& pair : groups) {node->branches[pair.first] = buildTree(pair.second, remainingAttributes);}return node;
}
// 預測
std::string DecisionTree::predict(const Sample& sample, TreeNode* node) const {if (node->label != "") return node->label;std::string value;if (node->attribute == "Weather") value = sample.weather;if (node->attribute == "Temperature") value = sample.temperature;if (node->attribute == "Humidity") value = sample.humidity;if (node->attribute == "Wind") value = sample.wind;auto it = node->branches.find(value);if (it != node->branches.end()) {return predict(sample, it->second);}return "Unknown";
}
// 可視化決策樹
void DecisionTree::visualizeTree(QGraphicsScene* scene, TreeNode* node, int x, int y, int dx, int dy) const {if (!node) return;// 節點樣式int nodeRadius = 20; // 節點半徑QColor nodeColor = node->label.empty() ? Qt::yellow : Qt::green; // 葉子節點用綠色,非葉子用黃色QGraphicsEllipseItem* ellipse = scene->addEllipse(x - nodeRadius, y - nodeRadius, nodeRadius * 2, nodeRadius * 2, QPen(Qt::black), QBrush(nodeColor));// 節點文字QGraphicsTextItem* text = scene->addText(QString::fromStdString(node->label.empty() ? node->attribute : node->label));text->setDefaultTextColor(Qt::black);text->setFont(QFont("Arial", 10, QFont::Bold));text->setPos(x - text->boundingRect().width() / 2, y - text->boundingRect().height() / 2);// 如果是葉子節點,終止遞歸if (!node->label.empty()) return;// 動態調整分支間距int childCount = static_cast<int>(node->branches.size());if (childCount == 0) return;int totalWidth = (childCount - 1) * dx; // 子節點總寬度int startX = x - totalWidth / 2; // 子節點起始位置// 遍歷子節點,繪制分支線和遞歸調用int index = 0;for (const auto& branch : node->branches) {int childX = startX + index * dx; // 子節點X坐標int childY = y + dy; // 子節點Y坐標// 繪制分支線QGraphicsLineItem* line = scene->addLine(x, y + nodeRadius, childX, childY - nodeRadius, QPen(Qt::black, 2));// 繪制分支文字QGraphicsTextItem* branchText = scene->addText(QString::fromStdString(branch.first));branchText->setDefaultTextColor(Qt::darkGray);branchText->setFont(QFont("Arial", 10));branchText->setPos((x + childX) / 2 - branchText->boundingRect().width() / 2,(y + childY) / 2 - branchText->boundingRect().height() + 30);// 遞歸繪制子節點visualizeTree(scene, branch.second, childX, childY, dx / 2, dy);++index;}
}
項目源代碼:Data-structure-coursework/3/Question_3 at main · CUGLin/Data-structure-courseworkhttps://github.com/CUGLin/Data-structure-coursework/tree/main/3/Question_3https://github.com/CUGLin/Data-structure-coursework/tree/main/3/Question_3