使用 PyTorch 构建 LSTM 股票价格预测模型

目录

      • 引言
      • 准备工作
      • 1. 训练模型(`train.py`)
      • 2. 模型定义(`model.py`)
      • 3. 测试模型和可视化(`test.py`)
      • 使用说明
      • 模型调整
      • 结论

引言

在金融领域,股票价格预测是一个重要且具有挑战性的任务。随着深度学习的发展,长短期记忆网络(LSTM)因其在处理时间序列数据方面的出色表现而受到关注。本篇博客将指导你如何使用PyTorch构建一个LSTM模型来预测股票价格,我们将逐步介绍数据预处理、模型训练和结果可视化的完整流程。

准备工作

  1. 安装依赖
    确保你已经安装了以下 Python 库:

    pip install pandas numpy torch matplotlib scikit-learn
    
  2. 下载数据
    使用 yfinance 库下载你感兴趣的股票的历史数据,并保存为 CSV 文件。我们这里使用 Apple(AAPL)过去五年的数据,文件命名为 AAPL_5y_data.csv。以下是一个下载数据的代码示例:

    import yfinance as yf# 下载Apple股票过去5年的数据
    data = yf.download('AAPL', start='2019-01-01', end='2024-01-01')
    data.to_csv('AAPL_5y_data.csv')
    

1. 训练模型(train.py

在这个脚本中,我们将读取 CSV 文件,归一化数据,并使用 LSTM 模型进行训练。

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler
from model import LSTM  # 导入LSTM类# 设置随机种子
torch.manual_seed(42)# 读取CSV文件
file_path = 'AAPL_5y_data.csv'  # 替换为你的CSV文件路径
data = pd.read_csv(file_path)# 确保日期列是 datetime 类型
data['Date'] = pd.to_datetime(data['Date'])
data.set_index('Date', inplace=True)# 选择多特征:'Close', 'Open', 'High', 'Low', 'Volume'
features = data[['Close', 'Open', 'High', 'Low', 'Volume']].values# 数据归一化
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(features)# 准备训练和测试数据
train_size = int(len(scaled_data) * 0.8)
train_data = scaled_data[:train_size]
test_data = scaled_data[train_size:]def create_dataset(data, time_step=1):X, y = [], []for i in range(len(data) - time_step - 1):a = data[i:(i + time_step)]X.append(a)y.append(data[i + time_step, 0])  # 预测收盘价return np.array(X), np.array(y)# 创建数据集
time_step = 50  # 时间步长
X_train, y_train = create_dataset(train_data, time_step)# 转换为PyTorch张量
X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train).float().view(-1, 1)# 初始化模型、损失函数和优化器
model = LSTM()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)# 训练模型
num_epochs = 300
for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(X_train)loss = criterion(outputs, y_train)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 保存模型
torch.save(model.state_dict(), 'lstm_model.pth')
print("模型已保存为 'lstm_model.pth'")

