TensorFlow项目练手(三)——基于GRU股票走势预测任务

项目介绍

项目基于GRU算法通过20天的股票序列来预测第21天的数据,有些项目也可以用LSTM算法,两者主要差别如下:

  • LSTM算法:目前使用最多的时间序列算法,是一种特殊的RNN(循环神经网络),能够学习长期的依赖关系。主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。
  • GRU算法:是一种特殊的RNN。和LSTM一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。

一、准备数据

1、获取数据

  1. 通过命令行安装yfinance
  2. 通过api获取股票数据
  3. 保存到csv中方便使用
import pandas_datareader.data as web
import datetime
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
plt.rcParams['font.sans-serif']='SimHei' #图表显示中文import yfinance as yf
yf.pdr_override() #需要调用这个函数# 1、获取股票数据
#上海的股票代码+.SS;深圳的股票代码+.SZ :
stock = web.get_data_yahoo("601318.SS", start="2022-01-01", end="2023-07-17")
# 保存到csv中
pd.DataFrame(data=stock).to_csv('./stock.csv')# 2、获取csv中的数据
features = pd.read_csv('stock.csv')
features = features.drop('Adj Close',axis=1)
features.head()

在这里插入图片描述

2、数据可视化

通过绘图的方式查看当前的数据情况

# 3、绘图看看收盘价数据情况
close=features["Close"]
# 计算20天和100天移动平均线:
short_rolling_close = close.rolling(window=20).mean()
long_rolling_close = close.rolling(window=100).mean()
# 绘制
fig, ax = plt.subplots(figsize=(16,9))   #画面大小,可以修改
ax.plot(close.index, close, label='中国平安')   #以收盘价为索引值绘图
ax.plot(short_rolling_close.index, short_rolling_close, label='20天均线')
ax.plot(long_rolling_close.index, long_rolling_close, label='100天均线')
#x轴、y轴及图例:
ax.set_xlabel('日期')
ax.set_ylabel('收盘价 (人民币)')
ax.legend()      #图例
plt.show()      #绘图

在这里插入图片描述

3、数据预处理

取出当前的收盘价,删除无用的日期元素

# 4、取出label值
labels = features['Close']
time = features['Date']
features = features.drop('Date',axis=1)
features.head()

在这里插入图片描述

进行数据的归一化

# 5、数据预处理
from sklearn import preprocessing
input_features = preprocessing.StandardScaler().fit_transform(features)
input_features

在这里插入图片描述

4、构建数据序列

由于RNN的算法要求我们要有一定的序列,来预测出下一个值,所以我们按照20天的数据作为一个序列

# 6、定义序列,[下标1-20天预测第21天的收盘价]
from collections import dequex = []
y = []seq_len = 20
deq = deque(maxlen=seq_len)
for i in input_features:deq.append(list(i))if len(deq) == seq_len:x.append(list(deq))x = x[:-1] # 取少一个序列,因为最后个序列没有答案
y = features['Close'].values[seq_len: ] #从第二十一天开始(下标为20)
time = time.values[seq_len: ] #从第二十一天开始(下标为20)x, y, time = np.array(x), np.array(y), np.array(time)
print(x.shape)
print(y.shape)
print(time.shape)

在这里插入图片描述

二、构建模型

1、搭建GRU模型

