基于pytorch LSTM 的股票预测

学习记录于《PyTorch深度学习项目实战100例》
https://weibaohang.blog.csdn.net/article/details/127365867?ydreferer=aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L20wXzQ3MjU2MTYyL2NhdGVnb3J5XzEyMDM2MTg5Lmh0bWw%2Fc3BtPTEwMDEuMjAxNC4zMDAxLjU0ODI%3D

1.tushare

Tushare是一个免费、开源的Python财经数据接口包。主要用于提供股票及金融市场相关的数据,是国内常用的金融数据分析工具之一。Tushare对于投资研究、教学、项目背景调研等多种需求提供了极大的方便。

Tushare的主要特点有:

数据丰富:Tushare提供了包括实时行情数据、历史行情数据、基本面数据、宏观经济数据、公司基本信息、大盘指数数据、行业数据、新闻和公告等多种数据。

接口简单:使用Python调用Tushare接口相当简单。即便是初学者,只需数行代码,就能获取所需的金融数据。

社区活跃:由于Tushare的免费和开源特性,其社区相当活跃,有大量的开发者和爱好者进行维护和更新。

扩展性强:除了官方提供的数据接口,Tushare也支持自定义数据扩展,可以通过插件方式实现数据的快速扩展。

要使用Tushare,用户通常需要先进行安装,可以通过pip轻松完成。之后,通过简单的API调用,就可以获取所需的数据。

!pip install tushare

2.导入必要的依赖

# 1.加载股票数据
pro = ts.pro_api('your-key')
df = pro.daily(ts_code='000001.SZ', start_date='20130711', end_date='20230711')df.index = pd.to_datetime(df.trade_date)  # 索引转为日期
df = df.iloc[::-1]  # 由于获取的数据是倒序的,需要将其调整为正序

如何获取your-key

注册网站
https://tushare.pro/register

注册后进入个人主页

在这里插入图片描述

找到接口token

在这里插入图片描述
复制后粘贴 到your-key 里面

在这里插入图片描述

# 创建一个 StandardScaler 实例
scaler = StandardScaler()
scaler_model = StandardScaler()# 使用 scaler 对数据进行拟合和转换
data = scaler_model.fit_transform(np.array(df[['open', 'high', 'low', 'close']]).reshape(-1, 4))# 使用 scaler 对 'close' 列进行拟合和转换
scaler.fit_transform(np.array(df['close']).reshape(-1, 1))

分开训练集和测试集

def split_data(data,timestep):dataX= [] # 用于存储输入序列dataY= [] # 用于存储输出数据#将整个窗口的数据保存到X中,将未来的一天保存到Y中for index in range(len(data)-timestep):# 提取时间窗口内的数据作为输入序列dataX.append(data[index: index + timestep])# 下一个时间步的数据作为输出dataY.append(data[index + timestep][3]) #输出为第四列('close' 列)dataX = np.array(dataX)dataY = np.array(dataY)#获取训练集大小train_size = int(np.round(0.8 * dataX.shape[0]))# 划分训练集,测试集x_train = dataX[: train_size, :].reshape(-1, timestep, 4)y_train = dataY[: train_size]x_test = dataX[train_size:, :].reshape(-1, timestep, 4)y_test = dataY[train_size:]return[x_train,y_train,x_test,y_test]
# 3.获取训练数据, x_train:1750,1,4
x_train,y_train,x_test,y_test = split_data(data,timestep=1)

探究数据集

在这里插入图片描述

# 4.将数据转为tensor
x_train_tensor = torch.from_numpy(x_train).to(torch.float32)
y_train_tensor = torch.from_numpy(y_train).to(torch.float32)
x_test_tensor = torch.from_numpy(x_test).to(torch.float32)
y_test_tensor = torch.from_numpy(y_test).to(torch.float32)

在这里插入图片描述

# 5.形成训练数据集
train_data = TensorDataset(x_train_tensor, y_train_tensor)
test_data = TensorDataset(x_test_tensor, y_test_tensor)# 6.将数据加载成迭代器
batch_size =16
train_loader = torch.utils.data.DataLoader(train_data,batch_size,True)test_loader = torch.utils.data.DataLoader(test_data,batch_size,False)

LSTM

LSTM,即长短时记忆(Long Short-Term Memory)网络,是一种特殊的循环神经网络(Recurrent Neural Network, RNN)结构。LSTM旨在解决传统RNN在处理长序列数据时容易出现的梯度消失或梯度爆炸问题。

