机器学习实验3——支持向量机分类鸢尾花

文章目录

    • 🧡🧡实验内容🧡🧡
    • 🧡🧡数据预处理🧡🧡
      • 代码
      • 认识数据
      • 相关性分析
      • 径向可视化
      • 各个特征之间的关系图
    • 🧡🧡支持向量机SVM求解🧡🧡
      • 直觉理解:
      • 数学推导
      • 代码
      • 运行结果
    • 🧡🧡总结🧡🧡

🧡🧡实验内容🧡🧡

基于鸢尾花数据集,完成关于支持向量机的分类模型训练、测试与评估。

🧡🧡数据预处理🧡🧡

代码

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler# ==================特征探索====================# ===认识数据===
iris = datasets.load_iris()
print("Feature names: {}".format(iris['feature_names']))
print("Target names: {}".format(iris["target_names"]))
print("target:\n{}".format(iris['target'])) # 0 代表setosa,1 代表versicolor,2 代表virginica。
print("shape of data: {}".format(iris['data'].shape))# ===转为df对象===
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
feature_df=df.drop('label',axis=1,inplace=False) # 取出特征
print(df)# ===相关性矩阵===
corr_matrix = feature_df.corr()
plt.figure(figsize=(8, 6))
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm')
plt.title('Correlation Matrix')
plt.show()# ===径向可视化===
ax = pd.plotting.radviz(df, 'label', colormap='brg')
ax.add_artist(plt.Circle((0,0), 1, color='r', fill = False))# ===各特征之间关系矩阵图===
# 设置颜色主题
g = sns.pairplot(data=df, palette="pastel", hue= 'label')

认识数据

属性:花萼长度,花萼宽度,花瓣长度,花瓣宽度
分类:Setosa,Versicolour,Virginica
在这里插入图片描述
在这里插入图片描述

相关性分析

如下图,可以直观看到花瓣宽度(Petal Width)和花瓣长度(Petal Length)存在很高的正相关性,且它们与花萼长度(Speal Length)也具有很高的正相关性,而花萼宽度(Speal Width)与其他三个属性特征的相关性均很弱。
在这里插入图片描述

径向可视化

用于观察每种类别花的四个特征之间的相对关系(线性大小关系)。
如下图,其中0、1、2分别对应Setosa,Versicolour,Virginica类别,可以直观看出:Setosa花的花萼宽度(Speal Width)和花萼长度(Speal Length)这两个特征相比其他两个特征花瓣宽度(Petal Width)和花瓣长度(Petal Length)具有区分性,而Versicolour,Virginica花的四个特征分布很相似,不好区分。
在这里插入图片描述

各个特征之间的关系图

从下图可以看出,Setosa花的花瓣宽度(Petal Width)和花瓣长度(Petal Length)的分布相比其他两类具有很好的区分性。
在这里插入图片描述

🧡🧡支持向量机SVM求解🧡🧡

直觉理解:

对于二维特征,如何区分图中不同的点
第一种思路:如下左图画一条线,但是是一个不太好的分割线
而换一种思路,如下右图,先找两个分类的决策边界(两边的虚线)之间的间隔区域,再取间隔区域的中间为分割线,这样能保证分割效果最佳。因此寻找最佳决策边界线(中间线)的问题可以转化为求解两类数据的最大间隔问题。
在这里插入图片描述在这里插入图片描述
因此将决策边界上下移动c,得到间隔的两个边界线,如下左图,此时这两个边界线称为支持向量,它们决定了间隔距离。如下右图,经过数学变换,可以得到最终要求的超平面表达式,即求解参数w、b即可
在这里插入图片描述在这里插入图片描述
除此之外,只考虑分类点的决策边界之间的距离的间隔,称为硬间隔,同时考虑距离和异常点损失(下图红线上方的黄点)的间隔,称为软间隔。
在这里插入图片描述

数学推导

某点到超平面的距离r:(几何间隔,可以代表分类正确的确信度)
在这里插入图片描述
目标超平面之间的间隔距离γ:
在这里插入图片描述
约束条件:点到超平面距离r >= 超平面间隔距离γ的一半:
在这里插入图片描述
则最终求解的函数表达式为:
在这里插入图片描述

但是以上函数表达式为非凸函数,因此要:

  1. 先转为凸函数
  2. 用拉格朗日乘子法和KKT条件求解对偶问题

1.转为凸函数:
在这里插入图片描述

