机器学习:回归决策树(Python)

一、平方误差的计算

square_error_utils.py

import numpy as npclass SquareErrorUtils:"""平方误差最小化准则,选择其中最优的一个作为切分点对特征属性进行分箱处理"""@staticmethoddef _set_sample_weight(sample_weight, n_samples):"""扩展到集成学习,此处为样本权重的设置:param sample_weight: 各样本的权重:param n_samples: 样本量:return:"""if sample_weight is None:sample_weight = np.asarray([1.0] * n_samples)return sample_weight@staticmethoddef square_error(y, sample_weight):"""平方误差:param y: 当前划分区域的目标值集合:param sample_weight: 当前样本的权重:return:"""y = np.asarray(y)return np.sum((y - y.mean()) ** 2 * sample_weight)def cond_square_error(self, x, y, sample_weight):"""计算根据某个特征x划分的区域中y的误差值:param x: 某个特征划分区域所包含的样本:param y: x对应的目标值:param sample_weight: 当前x的权重:return:"""x, y = np.asarray(x), np.asarray(y)error = 0.0for x_val in set(x):x_idx = np.where(x == x_val)  # 按区域计算误差new_y = y[x_idx]  # 对应区域的目标值new_sample_weight = sample_weight[x_idx]error += self.square_error(new_y, new_sample_weight)return errordef square_error_gain(self, x, y, sample_weight=None):"""平方误差带来的增益值:param x: 某个特征变量:param y: 对应的目标值:param sample_weight: 样本权重:return:"""sample_weight = self._set_sample_weight(sample_weight, len(x))return self.square_error(y, sample_weight) - self.cond_square_error(x, y, sample_weight)

 二、树的结点信息封装


class TreeNode_R:"""决策树回归算法,树的结点信息封装,实体类:setXXX()、getXXX()"""def __init__(self, feature_idx: int = None, feature_val=None, y_hat=None, square_error: float = None,criterion_val=None, n_samples: int = None, left_child_Node=None, right_child_Node=None):"""决策树结点信息封装:param feature_idx: 特征索引,如果指定特征属性的名称,可以按照索引取值:param feature_val: 特征取值:param square_error: 划分结点的标准:当前结点的平方误差:param n_samples: 当前结点所包含的样本量:param y_hat: 当前结点的预测值:Ci:param left_child_Node: 左子树:param right_child_Node: 右子树"""self.feature_idx = feature_idxself.feature_val = feature_valself.criterion_val = criterion_valself.square_error = square_errorself.n_samples = n_samplesself.y_hat = y_hatself.left_child_Node = left_child_Node  # 递归self.right_child_Node = right_child_Node  # 递归def level_order(self):"""按层次遍历树...:return:"""pass# def get_feature_idx(self):#     return self.get_feature_idx()## def set_feature_idx(self, feature_idx):#     self.feature_idx = feature_idx

三、回归决策树CART算法实现

