【Pytroch】基于支持向量机算法的数据分类预测(Excel可直接替换数据)

【Pytroch】基于支持向量机算法的数据分类预测(Excel可直接替换数据)

  • 1.模型原理
  • 2.数学公式
  • 3.文件结构
  • 4.Excel数据
  • 5.下载地址
  • 6.完整代码
  • 7.运行结果

1.模型原理

支持向量机(Support Vector Machine,SVM)是一种强大的监督学习算法,用于二分类和多分类问题。它的主要思想是找到一个最优的超平面,可以在特征空间中将不同类别的数据点分隔开。

下面是使用PyTorch实现支持向量机算法的基本步骤和原理:

  1. 数据准备: 首先,你需要准备你的训练数据。每个数据点应该具有特征(Feature)和对应的标签(Label)。特征是用于描述数据点的属性,标签是数据点所属的类别。

  2. 数据预处理: 根据SVM的原理,数据点需要线性可分。因此,你可能需要进行一些数据预处理,如特征缩放或标准化,以确保数据线性可分。

  3. 定义模型: 在PyTorch中,你可以定义一个支持向量机模型作为一个线性模型,例如使用nn.Linear

  4. 定义损失函数: SVM的目标是最大化支持向量到超平面的距离,即最大化间隔(Margin)。这可以通过最小化损失函数来实现,通常使用hinge loss(合页损失)。PyTorch提供了nn.MultiMarginLoss损失函数,它可以用于SVM训练。

  5. 定义优化器: 选择一个优化器,如torch.optim.SGD,来更新模型的参数以最小化损失函数。

  6. 训练模型: 使用训练数据对模型进行训练。在每个训练步骤中,计算损失并通过优化器更新模型参数。

  7. 预测: 训练完成后,你可以使用训练好的模型对新的数据点进行分类预测。对于二分类问题,可以使用模型的输出值来判断数据点所属的类别。

2.数学公式

当使用支持向量机(SVM)进行数据分类预测时,目标是找到一个超平面(或者在高维空间中是一个超曲面),可以将不同类别的数据点有效地分隔开。以下是SVM的数学原理:

  1. 超平面方程: 在二维情况下,超平面可以表示为

    w 1 x 1 + w 2 x 2 + b = 0 w_1 x_1 + w_2 x_2 + b = 0 w1x1+w2x2+b=0

  2. 决策函数: 数据点 (x) 被分为两个类别的决策函数为

    f ( x ) = w T x + b f(x) = w^T x + b f(x)=wTx+b

  3. 间隔(Margin): 对于一个给定的超平面,数据点到超平面的距离被称为间隔。支持向量机的目标是找到能最大化间隔的超平面。间隔可以用下面的公式计算:

    间隔 = 2 ∥ w ∥ \text{间隔} = \frac{2}{\|w\|} 间隔=w2

  4. 支持向量: 支持向量是离超平面最近的那些数据点。这些点对于确定超平面的位置和间隔非常重要。支持向量到超平面的距离等于间隔。

  5. 最大化间隔: SVM 的目标是找到一个超平面,使得所有支持向量到该超平面的距离(即间隔)都最大化。这等价于最小化法向量的范数 (|w|),即:

    最小化 1 2 ∥ w ∥ 2 \text{最小化} \quad \frac{1}{2}\|w\|^2 最小化21w2

  6. 对偶问题和核函数: 对偶问题的解决方法涉及到拉格朗日乘子,可以得到一个关于训练数据点的内积的表达式。这样,如果直接在高维空间中计算内积是非常昂贵的,可以使用核函数来避免高维空间的计算。核函数将数据映射到更高维的空间,并在计算内积时使用高维空间的投影,从而实现了在高维空间中的计算,但在计算上却更加高效。

总之,SVM利用线性超平面来分隔不同类别的数据点,通过最大化支持向量到超平面的距离来实现分类。对偶问题和核函数使SVM能够处理非线性问题,并在高维空间中进行计算。以上是SVM的基本数学原理。

3.文件结构

在这里插入图片描述

iris.xlsx						% 可替换数据集
Main.py							% 主函数

4.Excel数据

在这里插入图片描述

5.下载地址

- 资源下载地址

6.完整代码

