机器学习(三)——决策树(附核心思想、重要算法、概念(信息熵、基尼指数、剪枝处理)及Python源码)

目录

  • 关于
  • 1 基本流程
  • 2 划分属性的选择
    • 2.1 方法一:依据信息增益选择
    • 2.2 方法二:依据增益率选择
    • 2.3 方法三:依据基尼指数选择
  • 3 剪枝处理:防止过拟合
    • 3.1 预剪枝
    • 3.2 后剪枝
  • 4 连续与缺失值
    • 4.1 连续值处理
    • 4.2 缺失值处理
  • 5 多变量决策树
  • X 案例代码
    • X.1 分类任务
      • X.1.1 源码
      • X.1.2 数据集 (鸢尾花数据集)
      • X.1.3 模型效果
    • X.2 回归任务
      • X.2.1 源码
      • X.2.2 数据集(糖尿病数据集)
      • X.2.3 模型效果


关于

  • 本文是基于西瓜书(第四章)的学习记录。讲解决策树的重要概念(划分属性的选择、剪枝处理、连续值和缺失值的处理、多变量决策树等)、核心流程,附Python分类和回归实现代码。
  • 西瓜书电子版:百度网盘分享链接

1 基本流程

决策树是一种模仿人类决策过程的机器学习算法,它通过树状结构来进行决策。

  • 决策树结构
    • 根节点:每个结点包含的样本集合根据属性测试的结果被划分到子结点中,根结点包含样本全集。
    • 内部节点:代表属性测试。
    • 叶节点:对应决策结果。
  • 决策过程:从根节点开始,通过一系列的测试(属性值的判断)到达叶节点,完成决策。
  • 算法流程
    在这里插入图片描述
    决策树的生成是一个递归过程,有三种情形会导致递归返回:
    • 当前结点包含的样本全属于同一类别,无需划分;
    • 当前属性集为空,或是所有样本在所有属性上取值相同,无法划分;
    • 当前结点包含的样本集合为空,不能划分;

2 划分属性的选择

  • 在决策树学习中,选择最优的划分属性是关键,这有助于提高样本集合的纯度,即分支结点所包含的样本尽可能属于同一类别。
  • 下面是三种常见的划分属性选择方法:

2.1 方法一:依据信息增益选择

  • 简述

    • 首先计算数据集D的信息熵;
    • 然后对于每个属性a,考虑将其作为划分属性划分后得到的各个子集,计算各子集的信息熵;
    • 然后利用下面的信息增益公式计算依据属性a划分的信息增益;
    • 最后选择增益最大的属性进行划分然后进行下一轮选择(如果不结束)。
  • 重要概念

    • 信息熵:衡量样本集合D纯度的指标,计算公式为
      E n t ( D ) = − ∑ k = 1 ∣ D ∣ p k log ⁡ 2 p k Ent(D) = -\sum_{k=1}^{|D|} p_k \log_2 p_k Ent(D)=k=1Dpklog2pk其中 p k p_k pk是第k类样本所占的比例。 E n t ( D ) Ent(D) Ent(D)值越小,纯度越高
    • 信息增益:使用属性a进行划分所获得的纯度提升,计算公式为
      G a i n ( D , a ) = E n t ( D ) − ∑ v = 1 V ∣ D v ∣ ∣ D ∣ E n t ( D v ) Gain(D, a) = Ent(D) - \sum_{v=1}^{V} \frac{|D_v|}{|D|} Ent(D_v) Gain(D,a)=Ent(D)v=1VDDvEnt(Dv)其中 D v D_v Dv是在属性a上取值为v的样本子集。
    • ID3算法:以信息增益为准则选择划分属性,是决策树学习的经典算法之一。
  • 案例
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

2.2 方法二:依据增益率选择

  • 简述:方法一的信息增益准则对可取值数目较多的属性有所偏好,所以考虑不直接使用信息增益,而是使用“增益率”来选择最优划分属性。其实就在每个属性的信息增益基础上除以 I V ( a ) IV(a) IV(a),得到的结果就是增益率,基于此选择(但是增益率又会偏好可取值数目少的属性,所以先从候选划分属性中找出信息增益高于平均水平的属性,再从中选择增益率最高的)。
  • 重要概念
    • 增益率:减少信息增益对可取值数目较多属性的偏好,计算公式为 G a i n R a t i o ( D , a ) = G a i n ( D , a ) I V ( a ) GainRatio(D, a) = \frac{Gain(D, a)}{IV(a)} GainRatio(D,a)=IV(a)Gain(D,a)其中 I V ( a ) IV(a) IV(a)是属性a的固有值,计算公式如下。
    • I V ( a ) IV(a) IV(a):这是属性a的固有值,a的可能取值越多,该值越大,起计算公式如下:
      I V ( a ) = − ∑ v = 1 V ∣ D v ∣ ∣ D ∣ log ⁡ 2 ∣ D v ∣ ∣ D ∣ \mathrm{IV}(a)=-\sum_{v=1}^V\frac{|D^v|}{|D|}\log_2\frac{|D^v|}{|D|} IV(a)=v=1VDDvlog2DDv
    • C4.5算法:使用增益率选择最优划分属性,是ID3算法的改进版。