import numpy as np
from utils.square_error_utils import SquareErrorUtils
from utils.tree_node_R import TreeNode_R
from utils.data_bin_wrapper import DataBinsWrapperclass DecisionTreeRegression:"""回归决策树CART算法实现:按照二叉树构造1. 划分标准:平方误差最小化2. 创建决策树fit(),递归算法实现,注意出口条件3. 预测predict_proba()、predict() --> 对树的搜索4. 数据的预处理操作,尤其是连续数据的离散化,分箱5. 剪枝处理"""def __init__(self, criterion="mse", max_depth=None, min_sample_split=2, min_sample_leaf=1,min_target_std=1e-3, min_impurity_decrease=0, max_bins=10):self.utils = SquareErrorUtils()  # 结点划分类self.criterion = criterion  # 结点的划分标准if criterion.lower() == "mse":self.criterion_func = self.utils.square_error_gain  # 平方误差增益else:raise ValueError("参数criterion仅限mse...")self.min_target_std = min_target_std  # 最小的样本目标值方差,小于阈值不划分self.max_depth = max_depth  # 树的最大深度,不传参,则一直划分下去self.min_sample_split = min_sample_split  # 最小的划分结点的样本量,小于则不划分self.min_sample_leaf = min_sample_leaf  # 叶子结点所包含的最小样本量,剩余的样本小于这个值,标记叶子结点self.min_impurity_decrease = min_impurity_decrease  # 最小结点不纯度减少值,小于这个值,不足以划分self.max_bins = max_bins  # 连续数据的分箱数,越大,则划分越细self.root_node: TreeNode_R() = None  # 回归决策树的根节点self.dbw = DataBinsWrapper(max_bins=max_bins)  # 连续数据离散化对象self.dbw_XrangeMap = {}  # 存储训练样本连续特征分箱的端点def fit(self, x_train, y_train, sample_weight=None):"""回归决策树的创建,递归操作前的必要信息处理(分箱):param x_train: 训练样本:ndarray,n * k:param y_train: 目标集:ndarray,(n, ):param sample_weight: 各样本的权重,(n, ):return:"""x_train, y_train = np.asarray(x_train), np.asarray(y_train)self.class_values = np.unique(y_train)  # 样本的类别取值n_samples, n_features = x_train.shape  # 训练样本的样本量和特征属性数目if sample_weight is None:sample_weight = np.asarray([1.0] * n_samples)self.root_node = TreeNode_R()  # 创建一个空树self.dbw.fit(x_train)x_train = self.dbw.transform(x_train)self._build_tree(1, self.root_node, x_train, y_train, sample_weight)def _build_tree(self, cur_depth, cur_node: TreeNode_R, x_train, y_train, sample_weight):"""递归创建回归决策树算法,核心算法。按先序(中序、后序)创建的:param cur_depth: 递归划分后的树的深度:param cur_node: 递归划分后的当前根结点:param x_train: 递归划分后的训练样本:param y_train: 递归划分后的目标集合:param sample_weight: 递归划分后的各样本权重:return:"""n_samples, n_features = x_train.shape  # 当前样本子集中的样本量和特征属性数目# 计算当前数结点的预测值,即加权平均值,cur_node.y_hat = np.dot(sample_weight / np.sum(sample_weight), y_train)cur_node.n_samples = n_samples# 递归出口判断cur_node.square_error = ((y_train - y_train.mean()) ** 2).sum()# 所有的样本目标值较为集中,样本方差非常小,不足以划分if cur_node.square_error <= self.min_target_std:# 如果为0,则表示当前样本集合为空,递归出口3returnif n_samples < self.min_sample_split:  # 当前结点所包含的样本量不足以划分returnif self.max_depth is not None and cur_depth > self.max_depth:  # 树的深度达到最大深度return# 划分标准,选择最佳的划分特征及其取值best_idx, best_val, best_criterion_val = None, None, 0.0for k in range(n_features):  # 对当前样本集合中每个特征计算划分标准for f_val in sorted(np.unique(x_train[:, k])):  # 当前特征的不同取值region_x = (x_train[:, k] <= f_val).astype(int)  # 是当前取值f_val就是1,否则就是0criterion_val = self.criterion_func(region_x, y_train, sample_weight)if criterion_val > best_criterion_val:best_criterion_val = criterion_val  # 最佳的划分标准值best_idx, best_val = k, f_val  # 当前最佳特征索引以及取值# 递归出口的判断if best_idx is None:  # 当前属性为空,或者所有样本在所有属性上取值相同,无法划分returnif best_criterion_val <= self.min_impurity_decrease:  # 小于最小不纯度阈值,不划分returncur_node.criterion_val = best_criterion_valcur_node.feature_idx = best_idxcur_node.feature_val = best_val# print("当前划分的特征索引:", best_idx, "取值:", best_val, "最佳标准值:", best_criterion_val)# print("当前结点的类别分布:", target_dist)# 创建左子树,并递归创建以当前结点为子树根节点的左子树left_idx = np.where(x_train[:, best_idx] <= best_val)  # 左子树所包含的样本子集索引if len(left_idx) >= self.min_sample_leaf:  # 小于叶子结点所包含的最少样本量,则标记为叶子结点left_child_node = TreeNode_R()  # 创建左子树空结点# 以当前结点为子树根结点,递归创建cur_node.left_child_Node = left_child_nodeself._build_tree(cur_depth + 1, left_child_node, x_train[left_idx],y_train[left_idx], sample_weight[left_idx])right_idx = np.where(x_train[:, best_idx] > best_val)  # 右子树所包含的样本子集索引if len(right_idx) >= self.min_sample_leaf:  # 小于叶子结点所包含的最少样本量,则标记为叶子结点right_child_node = TreeNode_R()  # 创建右子树空结点# 以当前结点为子树根结点,递归创建cur_node.right_child_Node = right_child_nodeself._build_tree(cur_depth + 1, right_child_node, x_train[right_idx],y_train[right_idx], sample_weight[right_idx])def _search_tree_predict(self, cur_node: TreeNode_R, x_test):"""根据测试样本从根结点到叶子结点搜索路径,判定所属区域(叶子结点)搜索:按照后续遍历:param x_test: 单个测试样本:return:"""if cur_node.left_child_Node and x_test[cur_node.feature_idx] <= cur_node.feature_val:return self._search_tree_predict(cur_node.left_child_Node, x_test)elif cur_node.right_child_Node and x_test[cur_node.feature_idx] > cur_node.feature_val:return self._search_tree_predict(cur_node.right_child_Node, x_test)else:# 叶子结点,类别,包含有类别分布return cur_node.y_hatdef predict(self, x_test):"""预测测试样本x_test的预测值:param x_test: 测试样本ndarray、numpy数值运算:return:"""x_test = np.asarray(x_test)  # 避免传递DataFrame、list...if self.dbw.XrangeMap is None:raise ValueError("请先进行回归决策树的创建,然后预测...")x_test = self.dbw.transform(x_test)y_test_pred = []  # 用于存储测试样本的预测值for i in range(x_test.shape[0]):y_test_pred.append(self._search_tree_predict(self.root_node, x_test[i]))return np.asarray(y_test_pred)@staticmethoddef cal_mse_r2(y_test, y_pred):"""模型预测的均方误差MSE和判决系数R2:param y_test: 测试样本的真值:param y_pred: 测试样本的预测值:return:"""y_test, y_pred = y_test.reshape(-1), y_pred.reshape(-1)mse = ((y_pred - y_test) ** 2).mean()  # 均方误差r2 = 1 - ((y_pred - y_test) ** 2).sum() / ((y_test - y_test.mean()) ** 2).sum()return mse, r2def _prune_node(self, cur_node: TreeNode_R, alpha):"""递归剪枝,针对决策树中的内部结点,自底向上,逐个考察方法:后序遍历:param cur_node: 当前递归的决策树的内部结点:param alpha: 剪枝阈值:return:"""# 若左子树存在,递归左子树进行剪枝if cur_node.left_child_Node:self._prune_node(cur_node.left_child_Node, alpha)# 若右子树存在,递归右子树进行剪枝if cur_node.right_child_Node:self._prune_node(cur_node.right_child_Node, alpha)# 针对决策树的内部结点剪枝,非叶结点if cur_node.left_child_Node is not None or cur_node.right_child_Node is not None:for child_node in [cur_node.left_child_Node, cur_node.right_child_Node]:if child_node is None:# 可能存在左右子树之一为空的情况,当左右子树划分的样本子集数小于min_samples_leafcontinueif child_node.left_child_Node is not None or child_node.right_child_Node is not None:return# 计算剪枝前的损失值(平方误差),2表示当前结点包含两个叶子结点pre_prune_value = 2 * alphaif cur_node and cur_node.left_child_Node is not None:pre_prune_value += (0.0 if cur_node.left_child_Node.square_error is Noneelse cur_node.left_child_Node.square_error)if cur_node and cur_node.right_child_Node is not None:pre_prune_value += (0.0 if cur_node.right_child_Node.square_error is Noneelse cur_node.right_child_Node.square_error)# 计算剪枝后的损失值,当前结点即是叶子结点after_prune_value = alpha + cur_node.square_errorif after_prune_value <= pre_prune_value:  # 进行剪枝操作cur_node.left_child_Node = Nonecur_node.right_child_Node = Nonecur_node.feature_idx, cur_node.feature_val = None, Nonecur_node.square_error = Nonedef prune(self, alpha=0.01):"""决策树后剪枝算法(李航)C(T) + alpha * |T|:param alpha: 剪枝阈值,权衡模型对训练数据的拟合程度与模型的复杂度:return:"""self._prune_node(self.root_node, alpha)return self.root_node

 四、回归决策树算法的测试

