山东大学机器学习实验lab9 决策树

山东大学机器学习实验lab9 决策树

  • 所有公式来源于<<机器学习>>周志华
  • github上有.ipynb源文件

修改:

  • 2024 5.15 添加了一些Node属性,用于标记每个Node使用的划分feature的名称,修改后的版本见 github
Node
  • 构造函数 初始化决策树的节点,包含节点ID、特征数据、特征ID、划分阈值、标签、父节点ID、子节点列表以及节点所属的类别
  • set_sons 设置节点的子节点列表
  • judge_stop 判断是否停止继续生成节点条件包括
    • 所有样本属于同一类别
    • 所有样本在所有属性上取值相同
    • 节点对应的样本集合为空
DecisionTree
  • 构造函数 初始化决策树,包括根节点列表、节点ID计数器
  • ent 计算信息熵,衡量数据集纯度
  • gain_single_feature 对于连续属性,寻找最佳划分点以最大化信息增益
  • gain_rate 计算所有特征的信息增益率,选择信息增益率最高的特征作为分裂依据
  • generate_node 递归生成子节点,根据特征的最佳分裂点划分数据集,直到满足停止条件
  • train 从根节点开始构建决策树
  • predict 给定数据,根据决策树进行分类预测
Data
  • 构造函数 加载数据集,初始化特征和标签
  • get_k_exmaine 实现K折交叉验证的数据切分,返回K个特征和标签组

整体流程

  1. 数据预处理:使用Data类读取数据文件,进行K折交叉验证的数据切分
  2. 模型训练与评估:对每一份测试数据,结合其余K-1份数据进行训练,使用DecisionTree类构建决策树模型,然后对测试集进行预测,计算正确率
  3. 结果展示:收集每一次交叉验证的正确率,最后计算并输出平均正确率

代码执行流程

  • 首先,Data类加载数据并进行K折交叉验证数据分割
  • 接着,对于每个验证折,训练数据被合并以训练决策树模型,然后在对应的测试数据上进行预测
  • 对每个测试样本,通过遍历决策树找到其所属类别,并与实际标签对比,累计正确预测的数量
  • 计算并打印每次验证的正确率,最后计算并输出所有折的平均正确率,以此评估模型的泛化能力

代码以及运行结果展示

