决策树实验分析(分类和回归任务,剪枝,数据对决策树影响)

目录

1. 前言

2. 实验分析

        2.1 导入包

        2.2 决策树模型构建及树模型的可视化展示

        2.3 概率估计

        2.4 绘制决策边界

        2.5 决策树的正则化(剪枝)

        2.6 对数据敏感

        2.7 回归任务

        2.8 对比树的深度对结果的影响

        2.9 剪枝


1. 前言

        本文主要分析了决策树的分类和回归任务,对比一系列的剪枝的策略对结果的影响,数据对于决策树结果的影响。

        介绍使用graphaviz这个决策树可视化工具

2. 实验分析

        2.1 导入包

#1.导入包
import os
import numpy as np
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
import warnings
warnings.filterwarnings('ignore')

        2.2 决策树模型构建及树模型的可视化展示

        下载安装包:https://graphviz.gitlab.io/_pages/Download/Download_windows.html

         选择一款安装,注意安装时要配置环境变量

        注意这里使用的是鸢尾花数据集,选择花瓣长和宽两个特征

#2.建立树模型
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
iris = load_iris()
X = iris.data[:,2:] # petal legth and width
y = iris.target
tree_clf = DecisionTreeClassifier(max_depth=2)
tree_clf.fit(X,y)
#3.树模型的可视化展示
from sklearn.tree import export_graphviz
export_graphviz(tree_clf,out_file='iris_tree.dot',feature_names=iris.feature_names[2:],class_names=iris.target_names,rounded=True,filled=True
)

        然后就可以使用graphviz包中的dot.命令工具将此文件转换为各种格式的如pdf,png,如 dot -Tpng iris_tree.png -o iris_tree.png

        可以去文件系统查看,也可以用python展示

from IPython.display import Image
Image(filename='iris_tree.png',width=400,height=400)

        分析:value表示每个节点所有样本中各个类别的样本数,用花瓣宽<=0.8和<=1.75 作为根节点划分,叶子节点表示分类结果,结果执行少数服从多数策略,gini指数随着分类进行在减小。

        2.3 概率估计

        估计类概率 输入数据为:花瓣长5厘米,宽1.5厘米的花。相应节点是深度为2的左节点,因此决策树因输出以下概率:

        iris-Setosa为0%(0/54)

        iris-Versicolor为90.7%(49/54)

        iris-Virginica为9.3%(5/54)

        

#4.概率估计
print(tree_clf.predict_proba([[5,1.5]]))
print(tree_clf.predict([[5,1.5]]))

        2.4 绘制决策边界

        

#5.绘制决策边界
from matplotlib.colors import ListedColormapdef plot_decision_boundary(clf,X,y,axes=[0,7.5,0,3],iris=True,legend=False,plot_training=True):#找两个特征 x1 x2x1s = np.linspace(axes[0],axes[1],100)x2s = np.linspace(axes[2],axes[3],100)#构建棋盘x1,x2 = np.meshgrid(x1s,x2s)#在棋盘中构建待测试数据X_new = np.c_[x1.ravel(),x2.ravel()]#将预测值算出来y_pred = clf.predict(X_new).reshape(x1.shape)#选择颜色custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])#绘制并填充不同的区域plt.contourf(x1,x2,y_pred,alpha=0.3,cmap=custom_cmap)if not iris:custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])plt.contourf(x1,x2,y_pred,alpha=0.8,cmap=custom_cmap2)#可以把训练数据展示出来if plot_training:plt.plot(X[:,0][y==0],X[:,1][y==0],'yo',label='Iris-Setosa')plt.plot(X[:,0][y==1],X[:,1][y==1],'bs',label='Iris-Versicolor')plt.plot(X[:,0][y==2],X[:,1][y==2],'g^',label='Iris-Virginica')if iris:plt.xlabel('Petal length',fontsize = 14)plt.ylabel('Petal width',fontsize = 14)else:plt.xlabel(r'$x_1$',fontsize=18)plt.ylabel(r'$x_2$',fontsize=18)if legend:plt.legend(loc='lower right',fontsize=14)plt.figure(figsize=(8,4))
plot_decision_boundary(tree_clf,X,y)
plt.plot([2.45,2.45],[0,3],'k-',linewidth=2)
plt.plot([2.45,7.5],[1.75,1.75],'k--',linewidth=2)
plt.plot([4.95,4.95],[0,1.75],'k:',linewidth=2)
plt.plot([4.85,4.85],[1.75,3],'k:',linewidth=2)
plt.text(1.40,1.0,'Depth=0',fontsize=15)
plt.text(3.2,1.80,'Depth=1',fontsize=13)
plt.text(4.05,0.5,'(Depth=2)',fontsize=11)
plt.title('Decision Tree decision boundareies')plt.show()

        可以看出三种不同颜色的代表分类结果,Depth=0可看作第一刀切分,Depth=1,2 看作第二刀,三刀,把数据集切分。

        2.5 决策树的正则化(剪枝)

        决策树的正则化

        DecisionTreeClassifier类还具有一些其他的参数类似地限制了决策树的形状

        min-samples_split(节点在分割之前必须具有的样本数)

        min-samples_leaf(叶子节点必须具有的最小样本数)

        max-leaf_nodes(叶子节点的最大数量)

        max_features(在每个节点处评估用于拆分的最大特征数)

        max_depth(树的最大深度)

