机器学习中参数优化调试方法

1  超参数优化

图片

调参即超参数优化,是指从超参数空间中选择一组合适的超参数,以权衡好模型的偏差(bias)和方差(variance),从而提高模型效果及性能。常用的调参方法有:

  • 人工手动调参

  • 网格/随机搜索(Grid / Random Search)

  • 贝叶斯优化(Bayesian Optimization)

图片

注:超参数 vs 模型参数差异 超参数是控制模型学习过程的(如网络层数、学习率);模型参数是通过模型训练学习后得到的(如网络最终学习到的权重值)。

2  人工调参

手动调参需要结合数据情况及算法的理解,选择合适调参的优先顺序及参数的经验值。

不同模型手动调参思路会有差异,如随机森林是一种bagging集成的方法,参数主要有n_estimators(子树的数量)、max_depth(树的最大生长深度)、max_leaf_nodes(最大叶节点数)等。(此外其他参数不展开说明) 对于n_estimators:通常越大效果越好。参数越大,则参与决策的子树越多,可以消除子树间的随机误差且增加预测的准度,以此降低方差与偏差。对于max_depth或max_leaf_nodes:通常对效果是先增后减的。取值越大则子树复杂度越高,偏差越低但方差越大。

图片

3 网格/随机搜索

图片

  • 网格搜索(grid search),是超参数优化的传统方法,是对超参数组合的子集进行穷举搜索,找到表现最佳的超参数子集。

  • 随机搜索(random search),是对超参数组合的子集简单地做固定次数的随机搜索,找到表现最佳的超参数子集。对于规模较大的参数空间,采用随机搜索往往效率更高。

import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RandomizedSearchCV
from sklearn.ensemble import RandomForestClassifier# 选择模型 
model = RandomForestClassifier()
# 参数搜索空间
param_grid = {'max_depth': np.arange(1, 20, 1),'n_estimators': np.arange(1, 50, 10),'max_leaf_nodes': np.arange(2, 100, 10)}
# 网格搜索模型参数
grid_search = GridSearchCV(model, param_grid, cv=5, scoring='f1_micro')
grid_search.fit(x, y)
print(grid_search.best_params_)
print(grid_search.best_score_)
print(grid_search.best_estimator_)
# 随机搜索模型参数
rd_search = RandomizedSearchCV(model, param_grid, n_iter=200, cv=5, scoring='f1_micro')
rd_search.fit(x, y)
print(rd_search.best_params_)
print(rd_search.best_score_)
print(rd_search.best_estimator_)

4 贝叶斯优化

贝叶斯优化(Bayesian Optimization) 与网格/随机搜索最大的不同,在于考虑了历史调参的信息,使得调参更有效率。(但在高维参数空间下,贝叶斯优化复杂度较高,效果会近似随机搜索。)

图片

4.1 算法简介

贝叶斯优化思想简单可归纳为两部分:

  • 高斯过程(GP):以历史的调参信息(Observation)去学习目标函数的后验分布(Target)的过程。

  • 采集函数(AC):由学习的目标函数进行采样评估,分为两种过程:1、开采过程:在最可能出现全局最优解的参数区域进行采样评估。2、勘探过程:兼顾不确定性大的参数区域的采样评估,避免陷入局部最优。

4.2 算法流程

for循环n次迭代:采集函数依据学习的目标函数(或初始化)给出下个开采极值点 Xn+1;评估超参数Xn+1得到表现Yn+1;加入新的Xn+1、Yn+1数据样本,并更新高斯过程模型;

图片