2.用拉格朗日乘子法和KKT条件求解对偶问题
这个过程就涉及到高阶的数学知识了,我这里也不是很懂,只大概了解:
为什么要用拉格朗日乘子法:将不等式约束转换为等式约束。
整合成如下拉格朗日表达式:
在这里插入图片描述
依据对偶性,求解问题为:
在这里插入图片描述
先求解:在这里插入图片描述
根据KKT条件:对w、b求偏导可得:
在这里插入图片描述
代入L(w,b,a):在这里插入图片描述
再求解:在这里插入图片描述
在这里插入图片描述

3.利用SMO求解α、从而求解w、b
现在优化问题变成了如上的形式,但是它的规模正比于训练样本数m,当m很大时,会有很大开销,因此针对这个问题的特性,有更高效的优化算法,即序列最小优化(SMO)算法。
其大概思想是:先固定α以外的参数,然后对α求极值,在上述约束条件下,α可以由其他变量导出,这样,在参数初始化后,不断迭代,可以最终达到收敛。
通过SMO求得的w、b为:
在这里插入图片描述
则超平面的公式为:
在这里插入图片描述
最后根据超平面的符号,表达成分类决策函数即可:
在这里插入图片描述

代码

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScalerclass SMO:def __init__(self, X, y, C, kernel, tol, max_passes=10):self.X = X  # 样本特征 m*n m个样本 n个特征self.y = y  # 样本标签 m*1self.C = C  # 惩罚因子, 用于控制松弛变量的影响self.kernel = kernel  # 核函数self.tol = tol  # 容忍度self.max_passes = max_passes  # 最大迭代次数self.m, self.n = X.shapeself.alpha = np.zeros(self.m)self.b = 0self.w = np.zeros(self.n)# 计算核函数def K(self, i, j):if self.kernel == 'linear':return np.dot(self.X[i].T, self.X[j])elif self.kernel == 'rbf':gamma = 0.5return np.exp(-gamma * np.linalg.norm(self.X[i] - self.X[j]) ** 2)else:raise ValueError('Invalid kernel specified')def predict(self, X_test):pred = np.zeros_like(X_test[:, 0])pred = np.dot(X_test, self.w) + self.breturn np.sign(pred)def train(self):"""训练模型:return:"""passes = 0while passes < self.max_passes:num_changed_alphas = 0for i in range(self.m):# 计算E_i, E_i = f(x_i) - y_i, f(x_i) = w^T * x_i + b# 计算误差E_iE_i = 0for ii in range(self.m):E_i += self.alpha[ii] * self.y[ii] * self.K(ii, i)E_i += self.b - self.y[i]# 检验样本x_i是否满足KKT条件if (self.y[i] * E_i < -self.tol and self.alpha[i] < self.C) or (self.y[i] * E_i > self.tol and self.alpha[i] > 0):# 随机选择样本x_jj = np.random.choice(list(range(i)) + list(range(i + 1, self.m)), size=1)[0]# 计算E_j, E_j = f(x_j) - y_j, f(x_j) = w^T * x_j + b# E_j用于检验样本x_j是否满足KKT条件E_j = 0for jj in range(self.m):E_j += self.alpha[jj] * self.y[jj] * self.K(jj, j)E_j += self.b - self.y[j]alpha_i_old = self.alpha[i].copy()alpha_j_old = self.alpha[j].copy()# L和H用于将alpha[j]调整到[0, C]之间if self.y[i] != self.y[j]:L = max(0, self.alpha[j] - self.alpha[i])H = min(self.C, self.C + self.alpha[j] - self.alpha[i])else:L = max(0, self.alpha[i] + self.alpha[j] - self.C)H = min(self.C, self.alpha[i] + self.alpha[j])# 如果L == H,则不需要更新alpha[j]if L == H:continue# eta: alpha[j]的最优修改量eta = 2 * self.K(i, j) - self.K(i, i) - self.K(j, j)# 如果eta >= 0, 则不需要更新alpha[j]if eta >= 0:continue# 更新alpha[j]self.alpha[j] -= (self.y[j] * (E_i - E_j)) / eta# 根据取值范围修剪alpha[j]self.alpha[j] = np.clip(self.alpha[j], L, H)# 检查alpha[j]是否只有轻微改变,如果是则退出for循环if abs(self.alpha[j] - alpha_j_old) < 1e-5:continue# 更新alpha[i]self.alpha[i] += self.y[i] * self.y[j] * (alpha_j_old - self.alpha[j])# 更新b1和b2b1 = self.b - E_i - self.y[i] * (self.alpha[i] - alpha_i_old) * self.K(i, i) \- self.y[j] * (self.alpha[j] - alpha_j_old) * self.K(i, j)b2 = self.b - E_j - self.y[i] * (self.alpha[i] - alpha_i_old) * self.K(i, j) \- self.y[j] * (self.alpha[j] - alpha_j_old) * self.K(j, j)# 根据b1和b2更新bif 0 < self.alpha[i] and self.alpha[i] < self.C:self.b = b1elif 0 < self.alpha[j] and self.alpha[j] < self.C:self.b = b2else:self.b = (b1 + b2) / 2num_changed_alphas += 1if num_changed_alphas == 0:passes += 1else:passes = 0# 提取支持向量和对应的参数idx = self.alpha > 0  # 支持向量的索引# SVs = X[idx]selected_idx = np.where(idx)[0]SVs = self.X[selected_idx]SV_labels = self.y[selected_idx]SV_alphas = self.alpha[selected_idx]# 计算权重向量和截距self.w = np.sum(SV_alphas[:, None] * SV_labels[:, None] * SVs, axis=0)self.b = np.mean(SV_labels - np.dot(SVs, self.w))print("w", self.w)print("b", self.b)def score(self, X, y):predict = self.predict(X)print("predict", predict)print("target", y)return np.mean(predict == y)# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data
y = iris.target
y[y != 0] = -1
y[y == 0] = 1 # 分成两类# 为了方便可视化,只取前两个特征
X2 = X[:,:2]
# # 分别画出类别 0 和 1 的点
plt.scatter(X2[y == 1, 0], X2[y == 1, 1], color='red',label="class 1")
plt.scatter(X2[y == -1, 0], X2[y == -1, 1], color='blue',label="class -1")
plt.xlabel("Speal Width")
plt.ylabel("Speal Length")
plt.legend()
plt.show()# 数据预处理,将特征进行标准化,并将数据划分为训练集和测试集
scaler = StandardScaler()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=3706)
X_train_std = scaler.fit_transform(X_train)# 创建SVM对象并训练模型
svm = SMO(X_train_std, y_train, C=0.6, kernel='rbf', tol=0.001)
svm.train()# 预测测试集的结果并计算准确率
X_test_std = scaler.transform(X_test)
accuracy = svm.score(X_test_std, y_test)
print('正确率: {:.2%}'.format(accuracy))from sklearn.metrics import confusion_matrix, roc_curve, auc
y_pred=svm.predict(X_test_std)# 绘制混淆矩阵
def cal_ConfusialMatrix(y_true_labels, y_pred_labels):cm = np.zeros((2, 2))y_true_labels = [0 if x == -1 else x for x in y_true_labels]y_pred_labels = [0 if x == -1 else x for x in y_pred_labels]for i in range(len(y_true_labels)):cm[ y_true_labels[i], y_pred_labels[i] ] += 1plt.figure(figsize=(8, 6))sns.heatmap(cm, annot=True, fmt='g', cmap='Blues', xticklabels=['Predicted Negative', 'Predicted Positive'], yticklabels=['Actual Negative', 'Actual Positive'])plt.xlabel('Predicted label')plt.ylabel('True label')plt.title('Confusion Matrix')plt.show()y_pred=[int(x) for x in y_pred]
y_test=[int(x) for x in y_test]
cal_ConfusialMatrix(y_test, y_pred)

