python pytorch实现RNN,LSTM,GRU,文本情感分类

python pytorch实现RNN,LSTM,GRU,文本情感分类

数据集格式:
在这里插入图片描述
有需要的可以联系我

实现步骤就是:
1.先对句子进行分词并构建词表
2.生成word2id
3.构建模型
4.训练模型
5.测试模型

代码如下:


import pandas as pd
import torch
import matplotlib.pyplot as plt
import jieba
import numpy as np"""
作业:
一、完成优化
优化思路1 jieba
2 取常用的3000字
3 修改model:rnn、lstm、gru二、完成测试代码
"""# 了解数据
dd = pd.read_csv(r'E:\peixun\data\train.csv')
# print(dd.head())# print(dd['label'].value_counts())# 句子长度分析
# 确定输入句子长度为 500
text_len = [len(i) for i in dd['text']]
# plt.hist(text_len)
# plt.show()
# print(max(text_len), min(text_len))# 基本参数 config
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('my device:', DEVICE)MAX_LEN = 500
BATCH_SIZE = 16
EPOCH = 1
LR = 3e-4# 构建词表 word2id
vocab = []
for i in dd['text']:vocab.extend(jieba.lcut(i, cut_all=True))  # 使用 jieba 分词# vocab.extend(list(i))vocab_se = pd.Series(vocab)
print(vocab_se.head())
print(vocab_se.value_counts().head())vocab = vocab_se.value_counts().index.tolist()[:3000]  # 取频率最高的 3000 token
# print(vocab[:10])
# exit()WORD_PAD = "<PAD>"
WORD_UNK = "<UNK>"
WORD_PAD_ID = 0
WORD_UNK_ID = 1vocab = [WORD_PAD, WORD_UNK] + list(set(vocab))print(vocab[:10])
print(len(vocab))vocab_dict = {k: v for v, k in enumerate(vocab)}# 词表大小,vocab_dict: word2id; vocab: id2word
print(len(vocab_dict))import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import pandas as pd# 定义数据集 Dataset
class Dataset(data.Dataset):def __init__(self, split='train'):# ChnSentiCorp 情感分类数据集path =  r'E:/peixun/data/' + str(split) + '.csv'self.data = pd.read_csv(path)def __len__(self):return len(self.data)def __getitem__(self, i):text = self.data.loc[i, 'text']label = self.data.loc[i, 'label']return text, label# 实例化 Dataset
dataset = Dataset('train')# 样本数量
print(len(dataset))
print(dataset[0])# 句子批处理函数
def collate_fn(batch):# [(text1, label1), (text2, label2), (3, 3)...]sents = [i[0][:MAX_LEN] for i in batch]labels = [i[1] for i in batch]inputs = []# masks = []for sent in sents:sent = [vocab_dict.get(i, WORD_UNK_ID) for i in list(sent)]pad_len = MAX_LEN - len(sent)# mask = len(sent) * [1] + pad_len * [0]# masks.append(mask)sent += pad_len * [WORD_PAD_ID]inputs.append(sent)# 只使用 lstm 不需要用 masks# masks = torch.tensor(masks)# print(inputs)inputs = torch.tensor(inputs)labels = torch.LongTensor(labels)return inputs.to(DEVICE), labels.to(DEVICE)# 测试 loader
loader = data.DataLoader(dataset,batch_size=BATCH_SIZE,collate_fn=collate_fn,shuffle=True,drop_last=False)inputs, labels = iter(loader).__next__()
print(inputs.shape, labels)# 定义模型
class Model(nn.Module):def __init__(self, vocab_size=5000):super().__init__()self.embed = nn.Embedding(vocab_size, 100, padding_idx=WORD_PAD_ID)# 多种 rnnself.rnn = nn.RNN(100, 100, 1, batch_first=True, bidirectional=True)self.gru = nn.GRU(100, 100, 1, batch_first=True, bidirectional=True)self.lstm = nn.LSTM(100, 100, 1, batch_first=True, bidirectional=True)self.l1 = nn.Linear(500 * 100 * 2, 100)self.l2 = nn.Linear(100, 2)def forward(self, inputs):out = self.embed(inputs)out, _ = self.lstm(out)out = out.reshape(BATCH_SIZE, -1)  # 16 * 100000out = F.relu(self.l1(out))  # 16 * 100out = F.softmax(self.l2(out))  # 16 * 2return out# 测试 Model
model = Model()
print(model)# 模型训练
dataset = Dataset()
loader = data.DataLoader(dataset,batch_size=BATCH_SIZE,collate_fn=collate_fn,shuffle=True)model = Model().to(DEVICE)# 交叉熵损失
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)model.train()
for e in range(EPOCH):for idx, (inputs, labels) in enumerate(loader):# 前向传播,计算预测值out = model(inputs)# 计算损失loss = loss_fn(out, labels)# 反向传播,计算梯度loss.backward()# 参数更新optimizer.step()# 梯度清零optimizer.zero_grad()if idx % 10 == 0:out = out.argmax(dim=-1)acc = (out == labels).sum().item() / len(labels)print('>>epoch:', e,'\tbatch:', idx,'\tloss:', loss.item(),'\tacc:', acc)# 模型测试
test_dataset = Dataset('test')
test_loader = data.DataLoader(test_dataset,batch_size=BATCH_SIZE,collate_fn=collate_fn,shuffle=False)loss_fn = nn.CrossEntropyLoss()out_total = []
labels_total = []model.eval()
for idx, (inputs, labels) in enumerate(test_loader):out = model(inputs)loss = loss_fn(out, labels)out_total.append(out)labels_total.append(labels)if idx % 50 == 0:print('>>batch:', idx, '\tloss:', loss.item())correct=0
sumz=0
for i in range(len(out_total)):out = out_total[i].argmax(dim=-1)correct = (out == labels_total[i]).sum().item() +correctsumz=sumz+len(labels_total[i])#acc = (out_total == labels_total).sum().item() / len(labels_total)print('>>acc:', correct/sumz)