2.3 方法三:依据基尼指数选择

  • 简述:流程和方法一一致,但是选择的指标是基尼指数
  • 重要概念
    • 基尼指数:度量数据集纯度的指标(反映了从数据集D 中随机抽取两个样本,其类别标记不一致的概率.因此,Gini(D)越小,则数据集D 的纯度越高),计算公式为 G i n i ( D ) = 1 − ∑ k = 1 ∣ D ∣ p k 2 Gini(D) = 1 - \sum_{k=1}^{|D|} p_k^2 Gini(D)=1k=1Dpk2
      属性a划分后的基尼指数计算公式为: G i n i I n d e x ( D , a ) = ∑ v = 1 V ∣ D v ∣ ∣ D ∣ G i n i ( D v ) . \mathrm{GiniIndex}(D,a)=\sum_{v=1}^V\frac{|D^v|}{|D|}\mathrm{Gini}(D^v) . GiniIndex(D,a)=v=1VDDvGini(Dv).
    • CART算法:使用基尼指数选择划分属性,适用于分类和回归任务

3 剪枝处理:防止过拟合

  • 剪枝是决策树学习中对付过拟合的主要手段,通过剪枝可以降低过拟合的风险。剪枝策略有预剪枝后剪枝两种

3.1 预剪枝

在这里插入图片描述

  • 预剪枝:在生成过程中,对每个结点在划分前先进行估计,若当前结点的划分不能带来决策树泛化性能提升,则停止划分
  • 优点:这不仅降低了过拟合的风险,还显著减少了决策树的训练时间开销和测试时间开销
  • 缺点:预剪枝基于“贪心”本质禁止这些分支展开,给预剪枝决策树带来了欠拟合的风险
  • 决策树桩:只有一层划分的决策树

3.2 后剪枝

在这里插入图片描述

  • 后剪枝:先从训练集生成一棵完整的决策树,然后自底向上地对非叶结点进行考察,若将该结点对应的子树替换为叶结点能带来决策树泛化性能提升,则将该子树替换为叶结点。
  • 后剪枝决策树通常比预剪枝决策树保留了更多的分支.
  • 优点:后剪枝决策树的欠拟合风险很小,泛化性能往往优于预剪枝决策树
  • 缺点:但后剪枝过程是在生成完全决策树之后进行的,并且要自底向上地对树中的所有非叶结点进行逐一考察,因此其训练时间开销比未剪枝决策树和预剪枝决策树都要大得多.

4 连续与缺失值

4.1 连续值处理

  • 必要性:现实学习任务中常会遇到连续属性,有必要讨论如何在决策树学习中使用连续属性.
  • 连续属性离散化:使用二分法处理连续属性,按照某个划分点的值将连续属性的值划分为两个子集,选择最优的划分点以最大化信息增益。
  • 划分点的选择:对于包含n个不同值的某个属性,可以得到n-1个划分点,根据划分点可以将取值范围划分为两个区间
  • 核心逻辑:得到某个属性的划分点集合之后就可以按照信息增益的大小选择出最合适的划分点构造一个划分节点。
  • 连续数据划分点和离散数据的区别:与离散属性不同,若当前结点划分属性为连续属性,该属性还可作为其后代结点的划分属性
  • 示例:
    在这里插入图片描述

4.2 缺失值处理

  • 必要性:一部分样本得属性缺少值,直接丢弃会造成信息丢失
  • 要解决的问题:
    • 如何在属性值缺失的情况下进行划分属性选择给定划分属性?——将前文的信息增益公式推广后执行类似流程。
    • 若样本在该属性上的值缺失,如何对样本进行划分?——同时划分到子节点,不过进行权重调整,直观地看就是将同一个样本以不同的概率划分到不同的子节点去
  • 权重调整:对缺失值样本进行权重调整后进行划分,确保模型能够从不完整样本中学习。

