循环神经网络-简洁实现

参考:
https://zh-v2.d2l.ai/chapter_recurrent-neural-networks/rnn-concise.html
https://pytorch.org/docs/stable/generated/torch.nn.RNN.html?highlight=rnn#torch.nn.RNN

RNN

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lbatch_size, num_steps = 32, 35  # num_steps: sequence length
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps) #  vocab:Vocab 26# 1 定义模型
# 构造一个具有256个隐藏层的循环神经网络 rnn_layer
# 此处先仅设计一层循环神经网络,以后讨论多层神经网络
num_hiddens = 256
rnn_layer = nn.RNN(len(vocab),num_hiddens) # RNN(28,256)
"""input_size – The number of expected features in the input x
hidden_size – The number of features in the hidden state h
num_layers – Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two RNNs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final results. Default: 1
nonlinearity – The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh'
bias – If False, then the layer does not use bias weights b_ih and b_hh. Default: True
batch_first – If True, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature). Note that this does not apply to hidden or cell states. See the Inputs/Outputs sections below for details. Default: False
dropout – If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout. Default: 0
bidirectional – If True, becomes a bidirectional RNN. Default: False
"""
# 2.我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)
state = torch.zeros((1,batch_size,num_hiddens))
print(state.shape)  #(torch.size([1,32,256]))#3. 通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。
# 需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。
X=torch.rand(size=(num_steps,batch_size,len(vocab)))  #torch.Size([35, 32, 28])   # (L,N,H(in)) L:sequence length  N batch size Hin: input_size
Y,state_new = rnn_layer(X,state)
print(Y.shape,state_new.shape) #torch.Size([35, 32, 256]) torch.Size([1, 32, 256])class RNNModel(nn.Module):"""循环神经网络"""def __init__(self,rnn_layer,vocab_size,**kwargs):super(RNNModel,self).__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_size# 如果RNN是双向的,num_directions 应该是2,否则应该是1if not self.rnn.bidirectional:self.num_directions = 1self.linear = nn.Linear(self.num_hiddens,self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens*2,self.vocab_size)def forward(self,inputs,state):X = F.one_hot(inputs.T.long(),self.vocab_size)X = X.to(torch.float32)Y,state = self.rnn(X,state)# 全连接首层将Y的形状改为(时间步数*批量大小,隐藏单元数)output = self.linear(Y.reshape((-1,Y.shape[-1])))return output,statedef begin_state(self, device, batch_size=1):if not isinstance(self.rnn, nn.LSTM):# nn.GRU以张量作为隐状态return  torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device)else:# nn.LSTM以元组作为隐状态return (torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device),torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device))# 训练
device = d2l.try_gpu()
net = RNNModel(rnn_layer,vocab_size=len(vocab))
net = net.to(device)
num_epochs ,lr = 500,1
d2l.train_ch8(net,train_iter,vocab,lr,num_epochs,device)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/137775.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

排序算法:归并排序(递归和非递归)

朋友们、伙计们,我们又见面了,本期来给大家解读一下有关排序算法的相关知识点,如果看完之后对你有一定的启发,那么请留下你的三连,祝大家心想事成! C 语 言 专 栏:C语言:从入门到精通…

什么是ELK

什么是ELK ELK 并不是一个技术框架的名称,它其实是一个三位一体的技术名词,ELK 的每个字母都来自一个技术组件,分别是 Elasticsearch(简称 ES)、Logstash 和 Kibana。 三个技术组件是独立的,后两个被elast…

table 写表格

<!-- colspan"3" 合并3列 --> <!-- rowspan"4" 合并4行 --> <!-- 7行4列 --><table><tr><th>企业名称</th><td>2020.11.11</td><th>法定代表人</th><td>2020.11.11</td>&l…

Nginx替代产品-Tengine健康检测

1、官网地址 官网地址&#xff1a;The Tengine Web Server 文档地址&#xff1a;文档 - The Tengine Web Server 健康检测模块&#xff1a;ngx_http_upstream_check_module - The Tengine Web Server 2、安装 下载 wget https://tengine.taobao.org/download/tengine-3.…

如何使用ArcGIS Pro提取河网水系

DEM数据除了可以看三维地图和生成等高线之外&#xff0c;还可以用于水文分析&#xff0c;这里给大家介绍一下如何使用ArcGIS Pro通过水文分析提取河网水系&#xff0c;希望能对你有所帮助。 数据来源 本教程所使用的数据是从水经微图中下载的DEM数据&#xff0c;除了DEM数据&a…

PASCAL VOC2012数据集详细介绍

PASCAL VOC2012数据集详细介绍 0、数据集介绍2、Pascal VOC数据集目标类别3、 数据集下载与目录结构4、目标检测任务5、语义分割任务6、实例分割任务7、类别索引与名称对应关系 0、数据集介绍 2、Pascal VOC数据集目标类别 在Pascal VOC数据集中主要包含20个目标类别&#xff…

初见QT,控件的基本应用,实现简单登录窗口

