随机森林超参数的网格优化(机器学习的精华--调参)

随机森林超参数的网格优化(机器学习的精华–调参)

随机森林各个参数对算法的影响

影响力参数
⭐⭐⭐⭐⭐
几乎总是具有巨大影响力
n_estimators(整体学习能力)
max_depth(粗剪枝)
max_features(随机性)
⭐⭐⭐⭐
大部分时候具有影响力
max_samples(随机性)
class_weight(样本均衡)
⭐⭐
可能有大影响力
大部分时候影响力不明显
min_samples_split(精剪枝)
min_impurity_decrease(精剪枝)
max_leaf_nodes(精剪枝)
criterion(分枝敏感度)

当数据量足够大时,几乎无影响
random_state
ccp_alpha(结构风险)

探索的第一步(学习曲线)

先看看n_estimators的学习曲线(当然我们也可以看一下其他参数的学习曲线)

#参数潜在取值,由于现在我们只调整一个参数,因此参数的范围可以取大一些、取值也可以更密集
Option = [1,*range(5,101,5)]#生成保存模型结果的arrays
trainRMSE = np.array([])
testRMSE = np.array([])
trainSTD = np.array([])
testSTD = np.array([])#在参数取值中进行循环
for n_estimators in Option:#按照当下的参数,实例化模型reg_f = RFR(n_estimators=n_estimators,random_state=83)#实例化交叉验证方式,输出交叉验证结果cv = KFold(n_splits=5,shuffle=True,random_state=83)result_f = cross_validate(reg_f,X,y,cv=cv,scoring="neg_mean_squared_error",return_train_score=True,n_jobs=-1)#根据输出的MSE进行RMSE计算train = abs(result_f["train_score"])**0.5test = abs(result_f["test_score"])**0.5#将本次交叉验证中RMSE的均值、标准差添加到arrays中进行保存trainRMSE = np.append(trainRMSE,train.mean()) #效果越好testRMSE = np.append(testRMSE,test.mean())trainSTD = np.append(trainSTD,train.std()) #模型越稳定testSTD = np.append(testSTD,test.std())

定义画图函数

def plotCVresult(Option,trainRMSE,testRMSE,trainSTD,testSTD):#一次交叉验证下,RMSE的均值与std的绘图xaxis = Optionplt.figure(figsize=(8,6),dpi=80)#RMSEplt.plot(xaxis,trainRMSE,color="k",label = "RandomForestTrain")plt.plot(xaxis,testRMSE,color="red",label = "RandomForestTest")#标准差 - 围绕在RMSE旁形成一个区间plt.plot(xaxis,trainRMSE+trainSTD,color="k",linestyle="dotted")plt.plot(xaxis,trainRMSE-trainSTD,color="k",linestyle="dotted")plt.plot(xaxis,testRMSE+testSTD,color="red",linestyle="dotted")plt.plot(xaxis,testRMSE-testSTD,color="red",linestyle="dotted")plt.xticks([*xaxis])plt.legend(loc=1)plt.show()plotCVresult(Option,trainRMSE,testRMSE,trainSTD,testSTD)

在这里插入图片描述

当绘制学习曲线时,我们可以很容易找到泛化误差开始上升、或转变为平稳趋势的转折点。因此我们可以选择转折点或转折点附近的n_estimators取值,例如20。然而,n_estimators会受到其他参数的影响,例如:

  • 单棵决策树的结构更简单时(依赖剪枝时),可能需要更多的树
  • 单棵决策树训练的数据更简单时(依赖随机性时),可能需要更多的树

因此n_estimators的参数空间可以被确定为range(20,100,5),如果你比较保守,甚至可以确认为是range(15,25,5)。

探索第二步(随机森林中的每一棵决策树)

属性.estimators_,查看森林中所有的树

