2024最新分别用sklearn和NumPy设计k-近邻法对鸢尾花数据集进行分类(包含详细注解与可视化结果)

本文章代码实现以下功能:

利用sklearn设计实现k-近邻法。

利用NumPy设计实现k-近邻法。

将设计的k-近邻法对鸢尾花数据集进行分类,通过准确率来验证所设计算法的正确性,并将分类结果可视化。

评估k取不同值时算法的精度,并通过可视化展示。

sklearn实现

# (1)数据导入,分割数据
# 导入iris数据集
from sklearn.datasets import load_iris# 分割数据模块
from sklearn.model_selection import train_test_split# (2)K最近(KNN,K-Nearest Neighbor)分类算法
from sklearn.neighbors import KNeighborsClassifier# 加载iris数据集
data = load_iris()
# 导入数据和标签
data_X = data.data
data_y = data.target# ———————画图,看第一和第三特征的分布——————————————————
import matplotlib.pyplot as plt
print(data.feature_names)
# print(data.data[:, 0])
# print(data.data[:, 2])
feature_1 = data.data[:, 0] # 设置横坐标标签 代表的是花萼长度
feature_3 = data.data[:, 2]  # 设置纵坐标标签 代表的是花瓣宽度
plt.scatter(feature_1, feature_3)  # 看数据分布
plt.show()# _--------------------150个数据的行索引号0-149------------
plt.scatter(feature_1[:50], feature_3[:50], c='red')  # 第一类
plt.scatter(feature_1[50:100], feature_3[50:100], c='blueviolet')  # 第二类
plt.scatter(feature_1[100:], feature_3[100:], c='darkred')  # 第三类
plt.show()# 分割数据# 将完整数据集的70%作为训练集,30%作为测试集,
# 并使得测试集和训练集中各类别数据的比例与原始数据集比例一致(stratify分层策略),另外可通过设置shuffle=True 提前打乱数据。
X_train, X_test, y_train, y_test = train_test_split(data_X,data_y,random_state=12,stratify=data_y,test_size=0.3)
# 建立模型进行训练和预测# 建立模型
knn = KNeighborsClassifier()
# knn=KNeighborsClassifier(n_neighbors=3)# (3)训练模型
knn.fit(X_train, y_train)
print(knn.score(X_test, y_test))  # 计算模型的准确率# (4)预测模型
y_pred = knn.predict(X_test)
print(y_pred - y_test)# (5)评价— ——用accuracy_score计算准确率— ———————
from sklearn.metrics import accuracy_scoreprint(accuracy_score(y_test, y_pred))  # 也可以算正确率print(accuracy_score(y_test, y_pred, normalize=False))  # 统计测试样本分类的个数# 测试不同的k值
k_range = range(1, 31)  # 测试1到30的k值
accuracy = []  # 用于存储每个k值的准确率for k in k_range:knn = KNeighborsClassifier(n_neighbors=k)knn.fit(X_train, y_train)y_pred = knn.predict(X_test)accuracy.append(accuracy_score(y_test, y_pred))# 绘制k值与准确率的关系图
plt.figure(figsize=(10, 6))
plt.plot(k_range, accuracy, marker='o', linestyle='-', color='b')
plt.title('KNN Varying number of neighbors')
plt.xlabel('Number of neighbors, k')
plt.ylabel('Accuracy')
plt.xticks(k_range)
plt.grid(True)
plt.show()# (6)保存和加载模型
import joblib# 用joblib.dump保存模型
joblib.dump(knn, 'iris_KNN.pkl')
# # 用joblib.load加载已保存的模型
knn1 = joblib.load('iris_KNN.pkl')
# #测试读取后的Model
print(knn1.predict(data_X[0:1]))  # 预测第一个数据的类别
y_pred1 = knn1.predict(X_test)
print(y_pred1 - y_test)# 可视化分类结果
plt.figure(figsize=(8, 6)) #设置图形的大小为8英寸宽和6英寸高。#绘制实际类别
plt.scatter(feature_1[:50], feature_3[:50], c='red', label='Actual Setosa')
plt.scatter(feature_1[50:100], feature_3[50:100], c='blueviolet', label='Actual Versicolor')
plt.scatter(feature_1[100:], feature_3[100:], c='darkred', label='Actual Virginica')#绘制预测类别
plt.scatter(X_test[y_pred == 0][:, 0], X_test[y_pred == 0][:, 2], c='lightcoral', marker='x', label='Predicted Setosa')
plt.scatter(X_test[y_pred == 1][:, 0], X_test[y_pred == 1][:, 2], c='lightblue', marker='^', label='Predicted Versicolor')
plt.scatter(X_test[y_pred == 2][:, 0], X_test[y_pred == 2][:, 2], c='pink', marker='s', label='Predicted Virginica')plt.xlabel('Sepal Length')
plt.ylabel('Petal Width')
plt.title('KNN Classification Result on Iris Dataset')
plt.legend()
plt.grid(True)
plt.show()

