KNN-近邻算法 及 模型的选择与调优(facebook签到地点预测)

什么是K-近邻算法(K Nearest Neighbors)

1、K-近邻算法(KNN)

1.1 定义

如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。

来源:KNN算法最早是由Cover和Hart提出的一种分类算法

1.2 距离公式

两个样本的距离可以通过如下公式计算,又叫欧式距离

距离公式:
在这里插入图片描述
KNN核心思想:
你的“邻居”来推断出你的类别

  • 计算距离:
    • 距离公式
      • 欧氏距离
      • 曼哈顿距离 绝对值距离
      • 明可夫斯基距离

2、电影类型分析

假设我们有现在几部电影
在这里插入图片描述
其中? 7号电影不知道类别,如何去预测?我们可以利用K近邻算法的思想
下方是根据欧氏距离计算结果:
在这里插入图片描述
K值是重要影响元素:
当我们看,如果k=1,假如只有第二行一个”邻居“,那么就计算的距离误差就比较大,样本量过少:
在这里插入图片描述
如果k=7,也就是再多一行数据,假设是“封神榜”,那么也就是说邻近的算法中,动作片的类占多数,那么我们就会将位置的那行数据预测为动作片,但是实际位置那行数据的接吻镜头是比较多的,应该是个爱情片,所以预测也是错误的!

在这里插入图片描述

  • 总结:
    • k 值取得过小,容易受到异常点的影响
    • k 值取得过大,样本不均衡的影响

3、K-近邻算法API

  • sklearn.neighbors.KNeighborsClassifier(n_neighbors=5,algorithm=‘auto’)
    • n_neighbors:int,可选(默认= 5),k_neighbors查询默认使用的邻居数
    • algorithm:{‘auto’,‘ball_tree’,‘kd_tree’,‘brute’},可选用于计算最近邻居的算法:‘ball_tree’将会使用 BallTree,‘kd_tree’将使用 KDTree。‘auto’将尝试根据传递给fit方法的值来决定最合适的算法。 (不同实现方式影响效率)

4、代码:鸢尾花案例

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler"""
用KNN算法对鸢尾花进行分类
:return:
"""
# 1)获取数据
iris = load_iris()
# print(iris)
# print(iris.target_names)
# print(iris.DESCR)# 2)划分数据集
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.3, random_state=22)# 3)特征工程:标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)# 4)KNN算法预估器
estimator = KNeighborsClassifier(n_neighbors=3, algorithm='auto')
estimator.fit(x_train, y_train)# 5)模型评估
# 方法1:直接比对真实值和预测值
y_predict = estimator.predict(x_test)
print("y_predict:\n", y_predict)
print("直接比对真实值和预测值:\n", y_test == y_predict)# 方法2:计算准确率
score = estimator.score(x_test, y_test)
print("准确率为:\n", score)

在这里插入图片描述

那么说来说去,“邻居的数量” K 到底怎么取呢?
这就涉及到模型的选择与调优了;

5、为什么需要交叉验证

  • 交叉验证目的:为了让被评估的模型更加准确可信

6、什么是交叉验证(cross validation)

  • 交叉验证:将拿到的训练数据,分为训练和验证集。以下图为例:将数据分成4份,其中一份作为验证集。然后经过4次(组)的测试,每次都更换不同的验证集。即得到4组模型的结果,取平均值作为最终结果。又称4折交叉验证。
    • 训练集:训练集+验证集
    • 测试集:测试集
      在这里插入图片描述
      问题:那么这个只是对于参数得出更好的结果,那么怎么选择或者调优参数呢?

7、超参数搜索-网格搜索(Grid Search)

通常情况下,有很多参数是需要手动指定的(如k-近邻算法中的K值),这种叫超参数。但是手动过程繁杂,网格搜索帮我们实现了这个调参过程,首先需要对模型预设几种超参数组合,每组超参数都采用交叉验证来进行评估,最后选出最优参数组合建立模型。
在这里插入图片描述