参数参数含义对应属性属性含义
n_estimators树的数量reg.estimators_森林中所有树对象
max_depth允许的最大深度.tree_.max_depth0号树实际的深度
max_leaf_nodes允许的最大
叶子节点量
.tree_.node_count0号树实际的总节点量
min_sample_split分枝所需最小
样本量
.tree_.n_node_samples0号树每片叶子上实际的样本量
min_weight_fraction_leaf分枝所需最小
样本权重
tree_.weighted_n_node_samples0号树每片叶子上实际的样本权重
min_impurity_decrease分枝所需最小
不纯度下降量
.tree_.impurity
.tree_.threshold
0号树每片叶子上的实际不纯度
0号树每个节点分枝后不纯度下降量

可以通过对上述属性的调用查看当前模型每一棵树的各个属性,对我们对于参数范围的选择给予帮助。

正戏开始(网格搜索)

import numpy as np
import pandas as pd
import sklearn
import matplotlib as mlp
import matplotlib.pyplot as plt
import time #计时模块time
from sklearn.ensemble import RandomForestRegressor as RFR
from sklearn.model_selection import cross_validate, KFold, GridSearchCV
# 定义RMSE函数
def RMSE(cvresult,key):return (abs(cvresult[key])**0.5).mean()
#导入波士顿房价数据
data = pd.read_csv(r"D:\Pythonwork\datasets\House Price\train_encode.csv",index_col=0)
#查看数据
data.head()
X = data.iloc[:,:-1]
y = data.iloc[:,-1]

在这里插入图片描述

Step 1.建立benchmark

# 定义回归器
reg = RFR(random_state=83)
# 进行5折交叉验证
cv = KFold(n_splits=5,shuffle=True,random_state=83)result_pre_adjusted = cross_validate(reg,X,y,cv=cv,scoring="neg_mean_squared_error",return_train_score=True,verbose=True,n_jobs=-1)
#分别查看训练集和测试集在调参之前的RMSE
RMSE(result_pre_adjusted,"train_score")
RMSE(result_pre_adjusted,"test_score")

结果分别是
11177.272008319653
30571.26665524217

Step 2.创建参数空间

param_grid_simple = {"criterion": ["squared_error","poisson"], 'n_estimators': [*range(20,100,5)], 'max_depth': [*range(10,25,2)], "max_features": ["log2","sqrt",16,32,64,"auto"], "min_impurity_decrease": [*np.arange(0,5,10)]}

Step 3.实例化用于搜索的评估器、交叉验证评估器与网格搜索评估器

#n_jobs=4/8,verbose=True
reg = RFR(random_state=1412,verbose=True,n_jobs=-1)
cv = KFold(n_splits=5,shuffle=True,random_state=83)
search = GridSearchCV(estimator=reg,param_grid=param_grid_simple,scoring = "neg_mean_squared_error",verbose = True,cv = cv,n_jobs=-1)

Step 4.训练网格搜索评估器

#=====【TIME WARNING: 15mins】=====#   当然博主的电脑比较慢
start = time.time()
search.fit(X,y)
print(time.time() - start)

Step 5.查看结果

search.best_estimator_# 直接使用最优参数进行建模
ad_reg = RFR(n_estimators=85, max_depth=23, max_features=16, random_state=83)cv = KFold(n_splits=5,shuffle=True,random_state=83)
result_post_adjusted = cross_validate(ad_reg,X,y,cv=cv,scoring="neg_mean_squared_error",return_train_score=True,verbose=True,n_jobs=-1)
#查看调参后的结果
RMSE(result_post_adjusted,"train_score")
RMSE(result_post_adjusted,"test_score")

得出结果
11000.81099038192
28572.070208366855

调参前后对比

#默认值下随机森林的RMSE
xaxis = range(1,6)
plt.figure(figsize=(8,6),dpi=80)
#RMSE
plt.plot(xaxis,abs(result_pre_adjusted["train_score"])**0.5,color="green",label = "RF_pre_ad_Train")
plt.plot(xaxis,abs(result_pre_adjusted["test_score"])**0.5,color="green",linestyle="--",label = "RF_pre_ad_Test")
plt.plot(xaxis,abs(result_post_adjusted["train_score"])**0.5,color="orange",label = "RF_post_ad_Train")
plt.plot(xaxis,abs(result_post_adjusted["test_score"])**0.5,color="orange",linestyle="--",label = "RF_post_ad_Test")
plt.xticks([1,2,3,4,5])
plt.xlabel("CVcounts",fontsize=16)
plt.ylabel("RMSE",fontsize=16)
plt.legend()
plt.show()

