NLP项目实战01--电影评论分类

介绍:

欢迎来到本篇文章!在这里,我们将探讨一个常见而重要的自然语言处理任务——文本分类。具体而言,我们将关注情感分析任务,即通过分析电影评论的情感来判断评论是正面的、负面的。

展示:
训练展示如下:

在这里插入图片描述
在这里插入图片描述

实际使用如下:

请添加图片描述

实现方式:

选择PyTorch作为深度学习框架,使用电影评论IMDB数据集,并结合torchtext对数据进行预处理。

环境:

Windows+Anaconda
重要库版本信息
torch==1.8.2+cu102
torchaudio==0.8.2
torchdata==0.7.1
torchtext==0.9.2
torchvision==0.9.2+cu102

实现思路:

1、数据集
本次使用的是IMDB数据集,IMDB是一个含有50000条关于电影评论的数据集
数据如下:
请添加图片描述
请添加图片描述

2、数据加载与预处理
使用torchtext加载IMDB数据集,并对数据集进行划分
具体划分如下:

TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
LABEL = data.LabelField(dtype=torch.float)
# Load the IMDB dataset
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

创建一个 Field 对象,用于处理文本数据。同时使用spacy分词器对文本进行分词,由于IMDB是英文的,所以使用en_core_web_sm语言模型。
创建一个 LabelField 对象,用于处理标签数据。设置dtype 参数为 torch.float,表示标签的数据类型为浮点型。

使用 datasets.IMDB.splits 方法加载 IMDB 数据集,并将文本字段 TEXT 和标签字段 LABEL 传递给该方法。返回的 train_data 和 test_data 包含了 IMDB 数据集的训练和测试部分。
下面是train_data的输出
请添加图片描述

3、构建词汇表与加载预训练词向量

TEXT.build_vocab(train_data,max_size=25000,vectors="glove.6B.100d",unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)

train_data:表示使用train_data中数据构建词汇表
max_size:限制词汇表的大小为 25000
vectors=“glove.6B.100d”:表示使用预训练的 GloVe 词向量,其中 “glove.6B.100d” 指的是包含 100 维向量的 6B 版 GloVe。
unk_init=torch.Tensor.normal_ :表示指定未知单词(UNK)的初始化方式,这里使用正态分布进行初始化。
LABEL.build_vocab(train_data):表示对标签进行类似的操作,构建标签的词汇表

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits( (train_data, valid_data, test_data), batch_size=BATCH_SIZE, device=device)

使用data.BucketIterator.splits 来创建数据加载器,包括训练、验证和测试集的迭代器。这将确保你能够方便地以批量的形式获取数据进行训练和评估。

4、定义神经网络
这里的网络定义比较简单,主要采用在词嵌入层(embedding)后接一个全连接层的方式完成对文本数据的分类。
具体如下:

class NetWork(nn.Module):def __init__(self,vocab_size,embedding_dim,output_dim,pad_idx):super(NetWork,self).__init__()self.embedding = nn.Embedding(vocab_size,embedding_dim,padding_idx=pad_idx)self.fc = nn.Linear(embedding_dim,output_dim)self.dropout = nn.Dropout(0.5)self.relu = nn.ReLU()def forward(self,x):embedded = self.embedding(x)embedded = embedded.permute(1,0,2) pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)pooled = self.relu(pooled)pooled = self.dropout(pooled)output = self.fc(pooled)return output

5、模型初始化

vocab_size = len(TEXT.vocab)
embedding_dim  = 100
output = 1
pad_idx = TEXT.vocab.stoi[TEXT.pad_token]
model = NetWork(vocab_size,embedding_dim,output,pad_idx)
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)

定义模型的超参数,包括词汇表大小(vocab_size)、词向量维度(embedding_dim)、输出维度(output,在这个任务中是1,因为是二元分类,所以使用1),以及 PAD 标记的索引(pad_idx)

之后需要将预训练的词向量加载到嵌入层的权重中。TEXT.vocab.vectors 包含了词汇表中每个单词的预训练词向量,然后通过 copy_ 方法将这些词向量复制到模型的嵌入层权重中对网络进行初始化。这样做确保了模型的初始化状态良好。

6、训练模型

 total_loss = 0train_acc = 0 
model.train()
for batch in train_iterator:optimizer.zero_grad()preds = model(batch.text).squeeze(1)loss = criterion(preds,batch.label)total_loss += loss.item()batch_acc = (torch.round(torch.sigmoid(preds)) == batch.label).sum().item()train_acc += batch_accloss.backward()optimizer.step()average_loss = total_loss / len(train_iterator)train_acc /= len(train_iterator.dataset)