实现结果

用Numpy实现

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt# 加载iris数据集
data = load_iris()
X = data.data
y = data.target# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=12, stratify=y)# 定义欧氏距离函数
def euclidean_distance(x1, x2):return np.sqrt(np.sum((x1 - x2) ** 2))# 实现KNN算法
class KNN:def fit(self, X, y):self.X_train = Xself.y_train = ydef predict(self, X, k=3):y_pred = [self._predict(x, k) for x in X]return np.array(y_pred)def _predict(self, x, k):# 计算x与训练集中每个点的距离distances = [euclidean_distance(x, x_train) for x_train in self.X_train]# 获取k个最近邻的索引k_indices = np.argsort(distances)[:k]# 获取这些最近邻对应的标签k_nearest_labels = [self.y_train[i] for i in k_indices]# 通过多数投票确定预测类别most_common = np.bincount(k_nearest_labels).argmax()return most_common# 实例化KNN
knn = KNN()# 训练模型
knn.fit(X_train, y_train)# 进行预测
y_pred = knn.predict(X_test)# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")# 评估不同k值的准确率
k_values = range(1, 31)
accuracies = []for k in k_values:knn = KNN()knn.fit(X_train, y_train)y_pred = knn.predict(X_test, k=k)accuracy = accuracy_score(y_test, y_pred)accuracies.append(accuracy)# 可视化k值与准确率的关系
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(k_values, accuracies, marker='o')
plt.xlabel('Number of Neighbors')
plt.ylabel('Accuracy')
plt.title('KNN Varying number of neighbors')
plt.grid(True)# 可视化分类结果
plt.subplot(1, 2, 2)
plt.scatter(X_test[:, 2], X_test[:, 3], c=y_pred, cmap=plt.cm.Set1, edgecolor='k')
plt.title('KNN Classification Result')
plt.xlabel('Petal Length')
plt.ylabel('Petal Width')
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, labels, loc='upper left')
plt.grid(True)plt.tight_layout()
plt.show()# 找出最高精度和对应的k值
max_accuracy = max(accuracies)
best_k = k_values[accuracies.index(max_accuracy)]
print(f"The best accuracy is {max_accuracy:.2f} with k = {best_k}")

实现结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/447469.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HTML CSS 基础

HTML & CSS 基础 HTML一、HTML简介1、网页1.1 什么是网页1.2 什么是HTML1.3 网页的形成1.4总结 2、web标准2.1 为什么需要web标准2.2 Web 标准的构成 二、HTML 标签1、HTML 语法规范1.1基本语法概述1.2 标签关系 2、 HTML 基本结构标签2.1 第一个 HTML 网页2.2 基本结构标签…

uniapp 游戏 - 使用 uniapp 实现的扫雷游戏