运行结果如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/212016.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

各大期刊网址

1.NeurIPS&#xff0c;全称Annual Conference on Neural Information Processing Systems&#xff0c; 是机器学习领域的顶级会议&#xff0c;与ICML&#xff0c;ICLR并称为机器学习领域难度最大&#xff0c;水平最高&#xff0c;影响力最强的会议&#xff01; NeurIPS是CCF 推…

List系列集合

List系列集合特点&#xff1a;有序&#xff0c;可重复&#xff0c;有索引 ArrayList&#xff1a;有序&#xff0c;可重复&#xff0c;有索引 LinkedList&#xff1a;有序&#xff0c;可重复&#xff0c;有索引 &#xff08;底层实现不同&#xff01;适合的场景不同&#xff01;…

keepalive路由缓存实现前进刷新后退缓存

1.在app.vue中配置全局的keepalive并用includes指定要缓存的组件路由name名字数组 <keep-alive :include"keepCachedViews"><router-view /></keep-alive>computed: {keepCachedViews() {console.log(this.$store.getters.keepCachedViews, this.…

常见动物经济手术3d模拟交互演示教学实现了教育资源的共享

动物常见病防治是兽医必备的技能&#xff0c;为了让实习兽医在上岗作业前拥有丰富的常见病防治经验。借助动物常见病防治VR虚拟仿真技术开展动物常见病防治VR模拟实操培训&#xff0c;能极大方便院校实训。 提高教学质量 传统的动物医学教学往往依赖于理论知识和实验室实践&…

C语言实现植物大战僵尸(完整版)

实现这个游戏需要Easy_X 这个在我前面一篇C之番外篇爱心代码有程序教你怎么下载&#xff0c;大家可自行查看 然后就是需要植物大战僵尸的素材和音乐&#xff0c;需要的可以在评论区 首先是main.cpp //开发日志 //1导入素材 //2实现最开始的游戏场景 //3实现游戏顶部的工具栏…

《Java 并发编程艺术》笔记(上)

如何减少上下文切换 减少上下文切换的方法有无锁并发编程、CAS算法、使用最少线程和使用协程。 无锁并发编程&#xff1a;多线程竞争锁时&#xff0c;会引起上下文切换&#xff0c;所以多线程处理数据时&#xff0c;可以用一些办法来避免使用锁。如将数据的 ID 按照 Hash 算法…

Dockerfile与Docker网络

一、Dockerfile 1、概念&#xff1a; Dockerfile是用来构建docker镜像的文本文件&#xff0c;是由构建镜像所需要的指令和参数构建的脚本。 2、构建步骤&#xff1a; ① 编写Dockerfile文件 ② docker build命令构建镜像 ③ docker run依据镜像运行容器实例 Dockerfile …

探索低代码之路——JNPF

目录 一、低代码行业现状 二、产品分析 1.可视化应用开发 2.流程管理 3.整个平台源码合作 三、架构和技术 技术栈 四、规划和展望 低代码平台&#xff08;Low-code Development Platform&#xff09;是一种让开发者通过拖拽和配置&#xff0c;而非传统的手动编写大量代…

外包干了8个月,技术退步明显.......

先说一下自己的情况&#xff0c;大专生&#xff0c;18年通过校招进入武汉某软件公司&#xff0c;干了接近4年的功能测试&#xff0c;今年年初&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落! 而我已经在一个企业干了四年的功能测…

java单人聊天

服务端 package 单人聊天;import java.awt.BorderLayout; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; import…

Ubuntu22.04 使用Docker部署Neo4j出错 Exited(70)

项目场景&#xff1a; 最近需要使用Neo4j图数据库&#xff0c;因此打算使用docker部署 环境使用WSL Ubuntu22.04 问题描述 拉下最新Neo4j镜像&#xff0c;执行命令部署 启动容器脚本 docker run -d -p 7474:7474 -p 7687:7687 \ --name neo4j \ --env "NEO4J_AUTHneo…

洗鞋机行业分析:2023年市场发展前景及消费现状

随着消费主力的转移&#xff0c;年轻群体在消费中的话语权和影响力越来越大&#xff0c;“精致懒”正在成为潮流。洗鞋机作为消费升级时代的产物&#xff0c;自诞生以来&#xff0c;经过十几年的发展&#xff0c;逐渐被年轻消费者熟知&#xff0c;洗鞋机品牌阵营和产品种类也变…

任课老师和班主任的区别

任课老师和班主任都是学校中非常重要的角色&#xff0c;他们的工作性质和职责略有不同。作为一位老师&#xff0c;我来说说任课老师和班主任的区别。 任课老师的主要职责是教授学科知识&#xff0c;并负责解答学生在学习过程中遇到的问题。他们的工作涉及到备课、讲课、布置作业…

Qt之基于QMediaPlayer的音视频播放器(支持常见音视频格式)

Qt自带了一个Media Player的例子,如下图所示: 但是运行这个例子机会发现,连最基本的MP4格式视频都播放不了。因为QMediaPlayer是个壳(也可以叫框架),依赖本地解码器,视频这块默认基本上就播放个MP4,甚至连MP4都不能播放,如果要支持其他格式需要下载k-lite或者LAVFilte…

Java 并发编程面试题——Java 线程间通信方式

目录 1.✨Java 线程间有哪些通信方式&#xff1f;1.1.volatile 和 synchronized 关键字1.2.等待/通知机制1.2.1.概述1.2.2.经典范式 1.3.管道输入/输出流1.4.信号量 2.Thread.join() 有什么作用&#xff1f;它的使用场景是什么&#xff1f;3.Java 中需要主线程等待子线程执行完…

mac M系列芯片安装chatGLM3-6b模型

1 环境安装 1.1 mac安装conda. 下载miniconda&#xff0c;并安装 curl -O https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh sh Miniconda3-latest-MacOSX-arm64.sh1.2 创建虚拟环境并激活 创建名为chatglm3的虚拟环境&#xff0c;python版本为3.10…

js vue 输入正确手机号/邮箱后,激活“发送验证码”按钮

按钮禁止点击状态&#xff1a; 按钮能够点击状态&#xff1a; 我采用的方式是监听手机号/邮箱输入框的输入事件&#xff0c;即实判断用户输入的数据是否满足规则&#xff0c;如果满足手机号/邮箱规则&#xff0c;则激活“获取验证码”按钮。 话不多说&#xff0c;上代码 样式…

IMR TBR TBDR

IMR Immediate Mode Rendering(即时渲染)&#xff0c;是 PC 和主机 GPU 使用的渲染方式 IMR下的渲染示意图 每次渲染&#xff0c;都要读写Frame Buffer和Depth Buffer IMR优化 IMR需要大量的带宽和功耗&#xff0c;优化方式是L1、L2 Cache大缓存&#xff0c;不适用于移动G…

贪心算法及相关题目

贪心算法概念 贪心算法是指&#xff0c;在对问题求解时&#xff0c;总是做出在当前看来是最好的选择。也就是说&#xff0c;不从整体最优上加以考虑&#xff0c;算法得到的是在某种意义上的局部最优解 。 贪心算法性质&#xff08;判断是否可以使用贪心算法&#xff09; 1、贪…

微信小程序中生命周期钩子函数

微信小程序 App 的生命周期钩子函数有以下 7 个&#xff1a; onLaunch(options)&#xff1a;当小程序初始化完成时&#xff0c;会触发 onLaunch&#xff08;全局只触发一次&#xff09;。onShow(options)&#xff1a;当小程序启动或从后台进入前台显示时&#xff0c;会触发 on…