pytorch使用SVM实现文本分类

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn import metrics# 1. 数据准备(中文文本)
texts = ["今天的足球比赛非常激烈,球队表现出色,最终赢得了比赛。","NBA比赛今天开打,球员们的表现非常精彩,球迷们热情高涨。","张艺谋的新电影上映了,票房成绩非常好,观众反响热烈。","娱乐圈最近又出了一些新闻,明星们的私生活成了大家讨论的焦点。","昨晚的篮球赛真是太精彩了,球员们的进攻和防守都非常强硬。","李宇春在最新的音乐会上演出了她的新歌,现场观众反应热烈。","今年的世界杯比赛激烈异常,球队之间的竞争越来越激烈。","最近的综艺节目非常火,明星嘉宾的表现让观众们大笑不已。"
]# 标签:0表示体育,1表示娱乐
labels = [0, 0, 1, 1, 0, 1, 0, 1]# 2. 数据预处理:中文分词和 TF-IDF 特征提取
def jieba_cut(text):return " ".join(jieba.cut(text))texts_cut = [jieba_cut(text) for text in texts]vectorizer = TfidfVectorizer(max_features=10000)
X_tfidf = vectorizer.fit_transform(texts_cut).toarray()
y = np.array(labels)# 3. 数据集分割为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_tfidf, y, test_size=0.2, random_state=42)# 4. PyTorch 数据加载
class NewsGroupDataset(torch.utils.data.Dataset):def __init__(self, features, labels):self.features = torch.tensor(features, dtype=torch.float32)self.labels = torch.tensor(labels, dtype=torch.long)def __len__(self):return len(self.features)def __getitem__(self, idx):return self.features[idx], self.labels[idx]train_dataset = NewsGroupDataset(X_train, y_train)
test_dataset = NewsGroupDataset(X_test, y_test)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False)# 5. 定义 SVM 模型(使用线性层)
class SVM(nn.Module):def __init__(self, input_dim, output_dim):super(SVM, self).__init__()self.fc = nn.Linear(input_dim, output_dim)def forward(self, x):return self.fc(x)# 6. 获取特征数并初始化模型
input_dim = X_tfidf.shape[1]  # 自动获取特征数
model = SVM(input_dim=input_dim, output_dim=2)  # 使用特征数量设置输入维度
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 优化器,调整学习率# 7. 训练模型
num_epochs = 50  # 增加训练周期for epoch in range(num_epochs):model.train()running_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader)}, Accuracy: {100 * correct / total}%")# 8. 测试模型
model.eval()
correct = 0
total = 0with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Test Accuracy: {100 * correct / total}%")# 9. 输出性能指标
y_pred = []
y_true = []with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)y_pred.extend(predicted.numpy())y_true.extend(labels.numpy())print(metrics.classification_report(y_true, y_pred))# 10. 测试新样本
def predict(text, model, vectorizer):# 1. 进行分词text_cut = jieba_cut(text)# 2. 将文本转为 TF-IDF 特征向量text_tfidf = vectorizer.transform([text_cut]).toarray()# 3. 转换为 PyTorch 张量text_tensor = torch.tensor(text_tfidf, dtype=torch.float32)# 4. 模型预测model.eval()  # 设置模型为评估模式with torch.no_grad():output = model(text_tensor)_, predicted = torch.max(output.data, 1)# 5. 返回预测结果return predicted.item()# 测试一个新的中文文本
new_text = "今天的篮球比赛真是太精彩了,球员们的表现让大家都为之喝彩。"
predicted_label = predict(new_text, model, vectorizer)# 输出预测结果
if predicted_label == 0:print("预测类别: 体育")
else:print("预测类别: 娱乐")

1. 数据准备

  • 文本数据:我们定义了一个包含中文文本的列表,每条文本表示一个新闻或评论。
  • 标签:为每条文本分配了一个标签,0 代表“体育”,1 代表“娱乐”。

2. 数据预处理

  • 中文分词:使用 jieba 库对每条文本进行分词,并将分词后的结果连接成字符串。这是处理中文文本时的常见做法。
  • TF-IDF 特征提取:使用 TfidfVectorizer 将文本转化为数值特征。TF-IDF 是一种常见的文本表示方式,能够衡量单词在文档中的重要性。

3. 数据集分割

  • 使用 train_test_split 将数据分为训练集和测试集。80% 的数据用于训练,20% 用于测试。

4. PyTorch 数据加载

  • 定义 Dataset 类:创建了一个自定义的 NewsGroupDataset 类,继承自 torch.utils.data.Dataset,用于将文本特征和标签封装为 PyTorch 可用的数据集格式。
  • DataLoader:使用 DataLoader 将训练集和测试集数据进行批处理和加载。