在这里插入图片描述
不难发现,网格搜索之后的模型过拟合程度减轻,且在训练集与测试集上的结果都有提高,可以说从根本上提升了模型的基础能力。我们还可以根据网格的结果继续尝试进行其他调整,来进一步降低模型在测试集上的RMSE。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/251787.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

系统架构设计师考试大纲2023

一、 考试方式(机考) 考试采取科目连考、 分批次考试的方式, 连考的第一个科目作答结束交卷完成后自动进 入第二个科目, 第一个科目节余的时长可为第二个科目使用。 高级资格: 综合知识科目考试时长 150 分钟&#xff…

使用goland IDE编写go windows ui

最近突发奇想,想实现一款工作节奏的提示安排小闹钟。那首先解决的就是UI。本人擅长go语言。那go在windows ui的探索肯定有人做过了吧。一查还真有,通过知乎,csdn等查到目前支持最好的就是walk库了。那走起试试。 一、拷贝go代码 将官网例子…

idea中找到所有的TODO

idea中找到所有的TODO (1)快捷键 Alt6 (2)View -> Tool Windows -> TODO

kafka客户端生产者消费者kafka可视化工具(可生产和消费消息)

点击下载《kafka客户端生产者消费者kafka可视化工具(可生产和消费消息)》 1. 前言 因在工作中经常有用到kafka做消息的收发,每次调试过程中,经常需要查看接收的消息内容以及人为发送消息,从网上搜寻了一下&#xff0…

css1基础选择器

大纲 一.标签选择器 比较简单,前面直接写目标标签 二.类选择器 应用 例子 三.多类名选择器(调用时中间用空格隔开) 四.id选择器 应用 五.通配符选择器 应用 六.总结

OpenResty 安装

安装OpenResty 1.安装 首先你的Linux虚拟机必须联网 1)安装开发库 首先要安装OpenResty的依赖开发库,执行命令: yum install -y pcre-devel openssl-devel gcc --skip-broken2)安装OpenResty仓库 你可以在你的 CentOS 系统中…

VUE开发记录

1、VUE模板传递参数到JS方法 <select-language :value"item.language" change"selectLanguage($event, key)"></select-language>selectLanguage(value, key){console.log(value, key) }, 2、Element框架el-form-item自定义label和内容 <…

2024 年 5 款适用于免费 iPhone 数据恢复的工具软件

搜索一下&#xff0c;你会发现许多付费或免费的iPhone数据恢复工具声称它们可以帮助你以很高的成功率找回所有丢失的数据。然而&#xff0c;这正是问题所在。真的很难做出选择。为了进一步帮助您解决数据丢失问题&#xff0c;我们在此列出了 5 款最好的免费 iPhone 恢复软件供您…

JS第二天、原型、原型链、正则

☆☆☆☆ 什么是原型&#xff1f; 构造函数的prototype 就是原型 专门保存所有子对象共有属性和方法的对象一个对象的原型就是它的构造函数的prototype属性的值。prototype是哪来的&#xff1f;所有的函数都有一个prototype属性当函数被创建的时候&#xff0c;prototype属性…

Palworld幻兽帕鲁自建服务器32人联机开黑!

玩转幻兽帕鲁服务器&#xff0c;阿里云推出新手0基础一键部署幻兽帕鲁服务器教程&#xff0c;傻瓜式一键部署&#xff0c;3分钟即可成功创建一台Palworld专属服务器&#xff0c;成本仅需26元&#xff0c;阿里云服务器网aliyunfuwuqi.com分享2024年新版基于阿里云搭建幻兽帕鲁服…

ele-h5项目使用vue3+vite+vant4开发:第四节、业务组件-SearchView组件开发

