分类算法系列③:模型选择与调优 (Facebook签到位置预测)

目录

模型选择与调优

1、介绍

模型选择(Model Selection):

调优(Hyperparameter Tuning):

本章重点

2、交叉验证

介绍

为什么需要交叉验证

数据处理

3、⭐超参数搜索-网格搜索(Grid Search)

介绍

API

🔺Facebook签到位置预测K值调优


 🍃作者介绍:准大三网络工程专业在读,努力学习Java,涉猎深度学习,积极输出优质文章

⭐分类算法系列①:初识概念

⭐分类算法系列②:KNN(K-近邻)算法

🍁您的三连支持,是我创作的最大动力🌹

模型选择与调优

1、介绍

在机器学习中,模型选择(Model Selection)和调优(Hyperparameter Tuning)是优化模型性能的关键步骤。模型选择涉及选择合适的算法或模型架构,而调优则涉及调整模型的超参数以达到最佳性能。以下是关于这两个步骤的详细介绍:

模型选择(Model Selection):

模型选择是选择在给定任务中使用哪种算法或模型的过程。不同的算法适用于不同的问题,因此选择适当的算法对于取得良好的性能至关重要。模型选择通常涉及以下步骤:

  1. 问题定义:明确定义要解决的问题,例如分类、回归、聚类等。
  2. 数据准备:对数据进行预处理、清洗和特征工程,以确保数据适用于所选的算法。
  3. 候选模型:根据问题和数据类型,选择几种合适的模型作为候选,例如决策树、支持向量机、神经网络等。
  4. 训练和验证:使用交叉验证等技术,在训练数据上训练候选模型,并在验证集上评估其性能。
  5. 性能比较:比较不同模型在验证集上的性能指标,如准确率、精确度、召回率等。
  6. 选择最佳模型:根据性能比较,选择性能最佳的模型作为最终模型。

调优(Hyperparameter Tuning):

调优是指为机器学习模型选择最佳的超参数,以优化模型的性能。超参数是在训练过程之外设置的参数,它们决定了模型的结构和行为,如学习率、正则化参数、树的深度等。调优的目标是找到使模型在验证集上表现最佳的超参数组合。调优通常包括以下步骤:

  1. 选择超参数空间:选择要调优的超参数和它们的可能取值范围。
  2. 搜索方法:选择超参数搜索方法,如网格搜索、随机搜索、贝叶斯优化等。
  3. 交叉验证:使用交叉验证将数据分为训练集和验证集,以评估不同超参数组合的性能。
  4. 评价指标:选择适当的评价指标来衡量不同超参数组合的性能。
  5. 调优过程:根据选择的搜索方法,不断尝试不同的超参数组合,并记录它们的性能。
  6. 选择最佳组合:从调优过程中选择在验证集上性能最佳的超参数组合作为最终模型的超参数。

模型选择和调优是迭代过程,可能需要多次尝试不同的模型和超参数组合,以找到最适合任务的模型并达到最佳性能。使用交叉验证、可视化工具和自动化调优库(如scikit-learn中的GridSearchCV和RandomizedSearchCV)可以帮助更有效地进行模型选择和调优。

本章重点

本章重点是交叉验证!结合的示例是之前的Facebook签到位置问题,对之前使用KNN算法完成的Facebook签到位置预测进行调优,使其结果更加准确。

2、交叉验证

介绍

交叉验证(Cross-Validation)是一种用于评估机器学习模型性能的技术,它有助于更准确地估计模型在未知数据上的表现。交叉验证通过在不同的数据子集上进行多次训练和验证,提供了对模型泛化性能的更稳定估计。

在传统的训练-测试集划分中,数据被划分为训练集和测试集,然后使用训练集训练模型,使用测试集评估模型性能。然而,这种方法可能因为数据的划分方式而导致评估结果不稳定,特别是在数据量有限的情况下。交叉验证通过将数据划分为多个折(folds),多次进行训练和测试,从而克服了这些问题。

