【机器学习】回归树

回归树是一种用于数值型目标变量的监督学习算法,通过将特征空间划分为多个区域,并在每个区域内使用简单的预测模型(如区域均值)来进行回归。回归树以“递归划分-计算区域均值”的方式逐层生成树节点,最终形成叶节点预测值。相比于线性回归,回归树更适合处理非线性和复杂数据结构。

回归树的基本原理

在回归树中,每个节点执行以下操作:

  • 选择最优特征及分割点:通过最小化均方误差(Mean Squared Error, MSE)等标准选择最佳分割特征和分割点。
  • 分割数据:根据选择的分割特征将数据划分成两部分,形成左子节点和右子节点。
  • 递归分割:对子节点进行递归分割,直至满足停止条件(如最大深度或最小样本数)。

分割准则

均方误差(MSE)

在回归树中,常用均方误差(MSE)作为分割准则:
MSE = 1 N ∑ i = 1 N ( y i − y ˉ ) 2 \text{MSE} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \bar{y})^2 MSE=N1i=1N(yiyˉ)2
其中,( y_i ) 是样本 ( i ) 的实际值,( \bar{y} ) 是区域内样本的平均值。分割点选择通过最小化分割前后数据的 MSE 来完成。

回归树的构建步骤

  1. 选择最佳分割特征与分割点:遍历每个特征和可能的分割点,计算分割后的MSE,选择使MSE最小的分割特征和点。
  2. 递归分割数据:在左、右子节点递归执行上述过程,形成新的分支节点。
  3. 生成叶节点:一旦满足停止条件,将当前节点的预测值设为该区域中所有样本的均值。

用 Numpy 实现回归树

以下代码展示了如何用 Numpy 实现一个基本的回归树,并通过均方误差来确定分割点。

import numpy as np
import matplotlib.pyplot as plt# 计算均方误差(MSE)
def mean_squared_error(y):return np.var(y) * len(y)# 数据集分割
def split_dataset(X, y, feature, threshold):left_mask = X[:, feature] <= thresholdright_mask = ~left_maskreturn X[left_mask], y[left_mask], X[right_mask], y[right_mask]# 查找最佳分割特征和分割点
def best_split(X, y):best_mse = float("inf")best_feature, best_threshold = None, Nonefor feature in range(X.shape[1]):thresholds = np.unique(X[:, feature])for threshold in thresholds:_, y_left, _, y_right = split_dataset(X, y, feature, threshold)if len(y_left) == 0 or len(y_right) == 0:continuemse_split = mean_squared_error(y_left) + mean_squared_error(y_right)if mse_split < best_mse:best_mse = mse_splitbest_feature = featurebest_threshold = thresholdreturn best_feature, best_threshold# 回归树类
class RegressionTree:def __init__(self, max_depth=3, min_samples_split=2):self.max_depth = max_depthself.min_samples_split = min_samples_splitself.tree = Nonedef fit(self, X, y, depth=0):if len(y) < self.min_samples_split or depth >= self.max_depth:return np.mean(y)feature, threshold = best_split(X, y)if feature is None:return np.mean(y)left_X, left_y, right_X, right_y = split_dataset(X, y, feature, threshold)left_node = self.fit(left_X, left_y, depth + 1)right_node = self.fit(right_X, right_y, depth + 1)self.tree = {"feature": feature, "threshold": threshold, "left": left_node, "right": right_node}return self.treedef predict_sample(self, x, tree):if not isinstance(tree, dict):return treeif x[tree["feature"]] <= tree["threshold"]:return self.predict_sample(x, tree["left"])else:return self.predict_sample(x, tree["right"])def predict(self, X):return np.array([self.predict_sample(x, self.tree) for x in X])# 生成示例数据
np.random.seed(0)
X = np.random.rand(100, 1) * 10  # 特征数据
y = 2 * X.flatten() + np.random.randn(100) * 2  # 标签数据# 训练回归树
tree = RegressionTree(max_depth=4, min_samples_split=5)
tree.fit(X, y)# 预测并可视化
X_test = np.linspace(0, 10, 100).reshape(-1, 1)
y_pred = tree.predict(X_test)plt.scatter(X, y, color="blue", label="训练数据")
plt.plot(X_test, y_pred, color="red", label="回归树预测")
plt.xlabel("特征")
plt.ylabel("目标值")
plt.title("回归树预测示意图")
plt.legend()
plt.show()

在代码中,我们首先通过遍历各个特征和分割点来选择最优分割点,使得均方误差最小。然后在每个节点递归进行分割,直至达到设定的深度或最小样本数。最终通过构建的树结构进行预测。

使用 Sklearn 的回归树