#6.决策树正则化
from sklearn.datasets import make_moons
X,y = make_moons(n_samples=100,noise=0.25,random_state=53)
plt.plot(X[:,0],X[:,1],"b.")
tree_clf1 = DecisionTreeClassifier(random_state=42)
tree_clf2 = DecisionTreeClassifier(random_state=42,min_samples_leaf=4)
tree_clf1.fit(X,y)
tree_clf2.fit(X,y)
plt.figure(figsize=(12,4))
plt.subplot(121)
plot_decision_boundary(tree_clf1,X,y,axes=[-1.5,2.5,-1,1.5],iris=False)
plt.title('no restriction')
plt.subplot(122)
plot_decision_boundary(tree_clf2,X,y,axes=[-1.5,2.5,-1,1.5],iris=False)
plt.title('min_samples_leaf={}'.format(tree_clf2.min_samples_leaf))

        可以看出在没有加限制条件之前,分类器要考虑每个点,模型变得复杂,容易过拟合。其他的一些参数读者可以自行尝试。

        2.6 对数据敏感

        决策树对于数据是很敏感的

        

#6.对数据敏感
np.random.seed(6)
Xs = np.random.rand(100,2) - 0.5
ys = (Xs[:,0] > 0).astype(np.float32) * 2angle = np.pi /4
rotation_matrix = np.array([[np.cos(angle),-np.sin(angle)],[np.sin(angle),np.cos(angle)]])
Xsr = Xs.dot(rotation_matrix)tree_clf_s = DecisionTreeClassifier(random_state=42)
tree_clf_sr = DecisionTreeClassifier(random_state=42)
tree_clf_s.fit(Xs,ys)
tree_clf_sr.fit(Xsr,ys)plt.figure(figsize=(11,4))
plt.subplot(121)
plot_decision_boundary(tree_clf_s,Xs,ys,axes=[-0.7,0.7,-0.7,0.7],iris=False)
plt.title('Sensitivity to training set rotation')plt.subplot(122)
plot_decision_boundary(tree_clf_sr,Xsr,ys,axes=[-0.7,0.7,-0.7,0.7],iris=False)
plt.title('Sensitivity to training set rotation')plt.show()

         这里是把数据又旋转了45度,然而决策边界并没有也旋转45度,却是变复杂了。可以看出,对于复杂的数据,决策树是很敏感的。

        2.7 回归任务

#7.回归任务 
np.random.seed(42)
m = 200
X = np.random.rand(m,1)
y = 4 * (X-0.5)**2
y = y + np.random.randn(m,1) /10
plt.plot(X,y,'b.')
from sklearn.tree import DecisionTreeRegressor
tree_reg = DecisionTreeRegressor(max_depth=2)
tree_reg.fit(X,y)
from sklearn.tree import export_graphviz
export_graphviz(tree_reg,out_file='regression_tree.dot',feature_names=['X1'],rounded=True,filled=True
)
from IPython.display import Image
Image(filename='regression_tree.png',width=400,height=400)

 

         回归任务,这里的衡量标准就变成了均方误差。

        2.8 对比树的深度对结果的影响

#8.对比树的深度对结果的影响
from sklearn.tree import DecisionTreeRegressor
tree_reg1 = DecisionTreeRegressor(random_state=42,max_depth=2)
tree_reg2 = DecisionTreeRegressor(random_state=42,max_depth=3)
tree_reg1.fit(X,y)
tree_reg2.fit(X,y)def plot_regression_predictions(tree_reg,X,y,axes=[0,1,-0.2,1],ylabel='$y$'):x1 = np.linspace(axes[0],axes[1],500).reshape(-1,1)y_pred = tree_reg.predict(x1)plt.axis(axes)plt.xlabel('$X_1$',fontsize =18)if ylabel:plt.ylabel(ylabel,fontsize = 18,rotation=0)plt.plot(X,y,'b.')plt.plot(x1,y_pred,'r.-',linewidth=2,label=r'$\hat{y}$')plt.figure(figsize=(11,4))
plt.subplot(121)plot_regression_predictions(tree_reg1,X,y)
for split,style in ((0.1973,'k-'),(0.0917,'k--'),(0.7718,'k--')):plt.plot([split,split],[-0.2,1],style,linewidth = 2)
plt.text(0.21,0.65,'Depth=0',fontsize= 15)
plt.text(0.01,0.2,'Depth=1',fontsize= 13)
plt.text(0.65,0.8,'Depth=0',fontsize= 13)
plt.legend(loc='upper center',fontsize = 18)
plt.title('max_depth=2',fontsize=14)
plt.subplot(122)
plot_regression_predictions(tree_reg2,X,y)
for split,style in ((0.1973,'k-'),(0.0917,'k--'),(0.7718,'k--')):plt.plot([split,split],[-0.2,1],style,linewidth = 2)
for split in (0.0458,0.1298,0.2873,0.9040):plt.plot([split,split],[-0.2,1],linewidth = 1)
plt.text(0.3,0.5,'Depth=2',fontsize= 13)
plt.title('max_depth=3',fontsize=14)plt.show()

        不同的树的深度,对于结果产生极大的影响

        2.9 剪枝

        

