基于LSTM及其变体的回归预测

1 所用模型

       代码中用到了以下模型:

      1. LSTM(Long Short-Term Memory):长短时记忆网络,是一种特殊的RNN(循环神经网络),能够解决传统RNN在处理长序列时出现的梯度消失或爆炸的问题。LSTM有门控机制,可以选择性地记住或忘记信息。

       2. FC-LSTM:全连接的LSTM,与传统的LSTM相比,其细胞单元之间采用全连接的方式。

       3. Coupled LSTM:耦合LSTM,是一种特殊的LSTM结构,其中每个LSTM单元被分解为两个交互的子单元。

       4. GRU(Gated Recurrent Unit):门控循环单元,与LSTM类似,但结构更简单,参数更少,通常训练更快,但可能不如LSTM准确。

       5. ConvLSTM:卷积LSTM,将卷积神经网络(CNN)与LSTM结合,可以捕捉时空特征,常用于处理图像和视频数据。

       6. Deep LSTM:深层LSTM,包含多个LSTM层的堆叠,可以捕捉更复杂的模式。

       7. DB-LSTM(Bidirectional LSTM):双向LSTM,有两个方向的LSTM层,一个按时间顺序,一个逆序,可以同时获取过去和未来的信息。

       8. SRU(SimpleRNN):简单循环神经网络,是最基本的RNN形式。

       9. TPA-LSTM:时间感知LSTM,通过改变LSTM的内部计算方式,使其更加关注时间序列的特性。

       10. ConvGRU:卷积GRU,与ConvLSTM类似,但使用GRU代替LSTM。

       这些模型都是用于处理序列数据的深度学习模型,特别适用于时间序列预测、自然语言处理等领域。

2 运行结果

       左边是Epoch=50次的效果,右边是Epoch=15次的效果:

a1e88c48c6f645eea96360f59b239c00.jpg

 图2-1 训练损失

3623cb88b9294ce796d7dbacd244f481.jpg

 图2-2 测试损失

d9ab03d1196542bf9235bafc58288e07.jpg

 图2-3 预测结果

