使用RNN完成IMDB电影评论情感分析

使用RNN完成IMDB电影评论情感分析

    • 任务描述
    • 一、环境设置
    • 二、数据准备
    • 2.1 参数设置
    • 2.2 用padding的方式对齐数据
    • 2.3 用Dataset与DataLoader加载
    • 三、模型配置
    • 四、模型训练
    • 五、模型评估
    • 六、模型预测

任务描述

本示例教程演示如何在IMDB数据集上使用RNN网络完成文本分类的任务。IMDB数据集包含对电影评论进行正向和负向标注的数据,共有25000条文本数据作为训练集,25000条文本数据作为测试集。数据集的官方地址为:IMDB Dataset

在这里插入图片描述

一、环境设置

本示例基于飞桨开源框架2.0版本。

import paddle
import numpy as np
import matplotlib.pyplot as plt
import paddle.nn as nnprint(paddle.__version__)  # 查看当前版本# cpu/gpu环境选择,在 paddle.set_device() 输入对应运行设备。
device = paddle.set_device('gpu')2.0.1

二、数据准备

由于IMDB是NLP领域中常见的数据集,飞桨框架将其内置,路径为paddle.text.datasets.Imdb。通过mode参数可以控制训练集与测试集。

print('loading dataset...')
train_dataset = paddle.text.datasets.Imdb(mode='train')
test_dataset = paddle.text.datasets.Imdb(mode='test')
print('loading finished')

构建了训练集与测试集后,可以通过word_idx获取数据集的词表。

word_dict = train_dataset.word_idx  # 获取数据集的词表# add a pad token to the dict for later padding the sequence
word_dict['<pad>'] = len(word_dict)for k in list(word_dict)[:5]:print("{}:{}".format(k.decode('ASCII'), word_dict[k]))print("...")for k in list(word_dict)[-5:]:print("{}:{}".format(k if isinstance(k, str) else k.decode('ASCII'), word_dict[k]))print("totally {} words".format(len(word_dict)))

2.1 参数设置

在这里设置词表大小、embedding大小、batch_size等参数。

vocab_size = len(word_dict) + 1
print(vocab_size)
emb_size = 256
seq_len = 200
batch_size = 32
epochs = 2
pad_id = word_dict['<pad>']classes = ['negative', 'positive']# 生成句子列表
def ids_to_str(ids):words = []for k in ids:w = list(word_dict)[k]words.append(w if isinstance(w, str) else w.decode('ASCII'))return " ".join(words)

2.2 用padding的方式对齐数据

文本数据中,每一句话的长度都是不一样的,为了方便后续的神经网络的计算,通常使用padding的方式对齐数据。

# 读取数据归一化处理
def create_padded_dataset(dataset):padded_sents = []labels = []for batch_id, data in enumerate(dataset):sent, label = data[0], data[1]padded_sent = np.concatenate([sent[:seq_len], [pad_id] * (seq_len - len(sent))]).astype('int32')padded_sents.append(padded_sent)labels.append(label)return np.array(padded_sents), np.array(labels)# 对train、test数据进行实例化
train_sents, train_labels = create_padded_dataset(train_dataset)
test_sents, test_labels = create_padded_dataset(test_dataset)# 查看数据大小及举例内容
print(train_sents.shape)
print(train_labels.shape)
print(test_sents.shape)
print(test_labels.shape)for sent in train_sents[:3]:print(ids_to_str(sent))

2.3 用Dataset与DataLoader加载

将前面准备好的训练集与测试集用DatasetDataLoader封装后,完成数据的加载。

class IMDBDataset(paddle.io.Dataset):'''继承paddle.io.Dataset类进行封装数据'''def __init__(self, sents, labels):self.sents = sentsself.labels = labelsdef __getitem__(self, index):data = self.sents[index]label = self.labels[index]return data, labeldef __len__(self):return len(self.sents)train_dataset = IMDBDataset(train_sents, train_labels)
test_dataset = IMDBDataset(test_sents, test_labels)train_loader = paddle.io.DataLoader(train_dataset, return_list=True,shuffle=True, batch_size=batch_size, drop_last=True)
test_loader = paddle.io.DataLoader(test_dataset, return_list=True,shuffle=True, batch_size=batch_size, drop_last=True)