运行结果

由于鸢尾花为三分类,为了简化实验,这里先把setosa定义为1类(+1),versicolor、virginica组合定义为1类(-1)。
做出其对于sepal width和sepal length的分布图,可以看到,训练样本应该是线性可分的。
在这里插入图片描述

按照训练集:测试集=8:2的比例进行训练,之后进行测试集分类结果如下:

线性核:
在这里插入图片描述
在这里插入图片描述

高斯核:
在这里插入图片描述
在这里插入图片描述

🧡🧡总结🧡🧡

实验结果:
当使用的核函数为线性核时,准确率能达到100%,而使用高斯核时,准确率降低到96.67%(其实从混淆矩阵可以看到,只分类错误1个),且运行时间相对长很多。

分析原因:
线性核适用于数据集具有线性可分性的情况,即类别之间可以通过一条直线进行划分。在这种情况下,线性核可以提供较好的分类性能,并且计算效率较高。
高斯核可以更好地处理非线性问题。高斯核可以将输入空间映射到一个更高维度的特征空间,从而使得数据在新的特征空间中更容易被线性分割。但是,高斯核也有其缺点:在使用高斯核时,需要调整的超参数较多,如 gamma 参数和正则化参数 C,不正确的参数选择可能导致过拟合或欠拟合的问题。此外,高斯核计算复杂度较高,需要计算每个样本与其他样本之间的相似度,因此在数据集上的训练和预测时间可能较长。
因此综合分析,本实验中鸢尾花的特征为线性,因此使用线性核效果更佳。同时高斯核对参数比较敏感,实验中对于高斯核的参数选择可能也不够恰当。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/244178.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HuoCMS|免费开源可商用CMS建站系统HuoCMS 2.0下载(thinkphp内核)

