机器学习 | 决策树 Decision Tree

—— 分而治之,逐个击破

                把特征空间划分区域

                每个区域拟合简单模型

                分级分类决策


1、核心思想和原理

  • 举例:
    • 特征选择、节点分类、阈值确定


2、信息嫡

       

        熵本身代表不确定性,是不确定性的一种度量。

        熵越大,不确定性越高,信息量越高。

       

        为什么用log?—— 两种解释,可能性的增长呈指数型;log可以将乘法变为加减法。

        

        联合熵 的物理意义:观察一个多变量系统获得的信息量。

        条件熵 的物理意义:知道其中一个变量的信息后,另一个变量的信息量。

                给定了训练样本 X ,分类标签中包含的信息量是什么。

         

        

        信息增益(互信息)

                代表了一个特征能够为一个系统带来多少信息。

        

        

        熵的分类

        

        

        熵的本质:特殊的衡量分布的混乱程度与分散程度的距离

        

        

        二分类信息熵:

二分类信息熵

import numpy as np
import matplotlib.pyplot as plt
def entropy(p):return -(p * np.log2(p) + (1 - p) * np.log2(1 - p))
plot_x = np.linspace(0.001, 0.999, 100)
plt.plot(plot_x, entropy(plot_x))
plt.show()

        

 

         决策树的本质

        


 3、决策树分类代码实现

 

数据集

from sklearn.datasets import load_irisiris = load_iris()
x = iris.data[:, 1:3]
y = iris.target
plt.scatter(x[:,0], x[:,1], c = y)
plt.show()

 

3.1、sklearn中的决策树

from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(max_depth=2, criterion='entropy')
clf.fit(x, y)

DecisionTreeClassifier

DecisionTreeClassifier(criterion='entropy', max_depth=2)

决策边界绘制的代码: 

def decision_boundary_plot(X, y, clf):axis_x1_min, axis_x1_max = X[:,0].min() - 1, X[:,0].max() + 1axis_x2_min, axis_x2_max = X[:,1].min() - 1, X[:,1].max() + 1x1, x2 = np.meshgrid( np.arange(axis_x1_min,axis_x1_max, 0.01) , np.arange(axis_x2_min,axis_x2_max, 0.01))z = clf.predict(np.c_[x1.ravel(),x2.ravel()])z = z.reshape(x1.shape)from matplotlib.colors import ListedColormapcustom_cmap = ListedColormap(['#F5B9EF','#BBFFBB','#F9F9CB'])plt.contourf(x1, x2, z, cmap=custom_cmap)plt.scatter(X[:,0], X[:,1], c=y)plt.show()
decision_boundary_plot(x, y, clf)

from sklearn.tree import plot_tree
plot_tree(clf)
[Text(0.4, 0.8333333333333334, 'X[1] <= 2.45\nentropy = 1.585\nsamples = 150\nvalue = [50, 50, 50]'),Text(0.2, 0.5, 'entropy = 0.0\nsamples = 50\nvalue = [50, 0, 0]'),Text(0.6, 0.5, 'X[1] <= 4.75\nentropy = 1.0\nsamples = 100\nvalue = [0, 50, 50]'),Text(0.4, 0.16666666666666666, 'entropy = 0.154\nsamples = 45\nvalue = [0, 44, 1]'),Text(0.8, 0.16666666666666666, 'entropy = 0.497\nsamples = 55\nvalue = [0, 6, 49]')]

 


 

3.2、最优划分条件

from collections import Counter
Counter(y)
Counter({0: 50, 1: 50, 2: 50})
def calc_entropy(y):counter = Counter(y)sum_ent = 0for i in counter:p = counter[i] / len(y)sum_ent += (-p * np.log2(p))return sum_ent
calc_entropy(y)
1.584962500721156
def split_dataset(x, y, dim, value):index_left = (x[:, dim] <= value)index_right = (x[:, dim] > value)return x[index_left], y[index_left], x[index_right], y[index_right]
def find_best_split(x, y):best_dim = -1best_value = -1best_entropy = np.infbest_entropy_left, best_entropy_right = -1, -1for dim in range(x.shape[1]):sorted_index = np.argsort(x[:, dim])for i in range(x.shape[0] - 1): # x列数value_left, value_right = x[sorted_index[i], dim], x[sorted_index[i + 1], dim]if value_left != value_right:value = (value_left + value_right) / 2x_left, y_left, x_right, y_right = split_dataset(x, y, dim, value)entropy_left, entropy_right = calc_entropy(y_left), calc_entropy(y_right)entropy = (len(x_left) * entropy_left + len(x_right) * entropy_right) / x.shape[0]if entropy < best_entropy:best_dim = dimbest_value = valuebest_entropy = entropybest_entropy_left, best_entropy_right = entropy_left, entropy_rightreturn best_dim, best_value, best_entropy, best_entropy_left, best_entropy_right
find_best_split(x, y)
(1, 2.45, 0.6666666666666666, 0.0, 1.0)
x_left, y_left, x_right, y_right = split_dataset(x, y, 1, 2.45)
find_best_split(x_right, y_right)
(1, 4.75, 0.34262624992678425, 0.15374218032876188, 0.4971677614160753)