import numpy as np 
import matplotlib.pyplot as plt 
import math 
class Node():def __init__(self,id_,features_,feature_id_,divide_,labels_,father_,sons_=[]):self.divide=divide_self.feature_id=feature_id_self.id=id_self.feature=features_self.labels=labels_ self.father=father_ self.sons=sons_self.clas='None'def set_sons(self,sons_):self.sons=sons_ def judge_stop(self):labels=np.unique(self.labels)#如果节点样本属于同一类别if(labels.shape[0]==1):self.clas=labels[0]return Truefeatures=np.unique(self.feature)#如果所有样本在所有属性上取值相同if(features.shape[0]==1):unique_values, counts = np.unique(labels, return_counts=True)self.clas = unique_values[counts.argmax()]return True #如果对应的样本集合为空if(self.feature.shape[0]==0 or self.feature.shape[1]==0):self.clas=1return Truereturn False
class DecisionTree():def __init__(self):self.tree=[]self.id=0pass #计算信息熵def ent(self,labels):labels_s=list(set([labels[i,0] for i in range(labels.shape[0])]))ans=0for label in labels_s:num=np.sum(labels==label)p=num/labels.shape[0]ans-=p*math.log(p,2)return ans #计算一个标签对应的最佳分界(连续值)def gain_single_feature(self,feature,labels):origin_ent=self.ent(labels)divide_edge=[]feature=list(set(feature))feature=np.sort(feature)divide_edge=[(feature[i]+feature[i+1])/2.0 for i in range(feature.shape[0]-1)]best_ent=0best_divide=0l1=l2=np.array([[]])for condition in divide_edge:labels1=np.array([labels[i] for i in range(feature.shape[0]) if feature[i]<=condition])labels2=np.array([labels[i] for i in range(feature.shape[0]) if feature[i]>condition])ent1=self.ent(labels1)ent2=self.ent(labels2)ans=origin_ent-((labels1.shape[0]/labels.shape[0])*ent1+(labels2.shape[0]/labels.shape[0])*ent2)if(ans>=best_ent):best_divide=conditionl1=labels1l2=labels2best_ent=ans return best_divide,l1,l2,best_ent#计算信息增益def gain_rate(self,features,labels):origin_ent=self.ent(labels)gain_rate=0feature_id=-1divide=-1l=labels.shape[0]for id in range(features.shape[1]):divide1,labels1,labels2,th_gain=self.gain_single_feature(features[:,id],labels)l1=labels1.shape[0]l2=labels2.shape[0]iv=-1*((l1/l)*math.log(l1/l,2)+(l2/l)*math.log(l2/l,2))if iv!=0:rate=th_gain/ivelse:rate=0if(rate>=gain_rate):gain_rate=ratedivide=divide1feature_id=idreturn feature_id,dividedef generate_node(self,node:Node):a=1features1_id=np.array([i for i in range(node.feature.shape[0]) if node.feature[i,node.feature_id]>=node.divide])features2_id=np.array([i for i in range(node.feature.shape[0]) if node.feature[i,node.feature_id]<node.divide])features1=node.feature[features1_id]features2=node.feature[features2_id]labels1=node.labels[features1_id]labels2=node.labels[features2_id]features1=np.delete(features1,node.feature_id,axis=1)features2=np.delete(features2,node.feature_id,axis=1)features_id1,divide1=self.gain_rate(features1,labels1)features_id2,divide2=self.gain_rate(features2,labels2)tmp=0if(features_id1!=-1):tmp+=1node1=Node(self.id+tmp,features1,features_id1,divide1,labels1,node.id,[])node1.father=node.idself.tree.append(node1)node.sons.append(self.id+tmp)if(features_id2!=-1):tmp+=1node2=Node(self.id+tmp,features2,features_id2,divide2,labels2,node.id,[])node2.father=node.idself.tree.append(node2)node.sons.append(self.id+tmp)self.id+=tmpif(tmp==0):unique_values, counts = np.unique(node.labels, return_counts=True)node.clas = 0 if counts[0]>counts[1] else 1return for n in [self.tree[i] for i in node.sons]:if(n.judge_stop()):continueelse:self.generate_node(n)def train(self,features,labels):feature_id,divide=self.gain_rate(features,labels)root=Node(0,features,feature_id,divide,labels,-1,[])self.tree.append(root)self.generate_node(root)def predict(self,features):re=[]for feature in features:node=self.tree[0]while(node.clas=='None'):th_feature=feature[node.feature_id]feature=np.delete(feature,node.feature_id,axis=0)th_divide=node.divideif(node.clas!='None'):break if(th_feature<th_divide):node=self.tree[node.sons[len(node.sons)-1]]else:node=self.tree[node.sons[0]]re.append(node.clas)return re 
class Data():def __init__(self):self.data=np.loadtxt('/home/wangxv/Files/hw/ml/lab9/data/ex6Data/ex6Data.txt',delimiter=',')self.data_num=self.data.shape[0]self.features=self.data[:,:-1]self.labels=self.data[:,-1:]def get_k_exmaine(self,k:int):num=int(self.data_num/k)data=self.datanp.random.shuffle(self.data)features=data[:,:-1]labels=self.data[:,-1:]feature_groups=[features[i:i+num-1] for i in np.linspace(0,self.data_num,k+1,dtype=int)[:-1]]labels_groups=[labels[i:i+num-1] for i in np.linspace(0,self.data_num,k+1,dtype=int)[:-1]]return feature_groups,labels_groups
data=Data() 
feature_groups,label_groups=data.get_k_exmaine(10)
rate_set=[]
for ind in range(10):dt=DecisionTree()features_=[feature_groups[i] for i in range(10) if i!=ind]labels_=[label_groups[i] for i in range(10) if i!=ind]train_features=features_[0]train_labels=labels_[0]for feature,label in zip(features_[1:],labels_[1:]):train_features=np.vstack((train_features,feature))train_labels=np.vstack((train_labels,label))test_features=feature_groups[ind]test_labels=label_groups[ind]dt.train(train_features,train_labels)pred_re=dt.predict(test_features)right_num=0for i in range(len(pred_re)): if pred_re[i]==test_labels[i][0]:right_num+=1right_rate=right_num/len(pred_re)print(str(ind+1)+'  correct_rate : '+str(right_rate))rate_set.append(right_rate)
print("average_rate : "+str(np.mean(np.array(rate_set))))
1  correct_rate : 0.7930327868852459
2  correct_rate : 0.8032786885245902
3  correct_rate : 0.8012295081967213
4  correct_rate : 0.7848360655737705
5  correct_rate : 0.7889344262295082
6  correct_rate : 0.7909836065573771
7  correct_rate : 0.7827868852459017
8  correct_rate : 0.7766393442622951
9  correct_rate : 0.8012295081967213
10  correct_rate : 0.7725409836065574
average_rate : 0.7895491803278689

