李宏毅机器学习笔记:RNN循环神经网络

RNN

  • 一、RNN
    • 1、场景引入
    • 2、如何将一个单词表示成一个向量
    • 3种典型的RNN网络结构
  • 二、LSTM
    • LSTM和普通NN、RNN区别
  • 三、 RNN的训练
    • RNN与auto encoder和decoder
  • 四、RNN和结构学习的区别
  • 五、pytorch实现RNN与LSTM
  • 六、训练RNN时如何处理单步预测模型和多步预测问题

一、RNN

1、场景引入

在这里插入图片描述
例如情景补充的情况,根据词汇预测该词汇所属的类别。这个时候的Taipi则属于目的地。但是,在订票系统中,Taipi也可能会属于出发地。到底属于目的地,还是出发地,如果不结合上下文,则很难做出判断。因此,使用传统的深度神经网络解决不了问题,必须引入RNN。

2、如何将一个单词表示成一个向量

在这里插入图片描述
如上图所示,将词汇Taipi表示成[x1,x2]组成的向量。
在这里插入图片描述
一个最简单的方法是1-N encoding。思路是将所有的可能用到的词汇组成一个词典,然后假如我们一共只可能用到5个单词,则如上图所示,每个单词可以用1个五维向量来表示。
在这里插入图片描述
除了1-N econding之外,还有一些其他的方法。
第一种思路是设置1个other选项,将所有没有预先在词典中所设定的单词表示成other
第二种思路是利用26个字母进行hash映射。这种情况下则不需要额外考虑other的情况。
在这里插入图片描述
这样,将词汇向量化之后,我们指导,网络的输入为一个个的词汇向量,网络的输出则为:y1表示词汇属于dest目的地的概率,y2则表示词汇属于出发地的概率。最后其实应该还有一层,做出预测,属于哪个概率最大,则输出哪个。
在这里插入图片描述
这个时候,我们所构建的NN则是需要有记忆的,否则无法解决该问题。
在这里插入图片描述
因此,我们引入了RNN来解决该问题。将每次hidden layer的输出先储存到memory cell中,作为下个词汇向量的输入。不断循环该过程。
在这里插入图片描述
举例来说,我们输入的第一个向量为[1,1],则hidden layer的输出为[2,2],先被储存起来,输出为[4,4]。
在这里插入图片描述
第2个输入仍然为[1,1]。这个时候结合前一个memory的输出[2,2],hdden layer的输出为[6,6],output为[12,12]。
在这里插入图片描述
第3个输入为[2,2],结合前一个memory的输入为[6,6],这个时候hidden layer的输出为[16,16],output为[32,32]。
在这里插入图片描述
RNN的网络结构如上图所示,重复利用了同一种相同的网络结构。
在这里插入图片描述
每次储存在memory中的值并不相同。
在这里插入图片描述
当然,也可以把hidden layer的层数加深。

3种典型的RNN网络结构

在这里插入图片描述
Jordan Network和Elamn Network的区别在于是将每个output的值作为下一个的输入。右侧的网络结构可解释性更强。
在这里插入图片描述
双向RNN则更为全面,同时兼顾到了前后的上下文信息,而不仅仅是前面的信息。

二、LSTM

我们在实际过程中使用更多的则是LSTM。
在这里插入图片描述
LSTM实际上,是将RNN中hidden layer的输出存入memory cell的过程稍微复杂化了一些,使用了3个gate进行代替。input gate的作用是控制输入通过,forget gate的作用是控制对memory cell中的值是否进行清空。output gate的作用是控制是否将该memory cell的值输出。
在这里插入图片描述
每个门的激活函数都是sigmoid函数,因为这样恰好可以将输入值映射到(0,1)之间。0表示不允许通过,1表示可以通过。
这里额外说下,forget gate和直觉似乎有点相反。当 f ( z f ) = 1 f(z_{f})=1 f(zf)=1时,表示forget gate打开,但是 c f ( z f ) = 1 cf(z_{f})=1 cf(zf)=1,c表示前一个memory cell的值, c ′ c' c表示本次计算出来的值。这个时候,前一次计算出来的c的信息完全没有被forget。因此,forget gate打开时,不是表示forget,而是表示unforget。