7.1、模型选择与调优 API

  • sklearn.model_selection.GridSearchCV(estimator, param_grid=None,cv=None)
    • 对估计器的指定参数值进行详尽搜索
    • estimator:估计器对象
    • param_grid:估计器参数(dict){“n_neighbors”:[1,3,5]}
    • cv:指定几折交叉验证
    • fit:输入训练数据
    • score:准确率
  • 结果分析:
    • bestscore:在交叉验证中验证的最好结果_
    • bestestimator:最好的参数模型
    • cvresults:每次交叉验证后的验证集准确率结果和训练集准确率结果

7.1、网格搜索与交叉验证代码

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler"""
用KNN算法对鸢尾花进行分类,添加网格搜索和交叉验证
:return:
"""
# 1)获取数据
iris = load_iris()# 2)划分数据集
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.3, random_state=22)# 3)特征工程:标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)# 4)KNN算法预估器
estimator = KNeighborsClassifier()# 加入网格搜索与交叉验证
# 参数准备
param_dict = {"n_neighbors": [1, 2, 3, 4, 5, 6, 7, 8, 9, 11]}
estimator = GridSearchCV(estimator, param_grid=param_dict, cv=10)
estimator.fit(x_train, y_train)# 5)模型评估
# 方法1:直接比对真实值和预测值
y_predict = estimator.predict(x_test)
print("y_predict:\n", y_predict)
print("直接比对真实值和预测值:\n", y_test == y_predict)# 方法2:计算准确率
score = estimator.score(x_test, y_test)
print("准确率为:\n", score)# 最佳参数:best_params_
print("最佳参数:\n", estimator.best_params_)
# 最佳结果:best_score_
print("最佳结果:\n", estimator.best_score_)
# 最佳估计器:best_estimator_
print("最佳估计器:\n", estimator.best_estimator_)
# 交叉验证结果:cv_results_
print("交叉验证结果:\n", estimator.cv_results_)

在这里插入图片描述

8、facebook 签到位置预测

在这里插入图片描述
在这里插入图片描述

  • 数据介绍:将根据用户的位置,准确性和时间戳预测用户正在查看的业务。
  • train.csv
    • row_id:登记事件的ID
    • xy:坐标
    • 准确性:定位准确性
    • 时间:时间戳
    • place_id:业务的ID,这是您预测的目标

官网:https://www.kaggle.com/navoshta/grid-knn/data

8.1、流程分析

对于数据做一些基本处理(这里所做的一些处理不一定达到很好的效果,我们只是简单尝试,有些特征我们可以根据一些特征选择的方式去做处理)

1、缩小数据集范围 DataFrame.query()(选择性处理!)
2、删除没用的日期数据 DataFrame.drop(可以选择保留)
3、将签到位置少于n个用户的删除

place_count = data.groupby('place_id').count()
tf = place_count[place_count.row_id > 3].reset_index()
data = data[data['place_id'].isin(tf.place_id)]

4、分割数据集
5、标准化处理
6、k-近邻预测

8.2、代码

import pandas as pd
# 1、获取数据
data = pd.read_csv("train.csv")
data.head()

在这里插入图片描述

# 1)处理时间特征
time_value = pd.to_datetime(data["time"], unit="s")
date = pd.DatetimeIndex(time_value)
data["day"] = date.day
data["weekday"] = date.weekday
data["hour"] = date.hour
data.head()

在这里插入图片描述

# 2)过滤签到次数少的地点
place_count = data.groupby("place_id").count()["row_id"]
data_final = data[data["place_id"].isin(place_count[place_count > 3].index.values)]
data_final.head()

在这里插入图片描述

# 筛选特征值和目标值
x = data_final[["x", "y", "accuracy", "day", "weekday", "hour"]]
y = data_final["place_id"]

在这里插入图片描述

