【深度学习实验】循环神经网络(一):循环神经网络(RNN)模型的实现与梯度裁剪

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. 数据处理

2. rnn

测试

3. grad_clipping

4. 代码整合


        经验是智慧之父,记忆是智慧之母。

——谚语

一、实验介绍

        本实验介绍了一个简单的循环神经网络(RNN)模型,并探讨了梯度裁剪在模型训练中的应用。

        在前馈神经网络中,信息的传递是单向的,这种限制虽然使得网络变得更容易学习,但在一定程度上也减弱了神经网络模型的能力.在生物神经网络中,神经元之间的连接关系要复杂得多.前馈神经网络可以看作一个复杂的函数,每次输入都是独立的,即网络的输出只依赖于当前的输入.但是在很多现实任务中, 网络的输出不仅和当前时刻的输入相关,也和其过去一段时间的输出相关.比如一个有限状态自动机,其下一个时刻的状态(输出)不仅仅和当前输入相关,也和当前状态(上一个时刻的输出)相关.此外,前馈网络难以处理时序数据,比如视频、语音、文本等.时序数据的长度一般是不固定的,而前馈神经网络要求输入和输出的维数都是固定的,不能任意改变.因此,当处理这一类和时序数据相关 的问题时,就需要一种能力更强的模型. 循环神经网络(Recurrent Neural Network,RNN)是一类具有短期记忆能力的神经网络

        在循环神经网络中,神经元不但可以接受其他神经元的信息,也可以接受自身的信息,形成具有环路的网络结构.和前馈神经网络相比,循环神经网络更加符合生物神经网络的结构.循环神经网络已经被广泛应用在语音识别、语言模型以及自然语言生成等任务上.循环神经网络的参数学习可以通过随时间反向传播算法[Werbos, 1990]来学习.随时间反向传播算法即按照时间的逆序将错误信息一步步地往前传递.

二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 导入必要的工具包

import torch

1. 数据处理

        与之前的模型有所不同,循环神经网络引入了隐藏状态时间步两个新概念。当前时间步的隐藏状态由当前时间的输入与上一个时间步的隐藏状态一起计算出。

         根据隐藏状态的计算公式,需要计算两次矩阵乘法和三次加法才能得到当前时刻的隐藏状态。这里通过代码说明: 该计算公式等价于将当前时刻的输入与上一个时间步的隐藏状态做拼接,将两个权重矩阵做拼接,然后对两个拼接后的结果做矩阵乘法。此处展示省略了偏置项。

# X为模拟的输入,H为模拟的隐藏状态,在实际情况时要更复杂一些
X, W_xh = torch.normal(0, 1, (3, 1)), torch.normal(0, 1, (1, 4))
H, W_hh = torch.normal(0, 1, (3, 4)), torch.normal(0, 1, (4, 4))
torch.matmul(X, W_xh) + torch.matmul(H, W_hh)

上面是按照公式计算得到的结果,下面是拼接后计算得到的结果,两个结果完全相同

torch.matmul(torch.cat((X, H), 1), torch.cat((W_xh, W_hh), 0))

  • X是一个形状为(3, 1)的张量,表示输入。
  • W_xh是一个形状为(1, 4)的张量,表示输入到隐藏状态的权重。
  • H是一个形状为(3, 4)的张量,表示隐藏状态。
  • W_hh是一个形状为(4, 4)的张量,表示隐藏状态到隐藏状态的权重。

2. rnn

        定义了一个名为rnn的函数,用于执行循环神经网络的前向传播,在函数内部,通过遍历输入序列的每个时间步,逐步计算隐藏状态和输出。

def rnn(inputs, state, params):# inputs的形状:(时间步数量,批量大小,词表大小)W_xh, W_hh, b_h, W_hq, b_q = paramsH = stateoutputs = []# X的形状:(批量大小,词表大小)for X in inputs:H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)Y = torch.mm(H, W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)
  • 参数:
    • inputs是一个形状为(时间步数量,批量大小,词表大小)的张量,表示输入序列。
    • state是一个形状为(批量大小,隐藏状态大小)的张量,表示初始隐藏状态。
    • params是一个包含了模型的参数的列表,包括W_xhW_hhb_hW_hqb_q
  • 对于每个时间步,
    • 使用tanh激活函数来更新隐藏状态
    • 根据更新后的隐藏状态,计算输出Y
    • 将输出添加到outputs列表中
  • 使用torch.cat函数将输出列表合并成一个张量,返回合并后的张量和最后一个隐藏状态 (H,)

测试

    inputs=torch.rand(10,3,50)params=[torch.rand((50,50)),torch.rand((50,50)),torch.rand((3,50)),torch.rand((50,60)),torch.rand((3,60))]state=torch.rand((3,50))output=rnn(inputs,state,params)print(output)
  • inputs是一个形状为(10, 3, 50)的随机张量,表示模拟的输入序列
  • params是一个包含了随机参数的列表,与rnn函数中的参数对应
  • state是一个形状为(3, 50)的随机张量,表示初始隐藏状态
  • 调用rnn函数
  • 打印输出结果output