在这里插入图片描述
举例来说,假如想设计一个LSTM网络,实现上面的功能。
当x2=1时,将x2的值写入到memory中。memory时最上面蓝色框的值。
当x2=-1时,将memory中的值进行reset。
当x3=1时,将memory中的值进行输出。
在这里插入图片描述
我们设计的NN结构如上图所示。输入乘的4个weight为[1,0,0,0]。input gate控制信号为输入与[0,100,0,-10]相乘,依次类推。
在这里插入图片描述
当输入为[3,1,0]时,input的值为3,input gate的值为1,multiply之后得到3.forget gate 的值为1,与前一个memory cell的值0相乘后再加3得到3,outputgate 的值为0,因此输出为0,memory cell的值更新为3,为本次运算的结果。
在这里插入图片描述
当输入为[4,1,0]时,input 的值为4,input gate=1,multiply之后得到4,forgat gate =1,与 C t − 1 = 3 C_{t-1}=3 Ct1=3相乘后+4=7,forget gate的值为0,因此output=0,memory cell更新为7.
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

LSTM和普通NN、RNN区别

在这里插入图片描述
前面已经讲述过,LSTM可以看作是将普通的hidden layer替代成由4个输入控制的cell。
在这里插入图片描述
将输入[x1,x2]分别乘上不同的matrix后输入,用于控制input ,input gate,forget gate,output gate。因此,LSTM网络结构的参数量是普通NN的4倍。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
这里,peephole,指的是,在实际LSTM网络结构设计中,会将前一时刻的memory cell的值ct,输出ht的值一并加入到下一时刻作为输入。
在这里插入图片描述
这里LSTM虽然看起来很复杂,但是在实际中往往这是最标准化的设计。我们可以借助工具来实现它。

三、 RNN的训练

在这里插入图片描述
如果需要train一个RNN,则必须首先定义好cost function。很显然,这里RNN的cost function为每个time step的输出和对应标签vector的cross entropy之和,也是我们需要minimize的函数。
在这里插入图片描述
使用的方法呢,叫做BNPP(Backpropagation through time),和一般的bp有细微的区别。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
为何会出现这种情况呢,我们可以分析原因。
在这里插入图片描述
其实问题的来源,就是在于长序列导致的梯度消失或爆炸。一个非常实用的方法则是使用LSTM。
LSTM可以解决梯度消失的问题,但不能解决梯度爆炸的问题。
在这里插入图片描述
为什么LSTM可以解决梯度消失的问题呢。因为对于LSTM来说,前面每一个timestep中的信息,只要forget gate没有关闭,便会一直累加到最后。而普通的RNN,只会保留上一个timestep的信息。
一般来说,再设计LSTM网络结构时,需要做到使得大多数情况下forget gate是开启的,仅在少部分情况下forget gate会关闭。
另外一种LSTM的变种结构叫做GRU,GRU区别于LSTM,仅有2个gate。核心思想为旧的不去,新的不来。LSTM中的input gate和forget gate相互拮抗,只有forget gate关闭时,input gate才会打开。forget gate打开时,input gate则会关闭。

RNN与auto encoder和decoder

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

四、RNN和结构学习的区别

在这里插入图片描述
(1)从考虑上下文情况来看,单向RNN仅考虑到前文的信息,没有考虑到后文的信息。HMM如果使用viterbi算法的话,则同时考虑了整个sequence的信息。这里来看,结构学习似乎更有优势,但是,双向RNN也可以同时考虑整个sequence信息。
(2)RNN的cost和error是直接相关的,而结构学习并不是。cost往往高于error。
(3)最大的一个区别在于RNN可以deep,而结构学习在deep上则没有优势。

五、pytorch实现RNN与LSTM

在这里插入图片描述
回忆一下我们最简单的Elman RNN网络结构。
看下官方文档如何实现:
在这里插入图片描述
h t h_{t} ht为当前时间t的hidden state。每一个time step,都会运行上面的公式。