2. 模型定义(model.py

在这个文件中定义 LSTM 模型结构。

import torch
import torch.nn as nnclass LSTM(nn.Module):def __init__(self):super(LSTM, self).__init__()self.lstm = nn.LSTM(input_size=5, hidden_size=100, num_layers=2, batch_first=True)self.fc = nn.Linear(100, 1)def forward(self, x):out, _ = self.lstm(x)out = self.fc(out[:, -1, :])  # 取最后时间步的输出return out

3. 测试模型和可视化(test.py

在这个脚本中,我们将加载训练好的模型,并使用测试数据进行预测和可视化。

import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from model import LSTM  # 导入LSTM类# 设置字体为SimHei,用于显示中文
plt.rcParams['font.family'] = 'SimHei'# 读取CSV文件
file_path = 'AAPL_5y_data.csv'  # 替换为你的CSV文件路径
data = pd.read_csv(file_path)# 确保日期列是 datetime 类型
data['Date'] = pd.to_datetime(data['Date'])
data.set_index('Date', inplace=True)# 选择多特征:'Close', 'Open', 'High', 'Low', 'Volume'
features = data[['Close', 'Open', 'High', 'Low', 'Volume']].values# 数据归一化
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(features)# 准备训练和测试数据
train_size = int(len(scaled_data) * 0.8)
train_data = scaled_data[:train_size]
test_data = scaled_data[train_size:]def create_dataset(data, time_step=1):X, y = [], []for i in range(len(data) - time_step - 1):a = data[i:(i + time_step)]X.append(a)y.append(data[i + time_step, 0])  # 预测收盘价return np.array(X), np.array(y)# 创建测试数据集
time_step = 50  # 时间步长
X_test, y_test = create_dataset(test_data, time_step)# 转换为PyTorch张量
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test).float().view(-1, 1)# 加载模型
model = LSTM()
model.load_state_dict(torch.load('lstm_model.pth'))
model.eval()# 测试模型
with torch.no_grad():test_outputs = model(X_test)# test_outputs 是预测的收盘价,将其重新归一化为原始价格test_outputs = scaler.inverse_transform(np.concatenate((test_outputs.numpy(), np.zeros((test_outputs.shape[0], 4))), axis=1))[:, 0]  # 反归一化收盘价y_test_inverse = scaler.inverse_transform(np.concatenate((y_test.numpy(), np.zeros((y_test.shape[0], 4))), axis=1))[:, 0]# 可视化结果
plt.figure(figsize=(14, 7))
plt.plot(data.index[-len(y_test):], y_test_inverse, label='真实价格', color='blue')
plt.plot(data.index[-len(test_outputs):], test_outputs, label='预测价格', color='red')
plt.title('股票价格预测')
plt.xlabel('日期')
plt.ylabel('价格')
plt.legend()
plt.show()

使用说明

  1. 保存脚本

    • 将训练脚本代码保存为 train.py
    • 将模型定义代码保存为 model.py
    • 将测试脚本代码保存为 test.py
  2. 运行训练

    • 在命令行中运行训练脚本:
      python train.py
      
    • 训练完成后,模型将保存为 lstm_model.pth
  3. 运行测试和可视化

    • 在命令行中运行测试脚本:

      python test.py
      
    • 这将加载已训练的模型,并可视化预测结果。
      在这里插入图片描述
      这只是一个演示,模型的预测效果还有待进一步优化。

模型调整

如果预测的价格和真实价格差距较大,可能是由于以下几个原因:

  1. 数据规模不足

    • 如果训练数据不足,模型可能无法学到市场的长期趋势。
    • 改进:使用更多的历史数据,尽量包括多年的数据。可以尝试增加数据的时间跨度。
  2. 数据预处理问题

    • 数据没有正确归一化,或归一化范围过窄。
    • 改进:检查 MinMaxScaler 的应用。你可以尝试不同的归一化范围,例如 (0, 1)(-1, 1),也可以使用其他标准化方法(例如 StandardScaler)。
  3. 模型复杂度不足

    • 模型的层数或隐藏单元数量可能不足以捕捉数据的复杂性。
    • 改进:增加 LSTM 的隐藏层数量或隐藏单元数量。你还可以考虑添加其他类型的层,例如卷积层(CNN)或全连接层,以提高模型的表达能力。
  4. 超参数调整

    • 学习率、批大小和时间步长等超参数可能需要调整以优化模型性能。
    • 改进:尝试不同的学习率(例如,0.001、0.0001 等)、不同的批大小(如 16、32、64)和时间步长(如 30、60)。
  5. 更改损失函数

    • 在某些情况下,使用不同的损失函数可能有助于模型的收敛。
    • 改进:可以尝试使用其他损失函数,例如 Huber 损失函数(nn.SmoothL1Loss)或自定义损失函数,以更好地适应数据。

结论

通过使用 PyTorch 构建 LSTM 模型,我们成功地实现了股票价格的预测。在这个过程中,我们学习了如何处理时间序列数据,构建和训练深度学习模型,以及如何评估和可视化预测结果。尽管模型的性能可能需要进一步的优化和调整,但这个示例为未来的工作奠定了基础。

希望这篇博客能够帮助你在股票价格预测方面取得更好的成果。欢迎分享你的成果和经验,或者提出你的问题!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/451694.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux文件操作基础

目录 Linux文件操作基础 引入 回顾C语言文件操作 系统调用接口 open函数 read函数和write函数 close函数 模拟C语言接口 文件描述符 如何理解Linux下一切皆文件 文本读写与二进制读写 Linux文件操作基础 引入 在Linux第一章提到过,在Linux中&#xff0…

快速创建一个vue项目并运行

前期准备工作: 1.安装node 2.安装npm 3.设置淘宝镜像 4.全局安装webpack 5.webpack 4.X 开始,需要安装 webpack-cli 依赖 6.全局安装vue-cli 正文开始: 1.创建项目 ,回车 vue init webpack vue-svg > Project name vue-demo 项目名称 回车 > Pro…

电脑桌面自己变成了英文Desktop,怎么改回中文

目录 前言找到Desktop查看位置查找目标修改文件名为桌面重启电脑 或 重启 Windows 资源管理器CtrlShiftEsc 打开任务管理器找到 Windows 资源管理器重启 Windows 资源管理器 查看修改结果 前言 许多人在使用电脑的时候发现,我们经常使用的桌面,不知道因为…

安卓流式布局实现记录

效果图&#xff1a; 1、导入第三方控件 implementation com.google.android:flexbox:1.1.0 2、布局中使用 <com.google.android.flexbox.FlexboxLayoutandroid:id"id/baggageFl"android:layout_width"match_parent"android:layout_height"wrap_co…

震惊!原来贡献开源代码这么简单,分分钟上手!

文章目录 前言一、什么是 Fork 和 PR&#xff1f;1. Fork&#xff08;分叉&#xff09;2. PR&#xff08;Pull Request&#xff0c;拉取请求&#xff09; 二、两种常见的贡献代码方式1. Fork 后通过 PR 提交代码2. 直接在项目分支中修改 三、如何 Fork 和发起 Pull Request&…

高效车辆管理:SpringBoot实现指南

1系统概述 1.1 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及&#xff0c;互联网成为人们查找信息的重要场所&#xff0c;二十一世纪是信息的时代&#xff0c;所以信息的管理显得特别重要。因此&#xff0c;使用计算机来管理车辆管理系统的相关信息成为必然。开发合适…

蜗牛兼职网的设计与实现(论文+源码)_kaic

摘 要 随着科学技术的飞速发展&#xff0c;社会的方方面面、各行各业都在努力与现代的先进技术接轨&#xff0c;通过科技手段来提高自身的优势&#xff0c;蜗牛兼职网当然也不能排除在外。蜗牛兼职网是以实际运用为开发背景&#xff0c;运用软件工程原理和开发方法&#xff0c…

Unity开发Hololens项目

Unity打包Hololens设备 目录Visual Studio2019 / Visual Studio2022 远端部署设置Visual Studio2019 / Visual Studio2022 USB部署设置Hololens设备如何查找自身IPHololens设备门户Unity工程内的打包设置 目录 记录下自己做MR相关&#xff1a;Unity和HoloLens设备的历程。 Vi…

智能家居的“眼睛”:计算机视觉如何让家更智能

引言 在不远的未来&#xff0c;当我们走进家门&#xff0c;灯光自动亮起&#xff0c;空调已经调至最舒适的温度&#xff0c;甚至音乐也播放着我们最喜欢的歌曲。 这一切&#xff0c;都得益于智能家居系统的发展。而在这个系统中&#xff0c;计算机视觉技术扮演着至关重要的角色…

opencv 图像BGR三通道分离 split 与 合并 merge -python 实现

图像BGR三通道分离 split 与 合并 merge 会在图像预处理和图像增强中使用。 具体代码如下&#xff1a; #-*-coding:utf-8-*- # date:2021-03-21 # Author: DataBall - XIAN 1、将彩色图片 BGR 三通道分离&#xff08;注意观察 B、G、R 单通道图像素的明暗&#xff09;2、将3个…

Java知识巩固(六)

什么是可变长参数&#xff1f; 从 Java5 开始&#xff0c;Java 支持定义可变长参数&#xff0c;所谓可变长参数就是允许在调用方法时传入不定长度的参数。就比如下面这个方法就可以接受 0 个或者多个参数。 public static void method1(String... args) {//...... } 另外&am…

python 作业1

任务1: python为主的工作是很少的 学习的python的优势在于制作工具&#xff0c;制作合适的工具可以提高我们在工作中的工作效率的工具 提高我们的竞争优势。 任务2: 不换行 换行 任务3: 安装pycharm 进入相应网站Download PyCharm: The Python IDE for data science and we…

分享一套SpringBoot+Vue民宿(预约)系统

大家好&#xff0c;我是java1234_小锋老师&#xff0c;看到一个不错的SpringBootVue民宿(预约)系统&#xff0c;分享下嘿嘿。 项目介绍 传统办法管理信息首先需要花费的时间比较多&#xff0c;其次数据出错率比较高&#xff0c;而且对错误的数据进行更改也比较困难&#xff0c…

qt QGraphicsEffect详解

一、QGraphicsEffect概述 QGraphicsEffect通过挂接到渲染管道并在源&#xff08;例如QGraphicsPixmapItem、QWidget&#xff09;和目标设备&#xff08;例如QGraphicsView的视口&#xff09;之间进行操作来更改元素的外观。它允许开发者为图形项添加各种视觉效果&#xff0c;如…

Redis——事务

文章目录 Redis 事务Redis 的事务和 MySQL 事务的区别:事务操作MULTIEXECDISCARDWATCHUNWATCHwatch的实现原理 总结 Redis 事务 什么是事务 Redis 的事务和 MySQL 的事务 概念上是类似的. 都是把⼀系列操作绑定成⼀组. 让这⼀组能够批量执行 Redis 的事务和 MySQL 事务的区别:…

无人机之融合集群技术篇

无人机的融合集群技术是一个涉及多个领域的复杂技术体系&#xff0c;它结合了无人机技术、自组网技术、集群控制技术以及反制设备等多个方面&#xff0c;旨在实现多架无人机之间的协同、编队、信息共享、任务分配和高效作业。 一、无人机自组网技术 无人机自组网技术是一种利用…

UDP/TCP协议

网络层只负责将数据包送达至目标主机&#xff0c;并不负责将数据包上交给上层的哪一个应用程序&#xff0c;这是传输层需要干的事&#xff0c;传输层通过端口来区分不同的应用程序。传输层协议主要分为UDP&#xff08;用户数据报协议&#xff09;和TCP&#xff08;传输控制协议…

1. 安装框架

一、安装 Laravel 11 框架 按照官方文档直接下一步安装即可 1. 安装步骤 2. 执行数据库迁移 在.env文件中提前配置好数据库连接信息 php artisan migrate二、安装 Filament3.2 参考 中文文档 进行安装 1. 安装 拓展包 composer require filament/filament:"^3.2" -W…

cisco网络安全技术第3章测试及考试

测试 使用本地数据库保护设备访问&#xff08;通过使用 AAA 中央服务器来解决&#xff09;有什么缺点&#xff1f; 试题 1选择一项&#xff1a; 必须在每个设备上本地配置用户帐户&#xff0c;是一种不可扩展的身份验证解决方案。 请参见图示。AAA 状态消息的哪一部分可帮助…

java基于SpringBoot+Vue+uniapp微信小程序的自助点餐系统的详细设计和实现(源码+lw+部署文档+讲解等)

项目运行截图 技术框架 后端采用SpringBoot框架 Spring Boot 是一个用于快速开发基于 Spring 框架的应用程序的开源框架。它采用约定大于配置的理念&#xff0c;提供了一套默认的配置&#xff0c;让开发者可以更专注于业务逻辑而不是配置文件。Spring Boot 通过自动化配置和约…