【深度学习--RNN 循环神经网络--附LSTM情感文本分类】

deep learning 系列 --RNN 循环神经网络

什么是序列模型

包括了RNN LSTM GRU等网络模型,主要用途是自然语言处理、语音识别等方面,比如生成乐曲,音频转换为文字,文本情感分类,机器翻译等等

标准模型的缺陷

以往的标准模型比如CNN,每次的输入不影响下次的输出,也就是说每次输入的图片都是独立的,没有任何关联,但是很多情况下,我们建立的模型与前项甚至后项的输入是相关的。举个例子,我们要从这两句话中识别出人名:

President Teddy was …
Teddy bear was…

这其中都有一个关键且同样的词Teddy,但是这个Teddy可以是个人名,也可以是泰迪,究竟是哪个呢?通常的识别方法是:

1.从之前学习过的人名识别

我们之前的位置可能提取过Teddy是人名, 但CNN网络,并不能共享从不同位置提取到到的特征,因此不可行

2.从本次输入的上下文出发

比如这里下文有 bear,President,但是CNN也不具备这种序列特性,因此也不可行

另外,CNN的网络对于输入的模型的数据长度都是固定的,但是不同的句子长度并不一致,当然我们可以padding,但是不那么好。
因此才有了循环神经网络架构,它可以克服以上的这些问题

循环神经网络

架构

在这里插入图片描述
图示是循环神经网络架构, 是有循环的,这个图形是最长用于展示的,但实际并不好理解,这个环实际上是多次的输入和输出,下面这种展开的方式更容易理解
在这里插入图片描述
RNN的结构就包括三类输入的权重参数
1.激活值的参数,水平方向的值,每个时间步的激活参数是相同的
2.输入到隐藏层的参数,也就是x到A方向,这个也是每个时间步相同
3.用来预测输出的参数,A到h的方向
在这里插入图片描述
其中y的输出可以用如下表示:
在这里插入图片描述

在这里插入图片描述
预测值y<t>包括了激活值a<t>,而a<t>包括了x<t-1>(注意此处t表示时间步),也就是说每一次的预测输出不仅包括了本次x<t>也包括了上个时间步x<t-1>,依次类推,此次时间步的输出包括了之前所有输入

BPTT 反向传播

实际上在工具中,反向传播都是自动进行的,原理跟普通的反向传播一致
输入序列后,假设预测的值是0.9,但实际是1,产生损失,这个损失可以用交叉熵损失函数来估量
RNN中的反向传播时将之前所有的箭头都反过来,计算出合适的变量,通过导数相关的计算,利用梯度下降算法更新参数,也就是图示:
在这里插入图片描述
这个传播过程中最重要的就是水平方向时间步的反向,因此又叫做穿越时间的反向传播Back Propagate Through time

不同类型的循环神经网络

  • 多对多
    比如机器翻译,输入是多个,输出也是多个,且并不对等,此时经常使用的是encode-decoder,encoder编码器获取输入,并输出,而decoder则使用encode编码的输出,执行decoder,这样输入x和输出y的长度就可以不相同了
    在这里插入图片描述
    当然输入和输出对等的情况也很多
    比如在句子中寻找人名,预测输出y就可以是每个x的对应的每个位置输出y(0表示非人名,1表示是人名)

  • 多对一
    在这里插入图片描述
    比如进行文本的情感分类,这样输入就有可能是一段文本,而输出我们只需要最后一层最后一个时间步的输出即可,一个0和1就足以标识这段文本是positive还是negative

  • 一对多
    在这里插入图片描述
    比如生成类的,输入一个音符或者不输入,就可以产生多个输出