以下是LSTM的主要组成部分和功能:

遗忘门 (Forget Gate): 决定从单元状态中删除什么信息。使用sigmoid函数,输出0表示“完全忘记”,输出1表示“完全保留”。

输入门 (Input Gate): 有两部分组成。第一部分是sigmoid层,决定哪些值我们将更新。第二部分是tanh层,创建一个新的候选值向量,它可能被加到状态中。

单元状态 (Cell State): LSTM的核心部分,其在整个链上都有运行,只有一些少量的线性交互。单元状态类似于传送带,信息可以在其上自由流动,除非受到遗忘门或输入门的影响。

输出门 (Output Gate): 决定基于单元状态输出什么值。首先,sigmoid层决定我们将输出哪些部分。然后,将单元状态通过tanh(得到值在-1到1之间)并乘以sigmoid层的输出,这样只输出我们决定的部分。

LSTM的关键优势在于其能够记住长期的依赖关系。在许多序列任务中,当前的输出不仅仅依赖于前几个步骤,还可能依赖于很早之前的步骤。LSTM相比于普通的RNN更能够捕捉这些长期依赖关系。

实际上,LSTM的变种还有很多,例如GRU(Gated Recurrent Units)。不过,无论其结构如何调整,LSTM的核心思想是利用不同的门控结构来有选择地控制信息流,从而更好地学习和记住长期依赖关系
在这里插入图片描述

class LSTM(nn.Module):def __init__(self,input_dim,hidden_dim,num_layers,output_dim):super(LSTM,self).__init__()self.hidden_dim = hidden_dim #隐藏层大小self.num_layers = num_layers #LSTM层数# input_dim 为特征维度,就是每个时间点对应的特征数量,这里为4self.lstm = nn.LSTM(input_dim,hidden_dim,num_layers,batch_first=True)self.fc = nn.Linear(hidden_dim,output_dim)def forward(self,x):# Initialize hidden state and cell stateh0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)# Forward propagate through LSTMoutput, (h_n, c_n)= self.lstm(x, (h0, c0))batch_size, timestep, hidden_dim = output.shape# 将output变成 batch_size * timestep, hidden_dimoutput = output.reshape(-1, hidden_dim)output = self.fc(output)  # 形状为batch_size * timestep, 1output = output.reshape(timestep, batch_size, -1)return output[-1]  # 返回最后一个时间片的输出
model = LSTM(input_dim, hidden_dim, num_layers, output_dim)  # 定义LSTM网络

打印模型
在这里插入图片描述

定义损失函数

loss_function = nn.MSELoss()  # 定义损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # 定义优化器

配置信息

timestep = 1  # 时间步长,就是利用多少时间窗口
batch_size = 16  # 批次大小
input_dim = 4  # 每个步长对应的特征数量,就是使用每天的4个特征,最高、最低、开盘、落盘
hidden_dim = 64  # 隐层大小
output_dim = 1  # 由于是回归任务,最终输出层大小为1
num_layers = 3  # LSTM的层数
epochs = 10
best_loss = 0
model_name = 'LSTM'
save_path = './{}.pth'.format(model_name)

train

# 8. model 训练
save_path = '/content/sample_data/best.pth'
for epoch in range(epochs):model.train()running_loss= 0 #初始loss值train_bar = tqdm(train_loader) #通过使用tqdm来展示进度条for data in train_bar:x_train, y_train = data ## print(x_train.shape,y_train.shape) torch.Size([16, 1, 4]) torch.Size([16])optimizer.zero_grad()y_train_pred = model(x_train)loss = loss_function(y_train_pred,y_train.reshape(-1,1))loss.backward()optimizer.step()running_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss) #每次你计算 loss 并更新 train_bar.desc,进度条的描述就会更新,从而实时显示你的训练 loss 和当前 epoch。torch.save(model.state_dict(),save_path)

加载模型

# 加载权重和偏置
model.load_state_dict(torch.load(save_path))

模型验证

# 模型验证
model.eval()
test_loss = 0
save_path_2="/content/sample_data/best_2.pth"
with torch.no_grad():test_bar = tqdm(test_loader)for data in test_bar:x_test, y_test = datay_test_pred = model(x_test)test_loss = loss_function(y_test_pred, y_test.reshape(-1, 1))if test_loss < best_loss:best_loss = test_losstorch.save(model.state_dict(), save_path_2)

