pytorch实现循环神经网络

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

PyTorch 提供三种主要的 RNN 变体:

  • nn.RNN:最基本的循环神经网络,适用于短时依赖任务。
  • nn.LSTM:长短时记忆网络,适用于长序列数据,能有效解决梯度消失问题。
  • nn.GRU:门控循环单元,比 LSTM 计算更高效,适用于大部分任务。
网络类型优势适用场景
RNN计算简单,适用于短时序列语音、文本处理(短序列)
LSTM适用于长序列,能记忆长期信息机器翻译、语音识别、股票预测
GRU比 LSTM 计算更高效,效果相似语音处理、文本生成

例子:

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 1. 生成正弦波数据(仅使用 PyTorch)
def generate_sine_wave(seq_length=10, num_samples=1000):x = torch.linspace(0, 100, num_samples)  # 生成 1000 个等间距数据点y = torch.sin(x)  # 计算正弦值X_data, Y_data = [], []for i in range(len(y) - seq_length):X_data.append(y[i:i + seq_length].unsqueeze(-1))  # 过去 seq_length 作为输入Y_data.append(y[i + seq_length])  # 预测下一个点return torch.stack(X_data), torch.tensor(Y_data).unsqueeze(-1)# 生成数据
seq_length = 10  # 序列长度
X, Y = generate_sine_wave(seq_length)# 划分训练集和测试集
train_size = int(0.8 * len(X))
X_train, X_test = X[:train_size], X[train_size:]
Y_train, Y_test = Y[:train_size], Y[train_size:]# 2. 定义 RNN 模型
class SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)  # 初始化隐藏状态out, _ = self.rnn(x, h0)out = self.fc(out[:, -1, :])  # 取最后一个时间步的输出return out# 3. 训练模型
# 超参数
input_size = 1
hidden_size = 32
output_size = 1
num_layers = 1
num_epochs = 100
learning_rate = 0.001# 初始化模型
model = SimpleRNN(input_size, hidden_size, output_size, num_layers)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练
for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(X_train)loss = criterion(outputs, Y_train)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 4. 评估与绘图
model.eval()
with torch.no_grad():predictions = model(X_test)# 画图
plt.figure(figsize=(10, 5))
plt.plot(Y_test.numpy(), label="Real Data")
plt.plot(predictions.numpy(), label="Predicted Data")
plt.legend()
plt.title("RNN Sine Wave Prediction")
plt.show()

代码解析

数据生成

  • torch.linspace(0, 100, num_samples) 生成 1000 个均匀分布的数据点。
  • torch.sin(x) 计算正弦值,形成时间序列数据。
  • X过去 10 个时间步的数据,Y下一个时间步的预测目标

构建 RNN

  • nn.RNN(input_size, hidden_size, num_layers, batch_first=True) 定义循环神经网络
    • input_size=1:每个时间步只有一个输入值(正弦波)。
    • hidden_size=32:隐藏层神经元数目。
    • num_layers=1:单层 RNN。
  • self.fc = nn.Linear(hidden_size, output_size) 负责最终输出。

训练

  • 使用 MSELoss(均方误差损失) 计算预测值与真实值的误差。
  • 使用 Adam 优化器 更新模型参数。
  • 每 10 个 epoch 输出一次损失 loss

测试 & 绘图

  • 关闭梯度计算 (torch.no_grad()),执行前向传播预测测试数据。
  • Matplotlib 绘制预测曲线与真实曲线。

运行效果

如果训练成功,预测曲线(橙色)应该与真实曲线(蓝色)非常接近

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/10975.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

创建前端项目的方法

目录 一、创建前端项目的方法 1.前提:安装Vue CLI 2.方式一:vue create项目名称 3.方式二:vue ui 二、Vue项目结构 三、修改Vue项目端口号的方法 一、创建前端项目的方法 1.前提:安装Vue CLI npm i vue/cli -g 2.方式一&…