在这里插入图片描述
可接受的输入参数如图。
(1)input_size,表示特征向量的维度,即输入的特征是多少维的。。
(2)hidden_size,表示hidden state的特征维度。
(3)num_layers,表示RNN的层数
(4)batch_first,这个参数决定入参是(seq,batch,feature)的形式还是(batch,seq,feature)的形式。RNN仅规定了这2种入参形式,为了迎合不同人的喜好,因此设定了这个参数作为模式切换。
注意,这里默认是batch_first=False的状态。
(5)bidirectional,决定了该RNN是否为双向RNN。
在这里插入图片描述
对于不使用batch的input来说,输入为 ( L , H i n ) (L,H_{in}) (L,Hin)。L为序列的长度, H i n H_{in} Hin为输入的特征维度
对于batch_first=True的来说,输入为 ( N , L , H i n ) (N,L,H_{in}) (N,L,Hin),N为batch_size。
对于batch_first=False来说,输入为 ( L , N , H i n ) (L,N,H_{in}) (L,N,Hin),N为batch_size。
h 0 h_{0} h0则表示初始的隐藏层状态。如果是双向RNN,则D=2,否则D=1。形状为 ( D ∗ n u m l a y e r s , H o u t ) (D*num_layers,H_{out}) (Dnumlayers,Hout)。如果使用了batch,则形状为 ( D ∗ n u m l a y e r s , N , H o u t ) (D*num_layers,N,H_{out}) (Dnumlayers,N,Hout)
在这里插入图片描述
输出为所有时间步的输出和最后一个时间步的隐层状态。
对于batch_first=True的情况来说,输出形状为 ( N , L , D ∗ H o u t ) (N,L,D*H_{out}) (N,L,DHout)
h n h_{n} hn为最后一个time step的隐藏层状态,输出形状为 ( D ∗ n u m l a y e r s , N , H o u t ) (D*num layers,N,H_{out}) (Dnumlayers,N,Hout)
在这里插入图片描述
官方给的例子如下:
建立一个RNN,输入的特征维度为10,hidden_state维度为20,num_layers=2。
输入为,seq_len=5,batch_size=3, H i n = 10 。 H_{in}=10。 Hin=10
初始隐藏层状态为num_layers=2,batch_size=3,H_out=20。
在这里插入图片描述
当然,这里估计有人就想问,output的维度 H o u t = 20 H_{out}=20 Hout=20,为什么会是20呢。别忘了,RNN的一个作用是可以进行分类。这里输出为20维的向量,通过torch.max()函数就可以获取预测的具体类别,也可以计算crossentropy
另外提下,如果要将原始的数据转换成具有batch_size的形式,直接使用tensor自带的reshape函数即可。
至于nn.lstm模块,和nn.RNN模块大同小异,几乎可以直接copy过来。

5.1为何 H o u t = h i d d e n s i z e H_{out}=hidden_size Hout=hiddensize?

这里在看官方文档时,注意到了一个问题。
在这里插入图片描述
这里,一般来说,nn.RNN的输出并不是我们最后所得到的分类或回归的1个值。而是隐藏层变量的输出,作为下一个time step的输入。所以2者是相等的。
num_layers表示RNN堆叠的层数。
在这里插入图片描述
例如说如图所示情况下,假如忽略中间的竖直的省略号,则num_layers=3。普通情况下,双向RNN则num_layers也等于2.

5.2 LSTM的实现

前面讲LSTM的实现与RNN大同小异,其实还是有些区别的。来重点看下区别。
在这里插入图片描述
在这里插入图片描述
前面的定义就不细讲,通过前面的内容,肯定是可以理解上述的这些公式的。

在这里插入图片描述
定义的参数和RNN几乎没有任何区别,input_size还是表示输入特征的维度,hidden_size表示隐藏层的维度,num_layers表示网络层数等。最后多了一个proj_size,一般也用不太上。
在这里插入图片描述
网络参数的输入,有些区别,多了初始的隐藏层状态h0和c0,在RNN中只有h0。
在这里插入图片描述
output相比也知识多了一个c_n。
在这里插入图片描述

六、训练RNN时如何处理单步预测模型和多步预测问题

单步预测,指的是每个timestep都有一个对应的输出。
多步预测,指的是只在seq的最后一个再有输出值。

