XGBoost介绍

      XGBoost:是eXtreme Gradient Boosting(极端梯度提升)的缩写,是一种强大的集成学习(ensemble learning)算法,旨在提高效率、速度和高性能。XGBoost是梯度提升(Gradient Boosting)的优化实现。集成学习将多个弱模型组合起来,形成一个更强大的模型。可用于回归、分类、排序。

      源码地址:https://github.com/dmlc/xgboost,它是经过优化的,可扩展、可移植和分布式梯度提升(GBDT, GBRT or GBM)库,适用于Python、C++等,可在单机、Hadoop、Spark等上运行,支持在Linux、Mac、Windows上的安装,license为Apache-2.0,最新发布版本为2.1.4。

      XGBoost使用决策树(decision trees)作为基础学习器(learners),按顺序组合它们以提高模型的性能。每棵新树都经过训练以纠正前一棵树的错误,这个过程称为提升(boosting)。该过程可以分解如下:

      (1).从基础学习器开始:第一个模型决策树在数据上进行训练。在回归任务中,这个基础模型只是预测目标变量的平均值。

      (2).计算误差:训练第一棵树后,计算预测值和实际值之间的误差。

      (3).训练下一棵树:下一棵树在前一棵树的错误上进行训练。此步骤尝试纠正第一棵树所犯的错误。

      (4).重复该过程:这个过程继续进行,每棵新树都试图纠正前一棵树的错误,直到满足停止标准。

      (5).合并预测:最终预测是所有树的预测总和。

      XGBoost模型主要保存为二进制文件UBJSON、也可选择JSON或文本格式。

      XGBoost优势

      (1).具有高度可扩展性和高效率,适合处理大型数据集。

      (2).实现并行处理技术并利用硬件优化来加快训练过程,在训练期间使用所有CPU内核并行构建树。

      (3).提供了广泛的可自定义参数和正则化技术,允许用户根据自己的特定需求对模型进行微调。

      (4).分布式计算,可使用一组机器训练非常大的模型。

      XGBoost缺点:

      (1).计算量非常大,尤其是在训练复杂模型时,因此不太适合资源受限的系统。

      (2).对噪声数据或异常值很敏感,因此需要数据预处理才能获得最佳性能。

      (3).在小数据集上或在模型中使用过多树时容易过拟合。

      注意

      (1).只有Linux平台支持使用多个GPU进行训练。

      (2).本地安装时,Linux上需要glibc 2.28+;Windows上需要安装Visual C++ Redistributable。pip的版本需要21.3+。

      (3).使用pip的默认安装("pip install xgboost")将安装完整的XGBoost包,包括对GPU算法和联合学习(federated learning)的支持。可安装仅cpu版的,执行"pip install xgboost-cpu",此版本将减少安装包的大小并节省磁盘空间,但不提供某些功能,如GPU算法和联合学习。

      (4).通过Conda安装,执行"conda install -c conda-forge py-xgboost",Conda应该能够检测到您的机器上是否存在GPU,并安装正确的XGBoost变体。也可指定仅安装cpu版本,执行"conda install -c conda-forge py-xgboost-cpu"。

      注:以上整理的内容主要来自:

      1. https://xgboost.readthedocs.io/en/latest/

      2. ttps://machinelearningmastery.com

      3. https://www.geeksforgeeks.org/xgboost/

      以下Python测试代码用于回归,使用波士顿房价数据集,共包含506个样本,每个样本有13个特征和1个目标变量:

import colorama
import argparse
import pandas as pd
import xgboost as xgbdef parse_args():parser = argparse.ArgumentParser(description="test XGBoost")parser.add_argument("--task", required=True, type=str, choices=["regress", "classify", "rank"], help="specify what kind of task")parser.add_argument("--csv", required=True, type=str, help="source csv file")parser.add_argument("--model", required=True, type=str, help="model file, save or load")args = parser.parse_args()return argsdef split_train_test(X, y):X = X.sample(frac=1, random_state=42).reset_index(drop=True) # random_state=42: make the results consistent each timey = y.sample(frac=1, random_state=42).reset_index(drop=True)index = int(len(X) * 0.8)X_train, X_test = X[:index], X[index:]y_train, y_test = y[:index], y[index:]return X_train, X_test, y_train, y_testdef calculate_rmse(input, target): # Root Mean Squared Errorreturn (sum((input - target) ** 2) / len(input)) ** 0.5def regress(csv_file, model_file):# 1. load datadata = pd.read_csv(csv_file)# 2. split into training set and test seX = data.drop('MEDV', axis=1)y = data['MEDV']print(f"X: type: {type(X)}, shape: {X.shape}; y: type: {type(X)}, shape: {y.shape}")X_train, X_test, y_train, y_test = split_train_test(X, y)train_dmatrix = xgb.DMatrix(X_train, label=y_train)test_dmatrix = xgb.DMatrix(X_test, label=y_test)print(f"train_dmatrix type: {type(train_dmatrix)}, shape(h,w): {train_dmatrix.num_row()}, {train_dmatrix.num_col()}")# 3. set XGBoost paramsparams = {'objective': 'reg:squarederror', # specify the learning task: classify: binary:logistic or multi:softmax or multi:softprob; rank: rank:ndcg'max_depth': 5, # maximum tree depth'eta': 0.1, # learning rate'subsample': 0.8, # subsample ratio of the training instance'colsample_bytree': 0.8, # subsample ratio of columns when constructing each tree'seed': 42, # random number seed'eval_metric': 'rmse' # metric used for monitoring the training result and early stopping}# 4. train modelbest = xgb.train(params, train_dmatrix, num_boost_round=1000) # num_boost_round: epochs# 5. predicty_pred = best.predict(test_dmatrix)# print(f"y_pred: {y_pred}")# 6. evaluate the modelrmse = calculate_rmse(y_test, y_pred)print(f"rmse: {rmse}")# 7. save modelbest.save_model(model_file)# 8. load mode and predictmodel = xgb.Booster()model.load_model(model_file)result = model.predict(test_dmatrix)test_label = test_dmatrix.get_label()for idx in range(len(result)):print(f"ground truth: {test_label[idx]:.1f}, \tpredict: {result[idx]:.1f}")if __name__ == "__main__":print("xgboost version:", xgb.__version__)colorama.init(autoreset=True)args = parse_args()if args.task == "regress":regress(args.csv, args.model)print(colorama.Fore.GREEN + "====== execution completed ======")

      执行结果如下图所示:

      GitHub:https://github.com/fengbingchun/NN_Test

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/31841.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Aliyun CTF 2025 web ezoj