RNN的缺陷 梯度消失和梯度爆炸

  • 现象:

    实际上深度比较大的网络都可能梯度消失或者爆炸,这种现象在在RNN中更加明显
    当我们输入的序列为1000的时候,拿最简单的模型举例 y = wx 经过1000次的传播,y1000 的变化
    在这里插入图片描述
    w仅仅变化一点,经过1000次的传播,变化非常的大

  • 原因:

    发生这样的根本原因是RNN中每一次的输出都将被前面的数据彻底的清洗,而经过长时间的传输,很前面步的影响都后面的影响已经很微弱了,损失的反向调整同样也是,经过长距离的调整,差错已经很难反馈很多个时间步之后了

  • 解决办法:

    梯度爆炸:进行梯度裁剪即可,比如我们发现输出有很多超大的值的时候,进行裁剪
    梯度消失:它很难察觉,也在标准的RNN结构中,可以用GRU或者LSTM解决,而解决梯度消失的实质是通过保留一些前期输入的记忆

LSTM

标准的RNN结构是这样的:
在这里插入图片描述
而标准的LSTM加入了四个门控单元
在这里插入图片描述
这四个门单元分别是:
it, ft,gt, ot 分别是输入门、遗忘门、cell(记忆)门,以及输出门
他们控制 是否输入,对输入的遗忘和记忆,也控制是否输出,从而控制重要信息的传递,不重要信息遗忘

GRU

它比LSTM诞生更晚,是LSTM的变形版本,由于门单元更少,计算简单些,因此训练时间更短一些。
在这里插入图片描述

LSTM实例

本实例以IMDB数据集为例,代码篇幅过长,本文仅列示其中LSTM使用相关的重点,后续会有专门的博客详细解析代码。

  1. 预处理数据
    读取IMDB数据集
def read_imdb(datafolder ='train', dataroot=imdb_zip_path):data=[]for label in ['pos', 'neg']:filepath = os.path.join(imdb_zip_path, datafolder, label)for file in tqdm(os.listdir(filepath)):with open(os.path.join(filepath,file), 'rb') as f:content = f.read().decode('utf-8').replace('\n', ' ').lower()data.append([content, 1 if label == 'pos' else 0])random.shuffle(data)return data

IMDB中的数据分词

def get_tokenized(data):def tokenizer(text):filters = ['!', '"', '#', '$', '%', '&', '\(', '\)', '\*', '\+', ',', '-', '\.', '/', ':', ';', '<', '=', '>','\?', '@', '\[', '\\', '\]', '^', '_', '`', '\{', '\|', '\}', '~', '\t', '\n', '\x97', '\x96', '”', '“', ]text = re.sub("<.*?>", " ", text, flags=re.S)text = re.sub("|".join(filters), " ", text, flags=re.S)return [i.strip().lower() for i in text.split()]return [tokenizer(context) for context, _ in data]

创建分词后的词典

def get_vocab(data):counter = collections.Counter(_flatten(data))return vocab.vocab(counter)

封装dataloader

class ImdbLoader(object):def __init__(self, set_name='train', batch_size='64'):super(ImdbLoader, self).__init__()self.data_set = set_nameself.batch_size = batch_sizedef get_data_loader(self):# train_data = [['"dick tracy" is one of our"', 1],#               ['arguably this is a  the )', 1],#               ["i don't  just to warn anyone ", 0]]train_data = read_imdb(self.data_set)data = preprocess(train_data)#print(data)data_set = Data.TensorDataset(*data)data_loader = Data.DataLoader(data_set, self.batch_size, shuffle=True)return data_loader
  1. 创建模型
    此处创建了一个模型,包括双向的LSTM层和一个全连接层
 class BiRNN(nn.Module):def __init__(self, vocabulary, embed_len, hidden_len, num_layer):super(BiRNN, self).__init__()self.embedding = nn.Embedding(len(vocabulary), embed_len)self.encoder = nn.LSTM(input_size=embed_len,hidden_size=hidden_len,num_layers=num_layer,bidirectional=True,dropout = 0.3)# 本次使用起始和最终时间步的隐藏状态座位全连接层的输入self.decoder = nn.Linear(2*2*hidden_len, 2)def forward(self, inputs):#print('rnn model py: input_shape: ', inputs.shape)embeddings = self.embedding(inputs)glove_vab = getGlove()net.embedding.weight.data.copy_(load_pretrained_embedding(vo.get_itos(), glove_vab))net.embedding.weight.requires_grad = False#print('after embed input shape:', embeddings.shape)embeddings = embeddings.permute(1, 0, 2)output_sequence, _ = self.encoder(embeddings)concat_out = torch.cat((output_sequence[0], output_sequence[-1]), -1)outputs = self.decoder(concat_out)return outputs