optimizer.zero_grad():表示将模型参数的梯度清零,以准备接收新的梯度。
preds = model(batch.text).squeeze(1):表示一次前向传播的过程,由于model输出的是torch.tensor(batch_size,1)所以使用squeeze(1)给其中的1维度数据去除,以匹配标签张量的形状
criterion(preds,batch.label):定义的损失函数 criterion 计算预测值 preds 与真实标签 batch.label 之间的损失

(torch.round(torch.sigmoid(preds)) == batch.label).sum().item():
通过比较模型的预测值与真实标签,计算当前批次的准确率,并将其累加到 train_acc 中
后面的就是进行反向传播更新参数,还有就是计算loss和train_acc的值了
7、模型评估:

model.eval()valid_loss = 0valid_acc = 0best_valid_acc = 0with torch.no_grad():for batch in valid_iterator:preds = model(batch.text).squeeze(1)loss = criterion(preds,batch.label)valid_loss += loss.item()batch_acc = ((torch.round(torch.sigmoid(preds)) == batch.label).sum().item())valid_acc += batch_acc

和训练模型的类似,这里就不解释了

8、保存模型
这里一共使用了两种保存模型的方式:

torch.save(model, "model.pth")
torch.save(model.state_dict(),"model.pth")

第一种方式叫做模型的全量保存
第二种方式叫做模型的参数保存

全量保存是保存了整个模型,包括模型的结构、参数、优化器状态等信息
参数量保存是保存了模型的参数(state_dict),不包括模型的结构
9、测试模型
测试模型的基本思路:
加载训练保存的模型、对待推理的文本进行预处理、将文本数据加载给模型进行推理

加载模型:

saved_model_path = "model.pth"
saved_model = torch.load(saved_model_path)

输入文本:
input_text = “Great service! The staff was very friendly and helpful.”

文本进行处理:

tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
tokenized_text = tokenizer(input_text)
indexed_text = [TEXT.vocab.stoi[token] for token in tokenized_text]
tensor_text = torch.LongTensor(indexed_text).unsqueeze(1).to(device)

模型推理:

saved_model.eval()
with torch.no_grad():output = saved_model(tensor_text).squeeze(1)prediction = torch.round(torch.sigmoid(output)).item()probability = torch.sigmoid(output).item()

由于笔者能力有限,所以在描述的过程中难免会有不准确的地方,还请多多包含!

更多NLP和CV文章以及完整代码请到"陶陶name"获取。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/216291.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

WPF仿网易云搭建笔记(0):项目搭建

文章目录 前言项目地址项目Nuget包搭建项目初始化项目架构App.xaml引入MateralDesign资源包 项目初步分析将标题栏去掉DockPanel初步布局 资源字典举例 结尾 前言 最近在找工作,发现没有任何的WPF可以拿的出手的工作经验,打算仿照网易云搭建一个WPF版本…

leaflet使用热力图报L找不到的问题ReferenceError: L is not defined at leaflet-heat.js:11:3

1.在main.js中直接引入会显示找不到L 2.解决办法 直接在组件中单独引入使用 可以直接显示出来。 至于为什么main中不能引入为全局,我是没找到,我的另外一个项目可以,新项目不行,不知哪里设置的问题

【Docker】swarm stack部署多service应用

前面我们已经学习过了Docker Compose,它可以用来进行一个完整的应用程序相互依赖的多个容器的编排的,但是缺点是只能在单机模式使用,不能在分布式多机器上使用;前面我们也学习了Docker swarm,它可以将单个服务部署为多…

每日一题,杨辉三角

给定一个非负整数 numRows,生成「杨辉三角」的前 numRows 行。 示例 1: 输入: numRows 5 输出: [[1],[1,1],[1,2,1],[1,3,3,1],[1,4,6,4,1]] 示例 2: 输入: numRows 1 输出: [[1]]

EarCMS 前台任意文件上传漏洞复现

0x01 产品简介 EarCMS是一个APP内测分发系统的平台。 0x02 漏洞概述 EarCMS前台put_upload.php中,存在pw参数硬编码问题,同时sql语句pdo使用错误,没有有效过滤sql语句,可以控制文件名和后缀,导致可以任意文件上传。 0x03 复现环境 FOFA:app="EearCMS" 0x0…

关系型数据库和非关系型数据库有什么区别?

一、什么是数据库? 数据库是一个结构化的数据集合,用于存储、管理和组织数据。它是一个电子化的文件柜,可以存储大量的数据,并提供了一种高效地检索、更新和管理数据的方法。数据库可以用于存储各种类型的数据,例如文…

CNN发展史脉络 概述图整理

CNN发展史脉络概述图整理,学习心得,供参考,错误请批评指正。 相关论文: LeNet:Handwritten Digit Recognition with a Back-Propagation Network; Gradient-Based Learning Applied to Document Recogniti…