test_decision_tree_R.py

import numpy as np
import matplotlib.pyplot as plt
from decision_tree_R import DecisionTreeRegression
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressorobj_fun = lambda x: np.sin(x)
np.random.seed(0)
n = 100
x = np.linspace(0, 10, n)
target = obj_fun(x) + 0.3 * np.random.randn(n)
data = x[:, np.newaxis]  # 二维数组tree = DecisionTreeRegression(max_bins=50, max_depth=10)
tree.fit(data, target)
x_test = np.linspace(0, 10, 200)
y_test_pred = tree.predict(x_test[:, np.newaxis])
mse, r2 = tree.cal_mse_r2(obj_fun(x_test), y_test_pred)plt.figure(figsize=(14, 5))
plt.subplot(121)
plt.scatter(data, target, s=15, c="k", label="Raw Data")
plt.plot(x_test, y_test_pred, "r-", lw=1.5, label="Fit Model")
plt.xlabel("x", fontdict={"fontsize": 12, "color": "b"})
plt.ylabel("y", fontdict={"fontsize": 12, "color": "b"})
plt.grid(ls=":")
plt.legend(frameon=False)
plt.title("Regression Decision Tree(UnPrune) and MSE = %.5f R2 = %.5f" % (mse, r2))plt.subplot(122)
tree.prune(0.5)
y_test_pred = tree.predict(x_test[:, np.newaxis])
mse, r2 = tree.cal_mse_r2(obj_fun(x_test), y_test_pred)
plt.scatter(data, target, s=15, c="k", label="Raw Data")
plt.plot(x_test, y_test_pred, "r-", lw=1.5, label="Fit Model")
plt.xlabel("x", fontdict={"fontsize": 12, "color": "b"})
plt.ylabel("y", fontdict={"fontsize": 12, "color": "b"})
plt.grid(ls=":")
plt.legend(frameon=False)
plt.title("Regression Decision Tree(Prune) and MSE = %.5f R2 = %.5f" % (mse, r2))plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/254301.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FlinkSql 窗口函数

