PyTorch -- RNN 快速实践

  • RNN Layer torch.nn.RNN(input_size,hidden_size,num_layers,batch_first)

    • input_size: 输入的编码维度
    • hidden_size: 隐含层的维数
    • num_layers: 隐含层的层数
    • batch_first: ·True 指定输入的参数顺序为:
      • x:[batch, seq_len, input_size]
      • h0:[batch, num_layers, hidden_size]
  • RNN 的输入

    • x:[seq_len, batch, input_size]
      • seq_len: 输入的序列长度
      • batch: batch size 批大小
    • h0:[num_layers, batch, hidden_size]
  • RNN 的输出

    • y: [seq_len, batch, hidden_size]

在这里插入图片描述


  • 实战之预测 正弦曲线:以下会以此为例,演示 RNN 预测任务的部署
    在这里插入图片描述
    • 步骤一:确定 RNN Layer 相关参数值并基于此创建 Net

      import numpy as np
      from matplotlib import pyplot as pltimport torch
      import torch.nn as nn
      import torch.optim as optimseq_len     = 50
      batch       = 1
      num_time_steps = seq_leninput_size  = 1
      output_size = input_size
      hidden_size = 10  	
      num_layers = 1  	
      batch_first = True class Net(nn.Module):  ## model 定义def __init__(self):super(Net, self).__init__()self.rnn = nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=batch_first)# for p in self.rnn.parameters():# 	nn.init.normal_(p, mean=0.0, std=0.001)self.linear = nn.Linear(hidden_size, output_size)def forward(self, x, hidden_prev):out, hidden_prev = self.rnn(x, hidden_prev)# out: [batch, seq_len, hidden_size]out = out.view(-1, hidden_size)  # [batch*seq_len, hidden_size]out = self.linear(out) 			 # [batch*seq_len, output_size]out = out.unsqueeze(dim=0)    # [1, batch*seq_len, output_size]return out, hidden_prev
      
    • 步骤二:确定 训练流程

      lr=0.01def tarin_RNN():model = Net()print('model:\n',model)criterion = nn.MSELoss()optimizer = optim.Adam(model.parameters(), lr)hidden_prev = torch.zeros(num_layers, batch, hidden_size)  #初始化hl = []for iter in range(100):  # 训练100次start = np.random.randint(10, size=1)[0]  ## 序列起点time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点output, hidden_prev = model(x, hidden_prev)hidden_prev = hidden_prev.detach()  ## 最后一层隐藏层的状态要 detachloss = criterion(output, y)model.zero_grad()loss.backward()optimizer.step()if iter % 100 == 0:print("Iteration: {} loss {}".format(iter, loss.item()))l.append(loss.item())#############################绘制损失函数#################################plt.plot(l,'r')plt.xlabel('训练次数')plt.ylabel('loss')plt.title('RNN LOSS')plt.savefig('RNN_LOSS.png')return hidden_prev,modelhidden_prev,model = tarin_RNN()
      
    • 步骤三:测试训练结果

      start = np.random.randint(3, size=1)[0]  ## 序列起点
      time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列
      data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据
      x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)
      y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点    predictions = []  ## 预测结果
      input = x[:,0,:]
      for _ in range(x.shape[1]):input = input.view(1, 1, 1)pred, hidden_prev = model(input, hidden_prev)input = pred  ## 循环获得每个input点输入网络predictions.append(pred.detach().numpy()[0])
      x= x.data.numpy()
      y = y.data.numpy( )
      plt.scatter(time_steps[:-1], x.squeeze(), s=90)
      plt.plot(time_steps[:-1], x.squeeze())
      plt.scatter(time_steps[1:],predictions)  ## 黄色为预测
      plt.show()
      

      在这里插入图片描述


【高阶】上述例子比较简单,便于入门以推理到自己的目标任务,实际 RNN (长时间序列) 训练可能更有难度,可以添加

  • 对于梯度爆炸的解决:
    for p in model.parameters()"print(p.grad.norm())  ## 查阅梯度,看看是否爆炸torch.nn.utils.clip_grad_norm_(p, 10)  ## grad 限幅,其中的 norm 后面的_ 表示 in place 操作
    
  • 对于梯度消失的解决:-> LSTM

  • 另一个很好的实例关于飞行轨迹预测- - RNN-博客链接,可供学习参考
  • B站视频参考资料

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/356835.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL 创建数据表