三、模型配置

本示例中使用一个序列特性的RNN网络,在查找到每个词对应的embedding后,取平均作为一个句子的表示。然后用Linear进行线性变换,同时使用Dropout防止过拟合。

class MyRNN(paddle.nn.Layer):def __init__(self):super(MyRNN, self).__init__()self.embedding = nn.Embedding(vocab_size, 256)self.rnn = nn.SimpleRNN(256, 256, num_layers=2, direction='forward',dropout=0.5)self.linear = nn.Linear(in_features=256*2, out_features=2)self.dropout = nn.Dropout(0.5)def forward(self, inputs):emb = self.dropout(self.embedding(inputs))output, hidden = self.rnn(emb)hidden = paddle.concat((hidden[-2,:,:], hidden[-1,:,:]), axis = 1)hidden = self.dropout(hidden)return self.linear(hidden) 

四、模型训练

# 可视化定义
def draw_process(title, color, iters, data, label):plt.title(title, fontsize=24)plt.xlabel("iter", fontsize=20)plt.ylabel(label, fontsize=20)plt.plot(iters, data, color=color, label=label) plt.legend()plt.grid()plt.show()# 对模型进行封装
def train(model):model.train()opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())steps = 0Iters, total_loss, total_acc = [], [], []for epoch in range(epochs):for batch_id, data in enumerate(train_loader):steps +=1sent = data[0]label = data[1]logits = model(sent)loss = paddle.nn.functional.cross_entropy(logits, label)acc = paddle.metric.accuracy(logits, label)if batch_id % 500 == 0:  # 500个epoch输出一次结果Iters.append(steps)total_loss.append(loss.numpy()[0])total_acc.append(acc.numpy()[0])print("epoch: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, loss.numpy()))loss.backward()opt.step()opt.clear_grad()# evaluate model after one epochmodel.eval()accuracies = []losses = []for batch_id, data in enumerate(test_loader):sent = data[0]label = data[1]logits = model(sent)loss = paddle.nn.functional.cross_entropy(logits, label)acc = paddle.metric.accuracy(logits, label)accuracies.append(acc.numpy())losses.append(loss.numpy())avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)print("[validation] accuracy: {}, loss: {}".format(avg_acc, avg_loss))model.train()# 保存模型paddle.save(model.state_dict(), str(epoch) + "_model_final.pdparams")# 可视化查看draw_process("training loss", "red", Iters, total_loss, "training loss")draw_process("training acc", "green", Iters, total_acc, "training acc")model = MyRNN()
train(model)

五、模型评估

model_state_dict = paddle.load('1_model_final.pdparams')  # 导入模型
model = MyRNN()
model.set_state_dict(model_state_dict) 
model.eval()
accuracies = []
losses = []for batch_id, data in enumerate(test_loader):sent = data[0]label = data[1]logits = model(sent)loss = paddle.nn.functional.cross_entropy(logits, label)acc = paddle.metric.accuracy(logits, label)accuracies.append(acc.numpy())losses.append(loss.numpy())avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)
print("[validation] accuracy: {}, loss: {}".format(avg_acc, avg_loss))

六、模型预测

def ids_to_str(ids):words = []for k in ids:w = list(word_dict)[k]words.append(w if isinstance(w, str) else w.decode('UTF-8'))return " ".join(words)label_map = {0: "negative", 1: "positive"}# 导入模型
model_state_dict = paddle.load('1_model_final.pdparams')
model = MyRNN()
model.set_state_dict(model_state_dict) 
model.eval()for batch_id, data in enumerate(test_loader):sent = data[0]results = model(sent)predictions = []for probs in results:# 映射分类labelidx = np.argmax(probs)labels = label_map[idx]predictions.append(labels)for i, pre in enumerate(predictions):print(' 数据: {} \n 情感: {}'.format(ids_to_str(sent[0]), pre))breakbreak

以上是使用RNN完成IMDB电影评论情感分析的示例。通过搭建RNN网络,对文本数据进行预处理、模型训练和评估,最终实现了对电影评论情感的分类。在实际应用中,可以根据需求调整网络结构和超参数,提高模型性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/237037.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