窗口实现代码 #include "widget.h"Widget::Widget(QWidget *parent): QWidget(parent) {//窗口设置this->setFixedSize(538, 373); //固定窗口大小this->setWindowIcon(QIcon("G:\\QT_Icon\\windos_icon2.png"))…

线性代数基础-行列式

一、行列式之前的概念 1.全排列&#xff1a; 把n个不同的元素排成一列&#xff0c;称为n个元素的全排列&#xff0c;简称排列 &#xff08;实际上就是我们所说的排列组合&#xff0c;符号是A&#xff0c;arrange&#xff09; 2.标准序列&#xff1a; 前一项均小于后一项的序列…

微博情绪分类

引自&#xff1a;https://blog.csdn.net/no1xiaoqianqian/article/details/130593783 友好借鉴&#xff0c;总体抄袭。 所需要的文件如下&#xff1a;https://download.csdn.net/download/m0_37567738/88340795 import os import torch import torch.nn as nn import numpy a…

32:TX Text Control ActiveX/ASP.NET/WinForms/WPF Crack

TX Text Control ActiveX 32.0 添加操作“普通”样式表的能力。 2023 年 9 月 14 日 - 15:38新版本 特征 脚注- 在文档中插入与 Microsoft Word 兼容的脚注。脚注是一种文字处理功能&#xff0c;允许用户在页面底部插入附加信息。 可编辑的[普通]样式表- 添加了操作[普通]样式的…

9.20号作业实现钟表

1.widget.h #include <QPainter> //画家 #include <QTimerEvent> #include <QTime> #include<QTimer> //定时器类QT_BEGIN_NAMESPACE namespace Ui { class Widget; } QT_END_NAMESPACEclass Widget : public QWidget {Q_OBJECTpublic:Wid…

物联网网络安全:保护物理世界和数字世界的融合

我们正在见证数字技术如何成为我们日常生活和经济系统的一部分&#xff0c;从而提高福利并增强竞争力。尽管如此&#xff0c;新的尖端互联技术的迅速出现和采用也对政府、企业和整个社会构成了重大威胁。 长期以来&#xff0c;网络安全威胁一直是电影行业的一个现成的灵感来源&…

el-table表格中加入输入框

<template><div class"box"><div class"btn"><el-button type"primary">发送评委</el-button><el-button type"primary" click"flag true" v-if"!flag">编辑</el-button…

RFID技术在仓储物流供应链管理中的应用

仓储物流供应链管理的透明度和库存周转率成为管控的重点&#xff0c;为了提高仓储物流的效率和减少库存损失&#xff0c;RFID技术被广泛应用于仓储、分发、零售管理等各个环节&#xff0c;为供应链管理带来了巨大的改变和提升。 首先&#xff0c;采用RFID技术进行仓库物流智能化…

Jenkins “Trigger/call builds on other project“用法及携带参数

1.功能 “Trigger/call builds on other project” 功能是 Jenkins 中的一个特性&#xff0c;允许您在某个项目的构建过程中触发或调用另一个项目的构建。 当您在 Jenkins 中启用了 “Trigger/call builds on other project” 功能并配置了相应的触发条件后&#xff0c;当主项…

(三十二)大数据实战——Maxwell安装部署及其应用案例实战

前言 Maxwell是一个开源的MySQL数据库binlog解析工具&#xff0c;用于将MySQL数据库的binlog转换成易于消费的JSON格式&#xff0c;并通过Kafka、RabbitMQ、Kinesis 等消息队列或直接写入文件等方式将其输出。本节内容主要介绍如何安装部署Maxwell以及如何使用Maxwell完成数据…

从淘宝数据分析产品需求(商品销量总销量精准月销)

淘宝数据分析总体来说可以分为商品分析、客户分析、地区分析、时间分析四大维度(参考数据雷达的分析思路)。在这里我重点说商品分析。 在淘宝上开店的竞争还是非常激烈的&#xff0c;随便拿出一个单品就有很多竞品存在&#xff0c;所以做起来还是很难的&#xff0c;而想要在众…

嵌入式学习 - 用电控制电

目录 前言&#xff1a; 1、继电器 2、二极管 3、三极管 3.1 特殊的三极管-mos管 3.2 npn类型三极管 3.3 pnp类型三极管 3.4 三极管的放大特性 3.5 mos管和三极管的区别 前言&#xff1a; 计算机的工作的核心原理&#xff1a;用电去控制电。 所有的电子元件都有数据手册…

MySQL的高级SQL语句

目录 一、高级SQL语句 1、select 查询表中一个或多个字段的数据 2、distinct 不显示重复的数据记录 3、where 有条件查询 4、and与or 且与或 5、in 显示在某个范围值内 的字段的信息 6、between 显示两个值范围内的数据记录 7、order by 对字…

ChatGLM 实现一个BERT

前言 本文包含大量源码和讲解,通过段落和横线分割了各个模块,同时网站配备了侧边栏,帮助大家在各个小节中快速跳转,希望大家阅读完能对BERT有深刻的了解。同时建议通过pycharm、vscode等工具对bert源码进行单步调试,调试到对应的模块再对比看本章节的讲解。 涉及到的jupyt…