文本分类-RNN-LSTM

1.前言

        本节介绍RNN和LSTM,并采用它们在电影评论数据集上实现文本分类,会涉及以下几个知识点。

        1. 词表构建:包括数据清洗,词频统计,词频截断,词表构建。

        2. 预训练词向量应用:下载并加载Glove的预训练embedding进行训练,主要是如何把词向量放到nn.embedding层中的权重。

        3. RNN及LSTM构建:涉及nn.RNN和nn.LSTM的使用。

2.任务介绍

        本节采用的数据集是斯坦福大学的大型电影评论数据集(large movie review dataset) https://ai.stanford.edu/~amaas/data/sentiment/

        包含25000个训练样本,25000个测试样本,下载解压后得到aclImdb文件夹,aclImdb下有train和test,neg和pos下分别 有txt文件,txt中为电影评论文本。

         来看看一条具体的样本,train/pos/3_10.txt:

        本节任务就是对这样的一条文本进行处理,输出积极/消极的二分类概率向量。

3.数据模块

        文本任务与图像任务不同,输入不再是像素这样的数值,而是字符串,因此需要将字符串转为矩阵运算可接受的向量形 式。

         为此需要在数据处理模块完成以下步骤:

        a.分词:将一长串文本切分为一个个独立语义的词,英文可用空格来切分。

        b. 词嵌入:词嵌入通常分两步。首先将词字符串转为索引序号,然后索引序号根据词嵌入矩阵(embedding层)取对应的向量。其中词与索引之间的映射关系需要提前构建,这就是词表构建的过程。

        因此,代码开发整体流程:

        1. 编写分词功能函数

        2. 构建词表:对训练数据进行分词,统计词频,并构建词表。例如{'UNK': 0, 'PAD': 1, 'the': 2, '.': 3, 'and': 4, 'a': 5, 'of': 6, 'to': 7, ...}

        3. 编写PyTorch的Dataset,实现分词、词转序号、长度填充/截断序号转词向量的过程由模型的nn.Embedding层实现,因此数据模块只需将词变为索引序号即可,接下来一一解析各环节核心功能代码实现。

        序号转词向量的过程由模型的nn.Embedding层实现,因此数据模块只需将词变为索引序号即可,接下来一一解析各环节核心功能代码实现。

4.词表构建

        参考配套代码a_gen_vocabulary.py,首先编写分词功能函数,分词前做一些简单的数据清洗,例如在标点符号前加入空 格、去除掉不是大小写字母及 .!? 符号的数据。

        接着,写一个词表统计类实现词频统计,和词表字典的创建,代码注释非常详细,这里不赘述。 运行代码,即可完成词频统计,词表的构建,并保存到本地npy文件,在训练及推理过程中使用。

        在词表构建过程中有一个截断数量的超参数需要设置,这里设置为20000,即最多有20000个词的表示,不在字典中的词被归为UNK这个词。

         在这个数据集中,原始词表长度为74952,即通过split切分后,有7万多个不一样的字符串,通常可以通过降序排列,取前面一部分即可。

        代码会输出词频统计图,也可以观察出词频下降的速度以及高频词是哪些。

5.Dataset编写

        参考配套代码aclImdb_dataset.py,getitem中主要做两件事,首先获取label,然后获取文本预处理后的列表,列表中元素是词所对应的索引序号。

        在self.word2index.encode中需要注意设置文本最大长度self.max_len,这是由于需要将所有文本处理到相同长度,长度不足的用词填充,长度超出则截断。

6.模型模块——RNN

        模型的构建相对简单,理论知识在这里不介绍,需要了解和温习的推荐看看《动手学》。这里借助动手学的RNN图片讲解代码的实现。

        在构建的模型RNNTextClassifier中,需要三个子module,分别是:

                1. nn.Embedding:将词序号变为词向量,用于后续矩阵运算

                2. nn.RNN:循环神经网络的实现

                3. nn.Linear:最终分类输出层的实现

        在forward时,流程如下:

                1. 获取词向量

                2. 构建初始化隐藏层,默认为全0

                3. rnn推理获得输出层和隐藏层

                4. fc层输出分类概率:fc层的输入是rnn最后一个隐藏层

        更多关于nn.RNN的参数设置,可以参考官方文档:

        torch.nn.RNN(self, input_size, hidden_size, num_layers=1, nonlinearity='tanh', bias=True, batch_first=False, dropout=0.0, bidirectional=False, device=None, dtype=None)