0. 思路 1. 效果图 2. 游戏规则 扫雷的规则很简单。盘面上有许多方格,方格中随机分布着一些雷。你的目标是避开雷,打开其他所有格子。一个非雷格中的数字表示其相邻 8 格子中的雷数,你可以利用这个信息推导出安全格和雷的位置。你可以用右键在你认为是雷的地方插旗(称为标…

中华春节符号·世界品牌——粤港澳企(实)业协会商会经济合作座谈会成功举办

日前,一场旨在推动粤港澳三地经济深度合作的盛会——《粤港澳企(实)业协会商会经济合作座谈会》在广州市天河区时代TIT广场2栋801车陂社区文化中心隆重举行。此次活动由泰康之家粤园与广东经贸文化促进会联合主办,吸引了全球华人企…

Dubbo SpringBoot应用创建和K8S部署

推荐阅读:Dubbo 快速入门-CSDN博客 创建基于Spring Boot的微服务应用 以下文档将引导您从头创建一个基于 Spring Boot 的 Dubbo 应用,并为应用配置 Triple 通信协议、服务发现等微服务基础能力。 快速创建应用 以下是我们为您提前准备好的示例项目&am…

AI大模型开发架构设计(12)——以真实场景案例驱动深度剖析 AIGC 新时代 IT 人的能力模型

文章目录 以真实场景案例驱动深度剖析 AIGC 新时代 IT 人的能力模型1 AI 工具以及大模型会给我们带来哪些实际收益?业务研发流程环节为什么 LLM 大模型不是万能的?LLM 大模型带来实际收益 2 新时代IT人的能力模型会发生哪些变化?古典互联网架构师能力模型IT人能力模型变化 以…

这都能封!开发者行为导致Google账号关联?

从去年10月开始,在AI加持下,Google Play不断更新和迭代审查机制,Google Play在最近一年的时间真是杀疯了,封号的声音响彻整个行业,尤其是一些敏感品类行业。账号关联,恶意软件,欺骗行为&#xf…

小红书新ID保持项目StoryMaker,面部特征、服装、发型和身体特征都能保持一致!(已开源)

继之前和大家介绍的小红书在ID保持以及风格转换方面相关的优秀工作,感兴趣的小伙伴可以点击以下链接阅读~ 近期,小红书又新开源了一款文生图身份保持项目:StoryMaker,是一种个性化解决方案,它不仅保留了面部的一致性&…

贪吃蛇游戏(代码篇)

我们并不是为了满足别人的期待而活着。 前言 这是我自己做的第五个小项目---贪吃蛇游戏(代码篇)。后期我会继续制作其他小项目并开源至博客上。 上一小项目是贪吃蛇游戏(必备知识篇),没看过的同学可以去看看&#xf…

Angular Count-To 项目教程

Angular Count-To 项目教程 angular-count-to Angular directive to animate counting to a number 项目地址: https://gitcode.com/gh_mirrors/an/angular-count-to 1. 项目介绍 Angular Count-To 是一个用于 AngularJS 的动画计数器指令。该指令可以在指定的时间内从…

Lfsr32

首先分析 Lfsr5 首先要理解什么是抽头点(tap),注意到图中有两个触发器的输入为前级输出与q[0]的异或,这些位置被称为 tap position.通过观察上图,所谓抽头点指的就是第5个,第3个寄存器的输入经过了异或逻辑…

利用C++封装鼠标轨迹算法为DLL:游戏行为检测的利器

在现代软件开发中,鼠标轨迹模拟技术因其在自动化测试、游戏脚本编写等领域的广泛应用而备受青睐。本文将介绍如何使用C语言将鼠标轨迹算法封装为DLL(动态链接库),以便在多种编程环境中实现高效调用,同时探讨其在游戏行…

cudnn8编译caffe过程(保姆级图文全过程,涵盖各种报错及解决办法)

众所周知,caffe是个较老的框架,而且只支持到cudnn7,但是笔者在复现ds-slam过程中又必须编译caffe,我的cuda版本是11.4,最低只支持到8.2.4,故没办法,只能编译了 在此记录过程、报错及解决办法如下; 首先安装依赖: sudo apt-get install git sudo apt-get install lib…

【IEEE独立出版 | 厦门大学主办】第四届人工智能、机器人和通信国际会议(ICAIRC 2024)

【IEEE独立出版 | 厦门大学主办】 第四届人工智能、机器人和通信国际会议(ICAIRC 2024) 2024 4th International Conference on Artificial Intelligence, Robotics, and Communication 2024年12月27-29日 | 中国厦门 >>往届均已成功见刊检索…

【Kubernetes① 基础】一、容器基础

目录 一、进程二、隔离与限制三、容器镜像总结参考书籍 一、进程 容器技术的兴起源于PaaS技术(平台即服务)的普及;Docker公司发布的Docker项目具有里程碑式的意义;Docker项目通过“容器镜像”解决了应用打包这个根本性难题(CloudFoundry)。 容器本身的价…

【QAMISRA】解决导入commands.json时报错问题

1、 文档目标 解决导入commands.json时报错“Could not obtain system-wide includes and defines”的问题。 2、 问题场景 客户导入commands.json时报错“Could not obtain system-wide includes and defines”。 3、软硬件环境 1、软件版本: QA-MISRA23.04 2、…

【电路笔记】-运算放大器多谐振荡器

运算放大器多谐振荡器 文章目录 运算放大器多谐振荡器1、概述2、施密特触发器3、运算放大器稳态多谐振荡器4、运算放大器单稳态多谐振荡器5、运算放大器双稳态多谐振荡器6、总结1、概述 本文将重点介绍通常称为多谐振荡器的配置,特别是基于运算放大器的电路。 事实上,多谐振…

AWS账号与邮箱的关系解析

在当今数字化时代,云计算服务的普及使得越来越多的企业和个人用户开始使用亚马逊网络服务(AWS)。作为全球领先的云服务平台,AWS为用户提供了丰富的计算、存储和数据库服务。然而,对于许多新用户来说,关于AW…

VLOG视频制作解决方案,开发者可自行定制包装模板

无论是旅行见闻、美食探店,还是日常琐事、创意挑战,每一个镜头背后都蕴含着创作者无限的热情和创意。然而,面对纷繁复杂的视频编辑工具,美摄科技凭借其前沿的视频制作技术和创新的解决方案,为每一位视频创作者提供了开…

LLaMA-Factory 让大模型微调变得更简单!!

背景 如果只需要构建一份任务相关的数据,就可以轻松通过网页界面的形式进行 Fine-tuning 微调操作, 那么必将大大减轻微调工作量。 今年的 ACL 2024见证了北航和北大合作的突破—论文《LLAMAFACTORY: 统一高效微调超百种语言模型》。他们打造的 LLaMA-…

三菱FX3UPLC机械原点回归- DSZR/ZRN指令

机械原点回归用指令的种类 产生正转脉冲或者反转脉冲后,增减当前值寄存器的内容。可编程控制器的定位指令,可编程控制器的电源0FF后,当前值寄存器清零,因此上电后,请务必使机械位置和当前值寄存器的位置相吻合…