XGBoost预测及调参过程(+变量重要性)--血友病计数数据

         所使用的数据是血友病数据,如有需要,可在主页资源处获取,数据信息如下:

ddcdd574478b441d91e491390799e8da.png

读取数据及数据集区分

        数据预处理及区分数据集代码如下(详细预处理说明见上篇文章--随机森林):

import pandas as pd
import numpy as np
hemophilia = pd.read_csv('D:/my_files/data.csv')  #读取数据#数值变量化为分类变量
hemophilia['hiv']=hemophilia['hiv'].astype(object) 
hemophilia['factor']=hemophilia['factor'].astype(object)
new_hemophilia=pd.get_dummies(hemophilia,drop_first=True)   #drop_first=True--删去一列,如hiv,处理后为两列,都是01表示,但只保留一列就足够表示两种状态
new_data=new_hemophilia
from sklearn.model_selection import train_test_split
x = new_data.drop(['deaths'],axis=1)   #删去标签列
X_train, X_test, y_train, y_test = train_test_split(x, new_data.deaths, test_size=0.3, random_state=0)  #区分数据集,70%训练集,30%测试集

默认参数XGBoost

        先使用默认参数XGBoost进行预测,输出预测均方误差为0.334.

from xgboost import XGBRegressor
from sklearn.model_selection import GridSearchCV
xgb_model = XGBRegressor(random_state=0)  #random_state=0是随机种子数
xgb_model.fit(X_train, y_train)
y_pred = xgb_model.predict(X_test)
print('MSE of xgb: %.3f' %metrics.mean_squared_error(y_test, y_pred))
'''MSE of xgb: 0.334
'''

XGBoost调参

        接下来对XGBoost进行调参,XGBoost参数很多,一般对少数参数进行调整就可以得到不错的效果,所以这里只对'max_depth','min_child_weight','gamma'这三个参数进行粗略调参,如果追求更加有效的调参结果,可以对多个参数逐一调参。调参后输出预测均方误差为0.287,已经有所下降,说明模型的预测效果已经得到了提升。

param_grid = {'max_depth':[1,2,3,4,5],'min_child_weight':range(10,70,10),'gamma':[i*0.01 for i in range(0,20,3)]}
GS = GridSearchCV(xgb_model,param_grid,scoring = 'neg_mean_squared_error',cv=5)
GS.fit(X_train, y_train)
GS.best_params_  #最佳参数组合#{'gamma': 0.15, 'max_depth': 3, 'min_child_weight': 68}xgb_model = XGBRegressor(gamma = 0.15, max_depth = 3, min_child_weight = 60, random_state=0)
xgb_model.fit(X_train, y_train)
y_pred = xgb_model.predict(X_test)
print('MSE of xgb: %.3f' %metrics.mean_squared_error(y_test, y_pred))
'''MSE of xgb: 0.287
'''

XGBoost变量重要性

        XGBoost和随机森林都能够输出变量重要性,代码如下:

import matplotlib.pyplot as plt
importances = list(xgb_model.feature_importances_)   #XGBoost
feature_list = list(x.columns)
feature_importances = [(feature, round(importance, 2)) for feature, importance in zip(feature_list, importances)]
feature_importances = sorted(feature_importances, key=lambda x: x[1], reverse=True)
f_list = []
importances_list = []
for i in range(0,8):feature = feature_importances[i][0]importances_r = feature_importances[i][1]f_list.append(feature),importances_list.append(importances_r)
x_values = list(range(len(importances_list)))
plt.figure(figsize=(14, 9))
plt.bar(x_values, importances_list, orientation='vertical')
plt.xticks(x_values, f_list, rotation=25, size =18)
plt.yticks(size =18)
plt.ylabel('Importance',size = 20)
plt.xlabel('Variable',size = 20)
plt.title('XGB Variable Importances',size = 22)
#plt.savefig('D:/files/xgb变量重要性.png', dpi=800)    #保存图片到指定位置 dpi--分辨率
plt.show()

63d57f3d881b494c9c82b321cef4ef92.png

        还可以输出图片对比预测结果和真实值的差异,代码及图片如下:

import matplotlib.pyplot as plt
y_test = y_test.reset_index(drop = True)
plt.plot(y_test,color="b",label = 'True')
plt.plot(y_pred,color="r",label = 'Prediction') 
plt.xlabel("index")  #x轴命名表示
plt.ylabel("deaths")  #y轴命名表示
plt.title("xgb Comparison between real and perdiction") 
plt.legend()  #增加图例
#plt.savefig('D:/my_files/xgb Comparison between real and perdiction.png', dpi = 500) #保存图片
plt.show()  #显示图片

5d065dbe99f747e68fc9b0d063c7a69c.png

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/350949.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQLServer 借助Navcate做定时备份的脚本

首先创建SQLServer链接,然后在Query标签种创建一个查询 查询内容如下 use ChengYuMES declare ls_time varchar(1000) declare ls_dbname varchar(1000) set ls_time convert(varchar, getdate(), 112) _ replace(convert(varchar, getdate(), 108), :, )-- 需…

68. UE5 RPG 优化敌人角色的表现效果

我们现在已经有了四个敌人角色,接下来,处理一下在战斗中遇到的问题。 处理角色死亡后还会攻击的问题 因为我们有角色溶解的效果,角色在死亡以后的5秒钟才会被销毁掉。所以在这五秒钟之内,角色其实还是会攻击。主要时因为AI行为树…

LeetCode 算法:回文链表 c++

原题链接🔗:回文链表 难度:简单⭐️ 题目 给你一个单链表的头节点 head ,请你判断该链表是否为回文链表。如果是,返回 true ;否则,返回 false 。 示例 1: 输入:head…