Scikit-Learn 提供了 DecisionTreeRegressor 来实现回归树模型,可以大大简化建模过程。

from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error# 训练回归树
regressor = DecisionTreeRegressor(max_depth=4, min_samples_split=5)
regressor.fit(X, y)# 预测
y_pred_sklearn = regressor.predict(X_test)# 计算均方误差
mse = mean_squared_error(y, regressor.predict(X))
print("均方误差:", mse)# 可视化
plt.scatter(X, y, color="blue", label="训练数据")
plt.plot(X_test, y_pred_sklearn, color="red", label="Sklearn 回归树预测")
plt.xlabel("特征")
plt.ylabel("目标值")
plt.title("Sklearn 回归树预测示意图")
plt.legend()
plt.show()

总结

本文介绍了回归树的基本概念与实现,包括回归树的分割准则、MSE 计算、最佳分割点选择等细节。通过 Numpy 手动实现了一个简单的回归树模型,并展示了如何在 Scikit-Learn 中快速实现和使用回归树。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/462518.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ChatGPT新体验:AI搜索功能与订阅支付指南

就在凌晨&#xff0c;在ChatGPT迎来两周岁生日之际&#xff0c;OpenAI重磅发布了ChatGPT的全新人工智能搜索体验。 期待已久的时刻终于到来&#xff0c; ChatGPT正式转型成为一款革命性的AI搜索引擎! 先来看看ChatGPT搜索&#xff1a;这次不是简单的加个搜索框&#xff0c;而…

JS | 如何更好地优化 JavaScript 的内存回收?

目录 一、理解JavaScript内存生命周期 ● 创建对象和分配内存 ● 内存的使用 ● 内存回收 二、减少内存泄露 ● 避免全局变量 ● 正确使用闭包 三、合理管理内存 ● 局部变量和即时函数 ● 解绑事件监听器 四、使用现代JavaScript特性辅助内存回收 ● 使用WeakMap和…

群控系统服务端开发模式-应用开发-上传配置功能开发

下面直接进入上传配置功能开发&#xff0c;废话不多说。 一、创建表 1、语句 CREATE TABLE cluster_control.nc_param_upload (id int(11) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 编号,upload_type tinyint(1) UNSIGNED NOT NULL COMMENT 上传类型 1&#xff1a;本站 2&a…

Cisco Packet Tracer 8.0 路由器的基本配置和Telnet设置

文章目录 构建拓扑图配置IP地址配置路由器命令说明测试效果 构建拓扑图 1&#xff0c;添加2811路由器。 2&#xff0c;添加pc0。 3&#xff0c;使用交叉线连接路由器和pc&#xff08;注意线路端口&#xff09;。 4&#xff0c;使用配置线连接路由器和pc&#xff08;注意线路…

从气象中心采集cma台风路径数据

在自然灾害监测与预警领域&#xff0c;台风作为一种极具破坏力的自然现象&#xff0c;其路径预测和强度评估对于减少潜在损失至关重要。随着互联网技术的发展&#xff0c;国家气象中心等专业机构提供了详尽的台风历史数据和实时跟踪服务&#xff0c;通过网络接口可便捷地访问这…

ssm+vue665基于Java的壁纸网站设计与实现

博主介绍&#xff1a;专注于Java&#xff08;springboot ssm 等开发框架&#xff09; vue .net php phython node.js uniapp 微信小程序 等诸多技术领域和毕业项目实战、企业信息化系统建设&#xff0c;从业十五余年开发设计教学工作 ☆☆☆ 精彩专栏推荐订阅☆☆☆☆☆不…

Applied Artificial Intelligence

文章目录 一、征稿简介二、重要信息三、服务简述四、投稿须知五、联系咨询 一、征稿简介 二、重要信息 期刊官网&#xff1a;https://ais.cn/u/3eEJNv 三、服务简述 四、投稿须知 1.在线投稿&#xff1a;由艾思科蓝支持在线投稿&#xff0c;请将文章全文投稿至艾思科蓝投稿…

oracle-函数-NULLIF (expr1, expr2)的妙用

【语法】NULLIF (expr1, expr2) 【功能】expr1和expr2相等返回NULL&#xff0c;不相等返回expr1经典的使用场景&#xff1a; 1. 数据清洗与转换 在数据清洗过程中&#xff0c;NULLIF 函数可以用于将某些特定值&#xff08;通常是无效或不需要的值&#xff09;替换为 NULL&…

pycharm 安装

双击pycharm-community-2024.2.0.1.exe安装包 可以保持默认&#xff0c;也可以改成D&#xff0c;如果你有D 盘 全选&#xff0c;下一步 安装完成 在桌面创建一个文件夹任意名字 拖动到pycharm 图标打开 如果出现这个勾选信任即可 下面准备汉化&#xff08;喜欢英语界面的…