#9.加一些限制
tree_reg1 = DecisionTreeRegressor(random_state=42)
tree_reg2 = DecisionTreeRegressor(random_state=42,min_samples_leaf=10)
tree_reg1.fit(X,y)
tree_reg2.fit(X,y)x1 = np.linspace(0,1,500).reshape(-1,1)
y_pred1 = tree_reg1.predict(x1)
y_pred2 = tree_reg2.predict(x1)plt.figure(figsize=(11,4))plt.subplot(121)
plt.plot(X,y,'b.')
plt.plot(x1,y_pred1,'r.-',linewidth=2,label=r'$\hat{y}$')
plt.axis([0,1,-0.2,1.1])
plt.xlabel('$x_1$',fontsize=18)
plt.ylabel('$y$',fontsize=18,rotation=0)
plt.legend(loc='upper center',fontsize =18)
plt.title('No restrctions',fontsize =14)plt.subplot(122)
plt.plot(X,y,'b.')
plt.plot(x1,y_pred2,'r.-',linewidth=2,label=r'$\hat{y}$')
plt.axis([0,1,-0.2,1.1])
plt.xlabel('$x_1$',fontsize=18)
plt.ylabel('$y$',fontsize=18,rotation=0)
plt.legend(loc='upper center',fontsize =18)
plt.title('min_samples_leaf={}'.format(tree_reg2.min_samples_leaf),fontsize =14)plt.show()

        一目了然。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/270224.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第十三届蓝桥杯大赛软件赛省赛C/C++ 大学 B 组 统计子矩阵

#include<iostream> #include<algorithm> #include<cstring> #include<string> #include<vector> #include<queue>using namespace std;int cnt,temp; int n,m,K; int a[505][505]; int pre[505][505];//二维前缀和void sol() {cin>>…

回溯算法01-组合(Java)

1.组合 题目描述 给定两个整数 n 和 k&#xff0c;返回范围 [1, n] 中所有可能的 k 个数的组合。 你可以按 任何顺序 返回答案。 示例 1&#xff1a; 输入&#xff1a;n 4, k 2 输出&#xff1a; [[2,4],[3,4],[2,3],[1,2],[1,3],[1,4]]示例 2&#xff1a; 输入&#x…

智慧城市中的公共服务创新:让城市生活更便捷

目录 一、引言 二、智慧城市公共服务创新的实践 1、智慧交通系统 2、智慧医疗服务 3、智慧教育系统 4、智慧能源管理 三、智慧城市公共服务创新的挑战 四、智慧城市公共服务创新的前景 五、结论 一、引言 随着信息技术的迅猛发展&#xff0c;智慧城市已成为现代城市发…

Python光速入门 - Flask轻量级框架

FlASK是一个轻量级的WSGI Web应用程序框架&#xff0c;Flask的核心包括Werkzeug工具箱和Jinja2模板引擎&#xff0c;它没有默认使用的数据库或窗体验证工具&#xff0c;这意味着用户可以根据自己的需求选择不同的数据库和验证工具。Flask的设计理念是保持核心简单&#xff0c…

【C语言】还有柔性数组?

前言 也许你从来没有听说过柔性数组&#xff08;flexible array&#xff09;这个概念&#xff0c;但是它确实是存在的。C99中&#xff0c;结构中的最后⼀个元素允许是未知⼤⼩的数组&#xff0c;这就叫做『柔性数组』成员。 欢迎关注个人主页&#xff1a;逸狼 创造不易&#xf…

Java基于微信小程序的旅游出行必备小程序,附源码

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;…

RabbitMQ如何保证消息不丢

如何保证Queue消息能不丢呢&#xff1f; RabbitMQ在接收到消息后&#xff0c;默认并不会立即进行持久化&#xff0c;而是先把消息暂存在内存中&#xff0c;这时候如果MQ挂了&#xff0c;那么消息就会丢失。所以需要通过持久化机制来保证消息可以被持久化下来。 队列和交换机的…