# 数据集划分
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y)
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV# 3)特征工程:标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)# 4)KNN算法预估器
estimator = KNeighborsClassifier()# 加入网格搜索与交叉验证
# 参数准备
param_dict = {"n_neighbors": [3, 5, 7, 9]}
estimator = GridSearchCV(estimator, param_grid=param_dict, cv=3)
estimator.fit(x_train, y_train)# 5)模型评估
# 方法1:直接比对真实值和预测值
y_predict = estimator.predict(x_test)
print("y_predict:\n", y_predict)
print("直接比对真实值和预测值:\n", y_test == y_predict)# 方法2:计算准确率
score = estimator.score(x_test, y_test)
print("准确率为:\n", score)# 最佳参数:best_params_
print("最佳参数:\n", estimator.best_params_)
# 最佳结果:best_score_
print("最佳结果:\n", estimator.best_score_)
# 最佳估计器:best_estimator_
print("最佳估计器:\n", estimator.best_estimator_)
# 交叉验证结果:cv_results_
print("交叉验证结果:\n", estimator.cv_results_)

这个结果数据量比较大,毕竟两千万训练数据了,各位可自行试验及调参;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/161208.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JVM面试】从JDK7 到 JDK8, JVM为啥用元空间替换永久代?

系列文章目录 【JVM系列】第一章 运行时数据区 【面试】第二章 从JDK7 到 JDK8, JVM为啥用元空间替换永久代? 大家好,我是青花。拥有多项发明专利(都是关于商品、广告等推荐产品)。对广告、Web全栈以及Java生态微服务拥有自己独到…

VSS、VDD、VBAT、VSSA

引言 在学习设计TM32时,发现芯片除了GPIO引脚外还会引出许多引脚,以STM32F407ZGT6为例除了GPIO引脚还会有以下引脚 如VSS、VDD、VBAT、VSSA、NRST、VREF、VDDA、VCAP_1、VCAP_2、PDR_ON这些引脚。他们有何作用,电路设计中应如何连接&#x…

迅为RK3588开发板使用RKNN-Toolkit-lite2运行测试程序

1 首先也需要部署运行环境,将库文件放入 RK3588 开发板上,我们将网盘资料“iTOP-3588 开发 板 \02_ 【 iTOP-RK3588 开 发 板 】 开 发 资 料 \12_NPU 使 用 配 套 资 料 \05_Linux_librknn_api\librknn_api\aarch64”路径下的文件通过U盘拷贝到开发板的…

LLM-RAG-WEB 大模型+文件+可视化界面

注意:这里只是简单实现了功能和界面,文件对话也暂时只支持一个文件,如果跳到模型对话再切换回文件对话会文件会删除重置会话,但模型对话切换回来时保留之前会话的。 1、代码(使用步骤说明在链接里) 参考下…

SQLite4Unity3d安卓 在手机上创建sqlite失败解决

总结 要在Unity上运行一次出现库,再打包进APK内 问题 使用示例代码的创建库 var dbPath string.Format("Assets/StreamingAssets/{0}", DatabaseName); #else// check if file exists in Application.persistentDataPathvar filepath string.Format…

Wireshark新手小白基础使用方法