"""
随机森林分类Iris使用贝叶斯优化调参
"""
import numpy as np
from hyperopt import hp, tpe, Trials, STATUS_OK, Trials, anneal
from functools import partial
from hyperopt.fmin import fmin
from sklearn.metrics import f1_score
from sklearn.ensemble import RandomForestClassifierdef model_metrics(model, x, y):""" 评估指标 """yhat = model.predict(x)return  f1_score(y, yhat,average='micro')def bayes_fmin(train_x, test_x, train_y, test_y, eval_iters=50):"""bayes优化超参数eval_iters:迭代次数"""def factory(params):"""定义优化的目标函数"""fit_params = {'max_depth':int(params['max_depth']),'n_estimators':int(params['n_estimators']),'max_leaf_nodes': int(params['max_leaf_nodes'])}# 选择模型model = RandomForestClassifier(**fit_params)model.fit(train_x, train_y)# 最小化测试集(- f1score)为目标train_metric = model_metrics(model, train_x, train_y)test_metric = model_metrics(model, test_x, test_y)loss = - test_metricreturn {"loss": loss, "status":STATUS_OK}# 参数空间space = {'max_depth': hp.quniform('max_depth', 1, 20, 1),'n_estimators': hp.quniform('n_estimators', 2, 50, 1), 'max_leaf_nodes': hp.quniform('max_leaf_nodes', 2, 100, 1)}# bayes优化搜索参数best_params = fmin(factory, space, algo=partial(anneal.suggest,), max_evals=eval_iters, trials=Trials(),return_argmin=True)# 参数转为整型best_params["max_depth"] = int(best_params["max_depth"])best_params["max_leaf_nodes"] = int(best_params["max_leaf_nodes"])best_params["n_estimators"] = int(best_params["n_estimators"])return best_params#  搜索最优参数
best_params = bayes_fmin(train_x, test_x, train_y, test_y, 100)
print(best_params)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/163740.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[python 刷题] 19 Remove Nth Node From End of List

[python 刷题] 19 Remove Nth Node From End of List 题目: Given the head of a linked list, remove the nth node from the end of the list and return its head. 题目说的是就是移除倒数第 n 个结点,如官方给的案例: 这里提供的 n 就是…

运筹优化 | 分支定界算法(Branch and Bound)Python求解整数规划

from gurobipy import * import copy import numpy as np import matplotlib.pyplot as plt plt.rcParams[font.sans-serif][SimHei]定义了一个线性松弛问题,并用Gurobi求解 initial_LP Model(initial LP) # 定义变量initial_LP,调用Gurobi的Model&…

SVN一直报错Error running context: 由于目标计算机积极拒绝,无法连接。解决办法【杭州多测师_王sir】...

一、发现SVN一直报错Error running context: 由于目标计算机积极拒绝,无法连接。 二、没有启动 VisualSVN Server。cmd--> services.msc打开本地服务。查看VisualSVN的三个服务的启动类型,建议选择“手动”,不能选择“禁用”,选…

电脑办公助手之桌面便签,助力高效率办公

在现代办公的快节奏中,大家有应接不暇的工作,每天面对着复杂的工作任务,总感觉时间不够用,而且工作无厘头。对于这种状态,大家可以选择在电脑上安装一款好用的办公便签软件来辅助日常办公。 敬业签是一款专为办公人士…

【Release】Photoshop ICO file format plug-in 3.0

【Introduction】 The Photoshop ICO plug-in is a file format plug-in developed for Photoshop, which allows Photoshop to directly read and write ICO format files. Because Photoshop has powerful pixel bitmap editing functions, it has many users and a good us…

“构建交互式用户界面的自定义组件应用与界面布局设置“

目录 引言自定义组件应用设置界面布局投票界面布局及实现投票选项界面总结 引言 在软件开发中,用户界面设计是至关重要的一环。良好的界面设计可以提升用户体验、增加用户黏性,并提高软件的易用性。本篇博客将介绍如何利用自定义组件应用和界面布局设置…

了解 AI :了解 AI 方面的一些术语 (中英文对照)

本心、输入输出、结果 文章目录 了解 AI :了解 AI 方面的一些术语 (中英文对照)前言AI 方面的一些术语 (中英文对照)AI 方面的一些术语 (中英文对照) - 文字版弘扬爱国精神 了解 AI &#xff1a…

PowerShell系列(十二):PowerShell Cmdlet高级参数介绍(二)

目录 1、ErrorVariable 错误变量 2、OutVariable 结果输出 3、OutBuffer 输出Buffer定义 4、PipelineVariable管道参数 今天给大家讲解PowerShell Cmdlet高级参数第二部分相关的知识,希望对大家学习PowerShell能有所帮助! 1、ErrorVariable 错误变量…