二极管限幅电路理论分析,工作原理+作用

一、限幅是什么意思&#xff1f; 限幅也就是&#xff0c;将电压限制在某个范围内&#xff0c;去除交流信号的一部分但不会对波形的剩余部分造成影响。通常来说&#xff0c;限幅电路主要是由二极管构成&#xff0c;波形的形状取决于电路的配置和设计。二、限幅电路工作原…

x-cmd pkg | busybox - 嵌入式 Linux 的瑞士军刀

目录 简介首次用户功能特点竞品和相关作品 进一步阅读 简介 busybox 是一个开源的轻量级工具集合&#xff0c;集成了一批最常用 Unix 工具命令&#xff0c;只需要几 MB 大小就能覆盖绝大多数用户在 Linux 的使用&#xff0c;能在多款 POSIX 环境的操作系统&#xff08;如 Linu…

Vulnhub靶机:driftingblues 3

一、介绍 运行环境&#xff1a;Virtualbox 攻击机&#xff1a;kali&#xff08;10.0.2.15&#xff09; 靶机&#xff1a;driftingblues3&#xff08;10.0.2.19&#xff09; 目标&#xff1a;获取靶机root权限和flag 靶机下载地址&#xff1a;https://www.vulnhub.com/entr…

WebGL在实验室方向的应用

WebGL在实验室方向的应用涉及到实验过程的可视化、数据分析、模拟等方面。以下是一些WebGL在实验室领域的应用示例&#xff0c;希望对大家有所帮助。北京木奇移动技术有限公司&#xff0c;专业的软件外包开发公司&#xff0c;欢迎交流合作。 1.分子模型和化学反应模拟&#xff…

D25XB80-ASEMI开关电源桥堆D25XB80

编辑&#xff1a;ll D25XB80-ASEMI开关电源桥堆D25XB80 型号&#xff1a;D25XB80 品牌&#xff1a;ASEMI 封装&#xff1a;GBJ-5&#xff08;带康铜丝&#xff09; 特性&#xff1a;插件、整流桥 平均正向整流电流&#xff08;Id&#xff09;&#xff1a;25A 最大反向击…

C++笔记

1.输入输出流 在C中要想输入和输出 我们会经常用到 #include <stdio.h>在C中头文件的命名风格不用.h #include <iostream>using namespace std;为什么要用上面俩句话的解释&#xff08;自己写的博客&#xff09; c中 为什么要写&#xff1c;iostream&#xff1e;…

数模学习day13-典型相关分析

典型相关分析&#xff08;Canonical Correlation analysis&#xff09; 研究两组变量&#xff08;每组变量中都可能有多个指标&#xff09;之间相关关系的一种多元统计方法。它能够揭示出两组变量之间的内在联系。 注&#xff1a;本文源于数学建模学习交流相关公众号观…

使用JGit拉取代码提示未授权not authorized

原因&#xff1a;2021年8月13日后不支持密码登录&#xff0c;需要使用token验证 调用时候需要先去git仓库创建个人令牌 需要在安全中心创建个人token&#xff0c;使用token名称作为账号&#xff0c;使用token作为密码。 另&#xff1a; Github克隆仓库的三种方式对比&#xff…

哪里能找到好用的PPT模板?12个免费模板网站让你畅快办公!

你是否有过这样的经历&#xff0c;在准备重要会议或者演讲的时候&#xff0c;为找不到合适的PPT模板而困扰&#xff1f;或是在网上漫无目的地搜寻&#xff0c;结果收获的是设计平淡无奇的PPT模板&#xff1f; 如果你有同样的疑问&#xff0c;那么你来对地方了&#xff01;在这…

基于嵌入式的智能台灯系统

基于嵌入式的智能台灯系统 功能说明 通过微信小程序控制台灯的亮灭及亮度。采集温湿度传到微信小程序上&#xff0c;台灯可以显示实时北京时间。 功能展示 01智能台灯演示 Mqtt服务器 http://www.yoyolife.fun/iot&#xff1a;Mqtt服务器&#xff0c;我是在这里注册的&#x…

WPF实现右键选定TreeViewItem