4、基尼系数

        

        基尼系数运算稍快;

        物理意义略有不同,信息熵表示的是随机变量的不确定度;

                基尼系数表示在样本集合中一个随机选中的样本被分错的概率,也就是纯度。

                基尼系数越小,纯度越高。

        模型效果上差异不大。

        

二分类信息熵和基尼系数代码实现:

import numpy as np
import matplotlib.pyplot as plt
def entropy(p):return -(p * np.log2(p) + (1 - p) * np.log2(1 - p))
def gini(p):return 1 - p ** 2 - (1 - p) ** 2
plot_x = np.linspace(0.001, 0.999, 100)
plt.plot(plot_x, entropy(plot_x), color = 'blue')
plt.plot(plot_x, gini(plot_x), color = 'red')
plt.show()


5、决策树剪枝

Chapter-07/7-6 决策树剪枝.ipynb · 梗直哥/Machine-Learning - Gitee.com

为什么要剪枝?

                复杂度过高。

                        预测复杂度:O(logm)

                        训练复杂度:O(n x m x logm)

                        logm为数的深度,n为数据的维度。

                容易过拟合

                        为非参数学习方法。

 目标:

                降低复杂度

                解决过拟合

 手段:

                限制深度(结点层数)

                限制广度(叶子结点个数)

   —— 设置超参数

                        


6、决策树回归

        基于一种思想:相似输入必会产生相似输出。

        取节点平均值。

        

6.1、决策树回归代码实现

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')boston = datasets.load_boston()
x = boston.data
y = boston.target
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=233)
from sklearn.tree import DecisionTreeRegressorreg = DecisionTreeRegressor()
reg.fit(x_train,y_train)

DecisionTreeRegressor

DecisionTreeRegressor()
reg.score(x_test,y_test)
0.7410680140563546
reg.score(x_train,y_train)
1.0

6.2、绘制学习曲线

from sklearn.metrics import r2_scoreplt.rcParams["figure.figsize"] = (12, 8)
max_depth = [2, 5, 10, 20]for i, depth in enumerate(max_depth):reg = DecisionTreeRegressor(max_depth=depth)train_error, test_error = [], []for k in range(len(x_train)):reg.fit(x_train[:k+1], y_train[:k+1])y_train_pred = reg.predict(x_train[:k + 1])train_error.append(r2_score(y_train[:k + 1], y_train_pred))y_test_pred = reg.predict(x_test)test_error.append(r2_score(y_test, y_test_pred))plt.subplot(2, 2, i + 1)plt.ylim(0, 1.1)plt.title("Depth: {0}".format(depth))plt.plot([k + 1 for k in range(len(x_train))], train_error, color = "red", label = 'train')plt.plot([k + 1 for k in range(len(x_train))], test_error, color = "blue", label = 'test')plt.legend()plt.show()

6.3、网格搜索

from sklearn.model_selection import GridSearchCVparams = {'max_depth': [n for n in range(2, 15)],'min_samples_leaf': [sn for sn in range(3, 20)],
}grid = GridSearchCV(estimator = DecisionTreeRegressor(), param_grid = params, n_jobs = -1
)
grid.fit(x_train,y_train)

GridSearchCV

GridSearchCV(estimator=DecisionTreeRegressor(), n_jobs=-1,param_grid={'max_depth': [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,14],'min_samples_leaf': [3, 4, 5, 6, 7, 8, 9, 10, 11, 12,13, 14, 15, 16, 17, 18, 19]})

estimator: DecisionTreeRegressor

DecisionTreeRegressor()

DecisionTreeRegressor