绘制模型

# 9.绘制结果 train 的结果
pred_value =model(x_train_tensor).detach().numpy().reshape(-1, 1)
true_value =y_train_tensor.detach().numpy().reshape(-1, 1)
plt.figure(figsize=(12, 8))
plt.plot(scaler.inverse_transform(pred_value, "b"),label="Preds_value")
plt.plot(scaler.inverse_transform(true_value, "r"),label="Data_value")
plt.legend()
plt.show()

在这里插入图片描述

绘制test的结果

y_test_pred = model(x_test_tensor)
plt.figure(figsize=(12, 8))
plt.plot(scaler.inverse_transform(y_test_pred.detach().numpy()), "b",label="pred_value")
plt.plot(scaler.inverse_transform(y_test_tensor.detach().numpy().reshape(-1, 1)), "r",label='true_value')
plt.legend()
plt.show()

在这里插入图片描述
代码ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/124815.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mariadb高可用(四十)

目录 一、概述 &#xff08;一&#xff09;概念 &#xff08;二&#xff09;组成 &#xff08;三&#xff09;特点 &#xff08;四&#xff09;工作原理 二、实验要求 三、构建MHA &#xff08;一&#xff09;ssh免密登录 &#xff08;二&#xff09;安装mariadb数据库…

【工作技术栈】【源码解读】一次springboot注入bean失败问题的排查过程

目录 前言现象分析原因解决方法思考感悟 前言 对这次的过程排查如果要形容的话&#xff0c;我觉得更像是悬疑剧&#xff0c;bean not found 这种错误&#xff0c;已经看腻了&#xff0c;甚至有时候都看不起这种错误&#xff0c;但是似乎这个想法被springboot听见了&#xff0c…

云服务器下如何部署Flask项目详细操作步骤

参考网上各种方案&#xff0c;再结合之前学过的Django部署方案&#xff0c;最后确定Flask总体部署是基于&#xff1a;centos7nginxuwsgipython3Flask之上做的。 本地windows开发测试好了我的OCR项目&#xff0c;现在要部署我的OCR项目到云服务器上验证下。 第一步&#xff1a…

java 整合 swagger-ui 步骤

1.在xml 中添加Swagger 相关依赖 <!-- springfox-swagger2 --><dependency><groupId>io.springfox</groupId><artifactId>springfox-swagger2</artifactId><version>2.9.2</version></dependency><!-- springfox-swa…

Git 命令行查看仓库信息

目录 查看系统config ​编辑查看当前用户&#xff08;global&#xff09;配置 查看当前仓库配置信息 查看系统config git config --system --list 1 查看当前用户&#xff08;global&#xff09;配置 git config --global --list 1 查到的是email , name 等ssl签名信息&a…

测试岗位的不足和缺点-思考

软件测试岗位在实际工作中可能会面临一些不足和缺点&#xff0c;以下是一些常见的问题&#xff1a; 高压力、高强度的工作&#xff1a;软件测试工作往往需要在项目截止日期前完成测试&#xff0c;这可能会带来巨大的压力。同时&#xff0c;如果开发团队在项目中进行了大量的更改…

CRM软件系统能否监控手机的使用

CRM可以监控手机吗&#xff1f;答案是不可以。CRM是一款帮助企业优化业务流程&#xff0c;提高销售效率的工具。例如Zoho CRM&#xff0c;最多也就是听一下销售的通话录音&#xff0c;却不可以监控手机&#xff0c;毕竟CRM不是一款监控软件。 CRM的主要作用有以下几点&#xf…

《Go 语言第一课》课程学习笔记(十四)

接口 认识接口类型 接口类型是由 type 和 interface 关键字定义的一组方法集合&#xff0c;其中&#xff0c;方法集合唯一确定了这个接口类型所表示的接口。type MyInterface interface {M1(int) errorM2(io.Writer, ...string) }我们在接口类型的方法集合中声明的方法&#…

深入探索KVM虚拟化技术:全面掌握虚拟机的创建与管理

文章目录 安装KVM开启cpu虚拟化安装KVM检查环境是否正常 KVM图形化创建虚拟机上传ISO创建虚拟机加载镜像配置内存添加磁盘能否手工指定存储路径呢&#xff1f;创建成功安装完成查看虚拟机 KVM命令行创建虚拟机创建磁盘通过命令行创建虚拟机手动安装虚拟机 KVM命令行创建虚拟机-…