在WPF中&#xff0c;TreeView默认情况是不支持右键选定的&#xff0c;也就是说&#xff0c;当右键点击某节点时&#xff0c;是无法选中该节点的。当我们想在TreeViewItem中实现右键菜单时&#xff0c;往往希望在弹出菜单的同时选中该节点&#xff0c;以使得菜单针对选中的节点生…

程序员如何弯道超车?周末有奇效

作为一名程序员&#xff0c;不断提升自己的技能和知识是至关重要的。然而&#xff0c;在繁忙的工作日常中&#xff0c;很难有足够的时间和精力来学习新技术或深入研究。因此&#xff0c;周末成为了一个理想的时机&#xff0c;可以专注于个人发展和技能提升。所以程序员如何利用…

LNMP架构及应用部署

众所周知&#xff0c;LAMP平台是目前应用最为广泛的网站服务器架构&#xff0c;其中的“A”对应着Web服务软件Apache HTTP Server。随着Nginx在企业中的使用越来越多&#xff0c;LNMP架构也受到越来越多Linux系统工程师的青睐 1.1 构建LNMP网站平台 就像构建LAMP平台一样&…

Java实现在线编辑预览office文档

文章目录 1 在线编辑1.1 PageOffice简介1.2 前端项目1.2.1 配置1.2.2 页面部分 1.3 后端项目1.3.1 pom.xml1.3.2 添加配置1.3.3 controller 2 在线预览2.1 引言2.2 市面上现有的文件预览服务2.2.1 微软2.2.2 Google Drive查看器2.2.3 阿里云 IMM2.2.4 XDOC 文档预览2.2.5 Offic…

Vue3-完成任意组件之间的传值

一、props&#xff08;只限于父子之间&#xff0c;不嫌麻烦可以不断传&#xff09; 父给子传值&#xff0c;子接收defineProps 父给子传事件&#xff0c;子接收defineProps&#xff0c;并触发事件的时候传值&#xff0c;然后父通过事件的回调函数拿到子传的值 二、mitt&#…

【UE Niagara学习笔记】07 - 火焰的热变形效果

目录 效果 步骤 一、创建热变形材质 二、添加新的发射器 2.1 设置粒子材质 2.2 设置粒子初始大小 2.3 设置粒子持续生成 三、修改材质 四、设置粒子效果 在上一篇博客&#xff08;【UE Niagara学习笔记】06 - 制作火焰喷射过程中飞舞的火星&#xff09;的基础上继续…

Python自动化我选DrissionPage,弃用Selenium

DrissionPage 是一个基于 python 的网页自动化工具。 它既能控制浏览器&#xff0c;也能收发数据包&#xff0c;还能把两者合而为一。 可兼顾浏览器自动化的便利性和 requests 的高效率。 它功能强大&#xff0c;内置无数人性化设计和便捷功能。 它的语法简洁而优雅&#x…

VMware workstation安装debian-12.1.0虚拟机并配置网络

VMware workstation安装debian-12.1.0虚拟机并配置网络 Debian 是一个完全自由的操作系统&#xff01;Debian 有一个由普罗大众组成的社区&#xff01;该文档适用于在VMware workstation平台安装debian-12.1.0虚拟机。 1.安装准备 1.1安装平台 Windows 11 1.2软件信息 软…

【野火i.MX6NULL开发板】挂载 NFS 网络文件系统

0、前言 参考资料&#xff1a; &#xff08;误人子弟&#xff09;《野火 Linux 基础与应用开发实战指南基于 i.MX6ULL 系列》PDF 第22章 参考视频&#xff1a;&#xff08;成功&#xff09; https://www.bilibili.com/video/BV1JK4y1t7io?p26&vd_sourcefb8dcae0aee3f1aab…

HarmonyOS4.0——ArkUI应用说明

一、ArkUI框架简介 ArkUI开发框架是方舟开发框架的简称&#xff0c;它是一套构建 HarmonyOS / OpenHarmony 应用界面的声明式UI开发框架&#xff0c;它使用极简的UI信息语法、丰富的UI组件以及实时界面语言工具&#xff0c;帮助开发者提升应用界面开发效率 30%&#xff0c;开发…