创建MySQL数据表需要以下信息: 表名表字段名定义每个表字段 语法 以下为创建MySQL数据表的SQL通用语法: CREATE TABLE table_name (column_name column_type); 以下例子中我们将在 W3CSCHOOL 数据库中创建数据表w3cschool_tbl: CREAT…

three.js 第八节 - gltf加载器、解码器

// ts-nocheck // 引入three.js import * as THREE from three // 导入轨道控制器 import { OrbitControls } from three/examples/jsm/controls/OrbitControls // 导入hdr加载器(专门加载hdr的) import { RGBELoader } from three/examples/jsm/loaders…

Unity3d自定义TCP消息替代UNet实现网络连接

以前使用UNet实现网络连接,Unity2018以后被弃用了。要将以前的老程序升到高版本,最开始打算使用Mirro,结果发现并不好用。那就只能自己写连接了。 1.TCP消息结构 (1). TCP消息是按流传输的,会发生粘包。那么在发射和接收消息时就需要对消息进行打包和解包。如果接收的消息…

建造者模式(大话设计模式)C/C++版本

建造者模式 C 参考&#xff1a;https://www.cnblogs.com/Galesaur-wcy/p/15907863.html #include <iostream> #include <vector> #include <algorithm> #include <string> using namespace std;// Product Class&#xff0c;产品类&#xff0c;由多个…

荒野大镖客2启动找不到emp.dll的7个修复方法,轻松解决dll丢失的办法

一、emp.dll文件丢失的常见原因 安装或更新问题&#xff1a;在软件或游戏的安装过程中&#xff0c;可能由于安装程序未能正确复制文件到目标目录&#xff0c;或在更新过程中文件被意外覆盖或删除&#xff0c;导致emp.dll文件丢失。 安全软件误删&#xff1a;某些安全软件可能…

【ajax基础01】ajax简介

目录 一&#xff1a;ajax简介 1 什么是ajax 二&#xff1a;ajax使用 1 如何使用ajax 2 axios使用&#xff08;重点&#xff09; 三&#xff1a;案例 四&#xff1a;如何赚钱 一&#xff1a;ajax简介 1 什么是ajax AJAX&#xff08;Asynchronous JavaScript And XML &am…

莱辅络Rebro BIM机电专业软件

莱辅洛&#xff08;Rebro&#xff09;是一款专业机电 BIM 软件。它具备专业人士所期待的各种专业功能&#xff0c;应用于建筑机电工程的三维设计&#xff0c;并且适用于建筑、结构、给排水、暖通、电气五大专业。 该软件具有以下特点&#xff1a; • 3D 模型&#xff1a;可以…

渗透测试-若依框架的杀猪交易所系统管理后台

前言 这次是带着摸鱼的情况下简单的写一篇文章&#xff0c;由于我喜欢探究黑灰产业&#xff0c;所以偶尔机遇下找到了一个加密H币的交易所S猪盘&#xff0c;我记得印象是上年的时候就打过这一个同样的站&#xff0c;然后我是通过指纹查找其它的一些站&#xff0c;那个站已经关…

计网重点面试题-TCP三次握手四次挥手

三次握手 第一次握手(syn1) 客户端会随机初始化序号&#xff08;client_isn&#xff09;&#xff0c;将此序号置于 TCP 首部的「序列号」字段中&#xff0c;同时把 SYN 标志位置为 1&#xff0c;表示 SYN 报文。接着把第一个 SYN 报文发送给服务端&#xff0c;表示向服务端发…

一个电商创业者眼中的618:平台大变局

战役结束了&#xff0c;战斗还在继续。 一位朋友去年5月创业&#xff0c;网上卖咖啡&#xff0c;这个赛道很拥挤&#xff0c;时机也不好&#xff0c;今年是他参加第一个618。朋友说&#xff0c;今年的目标是锤炼团队&#xff0c;总结方法&#xff0c;以及最重要的——活下去。…