3. grad_clipping

        在循环神经网络的训练中,当时间步较大时,可能导致数值不稳定, 例如梯度爆炸或梯度消失,所以一个很重要的步骤是梯度裁剪。通过下面的函数,梯度范数永远不会超过给定的阈值, 并且更新后的梯度完全与的原始方向对齐。

def grad_clipping(net, theta):if isinstance(net, nn.Module):params = [p for p in net.parameters() if p.requires_grad]else:params = net.paramsnorm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))if norm > theta:for param in params:param.grad[:] *= theta / norm

        函数接受两个参数:net和theta。该函数首先根据net的类型获取需要梯度更新的参数,然后计算所有参数梯度的平方和的平方根,并将其与阈值theta进行比较。如果超过阈值,则对参数梯度进行裁剪,使其不超过阈值。

4. 代码整合

# 导入必要的工具包
import torch# # X为模拟的输入,H为模拟的隐藏状态,在实际情况时要更复杂一些
# X, W_xh = torch.normal(0, 1, (3, 1)), torch.normal(0, 1, (1, 4))
# H, W_hh = torch.normal(0, 1, (3, 4)), torch.normal(0, 1, (4, 4))
# # torch.matmul(X, W_xh) + torch.matmul(H, W_hh)
# #
# # torch.matmul(torch.cat((X, H), 1), torch.cat((W_xh, W_hh), 0))def rnn(inputs, state, params):# inputs的形状:(时间步数量,批量大小,词表大小)W_xh, W_hh, b_h, W_hq, b_q = paramsH = stateoutputs = []# X的形状:(批量大小,词表大小)for X in inputs:H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)Y = torch.mm(H, W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)def grad_clipping(net, theta):if isinstance(net, nn.Module):params = [p for p in net.parameters() if p.requires_grad]else:params = net.paramsnorm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))if norm > theta:for param in params:param.grad[:] *= theta / normif __name__ == '__main__':inputs=torch.rand(10,3,50)params=[torch.rand((50,50)),torch.rand((50,50)),torch.rand((3,50)),torch.rand((50,60)),torch.rand((3,60))]state=torch.rand((3,50))output=rnn(inputs,state,params)print(output)

        使用随机生成的输入数据和参数进行模型的测试。测试结果显示,RNN模型能够正确计算隐藏状态和输出结果,并且通过梯度裁剪可以有效控制梯度的大小,提高模型的稳定性和训练效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/156173.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

排序算法——选择排序

一、介绍: 选择排序就是按照一定的顺序从选取第一个元素索引开始,将其储存在一个变量值中,根据排序规则比较后边每一个元素与这个元素的大小,根据排序规则需要,变量值的索引值进行替换,一轮遍历之后&#x…

通用监控视频web播放方案

业务场景 对接监控视频,实现海康大华等监控摄像头的实时画面在web端播放 方案一,使用 RTSP2webnode.jsffmpeg 说明:需要node环境,原理就是RTSP2web实时调用ffmpeg解码。使用单独html页面部署到服务器后,在项目中需要播…

Element组件案例 Vue路由 前端打包部署步骤

目录 Element组件案例案例需求与分析环境搭建整体布局顶部标题左侧导航栏核心-右侧导航栏表格编写表单编写分页工具栏编写 异步数据加载异步加载数据性别展示修复图片展示修复 Vue路由Vue路由简介Vue路由入门 打包部署前端工程打包部署前端工程nginx介绍部署 Element组件案例 …

云原生Kubernetes:K8S集群版本升级(v1.20.6 - v1.20.15)

目录 一、理论 1.K8S集群升级 2.集群概况 3.升级集群 4.验证集群 二、实验 1.升级集群 2.验证集群 三、问题 1.给node1节点打污点报错 一、理论 1.K8S集群升级 (1)概念 搭建K8S集群的方式有很多种,比如二进制,kubeadm…

把短信验证码储存在Redis

校验短信验证码 接着上一篇博客https://blog.csdn.net/qq_42981638/article/details/94656441,成功实现可以发送短信验证码之后,一般可以把验证码存放在redis中,并且设置存放时间,一般短信验证码都是1分钟或者90s过期,…

Redis 新手必读。这篇文章是学 Redis 的捷径。

Redis 简介Redis 优势Redis 数据类型基本命令发布订阅订阅者的客户端显示如下事务持久化复制哨兵分片 Redis 简介 Redis 是完全开源免费的,遵守 BSD 协议,是一个高性能的 key - value 数据库 Redis 与 其他 key - value 缓存产品有以下三个特点&#…