一、针对IP抓取 1、过滤格式: (1)、ip.src eq x.x.x.x (2)、ip.dst eq x.x.x.x (3)ip.src eq x.x.x.x or ip.dst eq x.x.x.x 二、针对端口过滤 1、过滤格式: (1&a…

百度智能云千帆大模型平台 2.0 产品技术解析

本文整理自 2023 年 9 月 5 日百度云智大会 - 智能计算&大模型技术分论坛,百度智能云 AI &大数据平台总经理忻舟的主题演讲《百度智能云千帆大模型平台 2.0 产品技术解析》。 这是关于技术主题的论坛,我首先问大家三个开发者的小问题。 第一个问…

如何使用RockPlus MES系统帮助SMT行业实现降本增效

SMT(Surface Mount Technology)是现代电子行业中主要的组装技术,广泛应用于电子产品的生产。SMT工艺涵盖了锡膏印刷、元器件贴装和回流焊接。经过这些关键工序,元器件被精确固定在电路板上,完成一个电子产品组装。 SM…

SpringCloud: feign整合sentinel实现降级

一、加依赖&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache…

excel+requests管理测试用例接口自动化框架

背景&#xff1a; 某项目有多个接口&#xff0c;之前使用的unittest框架来管理测试用例&#xff0c;将每个接口的用例封装成一个py文件&#xff0c;接口有数据或者字段变动后&#xff0c;需要去每个py文件中找出变动的接口测试用例&#xff0c;维护起来不方便&#xff0c;为了…

Spring AOP归纳与总结

前言 AOP的核心思想是面向切面编程。AOP规范定义了多种概念&#xff0c;常用的aop框架有spring aop和AspectJ&#xff0c;两者功能和性能差异较大&#xff0c;现在默认的AOP框架是AspectJ&#xff0c;下面逐渐归纳其相关概念、功能及实现原理。 1. 概念 1. 切面&#xff1a;…

从零开始学习 Java:简单易懂的入门指南之线程池(三十六)

线程池 1.1 线程状态介绍1.2 线程池-基本原理1.3 线程池-Executors默认线程池1.4 线程池-Executors创建指定上限的线程池1.5 线程池-ThreadPoolExecutor1.6 线程池-参数详解1.7 线程池-非默认任务拒绝策略 1.1 线程状态介绍 当线程被创建并启动以后&#xff0c;它既不是一启动…

linux文件权限与目录配置

用户与用户组 linux一般将文件可读写的身份分为三个类别&#xff1a;拥有者&#xff08;owner&#xff09;、所属群组&#xff08;group&#xff09;、其他人&#xff08;other&#xff09; 三种身份都有读、写、执行等权限 文件拥有者 linux是个多人多任务的系统&#xff0c…

论文解析-moETM

论文解析-moETM 参考亮点动机发展现状现存问题 功能方法Encoder改进Decoder改进 评价指标生物保守性批次效应移除 实验设置结果多组学数据整合cell-topic mixture可解释性组学翻译性能评估RNA转录本、表面蛋白、染色质可及域调控关系研究1. 验证同一主题下&#xff0c;top gene…

什么是NetApp的DQP和如何安装DQP?

首先看看什么是DQP&#xff0c;DQPDisk Qualification Package&#xff0c;文字翻译就是磁盘验证包。按照NetApp的最佳实践&#xff0c;要定期升级DQP包&#xff0c;保证对最新磁盘和磁盘扩展柜的兼容。 本文主要介绍7-mode下如何升级DQP&#xff0c;至于cluster mode另外文章…

NewStarCTF2023week2-Upload again!

尝试传修改后缀的普通一句话木马&#xff0c;被检测 尝试传配置文件 .htaccess 和 .user.ini 两个都传成功了 接下来继续传入经过修改的木马 GIF89a <script language"php"> eval($_POST[cmd]); </script> 没有被检测&#xff0c;成功绕过 直接上蚁剑…

圣树唤歌最强阵容2023,圣树唤歌阵容推荐

无疑圣树唤歌作为一款备受欢迎的手机游戏&#xff0c;其深刻的战斗系统一直以来都受到大家的追捧。在这个虚拟世界中胜利的关键在于组建一支无懈可击的强大队伍&#xff0c;要想成为强者&#xff0c;就必须拥有最强阵容。 关注【娱乐天梯】&#xff0c;获取内部福利号 在本篇攻…

C++项目实战——基于多设计模式下的同步异步日志系统(总集篇)

文章目录 专栏导读项目介绍开发环境核心技术环境搭建日志系统介绍1.为什么需要日志系统2.日志系统技术实现2.1同步写日志2.2异步写日志 前置知识补充不定参函数C风格不定参函数不定参宏函数设计模式六大原则单例模式饿汉模式懒汉模式 工厂模式简单工厂模式工厂方法模式抽象工厂…

Linux:mongodb数据逻辑备份与恢复(3.4.5版本)

我在数据库aaa的里创建了一个名为tarro的集合&#xff0c;其中有三条数据 备份语法 mongodump –h server_ip –d database_name –o dbdirectory 恢复语法 mongorestore -d database_name --dirdbdirectory 备份 现在我要将aaa.tarro进行备份 mongodump --host 192.168.254…

攻防演练蓝队|Windows应急响应入侵排查

文章目录 日志分析web日志windows系统日志 文件排查进程排查新增、隐藏账号排查启动项/服务/计划任务排查工具 日志分析 web日志 dirpro扫描目录&#xff0c;sqlmap扫描dvwa Python dirpro -u http://192.168.52.129 -b sqlmap -u "http://192.168.52.129/dvwa/vulnera…