class LSTMModel(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(LSTMModel, self).__init__()self.hidden_dim = hidden_dimself.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)  # 分别表示输入x的特征维度和中间隐藏层的特征维度self.linear = nn.Linear(hidden_dim, output_dim)def forward(self, input):  # batch_first=True。x严格按照batch_size,seq_len,feature的形式进行排列batch_size = input.size(0)seq_len = input.size(0)hidden_state = torch.zeros(1, batch_size, self.hidden_dim)cell_state = torch.zeros(1, batch_size, self.hidden_dim)output_seq, _ = self.lstm(input, (hidden_state, cell_state))  # 第2个为最后一个时间步的隐藏状态,output格式为[N,L,N*Hout]# last_hidden_state = output_seq[:, -1, :] #返回一个[N,1]的数据output = self.linear(output_seq)  # output size =[N,L,1]return output

区别就在于是否要注释中间那一行。在多步预测过程中,需要取1个batch_size中的每个seq_len的最后一个数值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/129100.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一个集成的BurpSuite漏洞探测插件1.1

免责声明 本文发布的工具和脚本,仅用作测试和学习研究,禁止用于商业用途,不能保证其合法性,准确性,完整性和有效性,请根据情况自行判断。如果任何单位或个人认为该项目的脚本可能涉嫌侵犯其权利&#xff0c…

Mysql数据库基础总结:

什么是数据库: 数据库(DataBase):存储和管理数据的一个仓库。 数据库类型分为:关系型数据库和非关系型数据库。 关系型数据库(SQL):存储的数据以行和列为格式,类似于e…

手写Mybatis

Mybatis核心配置文件就是为了配置Configration 因此要首先会解析Mybatis核心配置文件 首先使用dom4J解析Mybatis核心配置文件 新建模块演示dom4j解析.xml 目录放错了 无所谓 引入依赖 从原来项目可以拷贝过来 就些简单配置就好 解析核心配置文件和解析xxxMapper.xml映射文件…

vue学习之属性绑定

内容渲染 采用 &#xff1a;进行属性渲染创建 demo3.html,内容如下 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"&…

CHS零壹视频恢复程序OCR使用方法

目前CHS零壹视频恢复程序监控版、专业版、高级版已经支持了OCR&#xff0c;OCR是一种光学识别系统&#xff0c;通俗说就和扫描仪带的OCR软件一样的原理&#xff1a; 分析照片->OCR获取字符串->整理字符串->输出 使用方法如下&#xff08;以CHS零壹视频恢复程序监控版…

使用LlamaIndex构建自己的PandasAI

推荐&#xff1a;使用 NSDT场景编辑器 快速搭建3D应用场景 Pandas AI 是一个 Python 库&#xff0c;它利用生成 AI 的强大功能来增强流行的数据分析库 Pandas。只需一个简单的提示&#xff0c;Pandas AI 就可以让你执行复杂的数据清理、分析和可视化&#xff0c;而这以前需要很…

STL线程各种容器对比、数组和vector如何互相转换

STL vector如何扩展内存和释放内存STL中各种容器对比STL中的swap函数STL中哈希表扩容STL迭代器失效的情况和原因vector删除元素后如何避免当前迭代器会失效vector的iterator和const_iterator和const iterator vector如何扩展内存和释放内存 内存增长 1.5还是2倍扩容 gcc 二倍扩…

微信小程序ibeacon搜索功能制作

以下是一个完整的微信小程序代码示例&#xff0c;演示如何实现iBeacon搜索功能&#xff1a; // 在小程序页面中的js文件中编写代码Page({data: {beacons: [] // 存储搜索到的iBeacon设备信息},onReady() {// 初始化iBeaconwx.startBeaconDiscovery({uuids: [你的UUID], // 替换…

数据结构和算法(1):开始

算法概述 所谓算法&#xff0c;即特定计算模型下&#xff0c;旨在解决特定问题的指令序列 输入 待处理的信息&#xff08;问题&#xff09; 输出 经处理的信息&#xff08;答案&#xff09; 正确性 的确可以解决指定的问题 确定性 任一算法都可以描述为一个由基本操作组成的序…

用户促活留存新方式——在APP中嵌入小游戏