文章目录 ezoj ezoj 进来一看是算法题,先做了试试看,gpt写了一个高效代码通过了 通过后没看见啥,根据页面底部提示去/source看到源代码,没啥思路,直接看wp吧,跟算法题没啥关系,关键是去看源码 def audit_checker(even…

大数据hadoop课程笔记

1.课程导入 柯洁 Alpha Go是人工智能领域的里程碑。 深度学习 大模型deepseek chatgpt 大模型 和 大数据 之间有着非常紧密的关系。可以说,大数据是大模型发展的基石,而大模型是大数据价值挖掘的重要工具。 https://youtu.be/nN-VacxHUH8?sifj7Ltk…

Pandas数据清洗实战之清洗猫眼电影

本次案例所需要用到的模块 pandas(文件读取保存 操作表格的模块) 将上次Scrapy爬取下来的文件 做个数据清洗 变成我们想要的数据 确定目的:将此文件中的duration字段中的分钟 和publisher_time上映去掉 只保留纯数值 数据清洗题目如下: 修复 publish_time列中的错…

UDP-网络编程/socket编程

一,socket相关接口 1,socket 我们来介绍socket编程的第一个接口:socket,它需要用到的头文件如图: 其中domain表示域或者协议家族: 本次我就用AF_INET(ipv4)来做演示 type参数表示…

《人月神话》:软件工程的成本寓言与生存法则

1975年,Fred Brooks在《人月神话》中写下那句振聋发聩的断言——“向进度落后的项目增加人力,只会让进度更加落后”——时,他或许未曾料到,这一观点会在半个世纪后的人工智能与云原生时代,依然如达摩克利斯之剑般悬在每…

ROS云课基础题库-01C++案例-甜甜圈

效率是核心,但效率高的教程会忽略掉非常多的细节。 解决问题的思路和细节对于一个问题的有效求解至关重要。 资料 云课五分钟-02第一个代码复现-终端甜甜圈C-CSDN博客 从云课五分钟到五秒钟焦虑的甜甜圈向前冲-CSDN博客 说明 复现重要性没有那么大,…

Oracle RHEL 5.8 安装 - 呆瓜式

前言 Red Hat Enterprise Linux Server release 5.8 为企业级 SO 镜像。绝大部分企业如果使用Oracle数据库均会使用其企业版 OS ,能够很好的支持数据库的运行 文档目的 当前文档仅针对 VMware Workstation Pro 进行 OS 介质安装。 镜像下载地址 注意&#xff1…

【数据分析大屏】基于Django+Vue汽车销售数据分析可视化大屏(完整系统源码+数据库+开发笔记+详细部署教程+虚拟机分布式启动教程)✅

目录 一、项目背景 二、项目创新点 三、项目功能 四、开发技术介绍 五、项目功能展示 六、权威视频链接 一、项目背景 汽车行业数字化转型加速,销售数据多维分析需求激增。本项目针对传统报表系统交互性弱、实时性差等痛点,基于DjangoVue架构构建…

软件IIC和硬件IIC的主要区别,用标准库举例!

学习交流792125321,欢迎一起加入讨论! 在学习iic的时候,我们经常会遇到软件 IC和硬件 IC,它两到底有什么区别呢? 软件 IC(模拟 IC)和硬件 IC(外设 IC)是两种实现 IC 总线通信的方式…

CSS-三大特性,盒子模型,圆角边框,盒子阴影,文字阴影

一、 CSS 的三大特性 CSS 有三个非常重要的三个特性:层叠性、继承性、优先级。 1.层叠性 相同选择器给设置相同的样式,此时一个样式就会覆盖(层叠)另一个冲突的样式。层叠性主要解决样式冲突 的问题 层叠性原则: 样式冲突,遵循的原…

基于 Qwen2.5-14B + Elasticsearch RAG 的大数据知识库智能问答系统

AI 时代,如何从海量私有文档(非公开)中快速提取精准信息成为了许多企业和个人的迫切需求。 本文介绍了一款基于 Qwen2.5-14B 大语言模型(换成 DeepSeek 原理一致)与 Elasticsearch 搜索引擎构建的大数据知识库智能问答…

算法手记1

🦄个人主页:修修修也 🎏所属专栏:数据结构 ⚙️操作环境:Visual Studio 2022 目录 一.NC313 两个数组的交集 题目详情: 题目思路: 解题代码: 二.AB5 点击消除 题目详情: 题目思路: 解题代码: 结语 一.NC313 两个数组的交集 牛客网题目链接(点击即可跳转)…

JMeter使用BeanShell断言

BeanShell简介 BeanShell是使用Java语法的一套脚本语言,在JMeter的多种组件中都有BeanShell的身影,如: 定时器:BeanShell Timer前置处理器:BeanShell PreProcessor采样器:BeanShell Sampler后置处理器&am…

【技海登峰】Kafka漫谈系列(五)Java客户端之生产者Producer核心组件与实现原理剖析

【技海登峰】Kafka漫谈系列(五)Java客户端之生产者Producer核心组件与实现原理剖析 向Kafka Broker服务节点中发送主题消息数据的应用程序被称为生产者,生产者与消费者均属于Kafka客户端,几乎所有主流语言都支持调用客户端API。官方提供了基于Java实现的kafka-clients,用于…

【eNSP实战】配置交换机端口安全

拓扑图 目的:让交换机端口与主机mac绑定,防止私接主机。 主机PC配置不展示,按照图中配置即可。 开始配置之前,使用PC1 ping 一遍PC2、PC3、PC4、PC5,让交换机mac地址表刷新一下记录。 LSW1查看mac地址表 LSW1配置端…

AWS Bedrock 正式接入 DeepSeek-R1 模型:安全托管的生成式 AI 解决方案

亚马逊云科技(AWS)于 2024 年 1 月 30 日 宣布,DeepSeek-R1 模型 正式通过 Amazon Bedrock 平台提供服务,用户可通过 Bedrock Marketplace 或自定义模型导入功能使用该模型。 DeepSeek-R1,其安全防护机制与全面的 AI 部…

数据结构之线性表

目录 1 简介 2 线性表的基本概念 3 顺序存储的线性表 3.1 定义线性表结构 3.2 初始化线性表 3.3 插入元素 3.4 删除元素 3.5 查找元素 3.6 扩容操作 3.7 打印线性表 4 线性表的应用 5 总结 1 简介 线性表是数据结构中最基础且常用的一种结构,它是由一…

c#面试题12

1.ApplicationPool介绍一下 c#里没有 2.XML 可扩展标记语言,一般以.xml文件格式的形式存在。可用于存储结构化的数据 3.ASP.NET的用户控件 将原始的控件,用户根据需要进行整合成一个新的控件 4.介绍一下code-Behind 即代码后置技术,就是…

英语学习(GitHub学到的分享)

【英语语法:https://github.com/hzpt-inet-club/english-note】 【离谱的英语学习指南:https://github.com/byoungd/English-level-up-tips/tree/master】 【很喜欢文中的一句话:如果我轻轻松松的学习,生活的幸福指数会提高很多…

C++蓝桥杯基础篇(十一)

片头 嗨~小伙伴们,大家好!今天我们来学习C蓝桥杯基础篇(十一),学习类,结构体,指针相关知识,准备好了吗?咱们开始咯~ 一、类与结构体 类的定义:在C中&#x…