5 多变量决策树

  • 背景:

    • 分类意味着找到分类边界
    • 决策树分类边界的特点:由若干个与坐标轴平行的分段组成
      • 优点:可解释性强
      • 缺点:需要很多段,学习代价大
  • 多变量决策树

    • 优点:能够实现斜划分(使用属性的线性组合进行测试,允许分类边界不是轴平行的),简化模型复杂度。
    • 改变:在非叶节点上实现斜划分,不再是寻找一个最优划分属性,而是试图建立一个合适的线性分类器
      在这里插入图片描述

X 案例代码

X.1 分类任务

X.1.1 源码

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
import seaborn as sns# 1. 加载数据集
iris = load_iris()
X, y = iris.data, iris.target
print("此时X,y的数据类型为:", type(X), type(y), '\n')# 2. 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("此时X_train,y_train的数据类型为:", type(X_train), type(y_train), '\n')
print("X_train的前10条数据展示:")
print(pd.DataFrame(X_train).head(10).to_string(index=False, justify='left'), '\n')# 3. 构建并训练决策树分类模型
model = DecisionTreeClassifier(random_state=42)
model.fit(X_train, y_train)# 4. 预测测试集上的目标变量
y_pred = model.predict(X_test)# 5. 评估模型性能
accuracy = accuracy_score(y_test, y_pred)
print("模型准确率:", accuracy)print("分类报告:")
print(classification_report(y_test, y_pred))# 6. 绘制混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=iris.target_names, yticklabels=iris.target_names)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix for Iris Dataset')
plt.tight_layout()
plt.show()# 可选:将结果保存到DataFrame中以便进一步分析
results = pd.DataFrame({'Actual': y_test,'Predicted': y_pred
})
print("模型预测结果:")
print(results.head())

X.1.2 数据集 (鸢尾花数据集)

  • 鸢尾花数据集是机器学习领域中最著名的数据集之一,常被用于分类算法的测试和演示。

  • 概览

    • 样本数量:150个样本
    • 特征数量:4个特征
    • 标签种类数量:3个类别,每个类别有50个样本
  • 特征描述

    • 萼片长度 (sepal length):花萼的长度,单位为厘米。
    • 萼片宽度 (sepal width):花萼的宽度,单位为厘米。
    • 花瓣长度 (petal length):花瓣的长度,单位为厘米。
    • 花瓣宽度 (petal width):花瓣的宽度,单位为厘米。
  • 目标变量是鸢尾花的种类,共有三种:

    1. Iris setosa
    2. Iris versicolor
    3. Iris virginica
  • 使用

    • 可以使用 sklearn.datasets.load_iris() 函数来加载这个数据集,并查看其详细信息。

X.1.3 模型效果

在这里插入图片描述
在这里插入图片描述

X.2 回归任务

X.2.1 源码

import pandas as pd
from matplotlib import pyplot as plt
from sklearn import datasets
from sklearn.tree import DecisionTreeRegressor, export_text
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score# 加载糖尿病数据集
diabetes = datasets.load_diabetes()# 提取特征和标签
X = diabetes.data
y = diabetes.target
print("此时X,y的数据类型为:", type(X), type(y), '\n')# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("此时X_train,y_train的数据类型为:", type(X_train), type(y_train), '\n')
print("X_train的前10条数据展示:")
print(pd.DataFrame(X_train).head(10).to_string(index=False, justify='left'), '\n')# 创建决策树回归模型
regressor = DecisionTreeRegressor(random_state=42)# 训练模型
regressor.fit(X_train, y_train)# 进行预测
y_pred = regressor.predict(X_test)# 计算均方误差
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)print(f"均方误差(MSE): {mse}")
print(f"决定系数(R^2): {r2}", '\n')# 查看决策树的结构
tree_rules = export_text(regressor, feature_names=diabetes.feature_names)
print(tree_rules)# 绘制实际值和预测值的折线图
plt.figure(figsize=(12, 6))
plt.plot(y_test, label='Actual', marker='o', color='blue')
plt.plot(y_pred, label='Predicted', marker='x', color='red', linestyle='--')
plt.title('Actual vs Predicted Values')
plt.xlabel('Sample Index')
plt.ylabel('Target Value')
plt.legend()
plt.tight_layout()
plt.show()