7.模型模块——LSTM

        RNN是神经网络中处理时序任务最为经典的设计,但是其也存在一些缺点,例如梯度消失和梯度爆炸,以及长期依赖问 题。

        当序列很长时,RNN模型很难捕捉到远距离的依赖关系,导致模型预测不准确。

        为此,带门控机制的RNN涌现,包括GRU(Gated Recurrent Unit,门控循环单元)和LSTM(Long Short-Term Memory,长短期记忆网络),其中LSTM应用最广,这里直接跳过GRU。         LSTM模型引入了三个门(input gate、forget gate和output gate),用于控制输入、输出和遗忘的流动,允许模型有选择性地忘记或记住一些信息。

        input gate用于控制输入的流动

        forget gate用于控制遗忘的流动

        output gate用于控制输出的流动

        相较于RNN,除了输出隐藏层向量h,还输出记忆层向量c,不过对于下游使用,不需要关心向量c的存在。 同样地,借助《动手学》中的LSTM示意图来理解代码。

        在这里,借鉴《动手学》的代码,采用的LSTM为双向LSTM,这里简单介绍双向循环神经网络的概念。

         双向循环神经网络(Bidirectional Recurrent Neural Network,Bi-RNN)同时考虑前向和后向的上下文信息,前向层和后向层的输出在每个时间步骤上都被连接起来,形成了一个综合的输出,这样可以更好地捕捉序列中的上下文信息。

        在pytorch代码中,只需要将bidirectional设置为True即可,

        nn.LSTM(embed_size, num_hiddens, num_layers=num_layers, bidirectional=True)。

        当采用双向时,需要注意output矩阵的shape为 [ sequence length , batch size ,2×hidden size]

        更多关于nn.LSTM的参数设置,可以参考官方文档:torch.nn.LSTM(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None)

        详细参考:https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM

8.embedding预训练加载

        模型构建好之后,词向量的embedding层是随机初始化的,要从头训练具备一定逻辑关系的词向量表示是费时费力的, 通常可以采用在大规模预料上训练好的词向量矩阵。

        这里可以参考斯坦福大学的GloVe(Global Vectors for Word Representation)预训练词向量。

        GloVe是一种无监督学习算法,用于获取单词的向量表示,GloVe预训练词向量可以有效地捕捉单词之间的语义关系,被广泛应用于自然语言处理领域的各种任务,例如文本分类、命名实体识别和机器翻译等。

        Glove有四大类,根据数据量不同进行区分,相同数据下又根据向量长度分

        a.Wikipedia 2014 + Gigaword 5 (6B tokens, 400K vocab, uncased, 50d, 100d, 200d, & 300d vectors, 822 MB download): glove.6B.zip

        b.Common Crawl (42B tokens, 1.9M vocab, uncased, 300d vectors, 1.75 GB download): glove.42B.300d.zip

        c.Common Crawl (840B tokens, 2.2M vocab, cased, 300d vectors, 2.03 GB download): glove.840B.300d.zip

        d.Twitter (2B tweets, 27B tokens, 1.2M vocab, uncased, 25d, 50d, 100d, & 200d vectors, 1.42 GB download): glove.twitter.27B.zip

         在这里,采用Wikipedia 2014 + Gigaword 5 中的100d,即词向量长度为100,向量的token数量有6B。

        下载好的GloVe词向量矩阵是一个txt文件,一行是一个词和词向量,中间用空格隔开,因此加载该预训练词向量矩阵可以这样。

        原始GloVe预训练词向量有40万个词,在这里只关心词表中有的词,因此可以在加载字典时加一行过滤,即在词表中的词,才去获取它的词向量。

        在本案例中,词表大小是2万,根据匹配,只有19720个词在GloVe中找到了词向量,其余的词向量就需要随机初始化。

        获取GloVe预训练词向量字典后,需要把词向量放到embedding层中的矩阵,对弈embedding层来说,一行是一个词的词向量,因此通过词表的序号找到对应的行,然后把预训练词向量放进去即可,代码如下:

9.训练及实验记录

        准备好了数据和模型,接下来按照常规模型训练即可。

        这里将会做一些对比实验,包括模型对比:

         a.RNN vs LSTM

        b.有预训练词向量 vs 无预训练词向量

       c. 冻结预训练词向量 vs 放开预训练词向量

        具体指令如下,推荐放到bash文件中,一次性跑

        实验结果如下所示:

        1. RNN整体不work,经过分析发现设置的文本token长度太长,导致RNN梯度消失,以至于无法训练。调整 text_max_len为50后,train acc=0.8+, val=0.62,整体效果较差。

         2. 有了预训练词向量要比没有预训练词向量高出10多个点。

         3. 放开词向量训练,效果会好一些,但是不明显。

        补充实验:将RNN模型的文本最长token数量设置为50,其余保持不变,得到的三种embedding方式的结果如下:

        结论:

        1. LSTM较RNN在长文本处理上效果更好

        2. 预训练词向量在小样本数据集上很关键,有10多个点的提升

        3. 放开与冻结embedding层训练,效果差不多

10.小结

        本小节通过电影影评数据集实现文本分类任务,通过该任务可以了解:

        1. 文本预处理机制:包括清洗、分词、词频统计、词表构建、词表截断、UNK与PAD特殊词设定等。

        2. 预训练词向量使用:包括GloVe的下载及加载、nn.embedding层的设置 。

        3. RNN系列网络模型使用:大致了解循环神经网络的输入/输出是如何构建,如何配合fc层实现文本分类。

         4. RNN可接收的文本长度有限:文本过长,导致梯度消失,文本过短,导致无法捕获更多文本信息,因此推荐采用 LSTM等门控机制的模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/362838.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue2 - 首页登录实现随机验证码组件的封装与实现详解(详细的注释及常见问题汇总)

在网站首页等登录时,随机验证码在现代网络应用中扮演着重要的安全角色。为了帮助开发者轻松集成和使用随机验证码功能,本文将介绍如何利用 Vue.js 2 封装一个简单而功能强大的随机验证码组件。让你能够快速理解并应用这一组件到你的项目中。 一、解决方案 本文提供了完美便捷…

上海计算机考研避雷,25考研慎报

上大计算机一直很热 408考研er重来没有让我失望过,现在上大的专业课是11408,按理说,这个专业课的难度是很高的,但是408er给卷出了新高度,大家可以去上大官网看看今年最新的数据,我也帮大家统计了24年最新的…

Redis集群(Clustering in Redis)工作机制详解

Redis集群工作机制详解 Redis 集群是用于提高 Redis 可扩展性和高可用性的解决方案。 维基百科:Scalability is the property of a system to handle a growing amount of work by adding resources to the system. 可扩展性是系统的一种允许通过增加系统资源来处…

《Windows API每日一练》6.4 程序测试

前面我们讨论了鼠标的一些基础知识,本节我们将通过一些实例来讲解鼠标消息的不同处理方式。 本节必须掌握的知识点: 第36练:鼠标击中测试1 第37练:鼠标击中测试2—增加键盘接口 第38练:鼠标击中测试3—子窗口 第39练&…

Linux Static calls机制

文章目录 前言一、简介二、Background: indirect calls, Spectre, and retpolines2.1 Indirect calls2.2 Spectre (v2)2.3 RetpolinesConsequences 2.4 Static callsHow it works 三、其他参考资料 前言 Linux内核5.10内核版本引入新特性:Static calls。 Static c…

计算机毕业设计hadoop+spark+hive知识图谱医生推荐系统 医生数据分析可视化大屏 医生爬虫 医疗可视化 医生大数据 机器学习 大数据毕业设计