随着APP同类产品的不断出现&#xff0c;APP开发者们面临着激烈的竞争&#xff0c;很多APP下载后被新的APP取代&#xff0c;获客成本越来越高。同时开发者还会面临用户粘性差、忠诚度低、用完即走、留存困难&#xff0c;商业化价值被大大缩减。 在APP中植入小游戏来提高用户活跃…

Vue——vue3+element plus实现多选表格使用ajax发送id数组

代码来源: Vue 3结合element plus&#xff08;问题总结二&#xff09;之 table组件实现多选和清除选中&#xff08;在vue3中获取ref 的Dom&#xff09;_multipletableref.value.togglerowselection()打印出来的是u_子时不睡的博客-CSDN博客 前言 为了实现批量删除功能的功能…

【Python爬虫实战】爬虫封你ip就不会了?ip代理池安排上

前言 在进行网络爬取时&#xff0c;使用代理是经常遇到的问题。由于某些网站的限制&#xff0c;我们可能会被封禁或者频繁访问时会遇到访问速度变慢等问题。因此&#xff0c;我们需要使用代理池来避免这些问题。本文将为大家介绍如何使用IP代理池进行爬虫&#xff0c;并带有代…

C语言练习:输入日期输出该日期为当年第几天

用scanf()输入某年某月某日&#xff0c;判断这一天是这一年的第几天。以3月5日为例&#xff0c;应该先把前两个月的加起来&#xff0c;然后再加上5天即本年的第几天&#xff0c;特殊情况&#xff0c;闰年且输入月份≥3时需考虑多加一天。注&#xff1a;判断年份是否为闰年的方法…

【C刷题】day1

一、选择题 1.正确的输出结果是 int x5,y7; void swap() { int z; zx; xy; yz; } int main() { int x3,y8; swap(); printf("%d,%d\n"&#xff0c;x, y); return 0; } 【答案】&#xff1a; 3&#xff0c;8 【解析】&#xff1a; 考点&#xff1a; &#xff…

Matlab如何导入Excel数据并进行FFT变换

如果你发现某段信号里面有干扰&#xff0c;想要分析这段信号里面的频率成分&#xff0c;就可以使用matlab导入Excel数据后进行快速傅里叶变换&#xff08;fft&#xff09;。 先直接上使用方法&#xff0c;后面再补充理论知识。 可以通过串口将需要分析的数据发送到串口助手&a…

postgresql-窗口函数

postgresql-窗口函数 窗口函数简介窗口函数的定义分区排序选项窗口选项 窗口函数简介 包括 AVG、COUNT、MAX、MIN、SUM 以及 STRING_AGG。聚合函数的作用是针对一组数据行进行运算&#xff0c;并且返回一条汇总结果 分析的窗口函数&#xff08;Window Function&#xff09;。 …

投稿指南【NO.12_8】【极易投中】核心期刊投稿(组合机床与自动化加工技术)

近期有不少同学咨询投稿期刊的问题&#xff0c;大部分院校的研究生都有发学术论文的要求&#xff0c;少部分要求高的甚至需要SCI或者多篇核心期刊论文才可以毕业&#xff0c;但是核心期刊要求论文质量高且审稿周期长&#xff0c;所以本博客梳理一些计算机特别是人工智能相关的期…

单相并联下垂控原理

Part1 上述有个核心的piont是等效阻抗上的电压一般时很小的&#xff0c;这就导致逆变器输出电压矢量E和负载电压矢量UL之间的夹角很小 》基于上述的结论有助于我们去简化下垂控制的公式&#xff01;&#xff01;&#xff01; Part2 上述得到负载电流&#xff0c;接着乘以负载…

mac 查看端口占用

sudo lsof -i tcp:port # 示例 sudo lsof -i tcp:8080 杀死进程 sudo kill -9 PID # 示例 sudo kill -9 8080

基于奇偶模的跨线桥(crossover)分析

文章目录 1、ADS建模2、奇偶模分析2.1 Port1→Port2传输特性2.1.1奇模分析2.1.2偶模分析 2.2 Port1→Port4传输特性 附&#xff1a;正交混合网络的奇偶模分析1、 Port1→Port21.1奇模分析1.2Port1→Port2偶模分析1.3 奇模传输与偶模传输相位关系![在这里插入图片描述](https://…