机器学习参数调优

手动调参  

分析影响模型的参数,设计步长进行交叉验证

我们以随机森林为例:

本文将使用sklearn自带的乳腺癌数据集,建立随机森林,并基于泛化误差(Genelization Error)与模型复杂度的关系来对模型进行调参,从而使模型获得更高的得分。

泛化误差是机器学习中,用来衡量模型在未知数据上的准确率的指标;

1、导入相关包

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GridSearchCV
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

2、导入乳腺癌数据集,建立模型

由于sklearn自带的数据集已经很工整了,所以无需做预处理,直接使用。

# 导入乳腺癌数据集
data = load_breast_cancer()# 建立随机森林
rfc = RandomForestClassifier(n_estimators=100, random_state=90)用交叉验证计算得分
score_pre = cross_val_score(rfc, data.data, data.target, cv=10).mean()
score_pre

3、调参

随机森林主要的参数有n_estimators(子树的数量)、max_depth(树的最大生长深度)、min_samples_leaf(叶子的最小样本数量)、min_samples_split(分支节点的最小样本数量)、max_features(最大选择特征数)。它们对随机森林模型复杂度的影响如下图所示:

n_estimators是影响程度最大的参数,我们先对其进行调整

# 调参,绘制学习曲线来调参n_estimators(对随机森林影响最大)
score_lt = []# 每隔10步建立一个随机森林,获得不同n_estimators的得分
for i in range(0,200,10):rfc = RandomForestClassifier(n_estimators=i+1,random_state=90)score = cross_val_score(rfc, data.data, data.target, cv=10).mean()score_lt.append(score)
score_max = max(score_lt)
print('最大得分:{}'.format(score_max),'子树数量为:{}'.format(score_lt.index(score_max)*10+1))# 绘制学习曲线
x = np.arange(1,201,10)
plt.subplot(111)
plt.plot(x, score_lt, 'r-')
plt.show()

如图所示,当n_estimators从0开始增大至21时,模型准确度有肉眼可见的提升。这也符合随机森林的特点:在一定范围内,子树数量越多,模型效果越好。而当子树数量越来越大时,准确率会发生波动,当取值为41时,获得最大得分。


框架自动调参

Optuna是一个自动化的超参数优化软件框架,专门为机器学习而设计。 这里对其进行简单的入门介绍,详细的学习可以参考:https://github.com/optuna/optuna

 

optuna是一个使用python编写的超参数调节框架。一个极简的 optuna 的优化程序中只有三个最核心的概念,目标函数(objective),单次试验(trial),和研究(study)。其中 objective 负责定义待优化函数并指定参/超参数数范围,trial 对应着 objective 的单次执行,而 study 则负责管理优化,决定优化的方式,总试验的次数、试验结果的记录等功能。

  • objective:根据目标函数的优化Session,由一系列的trail组成。
  • trail:根据目标函数作出一次执行。
  • study:根据多次trail得到的结果发现其中最优的超参数。

随机森林iris数据集调优