Windowing TVF 以前用的是Grouped Window Functions&#xff08;分组窗口函数&#xff09;&#xff0c;但是分组窗口函数只支持窗口聚合 现在FlinkSql统一都是用的是Windowing TVFs&#xff08;窗口表值函数&#xff09;&#xff0c;Windowing TVFs更符合 SQL 标准且更加强大…

第5章 数据库操作

学习目标 了解数据库&#xff0c;能够说出数据库的概念、特点和分类 熟悉Flask-SQLAlchemy的安装&#xff0c;能够在Flask程序中独立安装扩展包Flask-SQLAlchemy 掌握数据库的连接方式&#xff0c;能够通过设置配置项SQLALCHEMY_DATABASE_URI的方式连接数据库 掌握模型的定义…

rust语言tokio库底层原理解析

目录 1 rust版本及tokio版本说明1 tokio简介2 tokio::main2.1 tokio::main使用多线程模式2.2 tokio::main使用单线程模式 3 builder.build()函数3.1 build_threaded_runtime()函数新的改变功能快捷键合理的创建标题&#xff0c;有助于目录的生成如何改变文本的样式插入链接与图…

【开源】JAVA+Vue.js实现高校实验室管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、研究内容2.1 实验室类型模块2.2 实验室模块2.3 实验管理模块2.4 实验设备模块2.5 实验订单模块 三、系统设计3.1 用例设计3.2 数据库设计 四、系统展示五、样例代码5.1 查询实验室设备5.2 实验放号5.3 实验预定 六、免责说明 一、摘…

使用 Docker 镜像预热提升容器启动效率详解

概要 在容器化部署中,Docker 镜像的加载速度直接影响到服务的启动时间和扩展效率。本文将深入探讨 Docker 镜像预热的概念、必要性以及实现方法。通过详细的操作示例和实践建议,读者将了解如何有效地实现镜像预热,以加快容器启动速度,提高服务的响应能力。 Docker 镜像预热…

【数据结构】堆(创建,调整,插入,删除,运用)

目录 堆的概念&#xff1a; 堆的性质&#xff1a; 堆的存储方式&#xff1a; 堆的创建 &#xff1a; 堆的调整&#xff1a; 向下调整&#xff1a; 向上调整&#xff1a; 堆的创建&#xff1a; 建堆的时间复杂度&#xff1a; 向下调整&#xff1a; 向上调整&#xff…

红队打靶练习:GLASGOW SMILE: 1.1

目录 信息收集 1、arp 2、nmap 3、nikto 4、whatweb 目录探测 1、gobuster 2、dirsearch WEB web信息收集 /how_to.txt /joomla CMS利用 1、爆破后台 2、登录 3、反弹shell 提权 系统信息收集 rob用户登录 abner用户 penguin用户 get root flag 信息收集…

HARRYPOTTER: FAWKES

攻击机 192.168.223.128 目标机192.168.223.143 主机发现 nmap -sP 192.168.223.0/24 端口扫描 nmap -sV -p- -A 192.168.223.143 开启了21 22 80 2222 9898 五个端口&#xff0c;其中21端口可以匿名FTP登录&#xff0c;好像有点说法,百度搜索一下发现可以用anonymous登录…

网络安全产品之认识准入控制系统

文章目录 一、什么是准入控制系统二、准入控制系统的主要功能1. 接入设备的身份认证2. 接入设备的安全性检查 三、准入控制系统的工作原理四、准入控制系统的特点五、准入控制系统的部署方式1. 网关模式2. 控制旁路模式 六、准入控制系统的应用场景七、企业如何利用准入控制系统…