5. 模型定义

  • 定义了一个简单的线性 SVM 模型。实际上,使用了一个线性层 (nn.Linear) 来进行分类,输入是文本的 TF-IDF 特征,输出是两个类别(体育或娱乐)。
  • 使用了 CrossEntropyLoss 作为损失函数,因为这是分类任务中常用的损失函数。
  • 优化器使用了随机梯度下降(SGD),并设置了学习率为 0.01。

6. 训练过程

  • 训练模型的过程包括:前向传播(计算输出),计算损失,反向传播(更新参数),并在每个 epoch 后输出损失和准确率。
  • 每个 batch 的训练过程中,模型会通过计算损失并进行优化,逐步提升准确率。

7. 测试和评估

  • 在测试过程中,将模型设置为评估模式 (model.eval()),并计算测试集上的准确率。通过比较预测标签与真实标签,计算正确的预测数量并输出准确率。
  • 使用 classification_report 来输出精确度、召回率和 F1 分数等更多的评估指标。

8. 预测新样本

  • 定义了一个 predict() 函数,用于预测新的文本样本的分类。
  • 预测过程包括:对文本进行分词,转化为 TF-IDF 特征,传入模型进行前向传播,最后返回模型预测的标签。

9. 输出

  • 代码的最后,会输出模型对新文本的预测结果,标明是属于体育类别还是娱乐类别。

关键技术点:

  • 中文分词:使用 jieba 对中文文本进行分词处理,这对于中文文本的处理至关重要。
  • TF-IDF:将文本转换为数值特征,便于模型处理。TF-IDF 是基于单词在文档中的出现频率及其在整个语料中的稀有度进行加权的。
  • 模型训练与评估:通过多轮训练提升模型准确度,使用测试集来评估模型的泛化能力。
  • PyTorch DataLoader:通过 DataLoader 高效地处理训练集和测试集,进行批处理和自动化管理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/12643.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

014-STM32单片机实现矩阵薄膜键盘设计

1.功能说明 本设计主要是利用STM32驱动矩阵薄膜键盘,当按下按键后OLED显示屏上会对应显示当前的按键键值,可以将此设计扩展做成电子秤、超市收银机、计算器等需要多个按键操作的单片机应用。 2.硬件接线 模块管脚STM32单片机管脚矩阵键盘行1PA0矩阵键盘…

鸿蒙Harmony-双向数据绑定MVVM以及$$语法糖介绍

鸿蒙Harmony-双向数据绑定MVVM以及$$语法糖介绍 1.1 双向数据绑定概念 在鸿蒙(HarmonyOS)应用开发中,双向数据改变(或双向数据绑定)是一种让数据模型和UI组件之间保持同步的机制,当数据发生变化时&#x…

Chromium132 编译指南 - Android 篇(一):编译前准备

1. 引言 欢迎来到《Chromium 132 编译指南 - Android 篇》系列的第一部分。本系列指南将引导您逐步完成在 Android 平台上编译 Chromium 132 版本的全过程。Chromium 作为一款由 Google 主导开发的开源浏览器引擎,为众多现代浏览器提供了核心驱动力。而 Android 作…

通向AGI之路:人工通用智能的技术演进与人类未来

文章目录 引言:当机器开始思考一、AGI的本质定义与技术演进1.1 从专用到通用:智能形态的范式转移1.2 AGI发展路线图二、突破AGI的五大技术路径2.1 神经符号整合(Neuro-Symbolic AI)2.2 世界模型架构(World Models)2.3 具身认知理论(Embodied Cognition)三、AGI安全:价…

使用 DeepSeek-R1 与 AnythingLLM 搭建本地知识库

一、下载地址Download Ollama on macOS 官方网站:Ollama 官方模型库:library 二、模型库搜索 deepseek r1 deepseek-r1:1.5b 私有化部署deepseek,模型库搜索 deepseek r1 运行cmd复制命令:ollama run deepseek-r1:1.5b 私有化…

C++游戏开发实战:从引擎架构到物理碰撞

📝个人主页🌹:一ge科研小菜鸡-CSDN博客 🌹🌹期待您的关注 🌹🌹 1. 引言 C 是游戏开发中最受欢迎的编程语言之一,因其高性能、低延迟和强大的底层控制能力,被广泛用于游戏…

计算机网络——三种交换技术

目录 电路交换——用于电话网络 电路交换的优点: 电路交换的缺点: 报文交换——用于电报网络 报文交换的优点: 报文交换的缺点: 分组交换——用于现代计算机网络 分组交换的优点: 分组交换的缺点 电路交换——…

35.Word:公积金管理中心文员小谢【37】