3.模型训练

   def train(epoch, imdb_model, lr, train_batch_size):imdb_model_device = imdb_model.to(device)# 过滤掉不需要计算梯度的embedding的参数optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, imdb_model_device.parameters()), lr=lr)loader = ImdbLoader('train', train_batch_size)data_loader = loader.get_data_loader()for i in range(epoch):for idx, (inputs, target) in enumerate(data_loader):target = target.to(device)inputs = inputs.to(device)#print('train.py input shape:', inputs.shape)optimizer.zero_grad()output = imdb_model(inputs)#print('ouput.shape', output.shape)criterion = nn.CrossEntropyLoss()loss = criterion(output, target)loss.backward()optimizer.step()if idx % 10 == 0:predict = torch.max(output, dim=-1, keepdim=False)[-1]acc = predict.eq(target.data).cpu().numpy().mean() * 100print('train Epoch:{} processed:[{} / {} ({:.0f}%) Loss: {:.6f}, ACC: {:.6f}]'.format(i,idx * len(inputs),len(data_loader.dataset),100. * idx / len(data_loader),loss.item(),acc))torch.save(imdb_model.state_dict(), '../../resources/model_save/imdb_net.pkl')torch.save(optimizer.state_dict(), '../../resources/model_save/imdb_optimizer.pkl')

4.模型评估

 def test(imdb_model, test_batch_size):imdb_model.eval()imdb_model = imdb_model.to(device)loader = ImdbLoader('test', test_batch_size)data_loader = loader.get_data_loader()with torch.no_grad():for idx, (inputs, target) in enumerate(data_loader):target = target.to(device)inputs = inputs.to(device)#print(inputs)output = imdb_model(inputs)criterion = nn.CrossEntropyLoss()loss = criterion(output, target)predict = torch.max(output, dim=-1, keepdim=False)[-1]correct = predict.eq(target.data).sum()acc = 100. * predict.eq(target.data).cpu().numpy().mean()print('idx: {} loss : {}, accurate: {}/{} {:.2f}'.format(idx,  loss, correct, target.size(0), acc))

最终效果,当我们执行4个epoch后,准确率基本稳定在80%以上

train Epoch:3 processed:[23680 / 25000 (95%) Loss: 0.325240, ACC: 84.375000]
train Epoch:3 processed:[24320 / 25000 (97%) Loss: 0.449456, ACC: 75.000000]
train Epoch:3 processed:[15600 / 25000 (100%) Loss: 0.438567, ACC: 80.000000]
train Epoch:4 processed:[0 / 25000 (0%) Loss: 0.353131, ACC: 85.937500]
train Epoch:4 processed:[640 / 25000 (3%) Loss: 0.345814, ACC: 89.062500]
train Epoch:4 processed:[1280 / 25000 (5%) Loss: 0.195520, ACC: 93.750000]
train Epoch:4 processed:[1920 / 25000 (8%) Loss: 0.269773, ACC: 87.500000]
train Epoch:4 processed:[2560 / 25000 (10%) Loss: 0.287010, ACC: 85.937500]
train Epoch:4 processed:[3200 / 25000 (13%) Loss: 0.291449, ACC: 90.625000]

参考

colah https://colah.github.io/posts/2015-08-Understanding-LSTMs/
吴恩达 deep learning