(leetcode 213 打家劫舍ii)

代码随想录&#xff1a; 将一个线性数组换成两个线性数组&#xff08;去掉头&#xff0c;去掉尾&#xff09; 分别求两个线性数组的最大值 最后求这两个数组的最大值 代码随想录视频 #include<iostream> #include<vector> #include<algorithm> //nums:2,…

【Python】第七弹---Python基础进阶:深入字典操作与文件处理技巧

✨个人主页&#xff1a; 熬夜学编程的小林 &#x1f497;系列专栏&#xff1a; 【C语言详解】 【数据结构详解】【C详解】【Linux系统编程】【MySQL】【Python】 目录 1、字典 1.1、字典是什么 1.2、创建字典 1.3、查找 key 1.4、新增/修改元素 1.5、删除元素 1.6、遍历…

本地部署DeepSeek开源多模态大模型Janus-Pro-7B实操

本地部署DeepSeek开源多模态大模型Janus-Pro-7B实操 Janus-Pro-7B介绍 Janus-Pro-7B 是由 DeepSeek 开发的多模态 AI 模型&#xff0c;它在理解和生成方面取得了显著的进步。这意味着它不仅可以处理文本&#xff0c;还可以处理图像等其他模态的信息。 模型主要特点:Permalink…

BW AO/工作簿权限配置

场景&#xff1a; 按事业部配置工作簿权限&#xff1b; 1、创建用户 事务码&#xff1a;SU01&#xff0c;用户主数据的维护&#xff0c;可以创建、修改、删除、锁定、解锁、修改密码等 用户设置详情页 2、创建权限角色 用户的权限菜单是通过权限角色分配来实现的 2.1、自定…

jstat命令详解

jstat 用于监视虚拟机运行时状态信息的命令&#xff0c;它可以显示出虚拟机进程中的类装载、内存、垃圾收集、JIT 编译等运行数据。 命令的使用格式如下。 jstat [option] LVMID [interval] [count]各个参数详解&#xff1a; option&#xff1a;操作参数LVMID&#xff1a;本…

3.Spring-事务

一、隔离级别&#xff1a; 脏读&#xff1a; 一个事务访问到另外一个事务未提交的数据。 不可重复读&#xff1a; 事务内多次查询相同条件返回的结果不同。 幻读&#xff1a; 一个事务在前后两次查询同一个范围的时候&#xff0c;后一次查询看到了前一次查询没有看到的行。 二…

MYSQL--一条SQL执行的流程,分析MYSQL的架构

文章目录 第一步建立连接第二部解析 SQL第三步执行 sql预处理优化阶段执行阶段索引下推 执行一条select 语句中间会发生什么&#xff1f; 这个是对 mysql 架构的深入理解。 select * from product where id 1;对于mysql的架构分层: mysql 架构分成了 Server 层和存储引擎层&a…

ReentrantReadWriteLock源码分析

文章目录 概述一、状态位设计二、读锁三、锁降级机制四、写锁总结 概述 ReentrantReadWriteLock&#xff08;读写锁&#xff09;是对于ReentranLock&#xff08;可重入锁&#xff09;的一种改进&#xff0c;在可重入锁的基础上&#xff0c;进行了读写分离。适用于读多写少的场景…

51单片机开发:温度传感器

温度传感器DS18B20&#xff1a; 初始化时序图如下图所示&#xff1a; u8 ds18b20_init(void){ds18b20_reset();return ds18b20_check(); }void ds18b20_reset(void){DS18B20_PORT 0;delay_10us(75);DS18B20_PORT 1;delay_10us(2); }u8 ds18b20_check(void){u8 time_temp0;wh…

vue2项目(一)

项目介绍 电商前台项目 技术架构&#xff1a;vuewebpackvuexvue-routeraxiosless.. 封装通用组件登录注册token购物车支付项目性能优化 一、项目初始化 使用vue create projrct_vue2在命令行窗口创建项目 1.1、脚手架目录介绍 ├── node_modules:放置项目的依赖 ├──…