目录 Word1.docx ​ Word2.docx Word2.docx ​ 注意本套题还是与上一套存在不同之处 Word1.docx 布局样式的应用设计页眉页脚位置在水平/垂直方向上均相对于外边距居中排列:格式→大小对话框→位置→水平/垂直 按下表所列要求将原文中的手动纯文本编号分别替换…

小程序越来越智能化,作为设计师要如何进行创新设计

一、用户体验至上 (一)简洁高效的界面设计 小程序的特点之一是轻便快捷,用户期望能够在最短的时间内找到所需功能并完成操作。因此,设计师应致力于打造简洁高效的界面。避免过多的装饰元素和复杂的布局,采用清晰的导航…

全栈开发:使用.NET Core WebAPI构建前后端分离的核心技巧(二)

目录 配置系统集成 分层项目使用 筛选器的使用 中间件的使用 配置系统集成 在.net core WebAPI前后端分离开发中,配置系统的设计和集成是至关重要的一部分,尤其是在管理不同环境下的配置数据时,配置系统需要能够灵活、可扩展&#xff0c…

Linux——进程概念

目录 一、系统调用和库函数概念二、基本概念三、描述进程-PCB3.1 task_struct-PCB的一种3.2 task_ struct内容分类 四、组织进程五、查看进程六、通过系统调用获取进程标示符七、通过系统调用创建进程- fork初始7.1 fork函数创建子进程7.2 fork 之后通常要用 if 进行分流 八、进…

基于Springboot框架的学术期刊遴选服务-项目演示

项目介绍 本课程演示的是一款 基于Javaweb的水果超市管理系统,主要针对计算机相关专业的正在做毕设的学生与需要项目实战练习的 Java 学习者。 1.包含:项目源码、项目文档、数据库脚本、软件工具等所有资料 2.带你从零开始部署运行本套系统 3.该项目附…

2. K8S集群架构及主机准备

本次集群部署主机分布K8S集群主机配置主机静态IP设置主机名解析ipvs管理工具安装及模块加载主机系统升级主机间免密登录配置主机基础配置完后最好做个快照备份 2台负载均衡器 Haproxy高可用keepalived3台k8s master节点5台工作节点(至少2及以上)本次集群部署主机分布 K8S集群主…

Linux第105步_基于SiI9022A芯片的RGB转HDMI实验

SiI9022A是一款HDMI传输芯片,可以将“音视频接口”转换为HDMI或者DVI格式,是一个视频转换芯片。本实验基于linux的驱动程序设计。 SiI9022A支持输入视频格式有:xvYCC、BTA-T1004、ITU-R.656,内置DE发生器,支持SYNC格式…

物联网领域的MQTT协议,优势和应用场景

MQTT(Message Queuing Telemetry Transport)作为轻量级发布/订阅协议,凭借其低带宽消耗、低功耗与高扩展性,已成为物联网通信的事实标准。其核心优势包括:基于TCP/IP的异步通信机制、支持QoS(服务质量&…

Baklib推动数字化内容管理解决方案助力企业数字化转型

内容概要 在当今信息爆炸的时代,数字化内容管理成为企业提升效率和竞争力的关键。企业在面对大量数据时,如何高效地存储、分类与检索信息,直接关系到其经营的成败。数字化内容管理不仅限于简单的文档存储,更是整合了文档、图像、…

PAT甲级1052、Linked LIst Sorting

题目 A linked list consists of a series of structures, which are not necessarily adjacent in memory. We assume that each structure contains an integer key and a Next pointer to the next structure. Now given a linked list, you are supposed to sort the stru…

18 大量数据的异步查询方案

在分布式的应用中分库分表大家都已经熟知了。如果我们的程序中需要做一个模糊查询,那就涉及到跨库搜索的情况,这个时候需要看中间件能不能支持跨库求交集的功能。比如mycat就不支持跨库查询,当然现在mycat也渐渐被摒弃了(没有处理笛卡尔交集的…

UE求职Demo开发日志#21 背包-仓库-装备栏移动物品

1 创建一个枚举记录来源位置 UENUM(BlueprintType) enum class EMyItemLocation : uint8 {None0,Bag UMETA(DisplayName "Bag"),Armed UMETA(DisplayName "Armed"),WareHouse UMETA(DisplayName "WareHouse"), }; 2 创建一个BagPad和WarePa…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.22 形状操控者:转置与轴交换的奥秘

1.22 形状操控者:转置与轴交换的奥秘 目录 #mermaid-svg-Qb3eoIWrPbPGRVAf {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-Qb3eoIWrPbPGRVAf .error-icon{fill:#552222;}#mermaid-svg-Qb3eoIWrPbPGRVAf…