import tensorflow as tf
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras import layersfrom keras.models import load_model
from keras.models import Sequential
from keras.layers import Dropout
from keras.layers.core import Dense
from keras.optimizers import Adam# 7、搭建模型
model = tf.keras.Sequential()
model.add(layers.GRU(8,input_shape=(20,5), activation='relu', return_sequences=True,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.GRU(16, activation='relu', return_sequences=True,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.GRU(32, activation='relu', return_sequences=False,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.Dense(16,kernel_initializer='random_normal',kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.Dense(1))
model.summary()

在这里插入图片描述

2、优化器和损失函数

# 优化器和损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss=tf.keras.losses.MeanAbsoluteError(), # 标签和预测之间绝对差异的平均metrics = tf.keras.losses.MeanSquaredLogarithmicError()) # 计算标签和预测

3、开始训练

25%的比例作为验证集,75%的比例作为训练集

# 开始训练
model.fit(x,y,validation_split=0.25,epochs=200,batch_size=128)

在这里插入图片描述

4、模型预测

# 预测
y_pred = model.predict(x)
fig = plt.figure(figsize=(10,5))
axes = fig.add_subplot(111)
axes.plot(time,y,'b-',label='actual')
# 预测值,红色散点
axes.plot(time,y_pred,'r--',label='predict')
axes.set_xticks(time[::50])
axes.set_xticklabels(time[::50],rotation=45)plt.legend()
plt.show()

在这里插入图片描述

5、回归指标评估

from sklearn.metrics import mean_squared_error,mean_absolute_error,r2_score
from math import sqrt#回归评价指标
# calculate MSE 均方误差
mse=mean_squared_error(y,y_pred)
# calculate RMSE 均方根误差
rmse = sqrt(mean_squared_error(y, y_pred))
#calculate MAE 平均绝对误差
mae=mean_absolute_error(y,y_pred)
print('均方误差: %.6f' % mse)
print('均方根误差: %.6f' % rmse)
print('平均绝对误差: %.6f' % mae)

在这里插入图片描述

源代码

  • 源码查看

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/75035.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Safari 查看 http 请求

文章目录 1、开启 Safari 开发菜单2、显示 JavaScript 控制台 1、开启 Safari 开发菜单 Safari 设置中,打开开发菜单选项 *** 选择完成后,Safari 的目录栏就会出现一个 开发 功能。 2、显示 JavaScript 控制台 开启页面后,在开发中选中 显…

规划路线(微信小程序、H5)

//地图getLocationDian(e1, e2) {console.log(e1, e2);let self this;self.xx1 [];self.xx2 [];self.points [];// self.markers[]console.log(self.markers, >marks);// self.$jsonp(url, data).then(re > {// var coors re.result.routes[0].polyline;// for (v…

Electron逆向调试

复杂程序处理方式: 复杂方式通过 调用窗口 添加命令行参数 启动允许调用,就可以实现调试发布环境的electron程序。 断点调试分析程序的走向,程序基本上会有混淆代码处理, 需要调整代码格式,处理程序。

ARM 常见汇编指令学习 9 - 缓存管理指令 DC 与 IC

文章目录 ARM64 DC 与 IC 指令 上篇文章:ARM 常见汇编指令学习 8 - dsb sy 指令及 dsb 参数介绍 ARM64 DC 与 IC 指令 AArch64指令集中有两条关于缓存维护(cache maintenance)的指令,分别是IC和DC。 IC 是用于指令缓存操作&…

数智保险 创新未来 | GBASE南大通用亮相中国保险科技应用高峰论坛

本届峰会以“数智保险 创新未来”为主题,GBASE南大通用携新一代创新数据库产品及金融信创解决方案精彩亮相,与国内八百多位保险公司高管和众多保险科技公司技术专家,就保险领域数字化的创新应用及生态建设、新一代技术突破及发展机遇、前沿科…

Git下:Git命令使用-详细解读

今天给大家讲一讲 Git常用命令的使用说明,希望本篇文章对大家有所帮助。 一、Git 安装 Git 的详细安装教程:见上一篇文章《Git上:Git安装教程》: Git上:全网最全最详细的Git安装教程,建议收藏保存 二、…

2023,谁在引领实时互动进入高清时代?

实践是检验真理的唯一标准,技术是行业进步的核心动能。在实时互动的新时代里,不断进化的声网已然完成自证。 作者|斗斗 出品|产业家 “一个医疗行业的客户,曾向我们提出一个需求,希望在120急救场景下,可以远程看清…

.net 6 efcore一个model映射到多张表(非使用IEntityTypeConfiguration)

现在有两张表,结构一模一样,我又不想创建两个一模一样的model,就想一个model映射到两张表 废话不多说直接上代码 安装依赖包 创建model namespace oneModelMultiTable.Model {public class Test{public int id { get; set; }public string…

两数相加 II

给你两个 非空 链表来代表两个非负整数。数字最高位位于链表开始位置。它们的每个节点只存储一位数字。将这两数相加会返回一个新的链表。 你可以假设除了数字 0 之外,这两个数字都不会以零开头。 示例1: 输入:l1 [7,2,4,3], l2 [5,6,4] 输…

React Native元素旋转一定的角度

mMeArrowIcon: {fontSize: 30, color: #999, transform: [{rotate: 180deg}]},<Icon name"arrow" style{styles.mMeArrowIcon}></Icon>参考链接&#xff1a; https://reactnative.cn/docs/transforms https://chat.xutongbao.top/

子组件未抛出事件 父组件如何通过$refs监听子组件中数据的变化

我们平时开发项目会使用一些比较成熟的组件库, 但是在极小的情况下,可能会出现我们需要监听某个属性的变化,使我们的页面根据这个属性发生一些改变,但是偏偏组件库没有把这个属性抛出来,当我们使用watch通过refs监听时,由于生命周期的原因还不能拿到,这时候我们可以这样做,以下…

《TCP IP 网络编程》第十五章

第 15 章 套接字和标准I/O 15.1 标准 I/O 的优点 标准 I/O 函数的两个优点&#xff1a; 除了使用 read 和 write 函数收发数据外&#xff0c;还能使用标准 I/O 函数收发数据。下面是标准 I/O 函数的两个优点&#xff1a; 标准 I/O 函数具有良好的移植性标准 I/O 函数可以利用…

安卓:BottomNavigationBar——底部导航栏控件

目录 一、BottomNavigationBar介绍 二、BottomNavigationBar的常用方法及其常用类 &#xff08;一&#xff09;、常用方法 1. 添加菜单项 2. 移除菜单项 3. 设置选中监听器 4. 设置当前选中项 5. 设置徽章 6. 样式和颜色定制 7. 动画效果 8. 隐藏底部导航栏。 9、设…

过程:从虚拟机上添加 git 并成功提交到 GitLab 的全过程

Ⅰ、准备工作&#xff1a; 1、Git 查看&#xff1a; 其一、命令&#xff1a;git --version // 此时就能在虚拟机环境下看到 git 的版本为: git version 2.41.0 其二、如何在虚拟机上安装 git &#xff1a; A、命令 &#xff1a; sudo apt-get install git B、然后再输入虚…

小研究 - 主动式微服务细粒度弹性缩放算法研究(三)

微服务架构已成为云数据中心的基本服务架构。但目前关于微服务系统弹性缩放的研究大多是基于服务或实例级别的水平缩放&#xff0c;忽略了能够充分利用单台服务器资源的细粒度垂直缩放&#xff0c;从而导致资源浪费。为此&#xff0c;本文设计了主动式微服务细粒度弹性缩放算法…

React 之 Redux - 状态管理

一、前言 1. 纯函数 函数式编程中有一个非常重要的概念叫纯函数&#xff0c;JavaScript符合函数式编程的范式&#xff0c;所以也有纯函数的概念 确定的输入&#xff0c;一定会产生确定的输出 函数在执行过程中&#xff0c;不能产生副作用 2. 副作用 表示在执行一个函数时&a…

Docker实战-关于Docker镜像的相关操作(一)

导语   镜像&#xff0c;Docker中三大核心概念之一&#xff0c;并且在运行Docker容器之前需要本地存储对应的镜像。那么下面我们就来介绍一下在Docker中如何使用镜像。 如何获取镜像&#xff1f; 镜像作为容器运行的前提条件&#xff0c;在Docker Hub上提供了各种各样的开放的…

Edge浏览器安装vue devtools

1. 下载地址 GitHub - vuejs/devtools: ⚙️ Browser devtools extension for debugging Vue.js applications. 2. 下载后的压缩包解压并打开文件夹&#xff0c;右键选择&#xff1a;git bush here 3. 安装依赖 npm install 4. 成功安装依赖后打包 npm run build

万界星空科技/免费开源MES系统/免费仓库管理

仓库管理&#xff08;仓储管理&#xff09;&#xff0c;指对仓库及仓库内部的物资进行收发、结存等有效控制和管理&#xff0c;确保仓储货物的完好无损&#xff0c;保证生产经营活动的正常进行&#xff0c;在此基础上对货物进行分类记录&#xff0c;通过报表分析展示仓库状态、…

基于opencv的几种图像滤波

一、介绍 盒式滤波、均值滤波、高斯滤波、中值滤波、双边滤波、导向滤波。 boxFilter() blur() GaussianBlur() medianBlur() bilateralFilter() 二、代码 #include <opencv2/core/core.hpp> #include <opencv2/highgui/highgui.hpp> …