from sklearn.datasets import load_iris
x, y = load_iris().data, load_iris().target
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
def objective(trial):global x, yX_train, X_test, y_train, y_test=train_test_split(x, y, train_size=0.3)# 数据集划分param = {"n_estimators": trial.suggest_int('n_estimators', 5, 20),"criterion": trial.suggest_categorical('criterion', ['gini','entropy'])}dt_clf = RandomForestClassifier(**param)dt_clf.fit(X_train, y_train)pred_dt = dt_clf.predict(X_test)score = (y_test==pred_dt).sum() / len(y_test)return score
study=optuna.create_study(direction='maximize')
n_trials=20 # try50次
study.optimize(objective, n_trials=n_trials)
print(study.best_value)
print(study.best_params)
#######################################结果######################################
[32m[I 2021-04-12 16:20:13,627][0m A new study created in memory with name: no-name-47fe20d7-e9c0-4bed-bc6d-8113edae0bec[0m
[32m[I 2021-04-12 16:20:13,652][0m Trial 0 finished with value: 0.9523809523809523 and parameters: {'n_estimators': 15, 'criterion': 'gini'}. Best is trial 0 with value: 0.9523809523809523.[0m
[32m[I 2021-04-12 16:20:13,662][0m Trial 1 finished with value: 0.9523809523809523 and parameters: {'n_estimators': 5, 'criterion': 'gini'}. Best is trial 0 with value: 0.9523809523809523.[0m
[32m[I 2021-04-12 16:20:13,680][0m Trial 2 finished with value: 0.9428571428571428 and parameters: {'n_estimators': 15, 'criterion': 'entropy'}. Best is trial 0 with value: 0.9523809523809523.[0m
[32m[I 2021-04-12 16:20:13,689][0m Trial 3 finished with value: 0.9523809523809523 and parameters: {'n_estimators': 7, 'criterion': 'gini'}. Best is trial 0 with value: 0.9523809523809523.[0m
[32m[I 2021-04-12 16:20:13,704][0m Trial 4 finished with value: 0.9428571428571428 and parameters: {'n_estimators': 14, 'criterion': 'gini'}. Best is trial 0 with value: 0.9523809523809523.[0m
[32m[I 2021-04-12 16:20:13,721][0m Trial 5 finished with value: 0.9714285714285714 and parameters: {'n_estimators': 14, 'criterion': 'gini'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,733][0m Trial 6 finished with value: 0.9619047619047619 and parameters: {'n_estimators': 10, 'criterion': 'gini'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,753][0m Trial 7 finished with value: 0.9619047619047619 and parameters: {'n_estimators': 18, 'criterion': 'gini'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,764][0m Trial 8 finished with value: 0.9714285714285714 and parameters: {'n_estimators': 8, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,771][0m Trial 9 finished with value: 0.9333333333333333 and parameters: {'n_estimators': 5, 'criterion': 'gini'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,795][0m Trial 10 finished with value: 0.9333333333333333 and parameters: {'n_estimators': 20, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,809][0m Trial 11 finished with value: 0.9333333333333333 and parameters: {'n_estimators': 9, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,827][0m Trial 12 finished with value: 0.9428571428571428 and parameters: {'n_estimators': 12, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,842][0m Trial 13 finished with value: 0.9238095238095239 and parameters: {'n_estimators': 11, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,855][0m Trial 14 finished with value: 0.9428571428571428 and parameters: {'n_estimators': 8, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,880][0m Trial 15 finished with value: 0.9428571428571428 and parameters: {'n_estimators': 18, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,899][0m Trial 16 finished with value: 0.9428571428571428 and parameters: {'n_estimators': 13, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,911][0m Trial 17 finished with value: 0.9714285714285714 and parameters: {'n_estimators': 7, 'criterion': 'gini'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,933][0m Trial 18 finished with value: 0.9428571428571428 and parameters: {'n_estimators': 17, 'criterion': 'entropy'}. Best is trial 5 with value: 0.9714285714285714.[0m
[32m[I 2021-04-12 16:20:13,948][0m Trial 19 finished with value: 0.9523809523809523 and parameters: {'n_estimators': 11, 'criterion': 'gini'}. Best is trial 5 with value: 0.9714285714285714.[0m0.9714285714285714
{'n_estimators': 14, 'criterion': 'gini'}
##################################################################################

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/81620.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TOPIAM 社区版 1.0.0 发布,开源 IAM/IDaaS 企业身份管理平台

文章目录 产品概述系统架构功能列表管理端门户端 技术架构后续规划相关地址 ​Hi,亲爱的朋友们,今天是传统 24 节气中的立秋,秋天是禾谷成熟、收获的季节。经过长时间优化和迭代,TOPIAM 企业身份管控平台也迎来了当下的成长和收获…

PWNlab靶机渗透

安装靶机 下载地址:https://www/vulnhub.com/entry/pwnlab-init,158/ 信息收集: 收集靶机ip地址,由于搭建在本地使用kali自带命令 arp-scan -l nmap 扫描端口,服务 nmap -sV -p 1-65535 -A 靶机ip地址 漏洞探测 访问80端口地…

kubernetes 集群命令行工具 kubectl

1、kubectl 概述 kubectl是一种命令行工具,用于管理Kubernetes集群和与其相关的资源。通过kubectl,您可以查看和管理Pod、Deployment、Service、Volume、ConfigMap等资源,也可以创建、删除和更新它们。 kubectl还提供了许多其他功能&#x…

监控Kafka的关键指标

Kafka 架构 上面绿色部分 PRODUCER(生产者)和下面紫色部分 CONSUMER(消费者)是业务程序,通常由研发人员埋点解决监控问题,如果是 Java 客户端也会暴露 JMX 指标。组件运维监控层面着重关注蓝色部分的 BROKE…

LabVIEW开发高压配电设备振动信号特征提取与模式识别

LabVIEW开发高压配电设备振动信号特征提取与模式识别 矿用高压配电设备是井下供电系统中的关键设备之一,肩负着井下供配电和供电安全的双重任务,其工作状态直接影响着井下供电系统的安全性和可靠性。机械故障占配电总故障的70%。因此,机械故…

24届近5年江南大学自动化考研院校分析

今天给大家带来的是江南大学控制考研分析 满满干货~还不快快点赞收藏 一、江南大学 学校简介 江南大学(Jiangnan University)是国家“双一流”建设高校,“211工程”、“985工程优势学科创新平台”重点建设高校,入选…

Django架构图

1. Django 简介 基本介绍 Django 是一个由 Python 编写的一个开放源代码的 Web 应用框架 使用 Django,只要很少的代码,Python 的程序开发人员就可以轻松地完成一个正式网站所需要的大部分内容,并进一步开发出全功能的 Web 服务 Django 本身…

Vue2源码分析-day2

实现数组的响应式原理 首先我们在index.html中定义一个数组,并且打印实例 const vm new MVue({data() {return {name: "zhangsan",age: "16",hobby:[zhangsan,lisi]}} }) console.log(vm);我们会发现定义的数组每一项都有get和set方法虽然数…

线程池-手写线程池C++11版本(生产者-消费者模型)

本项目是基于C11的线程池。使用了许多C的新特性,包含不限于模板函数泛型编程、std::future、std::packaged_task、std::bind、std::forward完美转发、std::make_shared智能指针、decltype类型推断、std::unique_lock锁等C11新特性功能。 本项目有一定的上手难度。推…

计算机网络—TCP

这里写目录标题 TCP头格式有哪些为什么需要TCP,TCP工作在哪什么是TCP什么是TCP连接如何确定一个TCP连接TCP和UDP的区别,以及场景TCP和UDP能共用一个端口?TCP的建立TCP三次握手过程为什么是三次握手、不是两次、四次why每次建立连接&#xff0…

3.1 计算机网络和网络设备

数据参考:CISP官方 目录 计算机网络基础网络互联设备网络传输介质 一、计算机网络基础 1、ENIAC:世界上第一台计算机的诞生 1946年2月14日,宾夕法尼亚大学诞生了世界上第一台计算机,名为电子数字积分计算机(ENIAC…

Netty框架自带类DefaultEventExecutorGroup的作用,用来做业务的并发

一、DefaultEventExecutorGroup的用途 DefaultEventExecutorGroup 是 Netty 框架中的一个类,用于管理和调度事件处理器(EventExecutor)的组。在 Netty 中,事件处理是通过多线程来完成的,EventExecutor 是处理事件的基…

2023年中期奶粉行业分析报告(京东数据开放平台)

根据国家统计局和民政部数据公布,2022年中国结婚登记数创造了1980年(有数据公布)以来的历史新低,共计683.3万对。相较于2013年巅峰时期的数据,2022年全国结婚登记对数已接近“腰斩”。 2023年“520”期间的结婚登记数…

【MySQL】

这里写目录标题 MySQL架构一条sql执行流程MySQL数据存放电脑位置ibd文件结构行溢出是什么MySQL行记录存储格式索引为什么InnoDB选择B树作为索引数据结构什么时候需要创建索引优化索引方法InnoDB内部怎么存储数据B 树如何进行查询聚簇索引和二级索引为什么MySQL要采用B树作为索引…

【Spring Boot】Spring Boot项目的创建和文件配置

目录 一、为什么要学Spring Boot 1、Spring Boot的优点 二、创建Spring Boot项目 1、创建项目之前的准备工作 2、创建Spring Boot项目 3、项目目录的介绍 4、安装Spring Boot快速添加依赖的插件 5、在项目中写一个helloworld 三、Spring Boot的配置文件 1、配置文件的…

ELK 企业级日志分析系统(二)

目录 ELK Kiabana 部署(在 Node1 节点上操作) 1.安装 Kiabana 2.设置 Kibana 的主配置文件 3.启动 Kibana 服务 4.验证 Kibana 5.将 Apache 服务器的日志(访问的、错误的&#x…

知识图谱实战应用23-【知识图谱的高级用法】Neo4j图算法的Cypher查询语句实例

大家好,我是微学AI,今天给大家介绍一下知识图谱实战应用23-【知识图谱的高级用法】Neo4j图算法的Cypher查询语句实例,Neo4j图算法是一套在Neo4j图数据库上运行的算法集合。这些算法专门针对图数据结构进行设计,用于分析、查询和处理图数据。图算法可以帮助我们发现图中的模…

【Azure】office365邮箱测试的邮箱账号因频繁连接邮箱服务器而被限制连接 引起邮箱显示异常

azure微软office365邮箱会对频繁连接自身邮箱服务器的IP地址进行,连接邮箱服务器IP限制,也就是黑名单,释放时间不确定,但至少一天及以上。 解决办法,换一个IP,或者新注册一个office365邮箱再重试。 以下是…

使用 API Gateway Integrator 在 Quarkus 中实施适用于 AWS Lambda 的 OpenAPI

AWS API Gateway 集成使得使用符合 OpenAPI 标准的 Lambda Function 轻松实现 REST API。 关于开放API 它是一个 允许以标准方式描述 REST API 的规范。 OpenAPI规范 (OAS) 为 REST API 定义了与编程语言无关的标准接口描述。这使得人类和计算机都可以发现和理解服务的功能&am…

中介者模式(C++)

定义 用一个中介对象来封装(封装变化)一系列的对象交互。中介者使各对象不需要显式的相互引用(编译时依赖->运行时依赖),从而使其耦合松散(管理变化),而且可以独立地改变它们之间的交互。 应用场景 在软件构建过程中,经常会出现多个对象…