HuoCMS是一套基于ThinkPhp6.0Vue 开发的一套HuoCMS建站系统。 HuoCMS是一套内容管理系统同时也是一套企业官网建设系统&#xff0c;能够帮过用户快速搭建自己的网站。可以满足企业站&#xff0c;外贸站&#xff0c;个人博客等一系列的建站需求。HuoCMS的优势: 可以使用统一后台…

从规则到神经网络:机器翻译技术的演化之路

文章目录 从规则到神经网络&#xff1a;机器翻译技术的演化之路一、概述1. 机器翻译的历史与发展2. 神经机器翻译的兴起3. 技术对现代社会的影响 二、机器翻译的核心技术1. 规则基础的机器翻译&#xff08;Rule-Based Machine Translation, RBMT&#xff09;2. 统计机器翻译&am…

一文掌握SpringBoot注解之@Component 知识文集(1)

&#x1f3c6;作者简介&#xff0c;普修罗双战士&#xff0c;一直追求不断学习和成长&#xff0c;在技术的道路上持续探索和实践。 &#x1f3c6;多年互联网行业从业经验&#xff0c;历任核心研发工程师&#xff0c;项目技术负责人。 &#x1f389;欢迎 &#x1f44d;点赞✍评论…

基本语法和 package 与 jar

3.基本语法 1.输入输出 // 导入 java.util 包中的 Scanner 类 import java.util.Scanner;// 定义名为 ScannerExample 的公共类 public class ScannerExample {// 主方法&#xff0c;程序的入口点public static void main(String[] args) {// 创建 Scanner 对象&#xff0c;用…

远程git开发

