3 决策树
3.1 需求规格说明
【问题描述】
ID3算法是一种贪心算法,用来构造决策树。ID3算法起源于概念学习系统(CLS),以信息熵的下降速度为选取测试属性的标准,即在每个节点选取还尚未被用来划分的具有最高信息增益的属性作为划分标准,然后继续这个过程,直到生成的决策树能完美分类训练样例。
具体方法是:从根结点开始,对结点计算所有可能的特征的信息增益,选择信息增益最大的特征作为结点的特征,由该特征的不同取值建立子结点;再对子结点递归的调用以上方法,构建决策树;知道所以特征的信息增益均很小或没有特征可以选择为止。
参考论文:Quinlan J R. Induction of Decision Trees[J]. Machine Learning, 1986, 1(1): 81-106.
参考资料:决策树—ID3、C4.5、CART_决策树流程图-CSDN博客
【数据说明】
DT_data.csv为样本数据,共14条记录。每一条记录共4维特征,分别为Weather(天气), Temperature(温度),Humidity(湿度),Wind(风力);其中Date(约会)为标签列。
【基本要求】(60%)
(1)根据样本数据,建立决策树。
(2)输入测试数据,得到预测是否约会(yes/no)。
【提高要求】(40%)
(1)对决策树(ID3)中特征节点分类选择指标(信息增益)进行优化,选择信息增益率作为决策树(C4.5)中特征节点分类指标,。
决策树(C4.5)及信息增益率参考资料:数据挖掘--决策树C4.5算法(例题)_c4.5算法例题-CSDN博客
3.2 总体分析与设计
(1)设计思想
①存储结构
在决策树的实现中,我采用了以下存储结构:
样本数据结构(Sample):定义了一个结构体Sample来存储每个样本的特征和标签。特征包括天气(Weather)、温度(Temperature)、湿度(Humidity)和风力(Wind),标签为约会结果(Date),表示为yes或no。
决策树节点结构(TreeNode):定义了结构体TreeNode来表示决策树的节点。每个节点包含一个属性(attribute),用于分裂的属性名;一个分支映射(branches),存储该属性不同取值对应的子节点;以及一个标签(label),表示叶子节点的分类结果。
决策树类(DecisionTree):定义了一个类DecisionTree,包含根节点(root)和算法类型(algorithmType)。该类提供构建决策树(buildTree)、预测(predict)和可视化(visualizeTree)等方法。
②主要算法思想
决策树的构建主要基于ID3和C4.5算法。这两种算法都是利用信息增益或信息增益率作为属性选择的标准,递归地构建决策树,直到所有样本都能被正确分类或者没有更多的属性可以用来进一步分裂。
ID3算法:以信息熵的下降速度为选取测试属性的标准,选择信息增益最大的属性进行分裂。
C4.5算法:在ID3的基础上进行了优化,使用信息增益率来选择属性,以减少属性值中某个类别占比过大时的影响。
(2)设计表示
决策树构建的UML类图如图3.2-1所示。
图3.2-1 程序UML类图
这里Question_3类主要是为了读取数据并预生成标签列,而Decision类则主要进行的是决策树的构建以及其可视化的工作。
(3)详细设计表示
决策树构建的程序流程图如图3.2-2所示。
图3.2-2 决策时实现流程图
该程序的流程步骤如下所示:
1.初始化:构造函数中初始化根节点和算法类型。
2.属性选择:计算每个属性的信息增益或信息增益率,并选择最佳属性。
3.节点分裂:根据选择的属性值分裂节点,递归构建子树。
4.叶子节点生成:当所有样本属于同一类别或没有更多属性时,生成叶子节点。
5.预测:从根节点开始,根据样本的特征值递归查找对应的分支,直到到达叶子节点,得到预测结果。
6.可视化:递归遍历决策树,使用QT的GraphicsView控件实现绘制节点和分支,展示决策树结构。
由于本题中,采取了两种不同的算法,故这里我还绘制了构建决策树的流程图,如图3.2-3所示。
图3.2-3 构建决策树流程图
3.3 编码
【问题1】:在生成决策树和剪枝的时候指针容易丢失的问题
【解决方法】:在解决生成决策树中指针容易丢失的问题时,首先需要确保正确管理内存。动态分配内存时,使用new关键字进行分配,而在不再需要时使用delete进行释放。此外,释放内存后,将指针设置为nullptr,以避免野指针问题。另一方面,在实际调试过程中,通过手动绘图,监控指针指向,保证生成决策树和剪枝时要逻辑合理。反复调试保证代码正确。
【问题2】:树的结点只记录了属性,无法获取结点的特征属性,可视化效果较差
【解决方法】:使用map存储该结点的特征属性和特征属性取值,每个结点存储上一层分类的值以及下一层分类的最优属性,通过存储上一层分类的值,我们可以了解节点的父节点是什么,从而构建出完整的树形结构。而下一层分类的最优属性则可以帮助我们了解节点的子节点应该具备哪些特征,以便进一步展开子节点。在此基础上,完成树结构的绘制,实现可视化,能清晰明了的告诉用户在哪种情况下最有可能去约会。
3.4 程序及算法分析
①使用说明
打开程序后,即可出现初始化程序样式,如图3.4-1所示。
图3.4-1 初始化程序
接下来,可以点击相应的算法,并点击“选择CSV文件”将提供的测试CSV输入并进行训练,此时程序会调用visualizeTree函数会将训练结果同步可视化到界面中显示,如图3.4-2所示。
图3.4-2 选择决策树训练集文件
这里我应用了ID3算法为例子,选择不同的选项后,再点击“分析结果”即可得到相应的输出结果,如图3.4-3所示。
图3.4-3 输出结果
同样,如果选择“C45”算法则会调用C45算法进行决策树的构建,这里需要说明的是C4.5算法和ID3算法的一个显著差异就是应用了信息熵增益率来取代信息熵增益,这样可以显著消除众数分布的影响,并且C45算法所构建的决策树样式也是和ID3算法有较大差异的,如图3.4-4所示。
图3.4-4 决策树构建差别
而应用C4.5算法进行预测时显然会得到更加精确的预测结果,如图3.4-5所示。
图3.4-5 C4.5算法预测结果
3.5 小结
决策树是一种常用的监督学习算法,主要用于分类和回归问题。它的每个内部节点代表一个属性或特征,每个分支代表一个决策规则,每个叶节点则代表一个结果或决策。决策树的优点在于它的模型结果易于理解和解释,因为决策过程类似于人类的决策过程。此外,它需要的数据预处理较少,不需要进行归一化或标准化,且可以处理数值和分类数据。
在ID3算法中,我通过计算信息增益来选择最佳属性进行节点分裂。信息增益是衡量属性对于分类结果的信息量的增加程度,它帮助我们确定哪个属性最能减少分类的不确定性。然而,我发现单一地使用信息增益作为分裂标准有时会导致选择偏向于那些值较多的属性。为了改进这一问题,我在C4.5算法中引入了信息增益率。这一改进让我更加全面地考虑了属性的选择。信息增益率不仅考虑了信息增益的大小,还考虑了分裂信息值的大小。这样,算法在选择分裂属性时,能够更准确地衡量分裂后的信息量变化。而在实现ID3和C4.5决策树算法的过程中,我遇到了许多困难。决策树的复杂性不在算法逻辑上,而是涉及到对数据结构的深入理解。指针的正确使用对于决策树的构建至关重要,一旦出错,可能会导致树的结构混乱或内存泄漏。最棘手的是指针丢失问题。在动态构建决策树时,我经常遇到指针指向无效内存地址或未初始化的情况。为了解决这个问题,我仔细检查了内存管理代码,并确保在使用指针之前进行了正确的初始化。同时,我也加强了对new和delete操作符的使用,确保正确地分配和释放内存。
但是本题程序还有待改进,比如绘制时的标记与结点图案重合的情况,我还需要进一步详细的处理。而且未讨论如何处理连续数值型数据以构建更精确的决策树,例如引入CART算法更好地处理连续属性。当然,算法都是有缺陷的,需要针对不同的问题选择最适用的算法才能将问题解决,同时优化算法的工作也需要继续努力完成。
3.6 附录
//决策树构建
// 构造函数
DecisionTree::DecisionTree(AlgorithmType algoType) : root(nullptr), algorithmType(algoType) {}
// 析构函数
DecisionTree::~DecisionTree() {freeTree(root);
}
// 释放内存
void DecisionTree::freeTree(TreeNode* node) {if (!node) return;for (auto& branch : node->branches) {freeTree(branch.second);}delete node;
}
// 计算信息熵
double DecisionTree::calculateEntropy(const std::vector<Sample>& samples) const {std::map<std::string, int> labelCount;for (const auto& sample : samples) {labelCount[sample.date]++;}double entropy = 0.0;int total = samples.size();for (const auto& pair : labelCount) {double p = static_cast<double>(pair.second) / total;entropy -= p * log2(p);}return entropy;
}
// 按属性分组
std::map<std::string, std::vector<Sample>> DecisionTree::splitByAttribute(const std::vector<Sample>& samples, const std::string& attribute) const {std::map<std::string, std::vector<Sample>> groups;for (const auto& sample : samples) {std::string value;if (attribute == "Weather") value = sample.weather;if (attribute == "Temperature") value = sample.temperature;if (attribute == "Humidity") value = sample.humidity;if (attribute == "Wind") value = sample.wind;groups[value].push_back(sample);}return groups;
}
// 计算信息增益
double DecisionTree::calculateGain(const std::vector<Sample>& samples, const std::string& attribute) const {auto groups = splitByAttribute(samples, attribute);double totalEntropy = calculateEntropy(samples);double weightedEntropy = 0.0;int totalSamples = samples.size();for (const auto& pair : groups) {double p = static_cast<double>(pair.second.size()) / totalSamples;weightedEntropy += p * calculateEntropy(pair.second);}return totalEntropy - weightedEntropy;
}
//计算分裂信息
double DecisionTree::calculateSplitInfo(const std::vector<Sample>& samples, const std::string& attribute) const {auto groups = splitByAttribute(samples, attribute);double splitInfo = 0.0;int totalSamples = samples.size();for (const auto& pair : groups) {double p = static_cast<double>(pair.second.size()) / totalSamples;if (p > 0) {splitInfo -= p * log2(p);}}return splitInfo;
}
//计算信息增益率
double DecisionTree::calculateGainRatio(const std::vector<Sample>& samples, const std::string& attribute) const {double gain = calculateGain(samples, attribute);double splitInfo = calculateSplitInfo(samples, attribute);// 避免分母为0if (splitInfo == 0) return 0.0;return gain / splitInfo;
}
// 构建决策树
TreeNode* DecisionTree::buildTree(const std::vector<Sample>& samples, const std::set<std::string>& attributes) {// 如果样本全属于一个类别std::string firstLabel = samples.front().date;bool allSame = std::all_of(samples.begin(), samples.end(), [&firstLabel](const Sample& s) {return s.date == firstLabel;});if (allSame) {TreeNode* leaf = new TreeNode();leaf->label = firstLabel;return leaf;}// 如果没有剩余属性if (attributes.empty()) {TreeNode* leaf = new TreeNode();std::map<std::string, int> labelCount;for (const auto& sample : samples) {labelCount[sample.date]++;}leaf->label = std::max_element(labelCount.begin(), labelCount.end(),[](const auto& a, const auto& b) {return a.second < b.second;})->first;return leaf;}// 选择最佳属性double maxMetric = -1;std::string bestAttribute;for (const auto& attr : attributes) {double metric;if (algorithmType == ID3) {metric = calculateGain(samples, attr); // ID3 使用信息增益}else {metric = calculateGainRatio(samples, attr); // C4.5 使用信息增益率}if (metric > maxMetric) {maxMetric = metric;bestAttribute = attr;}}// 创建当前节点TreeNode* node = new TreeNode();node->attribute = bestAttribute;// 按最佳属性划分auto groups = splitByAttribute(samples, bestAttribute);std::set<std::string> remainingAttributes = attributes;remainingAttributes.erase(bestAttribute);for (const auto& pair : groups) {node->branches[pair.first] = buildTree(pair.second, remainingAttributes);}return node;
}
// 预测
std::string DecisionTree::predict(const Sample& sample, TreeNode* node) const {if (node->label != "") return node->label;std::string value;if (node->attribute == "Weather") value = sample.weather;if (node->attribute == "Temperature") value = sample.temperature;if (node->attribute == "Humidity") value = sample.humidity;if (node->attribute == "Wind") value = sample.wind;auto it = node->branches.find(value);if (it != node->branches.end()) {return predict(sample, it->second);}return "Unknown";
}
// 可视化决策树
void DecisionTree::visualizeTree(QGraphicsScene* scene, TreeNode* node, int x, int y, int dx, int dy) const {if (!node) return;// 节点样式int nodeRadius = 20; // 节点半径QColor nodeColor = node->label.empty() ? Qt::yellow : Qt::green; // 叶子节点用绿色,非叶子用黄色QGraphicsEllipseItem* ellipse = scene->addEllipse(x - nodeRadius, y - nodeRadius, nodeRadius * 2, nodeRadius * 2, QPen(Qt::black), QBrush(nodeColor));// 节点文字QGraphicsTextItem* text = scene->addText(QString::fromStdString(node->label.empty() ? node->attribute : node->label));text->setDefaultTextColor(Qt::black);text->setFont(QFont("Arial", 10, QFont::Bold));text->setPos(x - text->boundingRect().width() / 2, y - text->boundingRect().height() / 2);// 如果是叶子节点,终止递归if (!node->label.empty()) return;// 动态调整分支间距int childCount = static_cast<int>(node->branches.size());if (childCount == 0) return;int totalWidth = (childCount - 1) * dx; // 子节点总宽度int startX = x - totalWidth / 2; // 子节点起始位置// 遍历子节点,绘制分支线和递归调用int index = 0;for (const auto& branch : node->branches) {int childX = startX + index * dx; // 子节点X坐标int childY = y + dy; // 子节点Y坐标// 绘制分支线QGraphicsLineItem* line = scene->addLine(x, y + nodeRadius, childX, childY - nodeRadius, QPen(Qt::black, 2));// 绘制分支文字QGraphicsTextItem* branchText = scene->addText(QString::fromStdString(branch.first));branchText->setDefaultTextColor(Qt::darkGray);branchText->setFont(QFont("Arial", 10));branchText->setPos((x + childX) / 2 - branchText->boundingRect().width() / 2,(y + childY) / 2 - branchText->boundingRect().height() + 30);// 递归绘制子节点visualizeTree(scene, branch.second, childX, childY, dx / 2, dy);++index;}
}
项目源代码:Data-structure-coursework/3/Question_3 at main · CUGLin/Data-structure-courseworkhttps://github.com/CUGLin/Data-structure-coursework/tree/main/3/Question_3