C++11移动语义

前言 之前我们已经知道了在类里开辟数组后,每一次传值返回和拷贝是,都会生成一个临时变量 class Arr { public://构造Arr() {/*具体实现*/ };//拷贝Arr(const Arr& ar) {/*具体实现*/ };//重载Arr operator(const Arr& ar) { /*具体实现*/Arr …

Python高级编程:Functools模块的8个高级用法,强烈建议添加到你的开发工具箱中!

目录 1. functools.partial 2. functools.lru_cache lru_cache的特点 cache的特点 性能比较与选择 3. functools.reduce functools.reduce的作用 工作原理 示例 累加序列中的所有元素 计算阶乘 initializer的使用 应用场景 示例:计算平均销售额 小结 4. funct…

微前端乾坤方案

微前端乾坤方案 了解乾坤 官方文档 介绍 qiankun 是一个基于 single-spa 的微前端实现库,旨在帮助大家能更简单、无痛的构建一个生产可用微前端架构系统。 qiankun 的核心设计理念 🥄 简单 由于主应用微应用都能做到技术栈无关,qiankun 对…

Java 桥接模式(Bridge Pattern)是设计模式中的一种结构型设计模式,桥接模式的核心思想是将抽象与实现解耦

桥接模式(Bridge Pattern)是一种结构型设计模式,它将抽象部分与它的实现部分分离,使它们都可以独立地变化。桥接模式的核心思想是将抽象与实现解耦,使得它们可以独立扩展。 在桥接模式中,通常包含以下四个…

DAY3-力扣刷题

1.罗马数字转整数 13. 罗马数字转整数 - 力扣(LeetCode) 罗马数字包含以下七种字符: I, V, X, L,C,D 和 M。 字符 数值 I 1 V 5 X 10 L …

最长不下降子序列LIS详解

最长不下降子序列指的是在一个数字序列中,找到一个最长的子序列(可以不连续),使得这个子序列是不下降(非递减)的。 假如,现有序列A[1,2,3,-1,-2&…

16.大模型分布式训练框架 Microsoft DeepSpeed

微调、预训练显存对比占用 预训练LLaMA2-7B模型需要多少显存? 假设以bf16混合精度预训练 LLaMA2-7B模型,需要近120GB显存。即使A100/H100(80GB)单卡也无法支持。 为何比 QLoRA多了100GB?不妨展开计算下显存占用&…

誉天教育近期开班计划(6月15日更新)

云计算HCIP 周末班 2024/6/15 田老师 售前IP-L3 周末班 2024/6/15 陈老师 RHCA442 晚班 2024/6/17邹老师 数通HCIE 晚班 2024/6/24阮老师 云计算HCIE直通车晚班 2024/6/25 曾老师 售前IT-L3 周末班 2024/6/29 伍老师 数通HCIP 晚班 2024/7/1杨老师 存储直通车 晚班 2024/7/1 高…

从ES的JVM配置起步思考JVM常见参数优化

目录 一、真实查看参数 (一)-XX:PrintCommandLineFlags (二)-XX:PrintFlagsFinal 二、堆空间的配置 (一)默认配置 (二)配置Elasticsearch堆内存时,将初始大小设置为…

Windows Server 2008 r2 IIS .NET

Windows Server 2008 r2 IIS .NET

CleanMyMacX4.15.4如何优化苹果电脑系统缓存,告别MacBook卡顿,提升mac电脑性能

你是否曾为苹果电脑存储空间不够而烦恼?是否曾因系统运行缓慢而苦恼?别担心,今天我要给大家种草一个神器——CleanMyMac!这款软件可以帮助你轻松解决苹果电脑的种种问题,让你的电脑焕然一新! 让我来给大家介…

springboot原理篇-配置优先级

springboot原理篇-配置优先级(一) springboot项目一个支持三种配置文件 application.propertiesapplication.ymlapplication.yaml 其中,优先级的顺序是: application.properties > application.yml > application.yaml 也…

基于Nios-II实现流水灯

基于Nios-II实现流水灯的主要原理 涉及到FPGA(现场可编程门阵列)上的嵌入式软核处理器Nios II与LED控制逻辑的结合。以下是详细的实现原理,分点表示并归纳: Nios II软核处理器介绍: Nios II是Altera公司推出的一种应用…

Vue笔记(二)

Vue(一):Vue笔记(一)-CSDN博客 目录 综合案例:水果购物车 生命周期 1.生命周期&生命周期四个阶段 2.生命周期函数(钩子函数【8个】) 3.生命周期两个案例 初始化渲染 自动获…

Node.js和npm的安装及配置

Node.js 是一个基于 Chrome V8 引擎的 JavaScript 运行环境。Node.js 使用了一个事件驱动、非阻塞 I/O 的模型。 npm(node package manager)是一个 Node.js 包管理和分发工具,也是整个 Node.js 社区最流行、支持第三方模块最多的包管理器。使…

2024.6.14 作业 xyt

使用手动连接,将登录框中的取消按钮使用第二中连接方式,右击转到槽,在该槽函数中,调用关闭函数 将登录按钮使用qt4版本的连接到自定义的槽函数中,在槽函数中判断ui界面上输入的账号是否为"admin"&#xff0c…

Unity2D计算两个物体的距离

1.首先新建一个场景并添加2个物体 2.创建一个脚本并编写代码 using UnityEngine;public class text2: MonoBehaviour {public GameObject gameObject1; // 第一个物体public GameObject gameObject2; // 第二个物体void Update(){// 计算两个物体之间的距离float distance Vec…