X.2.2 数据集(糖尿病数据集)

  • 糖尿病数据集包含442名患者的10项生理特征,目标是预测一年后疾病水平的定量测量值。这些特征经过了标准化处理,使得每个特征的平均值为零,标准差为1。

  • 概览

    • 样本数量:442个样本
    • 特征数量:10个特征
    • 目标变量:1个目标变量(一年后疾病水平的定量测量值)
  • 特征描述

    1. 年龄 (age):患者年龄(已标准化)
    2. 性别 (sex):患者性别(已标准化)
    3. 体质指数 (bmi):身体质量指数(已标准化)
    4. 血压 (bp):平均动脉压(已标准化)
    5. S1:血清测量值1(已标准化)
    6. S2:血清测量值2(已标准化)
    7. S3:血清测量值3(已标准化)
    8. S4:血清测量值4(已标准化)
    9. S5:血清测量值5(已标准化)
    10. S6:血清测量值6(已标准化)
  • 目标变量

    • 一年后疾病水平的定量测量值:这是模型需要预测的目标变量。
  • 使用

    • 可以使用 sklearn.datasets.load_diabetes() 函数来加载这个数据集,并查看其详细信息。

X.2.3 模型效果

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/465651.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu和Debian系列的Release默认shell解释器变更

Debian 12 Bookworm 和 Ubuntu 24.04 中默认的 shell 解释器已经由 bash 变更为了 dash 。 这个变化对于我们直接在 CLI 上执行 Linux command 无影响,但对于执行shell解释性程序有影响,已知 bash 中的 变量正规表达式 (如 ${GIT_COMMIT:0:8…

ReLU6替换ReLU为什么可以增强硬件效率?

ReLU6(Rectified Linear Unit 6)是ReLU的一种变体,它在ReLU的基础上增加了一个上限值6,即输出范围被限制在[0, 6]之间。 这种变化在硬件实现中可以带来以下几个方面的效率提升: 1. 数据表示的简化 ReLU的输出范围是[…

vscode在windows和linux如何使用cmake构建项目并make生成可执行文件,两者有什么区别

vscode在windows和linux如何使用cmake构建项目并make生成可执行文件,两者有什么区别 windows默认使用的是最新的visual studio,而linux默认就是cmake 文章目录 vscode在windows和linux如何使用cmake构建项目并make生成可执行文件,两者有什么…

Spirngboot集成Knife4j spirngboot版本2.7.17 Knife4j版本4.0.0

Knife4j是什么?有什么作用? ‌Knife4j‌是一个基于Swagger的Java RESTful API文档工具,旨在帮助开发者轻松生成和维护API文档。它继承并增强了Swagger的功能,简化了使用流程,并提供了一系列增强功能,如接口…

ROS2humble版本使用colcon构建包

colcon与与catkin相比,没有 devel 目录。 创建工作空间 首先,创建一个目录 ( ros2_example_ws ) 来包含我们的工作区: mkdir -p ~/ros2_example_ws/src cd ~/ros2_example_ws 此时,工作区包含一个空目录 src : . └── src1 directory, …

GY-56 (VL53L0X) 激光测距

文章目录 一、GY-56 简介二、引脚功能三、通信协议1.串口协议: 当 GY-56 PS 焊点开放时候使用(默认)(1)串口通信参数(默认波特率值 9600bps)(2)模块输出格式,每帧包含 8-13 个字节&a…

C语言 | Leetcode C语言题解之第541题反转字符串II

题目&#xff1a; 题解&#xff1a; void swap(char* a, char* b) {char tmp *a;*a *b, *b tmp; }void reverse(char* l, char* r) {while (l < r) {swap(l, --r);} }int min(int a, int b) {return a < b ? a : b; }char* reverseStr(char* s, int k) {int n strl…

提升网站安全性 HTTPS的重要性与应用指南

内容概要 在如今数字化快速发展的时代&#xff0c;网站安全显得尤为重要。许多用户在访问网站时&#xff0c;尤其是涉及个人信息或金融交易时&#xff0c;对数据传输的安全性有着高度的关注。HTTPS&#xff08;超文本传输安全协议&#xff09;正是为了满足这种需求而诞生的。通…

Transformer究竟是什么?预训练又指什么?BERT

目录 Transformer究竟是什么? 预训练又指什么? BERT的影响力 Transformer究竟是什么? Transformer是一种基于自注意力机制(Self-Attention Mechanism)的神经网络架构,它最初是为解决机器翻译等序列到序列(Seq2Seq)任务而设计的。与传统的循环神经网络(RNN)或卷…

OpenDroneMap Webodm

OpenDroneMap & Webodm OpenDroneMap Webodm 开源无人机航拍系列图像及其它系列图像三维重建软件。很棒的开源无人机测绘软件OpenDroneMap,从航拍图像生成精确的地图、高程模型、3D 模型和点云。 应用领域 Mapping & Surveying 测绘和测量 从图像测量获得高精度的可…

Java+Swing可视化图像处理软件

JavaSwing可视化图像处理软件 一、系统介绍二、功能展示1.图片裁剪2.图片缩放3.图片旋转4.图像灰度处理5.图像变形6.图像扭曲7.图像移动 三、系统实现1.ImageProcessing.java 四、其它1.其他系统实现2.获取源码 一、系统介绍 该系统实现了图片裁剪、缩放、旋转、图像灰度处理、…

迈入国际舞台,AORO M8防爆手机获国际IECEx、欧盟ATEX防爆认证

近日&#xff0c;深圳市遨游通讯设备有限公司&#xff08;以下简称“遨游通讯”&#xff09;旗下5G防爆手机——AORO M8&#xff0c;通过了CSA集团的严格测试和评估&#xff0c;荣获国际IECEx及欧盟ATEX防爆认证证书。2024年11月5日&#xff0c;CSA集团和遨游通讯双方领导在遨游…

string模拟实现插入+删除

个人主页&#xff1a;Jason_from_China-CSDN博客 所属栏目&#xff1a;C系统性学习_Jason_from_China的博客-CSDN博客 所属栏目&#xff1a;C知识点的补充_Jason_from_China的博客-CSDN博客 string模拟实现reserve 这里实现的是扩容 扩容这里是可以实现缩容&#xff0c;可以实现…

如何实现KIS私有云数据到聚水潭的高效集成

KIS私有云数据集成到聚水潭&#xff1a;KIS-供应商——>空操作案例分享 在企业信息化建设中&#xff0c;数据的高效流动和准确对接是提升业务效率的关键。本文将重点介绍如何通过轻易云数据集成平台&#xff0c;将KIS私有云中的供应商数据无缝集成到聚水潭系统&#xff0c;…

GESP4级考试语法知识(算法概论(三))

爱因斯坦的阶梯代码&#xff1a; //算法1-12 #include<iostream> using namespace std; int main() {int n1; //n为所设的阶梯数while(!((n%21)&&(n%32)&&(n%54)&&(n%65)&&(n%70)))n; //判别是否满足一组同余式cout<<n<…

【无标题】123

软件包管理器yum yum类似应用商店客户端&#xff0c;有人已经把软件写好放在服务器上了&#xff0c;通过yum找到服务器上的软件下载 软件操作 yum list 可以显示所有可下载软件&#xff0c;我们要找lrzsz软件 yum install 下载 yum remove 卸载 yum源 yum下载软件是通过下载…

【Golang】sql.Null* 类型使用(处理空值和零值)

sql.NullString 和 sql.NullInt64 类型&#xff08;以及其他类似的 sql.Null* 类型&#xff09;在处理数据库操作时非常有用&#xff0c;尤其是在 Go 语言的 database/sql 包中。它们的主要用途包括&#xff1a; 表示 NULL 值&#xff1a; 在数据库中&#xff0c;NULL 表示“没…

【昇腾】从单机单卡到单机多卡训练

昇腾&#xff1a;单机单卡训练->单机多卡训练 分布式训练 &#xff08;1&#xff09;单机单卡的训练流程 硬盘读取数据CPU处理数据&#xff0c;将数据组成一个batch传入GPU网络前向传播计算loss网络反向传播计算梯度 &#xff08;2&#xff09;PyTorch中最早的数据并行框…

【动手学电机驱动】STM32-FOC(3)STM32 三路互补 PWM 输出

STM32-FOC&#xff08;1&#xff09;STM32 电机控制的软件开发环境 STM32-FOC&#xff08;2&#xff09;STM32 导入和创建项目 STM32-FOC&#xff08;3&#xff09;STM32 三路互补 PWM 输出 STM32-FOC&#xff08;4&#xff09;IHM03 电机控制套件介绍 STM32-FOC&#xff08;5&…

docker+nacos

安装数据库 以docker安装为例&#xff08;实际建议实体&#xff09; 初始化数据库 /******************************************/ /* 数据库全名 nacos_config */ /* 表名称 config_info */ /******************************************/ CREATE TABLE config_i…