最终得到的平均正确率为0.7895>0.78符合实验书要求

可视化

  • 需要将自定义的Node类转换为可识别的数据格式
from graphviz import Digraphdef show_tree(root_node, dot=None):if dot is None:dot = Digraph(comment='decision_tree')node_label = f"{root_node.id} [label=\"feature {root_node.feature_id}: {root_node.divide} | class: {root_node.clas}\"]"dot.node(str(root_node.id), node_label)if root_node.sons:for son_id in root_node.sons:dot.edge(str(root_node.id), str(son_id))visualize_tree(dt.tree[son_id], dot)return dot
root_node=dt.tree[0]
dot = show_tree(root_node)
dot.render('decision_tree', view=True,format='png')
from PIL import Image
Image.open('./decision_tree.png')
  • 这个决策树图有点过于大了,凑合看吧
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/328719.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ubuntu升级python

添加Python官方PPA源 sudo add-apt-repository ppa:deadsnakes/ppa 执行会显示各个版本ubuntu可以安装哪些python版本 更新软件包索引 sudo apt update 安装需要版本Python sudo apt install python3.11 检查Python版本: which python11 /usr/bin/python3.11 设置为系统默认Pyt…

命令行中,Python 想使用本地环境,但总是显示为Anaconda的虚拟环境

电脑环境 Python 本地环境&#xff08;Python3.9.5&#xff09;Anaconda 虚拟环境&#xff08;Python3.8.8&#xff09; 遇到的问题 在cmd 中&#xff0c;我想在本地环境使用 Python、pip &#xff0c;但它却是一直识别成Anaconda的虚拟环境。 解决方法 环境变量配置中&am…

Python | Leetcode Python题解之第91题解码方法

题目&#xff1a; 题解&#xff1a; class Solution:def numDecodings(self, s: str) -> int:n len(s)# a f[i-2], b f[i-1], c f[i]a, b, c 0, 1, 0for i in range(1, n 1):c 0if s[i - 1] ! 0:c bif i > 1 and s[i - 2] ! 0 and int(s[i-2:i]) < 26:c aa,…

汇中 SCL-61D2超声水表汇中通讯协议

RS-485串行通讯接口设置表 通用代码注释 读取正向仪表数据 DD的内容为 通讯示例 主机命令&#xff1a;2A 41 4A 仪表响应&#xff1a;26 41 4A 00 00 13 63 00 00 07 72 00 00 10 34 00 33 读取负向仪表数据&#xff1a;&#xff08;单向型仪表无此命令&#xff09; DD的内容…

redis-stack部署概要

第一步&#xff0c;下载redis-stack 下载链接&#xff1a;Downloads - Redis 第二步&#xff0c;redis安装包解压缩 gzip -d redis-stack-server-7.2.0-v10.rhel8.x86_64.tar.gz tar -xvf redis-stack-server-7.2.0-v10.rhel8.x86_64.tar 第三步&#xff0c;编辑etc下的redis…

【Spring security】Note01-pig登录验证过程

&#x1f338;&#x1f338; pig 登录验证 &#x1f338;&#x1f338; 一、大概执行顺序&#xff0c;便于理解 pig spring-security 二、执行过程分析 请求拦截&#xff1a; 当客户端发送请求时&#xff0c;Spring Security 的过滤器链会首先拦截该请求。过滤器链中的每个…

数据结构与算法—顺序表

目录 一、线性表 二、顺序表概念 三、实现顺序表 1、声明结构体 2、初始化 3、打印数据 4、销毁 5、尾插&头插 尾插 判断是否扩容 头插 6、尾删&头删 尾删 头删 7、 指定位置插入元素 8、 删除指定位置元素 9、 查找指定元素位置 10、修改指定位置元…

angular13 自定义组件全项目都可用 自存

1.定义自定义组件 使用命令创建一个组件 但删除它在你的module里的声明&#xff0c;因为会报错只能引用一次 在本组件中创建一个module文件&#xff0c;引入刚才的组件component.ts import { NgModule } from angular/core; import { CommonModule } from angular/common; im…

简化路径[中等]

优质博文&#xff1a;IT-BLOG-CN 一、题目 给你一个字符串path&#xff0c;表示指向某一文件或目录的Unix风格 绝对路径 &#xff08;以/开头&#xff09;&#xff0c;请你将其转化为更加简洁的规范路径。在Unix风格的文件系统中&#xff0c;一个点.表示当前目录本身&#x…