以下是交叉验证的常见方法:

  1. k折交叉验证(k-Fold Cross-Validation)
    • 将数据分为k个大小相似的折(folds)。
    • 每次将其中一个折作为验证集,其他k-1个折作为训练集。
    • 重复这个过程k次,每次选择不同的折作为验证集,其他折作为训练集。
    • 计算k次验证的平均性能作为最终性能评估。
  2. 留一交叉验证(Leave-One-Out Cross-Validation,LOOCV)
    • 将每个样本单独作为一个折,其他样本作为训练集。
    • 执行n次训练和验证,n为样本数量。
    • 计算n次验证的平均性能作为最终性能评估。
    • 适用于小样本数据集,但计算开销较大。
  3. 随机折交叉验证(Stratified k-Fold Cross-Validation)
    • 类似于k折交叉验证,但在划分折时会保持各个类别的比例相同。
    • 对于不均衡的数据集,这种方法可以更好地保持类别分布。

交叉验证的优势在于它能够提供更可靠的模型性能估计,因为每个样本都会被用于训练和验证,减少了数据划分可能引发的偶然性影响。交叉验证还有助于选择合适的模型和调整超参数,从而提高模型的泛化性能。在实际应用中,k折交叉验证是最常用的方法之一,但根据问题的特点和数据集的大小,选择适当的交叉验证方法非常重要。

为什么需要交叉验证

交叉验证目的:为了让被评估的模型更加准确可信

数据处理

一般情况下,数据分为训练集和测试集,但是为了让从训练得到模型结果更加准确。

做以下处理:

  1. 训练集:训练集+验证集
  2. 测试集:测试集

那么对于之前的Facebook签到位置预测问题的k值,如何取得一个合理的值?下面使用超参数搜索-网格搜索(Grid Search)

3超参数搜索-网格搜索(Grid Search)

介绍

超参数网格搜索(Grid Search)是一种常用的超参数调优方法,用于寻找最佳的超参数组合,从而优化机器学习模型的性能。它通过在预定义的超参数空间中搜索所有可能的组合,然后评估每个组合的性能,最终选择性能最佳的组合作为最终的超参数设置。

以下是超参数网格搜索的步骤和原理:

  1. 超参数空间定义: 首先,为模型选择要调优的超参数,并为每个超参数指定可能的取值范围。例如,对于支持向量机,可以选择C(正则化参数)和kernel(核函数类型)作为需要调优的超参数,为它们指定一组候选取值。
  2. 生成网格: 将每个超参数的可能取值组合成一个网格,生成所有可能的超参数组合。这个网格中的每个点都代表一组超参数设置。
  3. 交叉验证: 对于每个超参数组合,使用交叉验证来评估模型在验证集上的性能。通常使用k折交叉验证,对于每个超参数组合,训练模型k次,并计算平均性能指标。
  4. 选择最佳组合: 根据交叉验证的结果,选择性能最佳的超参数组合作为最终的选择。通常根据准确率、F1得分、均方误差等评价指标来衡量性能。
  5. 应用最佳超参数: 使用在步骤4中选择的最佳超参数组合来训练模型,然后在独立的测试集上评估其性能。

超参数网格搜索的优点在于它是一种简单而有效的方法,可以在有限的计算资源下尝试多种超参数组合。然而,网格搜索的缺点是它可能会对计算资源造成较大的负担,特别是在超参数空间较大时。为了提高效率,可以使用随机搜索等方法来在超参数空间中采样,以更快地找到性能较好的超参数组合。

API

sklearn.model_selection.GridSearchCV(estimator, param_grid=None,cv=None)

        对估计器的指定参数值进行详尽搜索

        estimator:估计器对象

        param_grid:估计器参数(dict){“n_neighbors”:[1,3,5]}

        cv:指定几折交叉验证

        fit:输入训练数据

        score:准确率

结果分析:

        bestscore:在交叉验证中验证的最好结果

        bestestimator:最好的参数模型

        cvresults:每次交叉验证后的验证集准确率结果和训练集准确率结果

🔺Facebook签到位置预测K值调优

使用网格搜索估计器,在原来的KNN算法实现Facebook签到位置预测的代码基础上,新的修改如下代码:

# 使用网格搜索和交叉验证找到合适的参数
knn = KNeighborsClassifier()param = {"n_neighbors": [3, 5, 10]}gc = GridSearchCV(knn, param_grid=param, cv=2)gc.fit(x_train, y_train)print("选择了某个模型测试集当中预测的准确率为:", gc.score(x_test, y_test))# 训练验证集的结果
print("在交叉验证当中验证的最好结果:", gc.best_score_)
print("gc选择了的模型K值是:", gc.best_estimator_)
print("每次交叉验证的结果为:", gc.cv_results_)

代码解释:

创建KNN分类器:knn = KNeighborsClassifier()

定义超参数空间:param = {"n_neighbors": [3, 5, 10]}

创建GridSearchCV实例,传入KNN分类器实例和定义好的超参数空间,cv=2表示使用2折交叉验证gc = GridSearchCV(knn, param_grid=param, cv=2)

执行网格搜索和交叉验证:gc.fit(x_train, y_train)

评估测试集性能,使用训练好的网格搜索模型在测试集上进行预测并输出准确率:print("选择了某个模型测试集当中预测的准确率为:", gc.score(x_test, y_test))

打印交叉验证结果:

  • gc.best_score_:输出在交叉验证中获得的最佳性能指标。
  • gc.best_estimator_:输出最佳性能对应的模型,包括超参数设置。
  • gc.cv_results_:输出每次交叉验证的结果,包括参数设置和性能指标。

回顾k值交叉验证:

KNN调优全部代码:

# -*- coding: utf-8 -*-
# @Author:︶ㄣ释然
# @Time: 2023/8/30 23:48
import pandas as pd
from sklearn.model_selection import train_test_split  # 将数据集分割为训练集和测试集。
from sklearn.neighbors import KNeighborsClassifier  # 实现KNN分类器
from sklearn.preprocessing import StandardScaler  # 特征标准化
from sklearn.model_selection import GridSearchCV  # 网格搜索'''
sklearn.model_selection.GridSearchCV(estimator, param_grid=None,cv=None)对估计器的指定参数值进行详尽搜索estimator:估计器对象param_grid:估计器参数(dict){“n_neighbors”:[1,3,5]}cv:指定几折交叉验证fit:输入训练数据score:准确率
结果分析:bestscore:在交叉验证中验证的最好结果_bestestimator:最好的参数模型cvresults:每次交叉验证后的验证集准确率结果和训练集准确率结果
'''
def knn_GridSearch():"""K近邻算法预测入住位置类别:return:"""# 一、处理数据以及特征工程# 1、读取收,缩小数据的范围data = pd.read_csv("./data/FBlocation/train.csv")# 数据逻辑筛选操作 df.query()data = data.query("x > 1.0 & x < 1.25 & y > 2.5 & y < 2.75")# 删除time这一列特征data = data.drop(['time'], axis=1)print(data)# 删除入住次数少于三次位置place_count = data.groupby('place_id').count()tf = place_count[place_count.row_id > 3].reset_index()data = data[data['place_id'].isin(tf.place_id)]# 3、取出特征值和目标值y = data['place_id']# y = data[['place_id']]x = data.drop(['place_id', 'row_id'], axis=1)# 4、数据分割与特征工程?# (1)、数据分割x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3)# (2)、标准化std = StandardScaler()# 队训练集进行标准化操作x_train = std.fit_transform(x_train)print(x_train)# 进行测试集的标准化操作x_test = std.fit_transform(x_test)# 二、算法的输入训练预测# K值:算法传入参数不定的值    理论上:k = 根号(样本数)# K值:后面会使用参数调优方法,去轮流试出最好的参数[1,3,5,10,20,100,200]# 使用网格搜索和交叉验证找到合适的参数knn = KNeighborsClassifier()param = {"n_neighbors": [3, 5, 10]}gc = GridSearchCV(knn, param_grid=param, cv=2)gc.fit(x_train, y_train)print("选择了某个模型测试集当中预测的准确率为:", gc.score(x_test, y_test))# 训练验证集的结果print("在交叉验证当中验证的最好结果:", gc.best_score_)print("gc选择了的模型K值是:", gc.best_estimator_)print("每次交叉验证的结果为:", gc.cv_results_)if __name__ == '__main__':knn_GridSearch()

执行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/114192.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Node.js 中间件是怎样工作的?

express自带路由功能&#xff0c;可以侦听指定路径的请求&#xff0c;除此之外&#xff0c;express最大的优点就是【中间件】概念的灵活运用&#xff0c;使得各个模块得以解耦&#xff0c;像搭积木一样串起来就可以实现复杂的后端逻辑。除此之外&#xff0c;还可以利用别人写好…

文件包含漏洞

文章目录 渗透测试漏洞原理文件包含漏洞1. 文件包含概述1.1 文件包含语句1.1.1 相关配置 1.2 动态包含1.2.1 示例代码1.2.2 本地文件包含1.2.3 远程文件包含 1.3 漏洞原理1.4 文件包含和文件读取的区别 2. 文件包含攻防2.1 利用方法2.1.1 包含图片木马2.1.2 读取敏感文件2.1.3 …

Java实现根据关键词搜索京东商品列表数据方法,当当API接口(jd.item_search)申请指南

要通过京东网的API获取商品列表数据&#xff0c;您可以使用京东开放平台提供的接口来实现。以下是一种使用Java编程语言实现的示例&#xff0c;展示如何通过京东开放平台API获取商品列表&#xff1a; 首先&#xff0c;确保您已注册成为当当开放平台的开发者&#xff0c;并创建…

提高工作效率的一键查询和保存大量快递物流信息的技巧

在如今快速发展的电商行业中&#xff0c;物流服务的准确与便捷是保证顺利交付商品的重要环节。为了方便用户追踪物流&#xff0c;固乔快递查询助手应运而生。这款软件不仅能够快速查询快递单号的物流信息&#xff0c;还具备保存查询结果的功能&#xff0c;方便用户随时查看。 首…

QT(8.31)加载资源文件,信号与槽机制

作业&#xff1a; 实现登录界面&#xff0c;设置账号为admin&#xff0c;密码为123456&#xff0c;登陆成功则退出当前界面&#xff0c;切换到其他界面&#xff0c;密码错误或者账号不匹配则清空账号密码输入框中的内容&#xff0c;并输出登录失败&#xff0c;点击取消则退出当…

C语言memcpy与memmove

C语言memcpy与memmove memcpy memcpy函数原型 void* memcpy(void* dst,const void* src,size_t size);//const修饰src,表示不应该修改src指向的数据memcpy用于实现数据的拷贝操作&#xff0c;将src往后的size字节数据拷贝到dst指向的空间 memcpy的实现&#xff1a; void*…

Leetcode每日一题:1267. 统计参与通信的服务器

原题 这里有一幅服务器分布图&#xff0c;服务器的位置标识在 m * n 的整数矩阵网格 grid 中&#xff0c;1 表示单元格上有服务器&#xff0c;0 表示没有。 如果两台服务器位于同一行或者同一列&#xff0c;我们就认为它们之间可以进行通信。 请你统计并返回能够与至少一台其…

gdb 快速上手(附带测试案例)

在终端使用 gdb 对程序进行调试比较复杂&#xff0c;本文旨在帮助小白快速上手 gdb &#xff0c;所以只介绍了一些比较重要的命令&#xff01; 案例代码在文末&#xff01; 一、gdb 调试 1、编译源文件 gcc -g test.c -o test 2、启动程序 gdb ./test 结果如下&#xff1a;…

Ansible学习笔记3

ansible模块&#xff1a; ansible是基于模块来工作的&#xff0c;本身没有批量部署的能力&#xff0c;真正具有批量部署的是ansible所运行的模块&#xff0c;ansible只是提供一个框架。 ansible支持的模块非常多&#xff0c;我们并不需要把每个模块记住&#xff0c;而只需要熟…

高基数类别特征预处理:平均数编码 | 京东云技术团队

一 前言 对于一个类别特征&#xff0c;如果这个特征的取值非常多&#xff0c;则称它为高基数&#xff08;high-cardinality&#xff09;类别特征。在深度学习场景中&#xff0c;对于类别特征我们一般采用Embedding的方式&#xff0c;通过预训练或直接训练的方式将类别特征值编…

day 31 面向对象 成员方法

class 类名称&#xff1a; 类的属类(定义在类中的变量&#xff0c;成员变量) 类的行为(定义在类中的函数&#xff0c;成员方法) # 设计一个类&#xff08;类比生活中&#xff1a;设计一张等级表&#xff09; class Student:name Nonegender Nonenatio…

DWA算法学习

一、DWA概念  DWA(动态窗口法)属于局部路径规划方法&#xff0c;为ROS中主要采用的方法。其原理主要是在速度空间&#xff08;v,w&#xff09;中采样多组速度&#xff0c;并模拟这些速度在一定时间内的运动轨迹&#xff0c;再通过一个评价函数对这些轨迹打分&#xff0c;最优的…

15.坐标添加带箭头的线

ol的官网示例中有绘制带箭头的线的demo&#xff0c;那个是交互式绘制&#xff0c;而不是根据经纬度坐标添加&#xff0c;在其基础上稍作修改&#xff0c;即可转为通过经纬度添加带箭头的线的功能&#xff0c;线和箭头的粗细大小样式都可以自定义 代码如下 <!DOCTYPE HTML P…

Java 多线程系列Ⅰ(创建线程+查看线程+Thread方法+线程状态)

多线程基础 一、创建线程的五种方法前置知识1、方法一&#xff1a;使用继承Thread类&#xff0c;重写run方法2、方法二&#xff1a;实现Runnable接口&#xff0c;重写run方法3、方法三&#xff1a;继承Thread&#xff0c;使用匿名内部类4、方法四&#xff1a;实现Runnable&…

5G工业网关赋能救护车远程监控,助力高效救援

智慧医疗是传统医疗业发展进步的必要趋势&#xff0c;医疗设备通过物联网技术的应用实现智能化转型。通过5G工业网关将医疗器械等设备的数据采集再经过专网传输到医疗系统中&#xff0c;实现医疗设备间的数据共享和远程监控&#xff0c;能够帮助医疗行业大大提高服务质量和管理…

Weblogic漏洞(四)之 CVE-2018-2894 任意文件上传漏洞

CVE-2018-2894 任意文件上传漏洞 漏洞影响 Weblogic受影响的版本&#xff1a; 10.3.6.012.1.3.012.2.1.212.2.1.3 漏洞环境 此次我们使用的是vnlhub靶场搭建的环境&#xff0c;是vnlhub中的Weblogic漏洞中的CVE-2018-2894靶场&#xff0c;我们 cd 到 CVE-2018-2894&#x…

基于KNN算法的鸢尾花种类预测

导入数据 iris_data load_iris() iris_data.data[0:5, :]array([[5.1, 3.5, 1.4, 0.2],[4.9, 3. , 1.4, 0.2],[4.7, 3.2, 1.3, 0.2],[4.6, 3.1, 1.5, 0.2],[5. , 3.6, 1.4, 0.2]])# 特征值名称 iris_data.feature_names[sepal length (cm),sepal width (cm),petal length (cm…

12、监测数据采集物联网应用开发步骤(9.1)

监测数据采集物联网应用开发步骤(8.2) TCP/IP Server开发 在com.zxy.common.Com_Para.py中添加如下内容 #锁机制 lock threading.Lock() #本机服务端端口已被连接客户端socket list dServThreadList {} #作为服务端接收数据拦截器 ServerREFLECT_IN_CLASS "com.plug…

PMP - 敏捷 3355

三个核心 产品负责人 负责最大化投资回报&#xff08;ROI&#xff09;&#xff0c;通过确定产品特性&#xff0c;把他们翻译成一个有优先级的列表 为下一个 sprint 决定在这个列表中哪些应该优先级最高&#xff0c;并且不断调整优先级以及调整这个列表 职责是定义需求、定义…

SSL核心概念 SSL类型级别

SSL&#xff1a;SSL&#xff08;Secure Sockets Layer&#xff09;即安全套接层&#xff0c;及其继任者传输层安全&#xff08;Transport Layer Security&#xff0c;TLS&#xff09;是为网络通信提供安全及数据完整性的一种安全协议。TLS与SSL在传输层对网络连接进行加密。 H…