labelme_json_to_dataset ValueError: path is on mount ‘D:‘,start on C

这是你的labelme运行时label照片的盘和保存目的地址的盘不同都值得报错 labelme_json_to_dataset ValueError: path is on mount D:,start on C 只需要放一个盘但可以不放一个目录

物联网 STM32【源代码形式-使用以太网】连接OneNet IOT从云产品开发到底层MQTT实现,APP控制 【保姆级零基础搭建】

物联网&#xff08;IoT&#xff09;‌是指通过各种信息传感器、射频识别技术、全球定位系统、红外感应器等装置与技术&#xff0c;实时采集并连接任何需要监控、连接、互动的物体或过程&#xff0c;实现对物品和过程的智能化感知、识别和管理。物联网的核心功能包括数据采集与监…

无心剑七绝《深度求索》

七绝深度求索 深研妙理定乾坤 度世玄机启智门 求路千难兼万险 索萦华夏自为尊 2025年2月1日 平水韵十三元平韵 无心剑七绝《深度求索》以平水韵十三元平韵写成&#xff0c;意境深远&#xff0c;气势磅礴。诗中“深研妙理定乾坤”开篇点题&#xff0c;展现出对深奥道理的钻研与探…

Hot100之普通数组

53最大子数组和 题目 思路解析 我们用一个dp数组来收集我们从左往右&#xff0c;加起来的最大的和 也就是我们的节点不是负数&#xff0c;那我们直接收集就好了 如果是负数&#xff0c;我们就用Max&#xff08;&#xff09;比较是这个节点大还是当前节点大&#xff08;这个情…

如何利用天赋实现最大化的价值输出-补

原文&#xff1a; https://blog.csdn.net/ZhangRelay/article/details/145408621 ​​​​​​如何利用天赋实现最大化的价值输出-CSDN博客 如何利用天赋实现最大化的价值输出-CSDN博客 引用视频差异 第一段视频目标明确&#xff0c;建议也非常明确。 录制视频的人是主动性…

新能源算力战争:为什么AI大模型需要绿色数据中心?

新能源算力战争:为什么AI大模型需要绿色数据中心? 近年来,人工智能(AI)大模型的爆发式增长正在重塑全球科技产业的格局。以GPT-4、Gemini、Llama等为代表的千亿参数级模型,不仅需要海量数据训练,更依赖庞大的算力支撑。然而,这种算力的背后隐藏着一个日益严峻的挑战——…

Spring Boot 中的事件发布与监听:深入理解 ApplicationEventPublisher(附Demo)

目录 前言1. 基本知识2. Demo3. 实战代码 前言 &#x1f91f; 找工作&#xff0c;来万码优才&#xff1a;&#x1f449; #小程序://万码优才/r6rqmzDaXpYkJZF 基本的Java知识推荐阅读&#xff1a; java框架 零基础从入门到精通的学习路线 附开源项目面经等&#xff08;超全&am…

unity学习24:场景scene相关生成,加载,卸载,加载进度,异步加载场景等

目录 1 场景数量 SceneManager.sceneCount 2 直接代码生成新场景 SceneManager.CreateScene 3 场景的加载 3.1 用代码加载场景&#xff0c;仍然build setting里先加入配置 3.2 卸载场景 SceneManager.UnloadSceneAsync(); 3.3 同步加载场景 SceneManager.LoadScene 3.3.…

【Android】布局文件layout.xml文件使用控件属性android:layout_weight使布局较为美观,以RadioButton为例

目录 说明举例 说明 简单来说&#xff0c;android:layout_weight为当前控件按比例分配剩余空间。且单个控件该属性的具体数值不重要&#xff0c;而是多个控件的属性值之比发挥作用&#xff0c;例如有2个控件&#xff0c;各自的android:layout_weight的值设为0.5和0.5&#xff0…