import torch
import torch.nn as nn
import pandas as pd
import numpy as np  # Don't forget to import numpy for the functions using it
import matplotlib.pyplot as plt  # Import matplotlib for plotting
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrixclass SVM(nn.Module):def __init__(self, input_size, num_classes):super(SVM, self).__init__()self.linear = nn.Linear(input_size, num_classes)def forward(self, x):return self.linear(x)def train(model, X, y, num_epochs, learning_rate):criterion = nn.CrossEntropyLoss()  # Use CrossEntropyLoss for multi-class classificationoptimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)for epoch in range(num_epochs):inputs = torch.tensor(X, dtype=torch.float32)labels = torch.tensor(y, dtype=torch.long)  # Use long for class indicesoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')def test(model, X, y):inputs = torch.tensor(X, dtype=torch.float32)labels = torch.tensor(y, dtype=torch.long)  # Use long for class indiceswith torch.no_grad():outputs = model(inputs)_, predicted = torch.max(outputs, 1)accuracy = (predicted == labels).float().mean()print("真实值:", labels)print("预测值:", predicted)print(f'Accuracy on test set: {accuracy:.2f}')# Define the plot functions
def plot_confusion_matrix(conf_matrix, classes):plt.figure(figsize=(8, 6))plt.imshow(conf_matrix, cmap=plt.cm.Blues, interpolation='nearest')plt.title("Confusion Matrix")plt.colorbar()tick_marks = np.arange(len(classes))plt.xticks(tick_marks, classes)plt.yticks(tick_marks, classes)plt.xlabel("Predicted Label")plt.ylabel("True Label")plt.tight_layout()plt.show()def plot_predictions_vs_true(y_true, y_pred):plt.figure(figsize=(10, 6))plt.plot(y_true, 'go', label='True Labels')plt.plot(y_pred, 'rx', label='Predicted Labels')plt.title("True Labels vs Predicted Labels")plt.xlabel("Sample Index")plt.ylabel("Class Label")plt.legend()plt.show()def main():data = pd.read_excel('iris.xlsx')X = data.iloc[:, :-1].valuesy = data.iloc[:, -1].valueslabel_encoder = LabelEncoder()y = label_encoder.fit_transform(y)X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)num_classes = len(label_encoder.classes_)model = SVM(X_train.shape[1], num_classes)num_epochs = 1000learning_rate = 0.001train(model, X_train, y_train, num_epochs, learning_rate)# Call the test function to get predictionsinputs = torch.tensor(X_test, dtype=torch.float32)labels = torch.tensor(y_test, dtype=torch.long)with torch.no_grad():outputs = model(inputs)_, predicted = torch.max(outputs, 1)# Convert torch tensors back to numpy arraysy_true = labels.numpy()y_pred = predicted.numpy()test(model, X_test, y_test)# Call the plot functionsconf_matrix = confusion_matrix(y_true, y_pred)plot_confusion_matrix(conf_matrix, label_encoder.classes_)plot_predictions_vs_true(y_true, y_pred)if __name__ == '__main__':main()

7.运行结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/89329.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

metaRTC7 demo mac/ios编译指南

概要 metaRTC7.0开始全面支持mac/ios操作系统,新版本7.0.023 mac os demo 包含有srs/zlm的推拉流演示。发布版自带了x64版第三方类库,arm版第三方类库还需开发者自己编译。 源码下载 下载文件metartc7.023.7z https://github.com/metartc/metaRTC/re…

php从静态资源到动态内容

1、从HTML到PHP demo.php:后缀由html直接改为php,实际上当前页面已经变成了动态的php应用程序脚本 demo.php: 允许通过<?php ... ?>标签,添加php代码到当前脚本中 php标签内部代码由php.exe解释, php标签之外的代码原样输出,仍由web服务器解析 <!DOCTYPE html>…

【云原生】kubernetes中容器的资源限制

目录 1 metrics-server 2 指定内存请求和限制 3 指定 CPU 请求和限制 资源限制 在k8s中对于容器资源限制主要分为以下两类: 内存资源限制: 内存请求&#xff08;request&#xff09;和内存限制&#xff08;limit&#xff09;分配给一个容器。 我们保障容器拥有它请求数量的…

【uniapp】使用Vs Code开发uniapp:

文章目录 一、使用命令行创建uniapp项目&#xff1a;二、安装插件与配置&#xff1a;三、编译和运行:四、修改pinia&#xff1a; 一、使用命令行创建uniapp项目&#xff1a; 二、安装插件与配置&#xff1a; 三、编译和运行: 该项目下的dist》dev》mp-weixin文件导入微信开发者…

时序预测 | MATLAB实现EEMD-LSTM、LSTM集合经验模态分解结合长短期记忆神经网络时间序列预测对比

时序预测 | MATLAB实现EEMD-LSTM、LSTM集合经验模态分解结合长短期记忆神经网络时间序列预测对比 目录 时序预测 | MATLAB实现EEMD-LSTM、LSTM集合经验模态分解结合长短期记忆神经网络时间序列预测对比效果一览基本介绍模型搭建程序设计参考资料 效果一览 基本介绍 时序预测 | …

【 运维这些事儿 】- Gerrit代码审查详

文章目录 背景作用代码审查工具Gerrit镜像构建Dockerfile 部署配置 Gitlab代码同步ssh-agent 相关概念常用命令Git 配置使用 Git Review针对已有项目添加commit-msg&#xff0c;用于自动添加changeId添加源配置 .gitreview备注指定审核人自定义git命令 开发使用代码审查 背景 …