测试过程及结果 本次对于医生推荐系统测试通过手动测试的方式共进行了两轮测试。 (1)第一轮测试中执行了个20个测试用例,通过16个,失败4个,其中属于严重缺陷的1个,属于一般缺陷的3个。 (2&am…

Spark SQL 的总体工作流程

Spark SQL 是 Apache Spark 的一个模块,它提供了处理结构化和半结构化数据的能力。通过 Spark SQL,用户可以使用 SQL 语言或 DataFrame API 来执行数据查询和分析。这个模块允许开发者将 SQL 查询与 Spark 的数据处理能力结合起来,实现高效、优化的数据处理。下面是 Spark S…

Spring Boot中实现定时任务最常用的方法 @Scheduled 注解和 TaskScheduler 接口【包含详情代码】

Spring Boot中实现定时任务最常用的方法 Scheduled 注解和 TaskScheduler 接口【包含详情代码】 学习总结 1、掌握 JAVA入门到进阶知识(持续写作中……) 2、学会Oracle数据库入门到入土用法(创作中……) 3、手把手教你开发炫酷的vbs脚本制作(完善中………

CogMG:用大模型解决知识图谱覆盖不足的问题

CogMG:用大模型解决知识图谱覆盖不足的问题 提出背景知识图谱的作用知识覆盖不完整知识更新不对齐 显式分解知识三元组和补全检索增强生成(RAG)和知识更新 框架设计1. 查询知识图谱2. 处理结果3. 知识图谱演化 CogMG 实现3.1 模型和组件问题分…

.NET 漏洞分析 | 某ERP系统存在SQL注入

01阅读须知 此文所提供的信息只为网络安全人员对自己所负责的网站、服务器等(包括但不限于)进行检测或维护参考,未经授权请勿利用文章中的技术资料对任何计算机系统进行入侵操作。利用此文所提供的信息而造成的直接或间接后果和损失&#xf…

c++智能指针shared_ptr

文章目录 概念1.shared_ptr1.基本使用2.如何获取原始指针3. 指定删除器 2 使用shared_ptr要注意的问题2.1不要用一个原始指针初始化多个shared_ptr2.2. 避免循环引用 小结 概念 C程序设计中使用堆内存是非常频繁的操作,堆内存的申请和释放都由程序员自己管理。内存…

安装 Docker 环境(通过云平台创建一个实例实现)

目录 1. 删除原有 yum 2. 手动配置 yum 源 3. 删除防火墙规则 4. 保存防火墙配置 5. 修改系统内核。打开内核转发功能。 6. 安装 Docker 7. 设置本地镜像仓库 8.重启服务 1. 删除原有 yum rm -rfv /etc/yum.repos.d/* 2. 手动配置 yum 源 使用 centos7-1511.iso 和 Xi…

Python 语法基础二

7.常用内置函数 执行这个命令可以查看所有内置函数和内置对象(两个下划线) >>>dir(__builtins__) [__class__, __contains__, __delattr__, __delitem__, __dir__, __doc__, __eq__, __format__, __ge__, __getattribute__, __getitem__, __gt…

深入剖析 Android 网络开源库 Retrofit 的源码详解

文章目录 概述一、Retrofit 简介Android主流网络请求库 二、Retrofit 源码剖析1. Retrofit 网络请求过程2. Retrofit 实例构建2.1 Retrofit.java2.2 Retrofit.Builder()2.2.1 Platform.get()2.2.2 Android 平台 2.3 Retrofit.Builder().baseUrl()2.4 Retrofit.Builder.client()…

OpenAI穿着「皇帝的新衣」;扒了数万条帖子汇总100种AIGC玩法;北美出海的财务避坑指南;我创业「如」有CTO | ShowMeAI日报

👀日报&周刊合集 | 🎡生产力工具与行业应用大全 | 🧡 点赞关注评论拜托啦! 1. 我扒了 Reddit 论坛数万条帖子,汇总了 GenAI 的 100 种玩法 ChatGPT 已经问世一年半了。这期间诞生了很多大语言模型和生成式人工智能…

备份和还原

stai和dnta snat:源地址转换 内网---外网 内网ip转换成可以访问外网的ip 内网的多个主机可以使用一个有效的公网ip地址访问外部网络 DNAT:目的地址转发 外部用户,可以通过一个公网地址访问服务内部的私网服务。 私网的ip和公网ip做一个…

【JavaEE进阶】Spring AOP使用篇

目录 1.AOP概述 2.SpringAOP快速入门 2.1 引入AOP依赖 2.2 编写AOP程序 3. Spring AOP详解 3.1 Spring AOP 核心概念 3.1.1切点(Pointcut) 3.1.2 连接点 (Join Point) 3.1.3 通知(Advice) 3.1.4 切面(Aspect) 3.2 通知类型 3.3PointCut 3.4 切面优先级 3.5 切点表…

「51媒体」政企活动媒体宣发如何做?

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 媒体宣传加速季,100万补贴享不停,一手媒体资源,全国100城线下落地执行。详情请联系胡老师。 政企活动媒体宣发是一个系统性的过程,需要明确…

使用Scala爬取安居客房产信息并存入CSV文件

使用Scala爬取安居客房产信息并存入CSV文件 本篇博客中,我们将介绍如何使用Scala语言编写一个简单的程序,来爬取安居客(Anjuke)网站上的房产信息,并将这些信息存储到CSV文件中。这个示例将涵盖HTTP请求、HTML解析、数…

麒麟系统安装MySQL

搞了一整天,终于搞定了,记录一下。 一、背景 项目的原因,基于JeecgBoot开发的系统需要国产化支持,这就需要在电脑上安装MySQL等支撑软件。 国产化项目的操作系统多是麒麟系统,我的系统如下: arm64架构。…