机器学习入门--循环神经网络原理与实践

循环神经网络

循环神经网络(RNN)是一种在序列数据上表现出色的人工神经网络。相比于传统前馈神经网络,RNN更加适合处理时间序列数据,如音频信号、自然语言和股票价格等。本文将介绍RNN的基本数学原理、使用PyTorch和Scikit-Learn数据集实现的代码。

数学原理

RNN是一种带有循环结构的神经网络,其在处理序列数据时将前一次的输出作为当前输入的一部分。这使得RNN能够记住先前的状态和信息,并且在处理长期依赖关系时表现出色。

RNN的基本公式可以表示为:

h t = f ( W h h h t − 1 + W x h x t ) h_t = f(W_{hh}h_{t-1} + W_{xh}x_t) ht=f(Whhht1+Wxhxt)

其中 h t h_t ht是RNN在时间步 t t t的隐藏状态, f f f是激活函数, W h h W_{hh} Whh是隐藏状态的权重矩阵, h t − 1 h_{t-1} ht1是上一次的隐藏状态, W x h W_{xh} Wxh是输入 x t x_t xt和隐藏状态 h t h_t ht之间的权重矩阵, x t x_t xt是时间步 t t t的输入。

在RNN的训练过程中,我们需要使用反向传播算法计算梯度并更新权重。由于RNN具有时间上的依赖关系,每一步的梯度都取决于前一步的梯度,这意味着我们需要使用反向传播算法的变体——反向传播通过时间(BPTT)算法来计算梯度。

代码实现