RAMday9

设置按键中断,按键1按下,LED亮,再按一次,灭;按键2按下,蜂鸣器响,再按一次,不响;按键3按下,风扇转,再按一次,风扇停 代码 do_irq.c #include "key.h" extern…

网工内推 | 技术支持工程师,厂商公司,HCIA即可,有带薪年假

01 华为终端有限公司 招聘岗位:初级技术支持 职责描述: 1、通过远程方式处理华为用户在产品使用过程中各种售后问题; 2、收集并整理消费者声音,提供服务持续优化建议; 3、对服务中发现的热点、难点问题及其他有可能造…

PTA 7-5 令人抓狂的四则运算

题目 曾记否,我们小学时,遇到这种四则运算,心情是抓狂的: 那么当我们学会使用计算机,自然是要程序去完成这个工作啦~ 现在请对输入的四则运算求值。注意: 四则运算表达式必定包含运算数,还可能…

记一次惊险的CDH6.3.2集群断电后重启的过程

重启服务 systemctl restart cloudera-scm-server.service systemctl restart cloudera-scm-agent.service查看服务是否启动,显然结果是failed systemctl status cloudera-scm-server.service查看异常 journalctl -xe去看服务日志 发现是这个位置错误 SqlExcep…

使用 Eziriz .NET Reactor 对c#程序加密

我目前测试过好几个c#加密软件。效果很多时候是加密后程序执行错误,或者字段找不到的现象 遇到这个加密软件用了一段时间都很正常,分享一下使用流程 破解版本自行百度。有钱的支持正版,我用的是 Eziriz .NET Reactor 6.8.0 第一步 安装 Ezi…

35道Rust面试题

这套Rust面试题包括了填空题、判断题、连线题和编码题等题型。 选择题 1 ,下面哪个是打印变量language的正确方法? A,println("{}", language); B,println(language); C,println!("{}", langu…

Vue-2.3v-model原理

原理:v-model本质上是一个语法糖,例如应用在输入框上,就是value属性和input事件的合写。 作用:提供数据的双向绑定 1)数据变,视图跟着变:value 2)视图变,数据跟着变input 注意&a…

RabbitMQ详细使用

工作队列 注意事项:一个消息只能被处理一次,不可以处理多次 轮询分发信息 消息应答 消费者在接收到消息并且处理该消息之后,告诉rabbitmq它已经处理了,rabbitmq可以把该消息删除了。倘若mq没有收到应答,mq会将消息转…

微信小程序 movable-view 控制长按才触发拖动 轻轻滑动页面正常滚动效果

今天写 movable-areamovable-view遇到了个头疼的问题 那就是 movable-view 监听了用户拖拽自己 但 我们小程序 上下滚动页面靠的也是拖拽 也就是说 如果放在这里 用户拖动 movable-view部分 就会永远触发不了滚动 那么 我们先可以 加一个 bindlongpress"longpressHandler…

FHRP首跳冗余的解析

首跳冗余的解析 个人简介 HSRP hot standby router protocol 热备份路由协议 思科设备上 HSRP VRRP 华为设备上 VRRP HSRP v1 version 1 HSRP v2 version 2 虚拟一个HSRP虚拟IP地址 192.168.1.1 开启HSRP的抢占功能 通过其他参数 人为调整谁是主 谁是从 &a…

Maven 构建配置文件

目录 构建配置文件的类型 配置文件激活 配置文件激活实例 1、配置文件激活 2、通过Maven设置激活配置文件 3、通过环境变量激活配置文件 4、通过操作系统激活配置文件 5、通过文件的存在或者缺失激活配置文件 构建配置文件是一系列的配置项的值,可以用来设置…

Linux知识点 -- 高级IO(一)

Linux知识点 – 高级IO(一) 文章目录 Linux知识点 -- 高级IO(一)一、5种IO模型1.IO再理解2.阻塞IO3.非阻塞轮询式IO4.信号驱动IO5.IO多路转接6.异步IO7.同步通信vs异步通信8.阻塞vs非阻塞 二、非阻塞IO1.设置非阻塞的方法2.非阻塞…

电子科大软件系统架构设计——系统架构设计

文章目录 系统架构设计系统设计概述系统设计定义系统设计过程系统设计活动系统设计基本方法系统设计原则系统设计方法分类面向对象系统分析与设计建模过程 系统架构基础系统架构定义系统架构设计定义系统架构作用系统架构类型系统总体架构系统拓扑架构系统拓扑架构类型系统拓扑…

JavaWeb---Servlet

1.Srvlet概述 Servlet是运行在java服务器端的程序,用于接收和响应来着客户端基于HTTP协议的请求 如果想实现Servlet的功能,可以通过实现javax。servlet。Servlet接口或者继承它的实现类 核心方法:service()&#xf…