https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/93120.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于ArcGis提取道路中心线

基于ArcGis提取道路中心线 文章目录 基于ArcGis提取道路中心线前言一、生成缓冲区二、导出栅格数据三、导入栅格数据四、新建中心线要素五、生成中心线总结 前言 最近遇到一个问题&#xff0c;根据道路SHP数据生成模型的时候由于下载的道路数据杂项数据很多&#xff0c;所以导…

windows server 2016 搭建使用 svn 服务器教程

参考教程&#xff1a; https://zhuanlan.zhihu.com/p/428552058 https://blog.csdn.net/weixin_33897722/article/details/85602029 配置环境 windows server 2016 远程服务器公网 ip 安装 SVN 服务端 下载 svn 服务端安装包&#xff1a;https://www.visualsvn.com/download…

Python Web框架:Django、Flask和FastAPI巅峰对决

今天&#xff0c;我们将深入探讨Python Web框架的三巨头&#xff1a;Django、Flask和FastAPI。无论你是Python小白还是老司机&#xff0c;本文都会为你解惑&#xff0c;带你领略这三者的魅力。废话不多说&#xff0c;让我们开始这场终极对比&#xff01; Django&#xff1a;百…

Vue 项目运行 npm install 时,卡在 sill idealTree buildDeps 没有反应

解决方法&#xff1a;切换到淘宝镜像。 以下是之前安装的 xmzs 包&#xff0c;用于控制切换淘宝镜像。 该截图是之前其他项目切换淘宝镜像的截图。 切换镜像后&#xff0c;顺利执行 npm install 。

基于令牌级 BERT 嵌入的趋势生成句子级嵌入

一、说明 句子&#xff08;短语或段落&#xff09;级别嵌入通常用作许多 NLP 分类问题的输入&#xff0c;例如&#xff0c;在垃圾邮件检测和问答 &#xff08;QA&#xff09; 系统中。在我上一篇文章发现不同级别的BERT嵌入的趋势中&#xff0c;我讨论了如何生成一个向量表示&a…

docker安装国产开源数据库tidb 单机版

docker pull pingcap/tidb 创建目录&#xff0c;然后安装tidb mkdir -p /data/tidb/data 创建容器 docker run --name tidb -d -v /data/tidb/data:/tmp/tidb --privilegedtrue -p 4000:4000 -p 10080:10080 pingcap/tidb:latest TiDB 简介 | PingCAP 文档中心

LabVIEW使用图像处理进行交通控制性能分析

LabVIEW使用图像处理进行交通控制性能分析 采用普雷维特、拉普拉斯、索贝尔和任意的空间域方法对存储的图像进行边缘检测&#xff0c;并获取实时图像。然而&#xff0c;对四种不同空间域边缘检测方法的核的性能分析。 以前&#xff0c;空路图像存储在数据库中&#xff0c;道路…

Python 使用Hadoop 3 之HDFS 总结

Hadoop 概述 Hadoop 是一个由Apache 软件基金会开发的分布式基础架构。用户可以在不了解分布式底层细节的情况下&#xff0c;开发分布式程序&#xff0c;充分利用集群的威力进行高速运算和存储。 Hadoop 实现一个分布式文件系统&#xff08;Hadoop Distributed File Sy…

电脑键盘打不了字按哪个键恢复?最新分享!

“有没有朋友知道电脑键盘为什么会莫名其妙就打不了字&#xff1f;明明用得好好的&#xff0c;突然就打不了字了&#xff0c;真的让人很迷惑&#xff01;有什么方法可以解决吗&#xff1f;” 电脑键盘为我们的办公提供了很大的方便&#xff0c;我们可以利用键盘输入我们需要的文…

Redis对象和五种常用数据类型

Redisobject 对象 对象分为键对象和值对象 键对象一般是string类型 值对象可以是string&#xff0c;list&#xff0c;set,zset,hash q&#xff1a;redisobj的结构 typedef struct redisObject { //类型 unsigned type:4; //编码 unsigned encoding:4; //指向底层实现…