nginx基于主机和用户访问控制以及缓存简单例子

一.基于主机访问控制 1.修改nginx.conf文件 2.到其他主机上测试 &#xff08;1&#xff09;191主机 &#xff08;2&#xff09;180主机 二.基于用户访问控制 1.修改nginx.conf文件 2.使用hpasswd为用户创建密码文件&#xff0c;并指定到刚才指定的密码文件webck 3.测试…

item_get_desc获得TB/TM商品描述

一、接口参数说明&#xff1a; item_get_desc-获得TB/TM商品描述&#xff0c;点击更多API调试&#xff0c;请移步注册API账号点击获取测试key和secret 公共参数 请求地址: https://api-gw.onebound.cn/taobao/item_get_desc 名称类型必须描述keyString是调用key&#xff08;…

Pytorch基于VGG cosine similarity实现简单的以图搜图(图像检索)

代码如下&#xff1a; from PIL import Image from torchvision import transforms import os import torch import torchvision import torch.nn.functional as Fclass VGGSim(torch.nn.Module):def __init__(self):super(VGGSim, self).__init__()blocks []blocks.append(t…

Python Opencv实践 - 图像缩放

import cv2 as cv import numpy as np import matplotlib.pyplot as pltimg_cat cv.imread("../SampleImages/cat.jpg", cv.IMREAD_COLOR) plt.imshow(img_cat[:,:,::-1])#图像绝对尺寸缩放 #cv.resize(src, dsize[, dst[, fx[, fy[, interpolation]]]]) #指定Size大…

企业服务器数据库遭到malox勒索病毒攻击后如何解决,勒索病毒解密

网络技术的发展不仅为企业带来了更高的效率&#xff0c;还为企业带来信息安全威胁&#xff0c;其中较为常见的就是勒索病毒攻击。近期&#xff0c;我们公司收到很多企业的求助&#xff0c;企业的服务器数据库遭到了malox勒索病毒攻击&#xff0c;导致系统内部的许多重要数据被加…

大数据课程I3——Kafka的消息流与索引机制

文章作者邮箱:yugongshiye@sina.cn 地址:广东惠州 ▲ 本章节目的 ⚪ 掌握Kafka的消息流处理; ⚪ 掌握Kafka的索引机制; ⚪ 掌握Kafka的消息系统语义; 一、Kafka消息流处理 1. Producer 写入消息 流程说明: 1. producer 要向Kafka生产消息,需要先通过…

使用node搭建服务器,前端自己写接口,将vue或react打包后生成的dist目录在本地运行

使用node.jsexpress或者使用node.jspm2搭建服务器&#xff0c;将vue或react打包后生成的dist目录在本地运行 vue项目打包后生成的dist目录如果直接在本地打开index.html,在浏览器中会报错&#xff0c;无法运行起来。 通常我是放到后端搭建的服务上面去运行&#xff0c;当时前端…

命令模式(C++)

定义 将一个请求(行为)封装为一个对象&#xff0c;从而使你可用不同的请求对客户进行参数化;对请求排队或记录请求日志&#xff0c;以及支持可撤销的操作。 应用场景 在软件构建过程中&#xff0c;“行为请求者”与“行为实现者”通常呈现一种“紧耦合”。但在某些场合——比…

jmeter提取token方式以及设置成全局变量(跨线程组传token值)方式

前言 今天Darren洋教大家如何使用jmeter中的插件来进行token值的提取与调用&#xff0c;今天Darren洋介绍两种jmeter提取token值的方式&#xff0c;一种是在当前线程组中直接提取token值&#xff0c;一种是跨线程组的方式进行token值的提取并调用给不同线程组里的HTTP接口使用。…

软件测试工程师的职业发展方向

一、软件测试工程师大致有4个发展方向: 1资深软件测试工程师 一般情况&#xff0c;软件测试工程师可分为测试工程师、高级测试工程师和资深测试工程师三个等级。 达到这个水平比较困难&#xff0c;这需要了解很多知识&#xff0c;例如C语言&#xff0c;JAVA语言&#xff0c;…

每天一道leetcode:1466. 重新规划路线(图论中等广度优先遍历)

今日份题目&#xff1a; n 座城市&#xff0c;从 0 到 n-1 编号&#xff0c;其间共有 n-1 条路线。因此&#xff0c;要想在两座不同城市之间旅行只有唯一一条路线可供选择&#xff08;路线网形成一颗树&#xff09;。去年&#xff0c;交通运输部决定重新规划路线&#xff0c;以…

el-tree-select那些事

下拉菜单树形选择器 用于记录工作及日常学习涉及到的一些需求和问题 vue3 el-tree-select那些事 1、获取el-tree-select选中的任意层级的节点对象 1、获取el-tree-select选中的任意层级的节点对象 1-1数据集 1-2画面 1-3代码 1-3-1画面代码 <el-tree-selectv-model"s…