Matlab实现蚁群算法求解旅行商优化问题(TSP)(理论+例子+程序)

一、蚁群算法 蚁群算法由意大利学者Dorigo M等根据自然界蚂蚁觅食行为提岀。蚂蚁觅食行为表示大量蚂蚁组成的群体构成一个信息正反馈机制&#xff0c;在同一时间内路径越短蚂蚁分泌的信息就越多&#xff0c;蚂蚁选择该路径的概率就更大。 蚁群算法的思想来源于自然界蚂蚁觅食&a…

计算机毕业设计Hadoop+大模型高考推荐系统 高考分数线预测 知识图谱 高考数据分析可视化 高考大数据 大数据毕业设计 Hadoop 深度学习

温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 开题报告 题目&#xff1a…

【qwen2-1.5-instruct 好于Gemma2-2b-instruct\Llama3.2-1B-instruct】

最新的qwen Llama Gemma小参数模型比较&#xff0c;移动端 qwen2-1.5-instruct 好于Gemma2-2b-instruct\Llama3.2-1B-instruct 从 Qwen2–1.5B-instruct 到 Gemma2–2B-instruct&#xff0c;再到 Llama3.2–1B-instruct&#xff0c;最后是新的 Qwen2.5–1.5B-instruct。虽然我…

C++之位算法

位算法 常见位运算总结 位1的个数 给定一个正整数 n&#xff0c;编写一个函数&#xff0c;获取一个正整数的二进制形式并返回其二进制表达式中 设置位 的个数&#xff08;也被称为汉明重量&#xff09;。 示例 1&#xff1a; 输入&#xff1a;n 11 输出&#xff1a;3 解释…

JAVA利用方法实现四道题

目录 1.给定一个字符串 s &#xff0c;找到 它的第一个不重复的字符&#xff0c;并返回它的索引 。如果不存在&#xff0c;则返回-1 2.计算字符串最后一个单词的长度&#xff0c;单词以空格隔开。&#xff08;注&#xff1a;字符串末尾不以空格为结尾&#xff09; 3.如果在将所…

【教程】Git 标准工作流

前言 Git 是日常开发中常用的版本控制工具&#xff0c;配合代码托管仓库&#xff08;如&#xff0c;Github&#xff0c;GitLab&#xff0c;Gitee 等&#xff09;用来实现多人多版本的协作开发。 但是 Git 的命令纷繁复杂&#xff0c;多如累卵&#xff0c;不可能也不需要全部搞…

基于AI深度学习的中医针灸实训室腹针穴位智能辅助定位系统开发

在中医针灸的传统治疗中&#xff0c;穴位取穴的精确度对于治疗效果至关重要。然而&#xff0c;传统的定位方法&#xff0c;如体表标志法、骨度折量法和指寸法&#xff0c;由于观察角度、个体差异&#xff08;如人体姿态和皮肤纹理&#xff09;以及环境因素的干扰&#xff0c;往…

金融标准体系

目录 基本原则 标准体系结构图 标准明细表 金融标准体系下载地址 基本原则 需求引领、顶层设计。 坚持目标导向、问题导向、结果 导向有机统一&#xff0c;构建支撑适用、体系完善、科学合理的金融 标准体系。 全面系统、重点突出。 以金融业运用有效、保护有力、 管理高…

.NET 8 Web API 中的身份验证和授权

本次介绍分为3篇文章&#xff1a; 1&#xff1a;.Net 8 Web API CRUD 操作.Net 8 Web API CRUD 操作-CSDN博客 2&#xff1a;在 .Net 8 API 中实现 Entity Framework 的 Code First 方法https://blog.csdn.net/hefeng_aspnet/article/details/143229912 3&#xff1a;.NET …

Spring Boot 与 Vue 共铸卓越采购管理新平台

作者介绍&#xff1a;✌️大厂全栈码农|毕设实战开发&#xff0c;专注于大学生项目实战开发、讲解和毕业答疑辅导。 &#x1f345;获取源码联系方式请查看文末&#x1f345; 推荐订阅精彩专栏 &#x1f447;&#x1f3fb; 避免错过下次更新 Springboot项目精选实战案例 更多项目…

字符串统计(Python)

接收键盘任意录入&#xff0c;分别统计大小写字母、数字及其它字符数量&#xff0c;打印输出。 (笔记模板由python脚本于2024年11月02日 08:23:31创建&#xff0c;本篇笔记适合熟悉python字符串并懂得基本编程技法的coder翻阅) 【学习的细节是欢悦的历程】 Python 官网&#xf…