diffusion model (七) diffusion model是一个zero-shot 分类器

Paper: Your Diffusion Model is Secretly a Zero-Shot Classifier Website: diffusion-classifier.github.io/ 文章目录 相关阅读背景方法大意diffusion model的背景知识如何将diffusion model应用到zero-shot classification如何求解 实验参考文献 相关阅读 diffusion mode…

数学建模之“层次分析法”原理和代码详解

一、层次分析法简介 层次分析法&#xff08;Analytic Hierarchy Process&#xff0c;AHP&#xff09;是一种用于多准则决策分析和评估问题的定量方法&#xff0c;常用于数学建模中。它是由数学家托马斯赛蒂&#xff08;Thomas Saaty&#xff09;开发的。 层次分析法将复杂的决…

【k8s、云原生】基于metrics-server弹性伸缩

第四阶段 时 间&#xff1a;2023年8月17日 参加人&#xff1a;全班人员 内 容&#xff1a; 基于metrics-server弹性伸缩 目录 一、Kubernetes部署方式 &#xff08;一&#xff09;minikube &#xff08;二&#xff09;二进制包 &#xff08;三&#xff09;Kubeadm 二…

【数据结构】二叉搜索树

&#x1f680; 作者简介&#xff1a;一名在后端领域学习&#xff0c;并渴望能够学有所成的追梦人。 &#x1f40c; 个人主页&#xff1a;蜗牛牛啊 &#x1f525; 系列专栏&#xff1a;&#x1f6f9;数据结构、&#x1f6f4;C &#x1f4d5; 学习格言&#xff1a;博观而约取&…

近 2000 台 Citrix NetScaler 服务器遭到破坏

Bleeping Computer 网站披露在某次大规模网络攻击活动中&#xff0c;一名攻击者利用被追踪为 CVE-2023-3519 的高危远程代码执行漏洞&#xff0c;入侵了近 2000 台 Citrix NetScaler 服务器。 研究人员表示在管理员安装漏洞补丁之前已经有 1200 多台服务器被设置了后门&#x…

shell之正则表达式及三剑客grep命令

一、正则表达式概述 什么是正则表达式&#xff1f; 正则表达式是一种描述字符串匹配规则的重要工具 1、正则表达式定义: 正则表达式&#xff0c;又称正规表达式、常规表达式 使用字符串描述、匹配一系列符合某个规则的字符串 正则表达式 普通字符&#xff1a; 大小写字母…

【网络编程系列】网络编程实战

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kuan 的首页,持续学…

什么是CSS预处理器?请列举几个常见的CSS预处理器。

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ CSS预处理器是什么&#xff1f;⭐ 常见的CSS预处理器⭐ 写在最后 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎来到前端入门之旅&#xff01;这个专栏是…

《安富莱嵌入式周报》第320期:键盘敲击声解码, 军工级boot设计,开源CNC运动控制器,C语言设计笔记,开源GPS车辆跟踪器,一键生成RTOS任务链表

周报汇总地址&#xff1a;嵌入式周报 - uCOS & uCGUI & emWin & embOS & TouchGFX & ThreadX - 硬汉嵌入式论坛 - Powered by Discuz! 视频版&#xff1a; https://www.bilibili.com/video/BV1Cr4y1d7Mp/ 《安富莱嵌入式周报》第320期&#xff1a;键盘敲击…

【STM32】FreeRTOS互斥量学习

互斥量&#xff08;Mutex&#xff09; 互斥量又称互斥信号量&#xff08;本质也是一种信号量&#xff0c;不具备传递数据功能&#xff09;&#xff0c;是一种特殊的二值信号量&#xff0c;它和信号量不同的是&#xff0c;它支持互斥量所有权、递归访问以及防止优先级翻转的特性…