我们将使用PyTorch和Scikit-Learn数据集实现一个简单的RNN模型,用于预测时间序列数据。以下是代码:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt# 加载数据集
data = load_boston()
X = data.data
y = data.target# 数据标准化
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)# 转换为PyTorch张量,并增加时间步维度
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)# 定义RNN模型
class SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.rnn = nn.RNN(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.rnn(x)out = self.fc(out[:, -1, :])return out# 创建模型实例
input_size = X.shape[2]  # 更新input_size的值
hidden_size = 32
output_size = 1
model = SimpleRNN(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 启用异常检测
torch.autograd.set_detect_anomaly(True)# 训练模型
num_epochs = 10000
# 记录损失
loss_list = []for epoch in range(num_epochs):optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 100 == 0:loss_list.append(loss.item())print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')# 关闭异常检测
torch.autograd.set_detect_anomaly(False)# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of RNN Training')
plt.show()
plt.savefig('Loss_of_RNN_Training.png')# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)  # 假设使用第一个数据点进行预测
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码实现了一个简单的循环神经网络(RNN)模型来预测波士顿房价,并可视化训练过程中损失的变化。代码首先加载并标准化了波士顿房价数据集,然后定义了一个包含RNN层和全连接层的SimpleRNN模型,并使用均方误差作为损失函数和Adam优化器进行训练。训练完成后,使用matplotlib库绘制训练过程中损失的变化曲线(如下图所示)。最后,使用训练好的模型对新的数据点进行预测,并输出预测值。这段代码可以为初学者提供一个实现RNN模型的参考,并通过可视化训练过程中的损失曲线来帮助理解模型的性能。
RNN 损失曲线

总结

本文介绍了RNN的基本数学原理、使用PyTorch和Scikit-Learn数据集实现的代码,以及如何解读代码并总结。RNN是一种在序列数据上表现出色的神经网络,常用于处理时间序列数据,如音频信号、自然语言和股票价格等。我们可以使用PyTorch和Scikit-Learn数据集来实现一个简单的RNN模型,并用它来预测未知的时间序列数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/258127.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PLC_博图系列☞FBD

PLC_博图系列☞FBD 文章目录 PLC_博图系列☞FBD背景介绍FBD优势局限性 FBD 元素 关键字: PLC、 西门子、 博图、 Siemens 、 FBD 背景介绍 这是一篇关于PLC编程的文章,特别是关于西门子的博图软件。我并不是专业的PLC编程人员,也不懂电路…

深度学习之梯度下降算法

梯度下降算法 梯度下降算法数学公式结果 梯度下降算法存在的问题随机梯度下降算法 梯度下降算法 数学公式 这里案例是用梯度下降算法,来计算 y w * x 先计算出梯度,再进行梯度的更新 import numpy as np import matplotlib.pyplot as pltx_data [1.0,…

心理辅导|高校心理教育辅导系统|基于Springboot的高校心理教育辅导系统设计与实现(源码+数据库+文档)

高校心理教育辅导系统目录 目录 基于Springboot的高校心理教育辅导系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、学生功能模块的实现 (1)学生登录界面 (2)留言反馈界面 (3)试卷列表界…

2.7日学习打卡----初学RabbitMQ(二)

2.7日学习打卡 目录: 2.7日学习打卡一. RabbitMQ 简单模式![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/42009c68e078440797c3183ffda6955d.png)生产者代码实现消费者代码实现 二. RabbitMQ 工作队列模式生产者代码实现消费者代码实现 三. RabbitMQ 发…

【天衍系列 04】深入理解Flink的ElasticsearchSink组件:实时数据流如何无缝地流向Elasticsearch

文章目录 01 Elasticsearch Sink 基础概念02 Elasticsearch Sink 工作原理03 Elasticsearch Sink 核心组件04 Elasticsearch Sink 配置参数05 Elasticsearch Sink 依赖管理06 Elasticsearch Sink 初阶实战07 Elasticsearch Sink 进阶实战7.1 包结构 & 项目配置项目配置appl…

《杨绛传:生活不易,保持优雅》读书摘录

目录 书简介 作者成就 书中内容摘录 良好的家世背景,书香门第为求学打基础 求学相关 念大学 清华研究生 自费英国留学 法国留学自学文学 战乱时期回国 当校长 当小学老师 创造话剧 支持钱锺书写《围城》 出任震旦女子文理学院的教授 接受清华大学的…

【AIGC】Stable Diffusion的ControlNet参数入门

Stable Diffusion 中的 ControlNet 是一种用于控制图像生成过程的技术,它可以指导模型生成特定风格、内容或属性的图像。下面是关于 ControlNet 的界面参数的详细解释: 低显存模式 是一种在深度学习任务中用于处理显存受限设备的技术。在这种模式下&am…

【AIGC】Stable Diffusion的模型入门

下载好相关模型文件后,直接放入Stable Diffusion相关目录即可使用,Stable Diffusion 模型就是我们日常所说的大模型,下载后放入**\webui\models\Stable-diffusion**目录,界面上就会展示相应的模型选项,如下图所示。作者…

【C++】 为什么多继承子类重写的父类的虚函数地址不同?『 多态调用汇编剖析』

👀樊梓慕:个人主页 🎥个人专栏:《C语言》《数据结构》《蓝桥杯试题》《LeetCode刷题笔记》《实训项目》《C》《Linux》《算法》 🌝每一个不曾起舞的日子,都是对生命的辜负 前言 本篇文章主要是为了解答有…

Pytest测试技巧之Fixture:模块化管理测试数据

在 Pytest 测试中,有效管理测试数据是提高测试质量和可维护性的关键。本文将深入探讨 Pytest 中的 Fixture,特别是如何利用 Fixture 实现测试数据的模块化管理,以提高测试用例的清晰度和可复用性。 什么是Fixture? 在 Pytest 中&a…

华为问界M9:领跑未来智能交通的自动驾驶黑科技

华为问界M9是一款高端电动汽车,其自动驾驶技术是该车型的重要卖点之一。华为在问界M9上采用了多种传感器和高级算法,实现了在不同场景下的自动驾驶功能,包括自动泊车、自适应巡航、车道保持、自动变道等。 华为问界M9的自动驾驶技术惊艳之处…

Linux之多线程

目录 一、进程与线程 1.1 进程的概念 1.2 线程的概念 1.3 线程的优点 1.4 线程的缺点 1.5 线程异常 1.6 线程用途 二、线程控制 2.1 POSIX线程库 2.2 创建一个新的线程 2.3 线程ID及进程地址空间布局 2.4 线程终止 2.5 线程等待 2.6 线程分离 一、进程与线程 在…

构造题记录

思路&#xff1a;本题要求构造一个a和b数组相加为不递减序列&#xff0c;并且b数组的极差为最小的b数组。 可以通过遍历a数组并且每次更新最大值&#xff0c;并使得b数组为这个最大值和当前a值的差。 #include <bits/stdc.h> using namespace std; #define int long lon…

优化策略模式,提高账薄显示的灵活性和扩展性

接着上一篇文章&#xff0c;账薄显示出来之后&#xff0c;为了提高软件的可扩展性和灵活性&#xff0c;我们应用策略设计模式。这不仅仅是为了提高代码的维护性&#xff0c;而是因为明细分类账账薄显示的后面有金额分析这个功能&#xff0c;从数据库后台分析及结合Java语言特性…

小程序或者浏览器chrome访问的时候出现307 interval redicrect内部http自动跳转到https产生的原理分析及解决方案

#小李子9479# 出现的情况如下&#xff0c;即我们访问http的时候&#xff0c;它会自动307重定向到https,产生的原因是&#xff0c; 当你通过https访问过一个没有配置证书的http的网站之后&#xff0c;你再访问http的时候&#xff0c;它就会自动跳转到https&#xff0c;导致访问…

奔跑吧小恐龙(Java)

前言 Google浏览器内含了一个小彩蛋当没有网络连接时&#xff0c;浏览器会弹出一个小恐龙&#xff0c;当我们点击它时游戏就会开始进行&#xff0c;大家也可以玩一下试试&#xff0c;网址&#xff1a;恐龙快跑 - 霸王龙游戏. (ur1.fun) 今天我们也可以用Java来简单的实现一下这…

黄金交易策略(Nerve Nnife.mql4):移动止盈的设计

完整EA&#xff1a;Nerve Knife.ex4黄金交易策略_黄金趋势ea-CSDN博客 相较mt4的止盈止损&#xff0c;在ea上实现移动止盈&#xff0c;可以尽最大可能去获得更高收益。移动止盈的大体逻辑是&#xff1a;到达止盈点就开始追踪止盈&#xff0c;直到在最高盈利点回撤指定点数即平…

【Python】通过conda安装Python的IDE

背景 系统&#xff1a;win11 软件&#xff1a;anaconda Navigator 问题现象&#xff1a;①使用Navigator安装jupyter notebook以及Spyder IDE 一直转圈。②然后进入anaconda prompt执行conda install jupyter notebook一直卡在Solving environment/-\。 类似问题&#xff1a; …

Stable Diffusion 模型下载:DreamShaper XL(梦想塑造者 XL)

本文收录于《AI绘画从入门到精通》专栏&#xff0c;专栏总目录&#xff1a;点这里。 文章目录 模型介绍生成案例案例一案例二案例三案例四案例五案例六案例七案例八案例九案例十 下载地址 模型介绍 DreamShaper 是一个分格多样的大模型&#xff0c;可以生成写实、原画、2.5D 等…

第80讲订单管理功能实现

后端 <?xml version"1.0" encoding"UTF-8" ?> <!DOCTYPE mapperPUBLIC "-//mybatis.org//DTD Mapper 3.0//EN""http://mybatis.org/dtd/mybatis-3-mapper.dtd"> <mapper namespace"com.java1234.mapper.OrderM…