pytorch实现RNN网络

目录

1.导包

2. 加载本地文本数据

 3.构建循环神经网络层

4.初始化隐藏状态state

5.创建随机的数据,检测一下代码是否能正常运行

6. 构建一个完整的循环神经网络¶ 

7.模型训练 

8.个人知识点理解


 

1.导包

import torch
from torch import nn
from torch.nn import functional as F
import dltools

2. 加载本地文本数据

#声明变量:批次大小(一批所取的数据量)、子序列的长度
batch_size, num_steps =32, 35
#获取训练数据的迭代器, 词汇表
train_iter, vocab = dltools.load_data_time_machine(batch_size=batch_size, num_steps=num_steps)

 3.构建循环神经网络层

#声明变量:隐藏层的神经元数量(每个神经元都会有一个输出)
num_hiddens = 256
#构建一个具有256个隐藏单元的单隐藏层的循环神经网络
#num_layers=1默认值:一层神经网络
rnn_layer = nn.RNN(input_size=len(vocab), hidden_size=num_hiddens, num_layers=1)

4.初始化隐藏状态state

# 括号中的1:因为num_layers=1默认值:一层神经网络
state = torch.zeros((1, batch_size, num_hiddens))
state.shape
torch.Size([1, 32, 256])

5.创建随机的数据,检测一下代码是否能正常运行

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
#传入X和初始化时的state,获取Y和state_new
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape#有输出表示代码正常运行!!!

 (torch.Size([35, 32, 256]), torch.Size([1, 32, 256])) 

6. 构建一个完整的循环神经网络¶ 

.long() 方法‌:这是PyTorch张量的一个方法,用于将张量的数据类型转换为torch.long。torch.long是一种整数数据类型,通常用于索引或存储不需要浮点数精度的整数数据。 

class RNNModel(nn.Module):   #继承nn.Module#初始化(需要用到的)参数,  **kwargs表示继承的其他参数(不一一写明的意思)#vocab_size = len(vocab)def __init__(self, rnn_layer, vocab_size, **kwargs):#继承父类的属性和方法super().__init__(**kwargs)self.rnn_layer = rnn_layer#词汇表的长度self.vocab_size =vocab_sizeself.num_hiddens = self.rnn_layer.hidden_size#判断是否为双向循环if not self.rnn_layer.bidirectional:self.num_directions = 1#nn.Linear用于定义线性层的类,一般用于全连接层self.linear = nn.Linear(in_features=self.num_hiddens, out_features=self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens*2, self.vocab_size)#定义了数据在模型中的前向传播过程。(串联每一件事件的逻辑顺序)def forward(self, inputs, state):#one_hot编码,处理输入的X数据,此时的X.shape=(batch_size, num_steps)#。T转置之后,X.shape=(num_steps,batch_size)#one_hot编码之后, X.shape=(num_steps,batch_size, len(vocab)X = F.one_hot(inputs.T.long(), self.vocab_size)#将数据转化为tensorX = X.to(torch.float32)Y, state = self.rnn_layer(X, state)#此时,Y.shape = torch.Size(num_steps, batch_size, num_hiddens)#输出层:Y.shape必须是一个二维的, -1表示合并Y.shape中的num_steps与batch_size,outputs = self.linear(Y.reshape(-1, Y.shape[-1]))return outputs, state# 初始化隐藏状态def begin_state(self, device, batch_size=1):return torch.zeros((self.num_directions * self.rnn_layer.num_layers, batch_size, self.num_hiddens), device=device)
#在训练之前,基于随机初始化的权重进行预测,测试模型
device = dltools.try_gpu()
rnn_net = RNNModel(rnn_layer, vocab_size=len(vocab))
rnn_net = rnn_net.to(device)
dltools.predict_ch8(prefix='time traveller',num_preds=10, net=rnn_net, vocab=vocab, device=device)
'time travellergghhhhhhhh'

7.模型训练 

#声明变量
#模型训练时,可以先让学习率的值稍大一些,让梯度下降的快一些,然后
#梯度下降到一定程度再改成较小的值
num_epochs, lr = 500, 0.1
dltools.train_ch8(net=rnn_net, train_iter=train_iter, vocab=vocab, lr=lr, num_epochs=num_epochs, device=device)

 

8.个人知识点理解

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/429454.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python画笔案例-057 绘制蜘蛛网