DecisionTreeRegressor()
grid.best_params_
{'max_depth': 5, 'min_samples_leaf': 3}
grid.best_score_
0.7327442904059717
reg = grid.best_estimator_
reg.score(x_test, y_test)
0.781690085676063

7、优缺点和适用条件

优点:

        符合人类直观思维

        可解释性强

        能够处理数值型数据和分类型数据

        能够处理多输出问题

缺点:

        容易产生过拟合

        决策边界只能是水平或竖直方向

                

        不稳定,数据的微小变化可能生成完全不同的树


参考于

Chapter-07/7-4 决策树分类.ipynb · 梗直哥/Machine-Learning - Gitee.com

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/219910.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenAI 偷偷在训练 GPT-4.5!?

最近看到有人已经套路出 ChatGPT 当前的版本&#xff0c;回答居然是 gpt-4.5-turbo&#xff1a; 实际试验下&#xff0c;用 starflow.tech&#xff0c;切换到小星 4 全能版&#xff08;同等官网最新 GPT-4&#xff09;&#xff0c;复制下面这段话问它&#xff1a; What is the…

人工智能在金融与商业领域的智能化变革

导言 随着人工智能技术的不断发展&#xff0c;金融和商业领域正迎来一场智能化的变革。随着人工智能的不断发展&#xff0c;其在金融和商业领域的应用正成为业界瞩目的焦点。本文将深入探讨人工智能在金融和商业应用中的关键技术、应用场景以及对未来的影响。 1. 关键技术与算…

数据结构:队列

数据结构&#xff1a;队列 文章目录 数据结构&#xff1a;队列1.队列常用操作&#xff1a;2.队列的实现3.队列典型应用 ***「队列 queue」是一种遵循先入先出规则的线性数据结构。***队列模拟了排队现象&#xff0c;即新来的人不断加入队列尾部&#xff0c;而位于队列头部的人逐…

ceph的osd盘删除操作和iscsi扩展

ceph的osd盘删除操作 拓展:osd磁盘的删除(这里以删除node1上的osd.0磁盘为例) 1, 查看osd磁盘状态 [rootnode1 ceph]# ceph osd tree ID CLASS WEIGHT TYPE NAME STATUS REWEIGHT PRI-AFF -1 0.00298 root default -3 0.00099 host node10 hdd 0.00…

微服务实战系列之ZooKeeper(下)

前言 通过前序两篇关于ZooKeeper的介绍和总结&#xff0c;我们可以大致理解了它是什么&#xff0c;它有哪些重要组成部分。 今天&#xff0c;博主特别介绍一下ZooKeeper的一个核心应用场景&#xff1a;分布式锁。 应用ZooKeeper Q&#xff1a;什么是分布式锁 首先了解一下&…

云原生向量计算引擎 PieCloudVector:为大模型提供独特记忆

拓数派大模型数据计算系统&#xff08;PieDataComputingSystem&#xff0c;缩写&#xff1a;πDataCS&#xff09;在10月24日程序员节「大模型数据计算系统」2023拓数派年度技术论坛正式发布。πDataCS 以云原生技术重构数据存储和计算&#xff0c;「一份存储&#xff0c;多引擎…

C# 基本桌面编程(二)

一、前言 本章为C# 基本桌面编程技术的第二节也是最后一节。前一节在下面这个链接 C# 基本桌面编程&#xff08;一&#xff09;https://blog.csdn.net/qq_71897293/article/details/135024535?spm1001.2014.3001.5502 二、控件布局 1 叠放顺序 在WPF当中布局&#xff0c;通…

华为配置OSPF与BFD联动示例

组网需求 如图1所示&#xff0c;SwitchA、SwitchB和SwitchC之间运行OSPF&#xff0c;SwitchA和SwitchB之间的交换机仅作透传功能。现在需要SwitchA和SwitchB能快速感应它们之间的链路状态&#xff0c;当链路SwitchA-SwitchB发生故障时&#xff0c;业务能快速切换到备份链路Swi…

极狐GitLab DevSecOps 之容器镜像安全扫描

容器镜像安全 现状 最近某银行遭受供应链攻击的事件传的沸沸扬扬&#xff0c;安全又双叒叕进入了人们的视野。安全确实是一个非常重要&#xff0c;但是又最容易被忽略的话题。但是现在到了一个不得不人人重视安全&#xff0c;人人为安全负责的时代。尤其以现在非常火爆的云原…