基于CentOS Stream 9平台 安装/卸载 Redis7.0.15

已更正systemctl管理Redis服务问题 1. 官方下载地址 https://redis.io/downloads/#redis-downloads 1.1 下载或上传到/opt/coisini目录下&#xff1a; mkdir /opt/coisini cd /opt/coisini wget https://download.redis.io/releases/redis-7.0.15.tar.gz2. 解压 tar -zxvf re…

基于IDEA的Maven简单工程创建及结构分析

目录 一、用 mvn 命令创建项目 二、用 IDEA 的方式来创建 Maven 项目。 &#xff08;1&#xff09;首先在 IDEA 下的 Maven 配置要已经确保完成。 &#xff08;2&#xff09;第二步去 new 一个 project &#xff08;创建一个新工程&#xff09; &#xff08;3&#xff09;…

1027. 方格取数

Powered by:NEFU AB-IN Link 文章目录 1027. 方格取数题意思路代码 1027. 方格取数 题意 某人从图中的左上角 A 出发&#xff0c;可以向下行走&#xff0c;也可以向右行走&#xff0c;直到到达右下角的 B 点。 在走过的路上&#xff0c;他可以取走方格中的数&#xff08;取…

Docker网络介绍

网络是虚拟化技术中最复杂的部分&#xff0c;也是Docker应用中的一个重要环节。 Docker中的网络主要解决容器与容器、容器与外部网络、外部网络与容器之间的互相通信的问题。 这些复杂情况的存在要求Docker有一个强大的网络功能去保障其网络的稳健性。因此&#xff0c;Docker…

SCI一区TOP|双曲正弦余弦优化算法(SCHO)原理及实现【免费获取Matlab代码】

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献5.代码获取 1.背景 2023年&#xff0c;J Bai受到双曲正弦余弦函数启发&#xff0c;提出了双曲正弦余弦优化算法&#xff08;Sinh Cosh optimizer, SCHO&#xff09;。 2.算法原理 2.1算法思想 SCHO灵感来源…

分布式锁(Redission)

分布式锁&#xff1a; 使用场景&#xff1a; 通常对于一些使用率高的服务&#xff0c;我们会进行多次部署&#xff0c;可能会部署在不同的服务器上&#xff0c;但是他们获取和操作的数据仍然是同一份。为了保证服务的强一致性&#xff0c;我们需要对线程进行加锁&#xff0c;…

ROS中的TF是什么

在ROS (Robot Operating System) 中&#xff0c;tf::TransformBroadcaster 是一个用于发布坐标变换信息的重要类&#xff0c;尤其在处理机器人定位和导航数据时非常常见。tf::TransformBroadcaster 对象允许你广播从一个坐标系到另一个坐标系的变换关系&#xff0c;这对于多传感…

Vue38 安装脚手架 vue-cli ,并使用脚手架创建项目

安装脚手架 vue-cli &#xff0c;并使用脚手架创建项目 第一步 安装脚手架 npm config set registry https:\\[registry.npmmirror.com // 切换淘宝镜像 npm install -g vue/cli第二步 切换到创建项目的目录&#xff0c;创建项目 cd XXX vue create XXX第三步 启动项目 npm…

【Linux】基础IO_3

文章目录 六、基础I/O3. 软硬链接4. 动静态库 未完待续 六、基础I/O 3. 软硬链接 使用 ln 就可以创建链接&#xff0c;使用 ln -s 可以创建软链接&#xff0c;直接使用 ln 则是硬链接。 我们对硬链接进行测试一下&#xff1a; 根据测试&#xff0c;我们知道了 硬链接就像一…

WPF文本绑定显示格式StringFormat设置-数值类型处理

绑定显示格式设置 在Textblock等文本控件中&#xff0c;我们经常要绑定一些数据类型&#xff0c;但是我们希望显示的时候能够按照我们想要的格式去显示&#xff0c;比如增加文本前缀&#xff0c;后面加单位&#xff0c;显示百分号等等&#xff0c;这种就需要对绑定格式进行处理…