vue3 自定义组件

在项目中&#xff0c;我们会遇到一些没有现成的组件&#xff0c;那这个时候我们就需要自己去写一个满足我们需求的组件。 比如&#xff0c;我需要一个上下排布&#xff0c;上面显示标题&#xff0c;下面显示内容的组件。封装完成后方便复用。 1、布局组件 我定义一个上下结构的…

2024生日快乐祝福HTML源码

源码介绍 2024生日快乐祝福HTML源码&#xff0c;源码由HTMLCSSJS组成&#xff0c;记事本打开源码文件可以进行内容文字之类的修改&#xff0c;双击html文件可以本地运行效果&#xff0c;也可以上传到服务器里面&#xff0c; 源码截图 源码下载 2024生日快乐祝福HTML源码

Gradio 案例——将 dicom 文件转为 nii文件

文章目录 Gradio 案例——将 dicom 文件转为 nii文件界面截图依赖安装项目目录结构代码 Gradio 案例——将 dicom 文件转为 nii文件 利用 SimpleITK 库&#xff0c;将 dicom 文件转为 nii文件更完整、丰富的示例项目见 GitHub - AlionSSS/dcm2niix-webui: The web UI for dcm2…

一种请求头引起的跨域问题记录(statusCode = 400/CORS)

问题表象 问题描述 当我们需要在接口的headers中添加一个自定义的变量的时候&#xff0c;前端的处理是直接在拦截器或者是接口配置的地方直接进行写&#xff0c;比如下面的这段比较基础的写法&#xff1a; $http({method: "post",url:constants.backend.SERVER_LOGIN…

selenium发展史

Selenium Core 2004 年&#xff0c;Thoughtworks 的工程师 Jason Huggins 正在负责一个 Web 应用的测试工作&#xff0c;由于这个项目需要频繁回归&#xff0c;这导致他不得不每天做着重复且低效的工作。为了解决这个困境&#xff0c;Jason 开发了一个运行在 JavaScript 沙箱中…

React框架-Next 学习-1

创建一个 Next.js 应用,node版本要高&#xff0c;16.5以上 npm淘宝镜像切为https://registry.npmmirror.com npm config set registry https://registry.npmmirror.com npx create-next-applatest//安装后 使用npm run dev 启动 Next.js 是围绕着 页面&#xff08;pages&am…

我21岁玩“撸货”,被骗1000多万

最近&#xff0c;撸货业界内发生了一些颇受瞩目的事件。 在郑州&#xff0c;数码档口下面抢手团长跑路失联&#xff0c;涉及金额几百万&#xff0c;在南京&#xff0c;一家知名的电商平台下的收货站点突然失联&#xff0c;涉及金额高达一千多万&#xff0c;令众多交易者震惊不已…

回归预测 | Matlab实现GA-LSSVM遗传算法优化最小二乘支持向量机多输入单输出回归预测

回归预测 | Matlab实现GA-LSSVM遗传算法优化最小二乘支持向量机多输入单输出回归预测 目录 回归预测 | Matlab实现GA-LSSVM遗传算法优化最小二乘支持向量机多输入单输出回归预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 Matlab实现GA-LSSVM遗传算法优化最小…

基于 Spring Boot 博客系统开发(十)

基于 Spring Boot 博客系统开发&#xff08;十&#xff09; 本系统是简易的个人博客系统开发&#xff0c;为了更加熟练地掌握 SprIng Boot 框架及相关技术的使用。&#x1f33f;&#x1f33f;&#x1f33f; 基于 Spring Boot 博客系统开发&#xff08;九&#xff09;&#x1f…

【考研数学】张宇《1000题》强化阶段正确率多少算合格?

张宇1000题真的很练人心态.... 基础不好&#xff0c;建议别碰1000题 基础好&#xff0c;1000题建议在两个月以内刷完 如果自己本身在基础阶段学的比较水&#xff0c;自己的薄弱点刷了一小部分题没有针对性完全解决&#xff0c;转身去刷1000题就会发现&#xff0c;会的题目刷…

算术平均数

算术平均数&#xff08;average&#xff09;是一组数据相加后除以数据的个数而得到的结果&#xff0c;是度量数据水平的常用统计量&#xff0c;在参数估计和假设检验中经常用到。比如&#xff1a;用职工平均工资来衡量职工工资的一般水平&#xff0c;用平均体重来观察某一人群体…