【论文阅读】Usenix Security 2023 你看不见我:对基于激光雷达的自动驾驶汽车驾驶框架的物理移除攻击

文章目录 一.论文信息二.论文内容1.摘要2.引言3.作者贡献4.主要图表5.结论 一.论文信息 论文题目&#xff1a; You Can’t See Me: Physical Removal Attacks on LiDAR-based Autonomous Vehicles Driving Frameworks&#xff08;你看不见我:对基于激光雷达的自动驾驶汽车驾驶…

如何搭建Nacos集群

1.搭建Nacos集群 众所周知&#xff0c;在实际的工作中&#xff0c;Nacos的生成环境下一定要部署为集群状态 其中包含3个nacos节点&#xff0c;然后一个负载均衡器代理3个Nacos。这里负载均衡器可以使用nginx。 我们计划的集群结构&#xff1a; 我就直接在本机上开三个Nacos来搭…

Linux网络编程(五-IO模型)

目录 一、同步I/O 二、异步I/O 三、阻塞I/O 四、非阻塞I/O 五、信号驱动I/O ​六、多路复用IO 一、同步I/O 在操作系统中&#xff0c;程序运行的空间分为内核空间和用户空间&#xff0c; 用户空间所有对io操作的代码(如文件的读写、socket的收发等)都会通过系统调用进…

纯css+html实现拟态开关按钮面板

适合智能家居的开关面板UI 参考&#xff1a;https://drams.framer.website/

【Flutter 】get-cli init报错处理

报错内容 get init 命令跳出,报错内如下 Select which type of project you want to creat Synchronous waiting using dart:cli waitFor Unhandled exceotion . Dart WaitforEvent is deprecated and disabled by default. This feature will be fully removed in Dart 3.4 …

springboot+vue+mysql项目使用的常用注解

实体类常用注解 Data Data 是一个 Lombok 提供的注解&#xff0c;使用 Data 注解可以简化代码&#xff0c;使代码更加简洁易读。 作用&#xff1a;自动为类生成常用的方法&#xff0c;包括 getter、setter、equals、hashCode 和 toString 等需要加Lombok的依赖 <depende…

Linux高级编程:网络

回顾&#xff1a; 进程间的通信&#xff1a; 同一主机内通信&#xff1a; 传统的进程间通信方式&#xff08;管道、信号&#xff09;&#xff1b; IPC对象&#xff08;共享内存&#xff0c;消息队列&#xff0c;信号量集&#xff09;&#xff1b; 不同主机间进程的通信&#…

C++ 链表OJ

目录 1、2. 两数相加 2、24. 两两交换链表中的节点 3、143. 重排链表 4、23. 合并 K 个升序链表 5、25. K 个一组翻转链表 解决链表题目常用方法&#xff1a; 1、画图 2、引入虚拟"头”结点 便于处理边界情况方便我们对链表操作 3、大胆定义变量&#xff0c;减少连接…

STM32CubeMX软件界面花屏,混乱的解决方案。

添加系统环境变量:J2D_D3D 值:false 即可解决。

MySQL:一行记录如何

1、表空间文件结构 表空间由段「segment」、区「extent」、页「page」、行「row」组成&#xff0c;InnoDB存储引擎的逻辑存储结构大致如下图&#xff1a; 行 数据库表中的记录都是按「行」进行存放的&#xff0c;每行记录根据不同的行格式&#xff0c;有不同的存储结构。 页…

机器学习中类别不平衡问题的解决方案

类别不平衡问题 解决方案简单方法收集数据调整权重阈值移动 数据层面欠采样过采样采样方法的优劣 算法层面代价敏感集成学习&#xff1a;EasyEnsemble 总结 类别不平衡&#xff08;class-imbalance&#xff09;就是指分类任务中不同类别的训练样例数目差别很大的情况 解决方案…

信呼OA普通用户权限getshell方法

0x01 前言 信呼OA是一款开源的OA系统&#xff0c;面向社会免费提供学习研究使用&#xff0c;采用PHP语言编写&#xff0c;搭建简单方便&#xff0c;在中小企业中具有较大的客户使用量。从公开的资产治理平台中匹配到目前互联中有超过1W的客户使用案例。 信呼OA目前最新的版本是…

码住!微信自动回复的实现方式

对于很多企业和个人来说&#xff0c;如何高效地回复微信客户消息是一个非常重要的问题。在这样的需求下&#xff0c;使用微信管理系统的机器人自动回复功能成为了一种不错的选择。 1、自动通过好友 当有大量的好友请求时&#xff0c;手动处理好友申请会变得非常繁琐。而微信管…