前端面试——CSS面经(持续更新)

1. CSS选择器及其优先级 !important > 行内样式 > id选择器 > 类/伪类/属性选择器 > 标签/伪元素选择器 > 子/后台选择器 > *通配符 2. 重排和重绘是什么?浏览器的渲染机制是什么? 重排(回流):当增加或删除dom节点&…

【STM32】STM32学习笔记-LED闪烁 LED流水灯 蜂鸣器(06-2)

00. 目录 文章目录 00. 目录01. GPIO之LED电路图02. GPIO之LED接线图03. LED闪烁程序示例04. LED闪烁程序下载05. LED流水灯接线图06. LED流水灯程序示例07. 蜂鸣器接线图08. 蜂鸣器程序示例09. 下载10. 附录 01. GPIO之LED电路图 电路图示例1 电路图示例2 02. GPIO之LED接线图…

【unity】【WebRTC】从0开始创建一个Unity远程媒体流app-设置输入设备

【项目源码】 包括本篇需要的脚本都打包在项目源码中,可以通过下面链接下载: 【背景】 目前我们能投射到远端浏览器(或者任何其它Peer)的媒体流只有默认的MainCamera画面,其实我们还可以通过配置输入来传输操作输入信息,比如键鼠等。 【追加input processing组件】 …

C++刷题 -- 哈希表

C刷题 – 哈希表 文章目录 C刷题 -- 哈希表1.两数之和2.四数相加II3.三数之和(重点) 当我们需要查询一个元素是否出现过,或者一个元素是否在集合里的时候,就要第一时间想到哈希法; 1.两数之和 https://leetcode.cn/problems/two…

Node.js黑马时钟案例

先上没有使用node.js之前的html部分代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title></title><style>* {margin: 0;padding: 0;}html,body {height: 100%;overflow: hidden;backgrou…

Linux 常用命令汇总

1 linux定时任务 查看定时任务&#xff1a;crontab -l 每晚一点半执行定时任务&#xff1a; 30 1 * * * sh /var/lib/pgsql/pg_db_backup.sh >> /var/lib/pgsql/pg_db_backup.log 2>&1 配置定时任务&#xff1a;crontab -e 2 linux 内核版本查询 cat /etc/r…

Kafka-快速实战

Kafka介绍 ChatGPT对于Apache Kafka的介绍&#xff1a; Apache Kafka是一个分布式流处理平台&#xff0c;最初由LinkedIn开发并于2011年开源。它主要用于解决大规模数据的实时流式处理和数据管道问题。 Kafka是一个分布式的发布-订阅消息系统&#xff0c;可以快速地处理高吞吐…

JavaScript——基本语法

1.定义变量&#xff1a; 变量类型 变量名 变量值 var关键字声明变量 es6版本以上 var 可写可不写 <script>// 定义变量&#xff1a;变量类型 变量名 变量值 var关键字声明变量 es6版本以上 var 可写可不写var num 2;</script>2.条件控制 <script>var …

Java连接数据库的各种细节错误(细节篇)

目录 前后端联调&#xff08;传输文件&#xff09; ClassNotFoundException: SQLException: SQL语法错误: 数据库连接问题: 驱动问题: 资源泄露: 并发问题: 超时问题: 其他库冲突: 配置问题: 网络问题: SSL/TLS问题: 数据库权限问题: 驱动不兼容: 其他未知错误…

SIEM 解决方案的不同部署方式,如何选择SIEM 解决方案

安全信息和事件管理&#xff08;SIEM&#xff09;作为一种网络安全解决方案&#xff0c;是多种技术的融合&#xff0c;这些技术结合了包括安全信息管理和安全事件管理在内的流程。简单来说&#xff0c;SIEM 解决方案是一种重要的安全工具&#xff0c;它收集、存储和分析来自整个…

Impala4.x源码阅读笔记(二)——Impala如何高效读取Iceberg表

前言 本文为笔者个人阅读Apache Impala源码时的笔记&#xff0c;仅代表我个人对代码的理解&#xff0c;个人水平有限&#xff0c;文章可能存在理解错误、遗漏或者过时之处。如果有任何错误或者有更好的见解&#xff0c;欢迎指正。 Iceberg表是一种用于存储大规模结构化数据的…

Redis 环境搭建2

文章目录 第2关&#xff1a;使用 Redis 第2关&#xff1a;使用 Redis 本文是接着上篇文章写的第二关代码&#xff0c;部分人再进入第二关时不会保留第一关的配置的环境&#xff0c;可以通过下面一句代码进行检验。 redis-cli -p 7001 -c如果进入到了redis界面就是有环境&…