浏览器缓存

浏览器的缓存是性能优化中最高效的方法看,他可以显著减少网络传输带来的损耗。 浏览器缓存可以帮助以下两种情况下进行优化: 发起请求:使用缓存不发起的请求浏览器响应:后端与前端数据是一致的,那么没有必要再将数据传…

网络安全内网渗透之信息收集--systeminfo查看电脑有无加域

systeminfo输出的内容很多,包括主机名、OS名称、OS版本、域信息、打的补丁程序等。 其中,查看电脑有无加域可以快速搜索: systeminfo|findstr "域:" 输出结果为WORKGROUP,可见该机器没有加域: systeminfo…

Docker 安装zookeeper

一、安装单机版 1、拉取镜像 docker pull zookeeper2、创建挂载目录 mkdir -p /mydata/zookeeper/{conf,data,logs}3、新建配置文件 cd /mydata/zookeeper/conf vi zoo.cfgdataDir/data dataLogDir/logs tickTime2000 initLimit10 syncLimit5 clientPort21814、单机主机启…

elasticsearch常用命令

Elasticsearch概念 ElasticsearchmysqlIndex(索引)数据库Type(类型)表Documents(文档)行Fields列 常用命令 索引 # 索引初始化,number_of_shards:分片数,不可修改;number_of_replicas:副本数,可修改 PUT lagou {"settings…

基于深度学习的目标检测模型综述

基于深度学习的目标检测模型综述 一 概论目标检测主要挑战评估指标 二 展望 一 概论 目标检测是目标分类的自然延伸,目标分类仅旨在识别图像中的目标。目标检测的目标是检测预定义类的所有实例并通过轴对齐的框提供其在图像中的初略定位。检测器应能够识别所有目标…

jmeter监听每秒点击数(Hits per Second)

jmeter监听每秒点击数(Hits per Second) 下载插件添加监听器执行压测,监听结果 下载插件 点击选项,点击Plugins Manager (has upgrades),点击Available Plugins,搜索5 Additional Graphs安装。 添加监听…

什么是热阻?

电流流过导体时,在导体两端会产生电压差,这个电压差除以流过导体的电流就是这个导体的电阻,单位是欧姆。这就是欧姆定律,大家都知道的东西。 当热源的热量在物体中传递时,在物体上也会产生温度差,这个温度差…

通过SVN拉取项目 步骤

方法一:文件夹方式 首先新建一个空的文件夹,例如,名为“demo”的空文件夹 在这个空的文件夹中鼠标右键,点击SVN Checkout 会出现下图所示的页面,第一个输入框是svn的项目地址,第二个输入框是拉取项目所放的…

安防监控系统EasyCVR视频汇聚平台设备树收藏按钮的细节优化

视频监控TSINGSEE青犀视频平台EasyCVR能在复杂的网络环境中,将分散的各类视频资源进行统一汇聚、整合、集中管理,在视频监控播放上,TSINGSEE青犀视频安防监控汇聚平台可支持1、4、9、16个画面窗口播放,可同时播放多路视频流&#…

Java学习

目录 一、变量 二、运算 三、判断和循环语句 四、数组 五、方法 六、类 七、字符串 八、static 九、继承 十、多态 十一、包 十二、final 十三、抽象类 十四、接口 十五、嵌套类 一、变量 1、byte范围【-128,127】 2、long变量后面要写l,float变量…

单链表算法经典OJ题

目录 1、移除链表元素 2、翻转链表 3、合并两个有序链表 4、获取链表的中间结点 5、环形链表解决约瑟夫问题 6、分割链表 1、移除链表元素 203. 移除链表元素 - 力扣(LeetCode) typedef struct ListNode LSNode; struct ListNode* remove…

C#冒泡排序算法

冒泡排序实现原理 冒泡排序是一种简单的排序算法,其原理如下: 从待排序的数组的第一个元素开始,依次比较相邻的两个元素。 如果前面的元素大于后面的元素(升序排序),则交换这两个元素的位置,使…