1、绘制蜘蛛网 通过 python 的turtle 库绘制 蜘蛛网,如下图: 2、实现代码 绘制蜘蛛网,以下为实现代码: """蜘蛛网.py """ import turtledef draw_circle(pos,r):"""p

实时数据的处理一致性

实时数据一致性的定义以及面临的挑战‍‍‍‍‍ 数据一致性通常指的是数据在整个系统或多个系统中保持准确、可靠和同步的状态。在实时数据处理中,一致性包括但不限于数据的准确性、完整性、时效性和顺序性。 下图是典型的实时/流式数据处理的流程: 1、…

Infineon——TC397 Multicore简介

文章目录 前言一、TC397简介二、命名规则三、多核开发建议 前言 AURIX™ TC3xx微控制器架构具有多达6个独立的处理器内核CPU0…CPU5, 可在一个统一平台上无缝托管多个应用程序和操作系统. 由于实现了具有独立读取接口的多个程序Flash模块, 该架构支持进一步的实时处理. AURIX™…

毕业设计选题:基于ssm+vue+uniapp的驾校预约管理系统小程序

开发语言:Java框架:ssmuniappJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:M…

iPhone16,超先进摄像头系统?丝滑的相机控制

iPhone 16将于9月20号正式开售,这篇文章我们来看下iPhone 16 在影像方面,有哪些升级和新feature。 芯片:采用第二代 3纳米芯片,A18。 摄像头配置: iPhone 16 前置:索尼 IMX714 ,1200 万像素&am…

Red Hat 和 Debian Linux 对比

原图的作者(https://bbs.deepin.org/post/209759) Red Hat Enterprise Linux https://www.redhat.com/ CentOS Linux https://www.centos.org/ Fedora Linux https://fedoraproject.org/ Debian https://www.debian.org/ Ubuntu https://cn.ubuntu.com/ https://ubuntu.c…

harbor私有镜像仓库,搭建及管理

私有镜像仓库 docker-distribution docker的镜像仓库,默认端口号5000 做个仓库,把镜像放里头,用什么服务,起什么容器 vmware公司在docker私有仓库的基础上做了一个web页面,是harbor docker可以把仓库的镜像下载到本地&…

Matlab 的.m 文件批量转成py文件

在工作中碰到了一个问题,需要将原来用matlab gui做出来的程序改为python程序,因为涉及到很多文件,所以在网上搜了搜有没有直接能转化的库。参考了【Matlab】一键Matlab代码转python代码详细教程_matlab2python-CSDN博客 这位博主提到的matla…

Lanterns (dp 紫 线段树 二分 维护dp)

Lanterns - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 让所有点被覆盖,那么状态可以设计成覆盖一段前缀,并且中间不允许出现断点 由于CF崩了,所以暂时没提交代码。 记f(i) 为前 i 个灯笼点亮的最长前缀。 由于答案具有保留性&#xff…

自闭症孩子送寄宿学校,给他们成长的机会

在自闭症儿童的教育与康复之路上,选择一种合适的寄宿方式对于孩子的成长至关重要。这不仅关乎到孩子能否获得专业的训练与关怀,还直接影响到他们未来的社交能力、独立生活能力以及心理健康。今天,我们将以广州的星贝育园自闭症儿童寄宿制学校…

SOI 刻蚀气体

Liu, Yingjie, et al. "Very sharp adiabatic bends based on an inverse design." Optics letters 43.11 (2018): 2482-2485.

yolov8模型在手部关键点检测识别中的应用【代码+数据集+python环境+GUI系统】

yolov8模型在手部关键点检测识别中的应用【代码数据集python环境GUI系统】 背景意义 在手势识别、虚拟现实(VR)、增强现实(AR)等领域,手部关键点检测为用户提供了更加自然、直观的交互方式。通过检测手部关键点&#…

FreeSWITCH 简单图形化界面29 - 使用mod_xml_curl 动态获取配置、用户、网关数据

FreeSWITCH 简单图形化界面29 - 使用mod_xml_curl 动态获取配置、用户、网关数据 FreeSWITCH GUI界面预览安装FreeSWITCH GUI先看使用手册1、简介2、安装mod_xml_curl模块3、配置mod_xml_curl模块3、编写API接口4、测试一下5、其他注意的地方 FreeSWITCH GUI界面预览 http://m…

鸿蒙开发(NEXT/API 12)【跨设备互通特性简介】协同服务

跨设备互通提供跨设备的相机、扫描、图库访问能力,平板或2in1设备可以调用手机的相机、扫描、图库等功能。 说明 本章节以拍照为例展开介绍,扫描、图库功能的使用与拍照类似。 用户在平板或2in1设备上使用富文本类编辑应用(如:…

【yolo破损纸板-包装盒-快递袋缺陷检测】

yolo破损纸板-包装盒-快递袋缺陷检测 破损纸质包装盒检测方盒型快递包裹检测 破损纸质包装盒检测 数据集合模型 可视化 方盒型快递包裹检测 数据集和模型 train: ../train/images val: ../valid/images test: ../test/images nc: 1 names: - box_packet可视化

股指期权交易详细基础介绍

股指期权是期权市场中的一种特定类型,其标的资产为股票指数。简而言之,它允许投资者在未来某个特定时间,以预先约定的价格,买入或卖出股票指数的权利。在中国,已上市的股指期权包括上证50、沪深300和中证1000股指期权&…

鸿萌数据恢复服务: 修复 Windows, Mac, 手机中 “SD 卡无法读取”错误

天津鸿萌科贸发展有限公司从事数据安全服务二十余年,致力于为各领域客户提供专业的数据恢复、数据备份解决方案与服务,并针对企业面临的数据安全风险,提供专业的相关数据安全培训。 公司是多款国际主流数据恢复软件的授权代理商,为…

C语言深入理解指针(四)

目录 字符指针变量数组指针变量数组指针变量是什么数组指针变量怎么初始化 二维数组传参的本质函数指针变量函数指针变量的创建函数指针变量的使用代码typedef关键字 函数指针数组转移表 字符指针变量 字符指针在之前我们有提到过,(字符)&am…

Python--TCP/UDP通信

文章目录 前言一、pandas是什么?二、使用步骤 1.引入库2.读入数据总结 一.客户端与服务端通信原理 1. 服务器端 服务器端的主要任务是监听来自客户端的连接请求,并与之建立连接,然后接收和发送数据。 创建套接字:首先&#xff0…

《使用 LangChain 进行大模型应用开发》学习笔记(四)

前言 本文是 Harrison Chase (LangChain 创建者)和吴恩达(Andrew Ng)的视频课程《LangChain for LLM Application Development》(使用 LangChain 进行大模型应用开发)的学习笔记。由于原课程为全英文视频课…