需求分析 展示切换动画搜索框输入文字&#xff0c;自动发送请求搜索结果展示搜索状态维护历史搜索展示&#xff0c;点击历史搜索后发送请求历史搜索更多切换动画效果 <script setup lang"ts"> import OpSearch from /components/OpSearch.vue import { ref } f…

视频播放器nPlayer

nPlayer是一款功能强大的媒体播放器&#xff0c;适用于各种操作系统&#xff0c;包括iOS、Android、Windows和Mac等。它支持多种音频和视频格式&#xff0c;包括MP4、AVI、FLV、MKV等&#xff0c;并提供了高质量的音视频播放和流畅的播放速度。 nPlayer具有多种功能和特点&…

基于PSO-BP神经网络的风电功率MATLAB预测程序

微❤关注“电气仔推送”获得资料&#xff08;专享优惠&#xff09; 参考文献 基于风电场运行特性的风电功率预测及应用分析——倪巡天 资源简介 由于自然风具有一定的随机性、不确定性与波动性&#xff0c;这将会使风电场的功率预测受到一定程度的影响&#xff0c;它们之间…

【Linux】EXT2文件系统 | 磁盘分区块组 | inode

文章目录 一、前言二、EXT2文件系统 - 逻辑存储结构&#x1f4be;分区&#xff08;Partition&#xff09;分区的概念每个分区的内容Linux下查询磁盘分区 &#x1f4be;块组&#xff08;Block Group&#xff09;磁盘格式化每个块组的内容1. Superblock&#xff08;超级块&#x…

Log360,引入全新安全与风险管理功能,助力企业积极抵御网络威胁

ManageEngine在其SIEM解决方案中推出了安全与风险管理新功能&#xff0c;企业现在能够更主动地减轻内部攻击和防范入侵。 SIEM 这项新功能为Log360引入了安全与风险管理仪表板&#xff0c;Log360是ManageEngine的统一安全信息与事件管理&#xff08;SIEM&#xff09;解决方案…

算法学习——华为机考题库2(HJ11 - HJ20)

算法学习——华为机考题库2&#xff08;HJ11 - HJ20&#xff09; HJ11 数字颠倒 描述 输入一个整数&#xff0c;将这个整数以字符串的形式逆序输出 程序不考虑负数的情况&#xff0c;若数字含有0&#xff0c;则逆序形式也含有0&#xff0c;如输入为100&#xff0c;则输出为0…

如何在 Microsoft Azure 上部署和管理 Elastic Stack

作者&#xff1a;来自 Elastic Osman Ishaq Elastic 用户可以从 Azure 门户中查找、部署和管理 Elasticsearch。 此集成提供了简化的入门体验&#xff0c;所有这些都使用你已知的 Azure 门户和工具&#xff0c;因此你可以轻松部署 Elastic&#xff0c;而无需注册外部服务或配置…

JSR303参数校验-SpringMVC

文章目录 JSR303技术标准简介JSR303标准几个具体实现框架validation-apijakarta.validation-apihibernate-validatorspring-boot-starter-validation Spring Validationjavax.validation.constraints包下提供的注解org.hibernate.validator.constraints包扩展的注解校验注解默认…

《QDebug 2024年1月》

一、Qt Widgets 问题交流 1. 二、Qt Quick 问题交流 1.Repeator 的 delegate 在 remove 移除时的注意事项 Qt Bug Tracker&#xff1a;https://bugreports.qt.io/browse/QTBUG-47500 Repeator 在调用 remove 函数之后&#xff0c;对应的 Item 会立即释放&#xff0c;后续就…

Apache Doris 整合 FLINK CDC + Iceberg 构建实时湖仓一体的联邦查询

1概况 本文展示如何使用 Flink CDC Iceberg Doris 构建实时湖仓一体的联邦查询分析&#xff0c;Doris 1.1版本提供了Iceberg的支持&#xff0c;本文主要展示Doris和Iceberg怎么使用&#xff0c;大家按照步骤可以一步步完成。完整体验整个搭建操作的过程。 2系统架构 我们整…