java设计模式-工厂方法模式

1.工厂方法(FactoryMethod)模式的定义 定义一个创建产品对象的工厂接口&#xff0c;将产品对象的实际创建工作推迟到具体子工厂类当中。这满足创建型模式中所要求的“创建与使用相分离”的特点。 2.工厂方法模式的主要优缺点 优点&#xff1a; 用户只需要知道具体工厂的名称…

智能优化算法应用:基于乌燕鸥算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于乌燕鸥算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于乌燕鸥算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.乌燕鸥算法4.实验参数设定5.算法结果6.参考文…

低代码企业级PMO项目管理系统,360度全景透视企业管理视角

在一个崇高的目标支持下&#xff0c;不停地工作&#xff0c;即使慢&#xff0c;也一定会获得成功。 爱因斯坦 ★ 前情概要&#xff1a; 企业级PMO项目管理业务是行业里相对成熟和规范的业务&#xff0c;拥有众多商业套件和标准产品。 然而随着企业数字化建设进入深水区&#…

《Global illumination with radiance regression functions》

总结一下最近看的这篇结合神经网络的全局光照论文 这是一篇2013年TOG的论文。 介绍 论文的主要思想是利用了神经网络的非线性特性去拟合全局光照中的间接光照部分&#xff0c;采用了基础的2层MLP去训练&#xff0c;最终能实现一些点光源、glossy材质的光照渲染。为了更好的理…

解决App Store上架提示您必须上传 12.9 英寸 iPad Pro(第 2 代)显示屏的截屏

出错场景 在App Store Connect中&#xff0c;上架App时&#xff0c;出现以下错误提示. 要开始审核流程&#xff0c;必须提供以下项目&#xff1a;您必须上传 12.9 英寸 iPad Pro&#xff08;第 2 代&#xff09;显示屏的截屏。&#xff08;2048&#xff0c;2732&#xff09;您…

overleaf 加载pdf格式的矢量图时,visio 图片保存为pdf格式,如何确保pdf页面大小和图片一致

Overleaf支持多种矢量图形格式&#xff0c;其中一些常见的包括&#xff1a; PDF&#xff08;Portable Document Format&#xff09;&#xff1a; PDF是一种常见的矢量图形格式&#xff0c;Overleaf可以直接加载和显示PDF文件。许多绘图工具和LaTeX生成的图形都可以导出为PDF格式…

ShenYu网关Http服务探活解析

文章目录 网关端服务探活admin端服务探活 Shenyu HTTP服务探活是一种用于检测HTTP服务是否正常运行的机制。它通过建立Socket连接来判断服务是否可用。当服务不可用时&#xff0c;将服务从可用列表中移除。 网关端服务探活 以divide插件为例&#xff0c;看下divide插件是如何获…

21、同济、微软亚研院、西安电子科技大提出HPT:层次化提示调优,独属于提示学习的[安妮海瑟薇]

前言&#xff1a; 本论文由同济大学、微软亚洲研究院、西安电子科技大学&#xff0c;于2023年12月11日中了AAAI2024 论文&#xff1a; 《Learning Hierarchical Prompt with Structured Linguistic Knowledge for Vision-Language Models》 地址&#xff1a; [2312.06323]…

网络(十)ACL和NAT

前言 网络管理在生产环境和生活中&#xff0c;如何实现拒绝不希望的访问连接&#xff0c;同时又要允许正常的访问连接&#xff1f;当下公网地址消耗殆尽&#xff0c;且公网IP地址费用昂贵&#xff0c;企业访问Internet全部使用公网IP地址不够现实&#xff0c;如何让私网地址也…

机器翻译:跨越语言边界的智能大使

导言 机器翻译作为人工智能领域的瑰宝&#xff0c;正在以前所未有的速度和精度&#xff0c;为全球沟通拓展新的可能性。本文将深入研究机器翻译的技术原理、应用场景以及对语言交流未来的影响。 1. 简介 机器翻译是一项致力于通过计算机自动将一种语言的文本翻译成另一种语言的…

android studio 快捷输入模板提示

在Android开发中&#xff0c;我们经常会遇到一些重复性的代码&#xff0c;例如创建一个新的Activity、定义一个Getter方法等。为了提高开发效率&#xff0c;Android Studio提供了Live Templates功能&#xff0c;可以通过简化输入来快速生成这些重复性代码。 按下图提示设置&am…