使用PDFBox实现pdf转其他图片格式

最近在做一个小项目&#xff0c;项目中有一个功能要把pdf格式的图片转换为其它格式&#xff0c;接下来看看用pdfbox来如何实现吧。 首先导入pdfbox相关依赖&#xff1a; <dependency> <groupId>org.apache.pdfbox</groupId> <artifactId>pdfbox</a…

【高阶数据结构】位图布隆过滤器

文章目录 1. 位图1.1什么是位图1.2为什么会有位图1.3 实现位图1.4 位图的应用 2. 布隆过滤器2.1 什么是布隆过滤器2.2 为什么会有布隆过滤器2.3 布隆过滤器的插入2.4 布隆过滤器的查找2.5 布隆过滤器的模拟实现2.6 布隆过滤器的优点2.7 布隆过滤器缺陷 3. 海量数据面试题3.1 哈…

CTFshow web(命令执行29-36)

?ceval($_GET[shy]);&shypassthru(cat flag.php); #逃逸过滤 ?cinclude%09$_GET[shy]?>&shyphp://filter/readconvert.base64-encode/resourceflag.php #文件包含 ?cinclude%0a$_GET[cmd]?>&cmdphp://filter/readconvert.base64-encode/…

Kubernetes实战(二十七)-HPA实战

1 HPA简介 HPA 全称是 Horizontal Pod Autoscaler&#xff0c;用于POD 水平自动伸缩&#xff0c; HPA 可以 基于 POD CPU 利用率对 deployment 中的 pod 数量进行自动扩缩容&#xff08;除了 CPU 也可以基于自定义的指标进行自动扩缩容&#xff09;。pod 自动缩放不适用于无法…

ubuntu22.04@laptop OpenCV Get Started: 001_reading_displaying_write_image

ubuntu22.04laptop OpenCV Get Started: 001_reading_displaying_write_image 1. 源由2. Read/Display/Write应用Demo2.1 C应用Demo2.2 Python应用Demo 3. 过程分析3.1 导入OpenCV库3.2 读取图像文件3.3 显示图像3.4 保存图像文件 4. 总结5. 参考资料 1. 源由 读、写、显示图像…

Windows - URL Scheme - 在Windows上无管理员权限为你的程序添加URL Scheme

Windows - URL Scheme - 在Windows上无管理员权限为你的程序添加URL Scheme What 想不想在浏览器打开/控制你的电脑应用&#xff1f; 比如我在浏览器地址栏输入wegame://后回车会提示是否打开URL:wegame Portocol。 若出现了始终允许选项&#xff0c;你甚至可以写一个Web界面…

【AIGC核心技术剖析】DreamCraft3D一种层次化的3D内容生成方法

DreamCraft3D是一种用于生成高保真、连贯3D对象的层次化3D内容生成方法。它利用2D参考图像引导几何塑造和纹理增强阶段&#xff0c;通过视角相关扩散模型执行得分蒸馏采样&#xff0c;解决了现有方法中存在的一致性问题。使用Bootstrapped Score Distillation来提高纹理&#x…

React 实现表单组件

表单是html的基础元素&#xff0c;接下来我会用React实现一个表单组件。支持包括输入状态管理&#xff0c;表单验证&#xff0c;错误信息展示&#xff0c;表单提交&#xff0c;动态表单元素等功能。 数据状态 表单元素的输入状态管理&#xff0c;可以基于react state 实现。 …

计算机网络——04接入网和物理媒体

接入网和物理媒体 接入网络和物理媒体 怎样将端系统和边缘路由器连接&#xff1f; 住宅接入网络单位接入网络&#xff08;学校、公司&#xff09;无线接入网络 住宅接入&#xff1a;modem 将上网数据调制加载到音频信号上&#xff0c;在电话线上传输&#xff0c;在局端将其…

Ubuntu22.04 gnome-builder gnome C 应用程序习练笔记(一)

一、序言 gnome-builder构建器是gnome程序开发的集成环境&#xff0c;支持主力语言C, C, Vala, jscript, python等&#xff0c;界面以最新的 gtk 4.12 为主力&#xff0c;将其下版本的gtk直接压入了depreciated&#xff0c;但gtk4.12与普遍使用的gtk3有很大区别&#xff0c;原…

Java stream 流的基本使用

Java stream 的基本使用 package com.zhong.streamdemo.usestreamdemo;import jdk.jfr.DataAmount; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor;import java.util.ArrayList; import java.util.Comparator; import java.util.Li…