攻防世界-WEB-ics-05

打开靶机 只有设备维护中心可以点开 点标签得到新的url pageindex 想到文件包含漏洞&#xff08;URL中出现path、dir、file、pag、page、archive、p、eng、语言文件等相关关键字眼 利用php伪协议查看源码 出现一段base64源码&#xff0c;进行转码得出源码 ?pagephp://filter…

驱动开发--day2

实现三盏灯的控制&#xff0c;编写应用程序测试 head.h #ifndef __HEAD_H__ #define __HEAD_H__#define LED1_MODER 0X50006000 #define LED1_ODR 0X50006014 #define LED1_RCC 0X50000A28#define LED2_MODER 0X50007000 #define LED2_ODR 0X50007014#endif mychrdev.c #inc…

LeetCode刷题笔记【30】:动态规划专题-2(不同路径、不同路径 II)

文章目录 前置知识62.不同路径题目描述解题思路代码 63. 不同路径 II题目描述障碍信息传递法(比较复杂)被障碍物阻挡后直接清空计数法(更简洁) 总结 前置知识 参考前文 参考文章&#xff1a; LeetCode刷题笔记【29】&#xff1a;动态规划专题-1&#xff08;斐波那契数、爬楼梯…

【深入理解Linux内核锁】七、互斥体

我的圈子: 高级工程师聚集地 我是董哥,高级嵌入式软件开发工程师,从事嵌入式Linux驱动开发和系统开发,曾就职于世界500强企业! 创作理念:专注分享高质量嵌入式文章,让大家读有所得! 文章目录 1、互斥体API2、API实现2.1 mutex2.2 mutex_init2.3 mutex_lock2.4 mutex_un…

2023外贸SEO推广怎么做?

答案是&#xff1a;2023外贸SEO推广可以选择谷歌SEO谷歌Ads双向运营。 外贸SEO的核心要素 外贸SEO不仅仅是关于关键词排名&#xff0c;它更多的是关于品牌建设和目标受众的吸引。 要想成功&#xff0c;必须认识到几个关键要素。 了解目标市场 首先&#xff0c;要深入了解目…

vue3+ts+vite项目引入echarts,vue3项目echarts组件封装

概述 技术栈&#xff1a;Vue3 Ts Vite Echarts 简介&#xff1a; 图文详解&#xff0c;教你如何在Vue3项目中引入Echarts&#xff0c;封装Echarts组件&#xff0c;并实现常用Echarts图例 文章目录 概述一、先看效果1.1 静态效果1.2 动态效果 二、话不多数&#xff0c;引入 …

redis持久化

Redis高可用 在web服务器中&#xff0c;高可用是指服务器可以正常访问的时间&#xff0c;衡量的标准是在多长时间内可以提供正常服务&#xff08;99.9%、99.99%、99.999%等等&#xff09;。 但是在Redis语境中&#xff0c;高可用的含义似乎要宽泛一些&#xff0c;除了保证提供…

ARM的异常处理

概念 处理器在正常执行程序的过程中可能会遇到一些不正常的事件发生 这时处理器就要将当前的程序暂停下来转而去处理这个异常的事件 异常事件处理完成之后再返回到被异常打断的点继续执行程序 异常处理机制 不同的处理器对异常的处理的流程大体相似&#xff0c;但是不同的处理器…

面试如何回答弹性盒子布局这个问题呢?

在我们面试中如果被问道css方面的面试题 那么极有可能被问到的一道面试题就是弹性盒子&#xff0c;本篇文章通过一张图带你拿捏这道面试题。 1、首先需要说一说弹性盒子的基本概念&#xff1a;弹性盒子是一种用于网页布局中创建灵活和响应式设计的CSS布局模型。 2、其次需要说…

IDEA找不到Maven窗口

有时候导入项目或者创建项目时候Maven窗口找不到了 然后指定项目的pom.xml文件

力扣(LeetCode)算法_C++——有效的数独

请你判断一个 9 x 9 的数独是否有效。只需要 根据以下规则 &#xff0c;验证已经填入的数字是否有效即可。 数字 1-9 在每一行只能出现一次。 数字 1-9 在每一列只能出现一次。 数字 1-9 在每一个以粗实线分隔的 3x3 宫内只能出现一次。&#xff08;请参考示例图&#xff09; …