两种本地与远程仓库同步 """ 1&#xff09;你作为项目仓库初始化人员&#xff1a;线上要创建空仓库 > 本地初始化好仓库 > 建立remote链接(remote add) > 提交本地仓库到远程(push)2&#xff09;你作为项目后期开发人员&#xff1a;远程项目仓库已经创…

OpenHarmony 鸿蒙使用指南——概述

简介 OpenHarmony采用多内核&#xff08;Linux内核或者LiteOS&#xff09;设计&#xff0c;支持系统在不同资源容量的设备部署。当相同的硬件部署不同内核时&#xff0c;如何能够让设备驱动程序在不同内核间平滑迁移&#xff0c;消除驱动代码移植适配和维护的负担&#xff0c;…

深入浅出理解目标检测的NMS非极大抑制

一、参考资料 物体检测中常用的几个概念迁移学习、IOU、NMS理解 目标定位和检测系列&#xff08;3&#xff09;&#xff1a;交并比&#xff08;IOU&#xff09;和非极大值抑制&#xff08;NMS&#xff09;的python实现 Pytorch&#xff1a;目标检测网络-非极大值抑制(NMS) …

【软考中级】3天擦线过软考中级-软件设计师

前提&#xff1a;已有数据结构、操作系统、计算机网络、数据库基础 &#xff08;风险系数较高&#xff0c;请谨慎参考&#xff09; 贴一个成绩单hhhh 弯路&#xff1a;很早之前有看过一遍网上的软考课程&#xff0c;也记录了一些笔记&#xff0c;然而听完还是啥都记不住。 推…

【超简版,代码可用!】【0基础Python爬虫入门——下载歌曲/视频】

安装第三方模块— requests 完成图片操作后输入&#xff1a;pip install requests 科普&#xff1a; get:公开数据 post:加密 &#xff0c;个人信息 进入某音乐网页&#xff0c;打开开发者工具F12 选择网络&#xff0c;再选择—>媒体——>获取URL【先完成刷新页面】 科…

Pycharm详细安装 配置教程

继上次安装完Anaconda之后&#xff0c;现在更新最新版本的pycharm的安装和使用教程~~~ Anaconda&#xff1a;是一个开源的Python发行版本&#xff0c;其中包含了conda、Python等180多个科学包及其依赖项。【Anaconda和Pycharm详细安装 配置教程_anconda安装时clear the packag…

【Emgu CV教程】6.1、图像平滑之添加雪花噪声

文章目录 前言一、什么样的图像需要平滑&#xff1f;二、平滑的办法有哪些三、制作需要平滑的图片1.制作微小斑点的噪声2.制作稍大一点的噪声 总结 前言 首先说三点&#xff1a; 图像平滑&#xff0c;一般就是指对图像进行模糊或去噪&#xff0c;平滑后的图像减少了噪声&…

​ElasticSearch

目录 简介 基本概念 倒排索引 FST 简介 ES是一个基于lucene构建的&#xff0c;分布式的&#xff0c;RESTful的开源全文搜索引擎。支持对各种类型的数据的索引&#xff1b;搜索速度快&#xff0c;可以提供实时的搜索服务&#xff1b;便于水平扩展&#xff0c;每秒可以处理 …

【深度学习:Collaborative filtering 协同过滤】深入了解协同过滤:技术、应用与示例

此图显示了使用协作筛选预测用户评分的示例。起初&#xff0c;人们会对不同的项目&#xff08;如视频、图像、游戏&#xff09;进行评分。之后&#xff0c;系统将对用户对项目进行评分的预测&#xff0c;而用户尚未评分。这些预测基于其他用户的现有评级&#xff0c;这些用户与…

npm install运行报错npm ERR! gyp ERR! not ok问题解决

执行npm install的时候报错&#xff1a; npm ERR! path D:..\node_modules\\**node-sass** npm ERR! command failed ...npm ERR! gyp ERR! node -v v20.11.0 npm ERR! gyp ERR! node-gyp -v v3.8.0 npm ERR! gyp ERR! not ok根据报错信息&#xff0c;看出时node-sass运行出现…

Thinkphp框架,最新ICP备案查询系统源码,附搭建教程

源码介绍 最新ICP备案查询系统源码 附教程 thinkphp框架 本系统支持网址备案&#xff0c;小程序备案&#xff0c;APP备案查询&#xff0c;快应用备案查询 优势&#xff1a; 响应速度快&#xff0c;没有延迟&#xff0c;没有缓存&#xff0c;数据与官方同步

基于SpringBoot Vue美食网站系统

大家好✌&#xff01;我是Dwzun。很高兴你能来阅读我&#xff0c;我会陆续更新Java后端、前端、数据库、项目案例等相关知识点总结&#xff0c;还为大家分享优质的实战项目&#xff0c;本人在Java项目开发领域有多年的经验&#xff0c;陆续会更新更多优质的Java实战项目&#x…

【趣味游戏-08】20240123点兵点将点到谁就是谁(列表倒置reverse)

背景需求&#xff1a; 上个月&#xff0c;看到大4班一个孩子在玩“点兵点将点到谁就是谁”的小游戏&#xff0c;他在桌上摆放两排奥特曼卡片&#xff0c;然后点着数“点兵点将点到谁就是谁”&#xff0c;第10次点击的卡片&#xff0c;拿起来与同伴的卡片进行交换。他是从第一排…

【新书推荐】2.3节 二进制的简写和转换

本节内容&#xff1a;二进制 ■电子计算机为何采用二进制&#xff1a;电子计算机电路只有低电平和高电平两种状态&#xff0c;分别表示二进制数0和1。 ■二进制的简写形式&#xff1a;计算机内的数据都使用二进制数。但是二进制书写不便&#xff0c;通常我们采用十六进制作为二…

网络协议与攻击模拟_06攻击模拟SYN Flood

一、SYN Flood原理 在TCP三次握手过程中&#xff0c; 客户端发送一个SYN包给服务器服务端接收到SYN包后&#xff0c;会回复SYNACK包给客户端&#xff0c;然后等待客户端回复ACK包。但此时客户端并不会回复ACK包&#xff0c;所以服务端就只能一直等待直到超时。服务端超时后会…

React16源码: React中的completeUnitOfWork的源码实现

completeUnitOfWork 1 &#xff09;概述 各种不同类型组件的一个更新过程对应的是在执行 performUnitOfWork 里面的 beginWork 阶段它是去向下遍历一棵 fiber 树的一侧的子节点&#xff0c;然后遍历到叶子节点为止&#xff0c;以及 return 自己 child 的这种方式在 performUni…