3 代码

     

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense, LSTM, GRU, SimpleRNN, Bidirectional, TimeDistributed, Conv1D, Attention
from keras.layers import Flatten, Dropout, BatchNormalization
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
from tensorflow.keras.layers import Conv1D
# 读取数据
data = pd.read_excel('A.xlsx')
data=data.dropna()
data = data['A'].values.reshape(-1, 1)
# 数据预处理
scaler = MinMaxScaler()
data = scaler.fit_transform(data)# 划分训练集和测试集
train_size = int(len(data) * 0.8)
train, test = data[:train_size], data[train_size:]# 转换数据格式以适应LSTM输入
def create_dataset(dataset, look_back=1):X, Y = [], []for i in range(len(dataset) - look_back - 1):X.append(dataset[i:(i + look_back), 0])Y.append(dataset[i + look_back, 0])return np.array(X), np.array(Y)look_back = 1
X_train, y_train = create_dataset(train, look_back)
X_test, y_test = create_dataset(test, look_back)# 重塑输入数据的维度以适应LSTM模型
X_train = np.reshape(X_train, (X_train.shape[0], 1, X_train.shape[1]))
X_test = np.reshape(X_test, (X_test.shape[0], 1, X_test.shape[1]))
# 定义模型函数
def create_model(name):model = Sequential()if name == 'LSTM':model.add(LSTM(50, activation='relu', input_shape=(1, 1)))elif name == 'FC-LSTM':model.add(LSTM(50, activation='relu', input_shape=(1, 1), recurrent_activation='sigmoid'))elif name == 'Coupled LSTM':model.add(LSTM(50, activation='relu', input_shape=(1, 1), implementation=2))elif name == 'GRU':model.add(GRU(50, activation='relu', input_shape=(1, 1)))elif name == 'ConvLSTM':model.add(Conv1D(filters=64, kernel_size=1, activation='relu', input_shape=(1, 1)))model.add(LSTM(50, activation='relu'))elif name == 'Deep LSTM':model.add(LSTM(50, return_sequences=True, activation='relu', input_shape=(1, 1)))model.add(LSTM(50, activation='relu'))elif name == 'DB-LSTM':model.add(Bidirectional(LSTM(50, activation='relu'), input_shape=(1, 1)))elif name == 'SRU':model.add(SimpleRNN(50, activation='relu', input_shape=(1, 1)))elif name == 'TPA-LSTM':model.add(LSTM(50, activation='relu', input_shape=(1, 1), unroll=True))elif name == 'ConvGRU':model.add(Conv1D(filters=64, kernel_size=1, activation='relu', input_shape=(1, 1)))model.add(GRU(50, activation='relu'))model.add(Dense(1))model.compile(optimizer=Adam(), loss='mse')return model# 训练模型并绘制损失图
names = ['LSTM', 'FC-LSTM', 'Coupled LSTM', 'GRU', 'ConvLSTM', 'Deep LSTM', 'DB-LSTM','SRU', 'TPA-LSTM', 'ConvGRU']
train_losses = []
test_losses = []
predictions = []for name in names:model = create_model(name)history = model.fit(train, train, epochs=15, batch_size=32, validation_data=(test, test), verbose=0)train_losses.append(history.history['loss'])test_losses.append(history.history['val_loss'])pred = model.predict(test)predictions.append(pred)import matplotlib.pyplot as plt# 设置不同的marker
markers = ['o', '.', '_', '^', '*', '>', '+', '1', 'p', '_', '8']
linestyles = ['-', '--', '--', ':', '-', '-.', '-.', ':', '-', '--']
# 绘制训练损失图
plt.figure(figsize=(16, 20))
for i, loss in enumerate(train_losses):plt.plot(loss, color='black',label=names[i], marker=markers[i], linestyle=linestyles[i])
plt.title('Train Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(fontsize=8,loc='best')
plt.show()
# 绘制测试损失图
for i, loss in enumerate(test_losses):plt.plot(loss, color='black',label=names[i], marker=markers[i], linestyle=linestyles[i])
plt.title('Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(fontsize=8,loc='best')
plt.show()
# 绘制预测结果折线图
for i, pred in enumerate(predictions):plt.plot(pred, color='black',label=names[i], marker=markers[i], linestyle=linestyles[i])
# 绘制真实值折线图
plt.plot(y_test, color='black', label='True Value')
plt.title('Predictions and True Values')
plt.xlabel('x')
plt.ylabel('value')
plt.legend(fontsize=8, loc='best')
# 显示图像
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/378180.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

动手学深度学习6.3 填充和步幅-笔记练习(PyTorch)

以下内容为结合李沐老师的课程和教材补充的学习笔记,以及对课后练习的一些思考,自留回顾,也供同学之人交流参考。 本节课程地址:填充和步幅_哔哩哔哩_bilibili 代码实现_哔哩哔哩_bilibili 本节教材地址:6.3. 填充和…

笔记 5 :linux 0.11 注释,函数 copy_mem() , copy_process () , 中断函数 int 80H 的代码框架

(38)接着介绍一个创建进程时的重要的函数 copy_mem() 函数: (39) 分析另一个关于 fork() 的重要的函数 copy_process(),与李忠老师的操…

Qcom平台通过Hexagon IDE 测试程序性能指导

Qcom平台通过Hexagon IDE 测试程序性能指导 1 安装Hexagon IDE工具2 测试工程2.1 打开Hexagon IDE2.2 新建工程2.3 添加测试案例2.3.1 方法一:新建2.3.2 方法二:拷贝 2.4 配置测试环境2.4.1 包含头文件2.4.2 添加程序优化功能(需先bulid一下)2.4.3 添加g…

【问题记录】Docker配置mongodb副本集实现数据流实时获取

配置mongodb副本集实现数据流实时获取 前言操作步骤1. docker拉取mongodb镜像2. 连接mongo1镜像的mongosh3. 在mongosh中初始化副本集 注意点 前言 由于想用nodejs实现实时获取Mongodb数据流,但是报错显示需要有副本集的mongodb才能实现实时获取信息流,…

HTTPS请求头缺少HttpOnly和Secure属性解决方案

问题描述: 建立Filter拦截器类 package com.ruoyi.framework.security.filter;import com.ruoyi.common.core.domain.model.LoginUser; import com.ruoyi.common.utils.SecurityUtils; import com.ruoyi.common.utils.StringUtils; import com.ruoyi.framework.…

某客户管理系统Oracle RAC节点异常重启问题详细分析记录

一、故障概述 某日10:58分左右客户管理系统数据库节点1所有实例异常重启,重启后业务恢复正常。经过分析发现,此次实例异常重启的是数据库节点1。 二、故障原因分析 1、数据库日志分析 从节点1的数据库日志来看,10:58:49的时候数据库进程开始…

1.MQ介绍

MQ 消息队列,本质是一个队列,先进先出,只不过队列中存放的内容是message而已。 为啥学习MQ 1.流量消峰 如果一个订单系统最多每秒能处理一万次订单,正常情况下我们下单1秒后就能返回结果。但是在高峰期,如果有两万…

Linux驱动开发笔记(十九)文件系统的构建

文章目录 前言一、文件系统1.1 Linux系统的组成1.2 什么是文件系统 二、根文件系统的制作工具三、busybox3.1 什么是busybox3.2 制作流程 四、buildroot4.1 什么是Buildroot4.2 制作流程 前言 上节我们在mdev实验进行配置时,利用了busybox,这里着重对这部…

活动预告|想更了解流式数据湖?亚马逊云科技数据开源软件-流式数据湖 Tech Talk来啦!

活动介绍 本次活动旨在探索在亚马逊云科技上构建和使用开源数据软件产品的一些最佳实践,特别关注流式数据湖的构建。活动将在线上举行,汇聚来自 AutoMQ Apache paimon和亚马逊云科技的顶尖专家,分享他们在这一领域的最新进展和实际经验。参与…

旗晟巡检机器人的应用场景有哪些?

巡检机器人作为现代科技的杰出成果,已广泛应用于各个关键场景。从危险的工业现场到至关重要的基础设施,它们的身影无处不在。它们以精准、高效、不知疲倦的特性,担当起保障生产、守护安全的重任,为行业发展注入新的活力。那么&…

2024华为数通HCIP-datacom最新题库(变题更新⑥)

请注意,华为HCIP-Datacom考试831已变题 请注意,华为HCIP-Datacom考试831已变题 请注意,华为HCIP-Datacom考试831已变题 近期打算考HCIP的朋友注意了,如果你准备去考试,还是用的之前的题库,切记暂缓。 1、…

【Elasticsearch7】3-基本操作

目录 RESTful 数据格式 HTTP操作 索引操作 倒排索引 创建索引 查看所有索引 查看单个索引 删除索引 文档操作 创建文档 查看文档 ​编辑 全量修改 ​编辑局部修改 删除文档 条件删除文档 高级查询 条件查询 URL带参查询 请求体带参查询 带请求体方式的查…

使用GPT3.5,LangChain,FAISS和python构建一个本地知识库

引言 介绍本地知识库的概念和用途 在现代信息时代,我们面临着海量的数据和信息,如何有效地管理和利用这些信息成为一项重要的任务。本地知识库是一种基于本地存储的知识管理系统,旨在帮助用户收集、组织和检索大量的知识和信息。它允许用户…

java8新特性

目录 一. lambda 1. 为什么要有lambda 2.功能接口 3. 使用lambda的条件 二. Stream流 1. 获取流 1.1 将集合转为流 1.2 将数组转为流 1.3 将相同数据类型的数据转为流 1.4 将文件里的内容转为流 2. 中间操作 3. 终端操作 一. lambda lambda:本质上就是将函数当做参…

Python | Leetcode Python题解之第240题搜索二维矩阵II

题目&#xff1a; 题解&#xff1a; class Solution:def searchMatrix(self, matrix: List[List[int]], target: int) -> bool:m, n len(matrix), len(matrix[0])x, y 0, n - 1while x < m and y > 0:if matrix[x][y] target:return Trueif matrix[x][y] > tar…

印尼语翻译通:AI驱动的智能翻译与语言学习助手

在这个多元文化交织的世界中&#xff0c;语言是连接我们的桥梁。印尼语翻译通&#xff0c;一款专为打破语言障碍而生的智能翻译软件&#xff0c;让您与印尼语的世界轻松接轨。无论是商务出差、学术研究&#xff0c;还是探索印尼丰富的文化遗产&#xff0c;印尼语翻译通都是您的…

基于luckysheet实现在线电子表格和Excel在线预览

概述 本文基于luckysheet实现在线的电子表格&#xff0c;并基于luckyexcel实现excel文件的导入和在线预览。 效果 实现 1. luckysheet介绍 Luckysheet &#xff0c;一款纯前端类似excel的在线表格&#xff0c;功能强大、配置简单、完全开源。 官方文档在线Demo 2. 实现 …

抖音seo短视频矩阵源码系统开发搭建----开源+二次开发

抖音seo短视频矩阵源码系统开发搭建 是一项技术密集型工作&#xff0c;需要对大数据处理、人工智能等领域有深入了解。该系统开发过程中需要用到多种编程语言&#xff0c;如Java、Python等。同时&#xff0c;需要使用一些框架和技术&#xff0c;如Hadoop、Spark、PyTorch等&am…

小程序-设置环境变量

在实际开发中&#xff0c;不同的开发环境&#xff0c;调用的接口地址是不一样的 例如&#xff1a;开发环境需要调用开发版的接口地址&#xff0c;生产环境需要正式版的接口地址 这时候&#xff0c;我们就可以使用小程序提供了 wx.getAccountInfoSync() 接口&#xff0c;用来获取…

iterator(迭代器模式)

引入 在想显示数组当中所有元素时&#xff0c;我们往往会使用下面的for循环语句来遍历数组 #include <iostream> #include <vector>int main() {std::vector<int> v({ 1, 2, 3 });for (int i 0; i < v.size(); i){std::cout << v[i] << &q…