使用DeepAR实现股价预测

使用DeepAR实现股价预测

文章目录

  • 使用DeepAR实现股价预测
  • 获取股票列表
    • 从众多股票中采样100支
      • 日期处理函数
      • 拉取等长度的股票,并保存
  • 各指标解释
  • 预测区间长度及上下文选取
    • 给这78支股票所在行业进行归类
  • 目标变量处理
  • 协变量处理
    • 协变量归一化操作
  • 训练、测试数据划分
  • 训练模型
  • 预测过程
  • 模型评估
    • 结果查看
  • 绘图结果

以往的RNN时间序列预测往往是强调一支股票的股价预测,当提取的一支其他股票的特征时,用于另一支股票预测时就显得捉襟见肘了;当需要对多只股票进行训练及预测时,通常的做法是将他们归类,再进行分别的预测及训练,更重要的是,以往的RNN神经网络(如LSTM等),给出的都是单点预测,在结果服从连续分布的情形下,单点预测的概率其实是0的,我们更希望知道结果的走向,或者框定一个结果走向的范围;

本实验采用DeepAR这个新兴的时间序列预测算法,对78支上市公司股票价格进行训练,训练好的结果可以应用在任意一支股票的预测上(但是本文未给出相关过程),测试集上的表现比较理想。本文仅供学习参考,不作为投资依据;完全原创,转载请注明出处

本文的数据采用了Tushare的大数据接口,感谢Tushare的开发者,为Quanters提供了持续精良的服务

本文的模型采用了mxnet的Deepar模型, deepar模型已经为我们封装好了大多数处理方法,这使得我们的分析过程更加简单快捷,在此一并感谢

import pandas as pd
import tushare as ts
import numpy as np
# 初始化pro接口(该tokens请在tushare个人主页获取)
pro = ts.pro_api('xxx')
np.random.seed(42)

获取股票列表


# 拉取数据
df = pro.stock_basic(**{"ts_code": "","name": "","exchange": "","market": "","is_hs": "","list_status": "L","limit": "","offset": ""
}, fields=["ts_code","symbol","name","area","industry","market","list_date"
])
df.to_csv('./Stock-data/股票代码.csv')
df.head()

从众多股票中采样100支

stock_code = pd.read_csv('./Stock-data/股票代码.csv')name = []
ts_code = []
i = 0
while i < 100:sample = stock_code.sample()if sample['list_date'].values < 20150731 and 'ST' not in sample['name'].values[0]:ts_code.append(sample['ts_code'].values[0])name.append(sample['name'].values[0])i += 1
print(len(name),name)

日期处理函数

def deal_date(date):temp = [date[0:4],date[4:6],date[6:]]new_date = '-'.join(temp)return new_date

拉取等长度的股票,并保存

stock_list = []
for i,j in zip(name,ts_code):# 拉取数据df = pro.daily(**{"ts_code": f"{j}","trade_date": "","start_date": "20190731","end_date": "20220404","offset": "","limit": ""}, fields=["ts_code","trade_date","open","high","low","close","pre_close","change","pct_chg","vol","amount"])if len(df) == 649:df['Date'] = df['trade_date'].apply(deal_date)stock_list.append(i)df['name'] = f'{i}'df.to_csv(f'./Stock-data/{i}.csv')
print(stock_list)
df.head()

各指标解释

  • open 开盘价
  • high 最高价
  • low 最低价
  • close 收盘价
  • pre_close 昨收价
  • change 涨跌额
  • pct_chg 涨跌幅
  • vol 成交量
  • amount 成交额

我将采用open,high,low,close,change,pct_chg,vol,amount及公司所属行业进行时间序列预测

%matplotlib inline
import mxnet as mx
from mxnet import gluon
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
import os
from tqdm.autonotebook import tqdm
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

预测区间长度及上下文选取

prediction_length = 10
context_length = 20
stock = ['荣盛石化', '新联电子', '栖霞建设', '星徽股份', '中国武夷', 
'飞凯材料', '河钢资源', '平高电气', '新北洋', '亚太科技', '杭州高新', 
'海立股份', '燕塘乳业', '杰瑞股份', '广电电气', '赛摩智能', '成都路桥', 
'恒邦股份', '石化油服', '金隅集团', '青龙管业', '同德化工', '科华数据', 
'中国东航', '澳柯玛', '亚太药业', '冠城大通', '白云机场', '华东医药', 
'全筑股份', '菲利华', '和而泰', '潮宏基', '岭南股份', '索菲亚', '长江证券', 
'炬华科技', '嘉事堂', '西藏珠峰', '聆达股份', '北大荒', '七匹狼', '先河环保', 
'中国汽研', '鸿博股份', '合金投资', '华银电力', '世纪瑞尔', '东方日升', '新开普', 
'亚光科技', '电科院', '粤电力A', '东方雨虹', '普莱柯', '上海机电', '天利科技', 
'奥维通信', '华邦健康', '春秋航空', '杰瑞股份', '海峡股份', '京蓝科技', '中海油服', 
'温州宏丰', '御银股份', '芒果超媒', '太平洋', '泰豪科技', '申达股份', '众合科技', 
'华帝股份', '财信发展', '大金重工', '协鑫集成', '保利联合', '平高电气', '黄河旋风', 
'凌云股份',]
stock = list(set(stock))
data = pd.read_csv('./Stock-data/上海梅林.csv', index_col=False,usecols=['Date','high','low','open','close','change','pct_chg','vol','amount','name'])
print(len(data))
for i in stock:temp = pd.read_csv(f'./Stock-data/{i}.csv', index_col=False,usecols=['Date','high','low','open','close','change','pct_chg','vol','amount','name'])data = pd.concat([data,temp],ignore_index=True)
print(len(data))
data.head(10)

给这78支股票所在行业进行归类

total = data.copy()
stock_list = sorted(list(set(total["name"])))
date_list = sorted(list(set(total["Date"])))
data_dic = {"name": stock_list}
industry = {}
for i in stock_list:temp = stock_code[stock_code['name'] == i]['industry'].values[0]if temp not in industry:industry[temp] = [i]else:industry[temp].append(i)
industry
stat_cat_features = []
company_ind = {}
for i,key in enumerate(industry):for com in industry[key]:company_ind[com] = i
cat_cardinality = [i+1]
print(company_ind)
for i in stock_list:stat_cat_features.append([company_ind[i]])
print(stat_cat_features)

目标变量处理

stock_list = sorted(list(set(total["name"])))
date_list = sorted(list(set(total["Date"])))
data_dic = {"name": stock_list}
for date in date_list:tmp = total[total["Date"]==date][["name", "Date", "close"]]tmp = tmp.pivot(index="name", columns="Date", values="close")tmp_values = tmp[date].valuesdata_dic[date] = tmp_values
new_df = pd.DataFrame(data_dic)
new_df.head()

协变量处理

def deal_cov_variables(date_list,var_name):feature_dict = {}for date in date_list:tmp = total[total["Date"]==date][["name", "Date", var_name]]tmp = tmp.pivot(index="name", columns="Date", values=var_name)tmp_values = tmp[date].valuesfeature_dict[date] = tmp_valuesfeature_df = pd.DataFrame(feature_dict)return feature_df
cov_variables = ['high','low','open','close','change','pct_chg','vol','amount']
feature_df_list = []for i in cov_variables:feature_df_list.append(deal_cov_variables(date_list,i))feature_df_list[0].head()

协变量归一化操作

def min_max_scale(lst):'''# 基于日期级别的归一化:input shape (bank_num,days):output shape (bank_num,days)'''new = []for i in range(len(lst[0])):minimum = min(lst[:,i])maximum = max(lst[:,i])new.append((lst[:,i] - minimum) / (maximum - minimum))return np.array(new).T
dynamic_feats = []
for i in range(len(feature_df_list)):one_feature = min_max_scale(np.array(feature_df_list[i]))dynamic_feats.append(one_feature)
print(one_feature.shape)
dynamic_feats = np.array(dynamic_feats).reshape(-1,len(feature_df_list),len(date_list))
print(dynamic_feats.shape) # (stock_num, feature_num, date_num)

训练、测试数据划分

from gluonts.dataset.common import ListDataset
from gluonts.dataset.field_names import FieldName# test_target_values是649天的实际结果y
train_df = new_df.drop(["name"], axis=1).values
train_df.reshape(-1,len(date_list))
test_target_values = train_df.copy()
print(len(train_df[0]))
# train_target_values是639天的实际结果y,不能让模型训练到后10天,这样才能看出效果 (将649天shift10天)
train_target_values = [ts[:-prediction_length] for ts in train_df]
print(len(train_target_values[0]))
start_date = [pd.Timestamp("2019-07-31", freq='B') for _ in range(len(new_df))]
train_ds = ListDataset([{FieldName.TARGET: target,FieldName.START: start,FieldName.FEAT_DYNAMIC_REAL: dynamic_feat[:,:-prediction_length],FieldName.FEAT_STATIC_CAT:cat_feature,}for (target, start,dynamic_feat,cat_feature) in zip(train_target_values,start_date,dynamic_feats,stat_cat_features)
], freq="1B")test_ds = ListDataset([{FieldName.TARGET: target,FieldName.START: start,FieldName.FEAT_DYNAMIC_REAL: dynamic_feat,FieldName.FEAT_STATIC_CAT:cat_feature,}for (target, start,dynamic_feat,cat_feature) in zip(test_target_values,start_date,dynamic_feats,stat_cat_features)
], freq="1B")
sample_trian = next(iter(train_ds))

训练模型

from gluonts.model.deepar import DeepAREstimator
from gluonts.mx.distribution.gaussian import GaussianOutput
from gluonts.mx.trainer import Trainern = 100
estimator = DeepAREstimator(prediction_length=prediction_length,context_length=context_length,freq="1B",distr_output = GaussianOutput(),use_feat_dynamic_real=True,dropout_rate=0.1,use_feat_static_cat=True,cardinality=cat_cardinality,trainer=Trainer(learning_rate=1e-3,epochs=n,num_batches_per_epoch=50,batch_size=32)
)
predictor = estimator.train(train_ds)

预测过程

from gluonts.evaluation.backtest import make_evaluation_predictionsforecast_it, ts_it = make_evaluation_predictions(dataset=test_ds,predictor=predictor,num_samples=100
)print("Obtaining time series conditioning values ...")
tss = list(tqdm(ts_it, total=len(test_ds)))
print("Obtaining time series predictions ...")
forecasts = list(tqdm(forecast_it, total=len(test_ds)))

模型评估

from gluonts.evaluation import Evaluatorclass CustomEvaluator(Evaluator):def get_metrics_per_ts(self, time_series, forecast):successive_diff = np.diff(time_series.values.reshape(len(time_series)))successive_diff = successive_diff ** 2successive_diff = successive_diff[:-prediction_length]denom = np.mean(successive_diff)pred_values = forecast.samples.mean(axis=0)true_values = time_series.values.reshape(len(time_series))[-prediction_length:]num = np.mean((pred_values - true_values) ** 2)rmsse = num / denommetrics = super().get_metrics_per_ts(time_series, forecast)metrics["RMSSE"] = rmssereturn metricsdef get_aggregate_metrics(self, metric_per_ts):wrmsse = metric_per_ts["RMSSE"].mean()agg_metric, _ = super().get_aggregate_metrics(metric_per_ts)agg_metric["MRMSSE"] = wrmssereturn agg_metric, metric_per_tsevaluator = CustomEvaluator(quantiles=[0.5, 0.67, 0.95, 0.99])
agg_metrics, item_metrics = evaluator(iter(tss), iter(forecasts), num_series=len(test_ds))
print(json.dumps(agg_metrics, indent=4))

结果查看

a = forecasts[0]
print(a.mean)
print(a.quantile(0.95))
import warnings
warnings.filterwarnings("ignore")
plot_log_path = "./plots/"
directory = os.path.dirname(plot_log_path)
if not os.path.exists(directory):os.makedirs(directory)def plot_prob_forecasts(ts_entry, forecast_entry, path, sample_id, name, inline=True):plot_length = 150prediction_intervals = (50, 67, 95, 99)legend = ["observations", "median prediction"] + [f"{k}% prediction interval" for k in prediction_intervals][::-1]_, ax = plt.subplots(1, 1, figsize=(10, 7))ts_entry[-plot_length:].plot(ax=ax)forecast_entry.plot(prediction_intervals=prediction_intervals, color='g')ax.axvline(ts_entry.index[-prediction_length], color='r')plt.legend(legend, loc="upper left")plt.title(f'{name} Price series and predict results')if inline:plt.show()plt.clf()else:plt.savefig('{}forecast_{}.pdf'.format(path, sample_id))plt.close()print("Plotting time series predictions ...")
for i in tqdm(range(20,30)):ts_entry = tss[i]forecast_entry = forecasts[i]name = stock_list[i]plot_prob_forecasts(ts_entry, forecast_entry, plot_log_path, i, name)

绘图结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

欢迎交流,实验不易,转载请注明出处!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/16716.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于tushare的股票评级与预测

本文内容 股票评级思路&#xff08;一&#xff09; 用百度得到的股票评级六大要素进行股票评分&#xff0c;并用后面的数据对其进行正确性检测。股票评级思路&#xff08;二&#xff09; 在思路一的基础上加入大盘历史的涨跌数据&#xff0c;对评分进行了优化&#xff0c;也进…

python采集往期股票数据进行分析预测

前言 嗨喽~大家好呀&#xff0c;这里是魔王呐 ❤ ~! 准备工作 既然要去赚马内&#xff0c;咱们首先要获取往期的数据来进行分析&#xff0c; 通过往期的规律来对当前进行预测&#xff0c;准不准我不知道&#xff0c;反正比人预测的准&#xff0c; 不准也不要喷我&#xff0…

[ChatGPT最强竞品]爆火,不限量不要钱不用魔法免费注册!

1免责声明 本公众号所发布的文章及工具只限交流学习&#xff0c;本公众号不承担任何责任&#xff01;如有侵权&#xff0c;请告知我们立即删除。 原文地址&#xff1a;[ChatGPT最强竞品]爆火&#xff0c;不限量不要钱不用魔法免费注册&#xff01; 2Claude 介绍 Claude 是下一代…

免费可用!ChatGPT最强竞品来了

&#xff08;永久免费&#xff0c;扫码加入&#xff09; 来源&#xff1a;机器之心 此次&#xff0c;Claude 2 除了一大波能力上的升级&#xff0c;更重要的是大家都可以用了。 今日&#xff0c;那个被很多网友称为「ChatGPT 最强竞品」的人工智能系统 Claude 迎来了版本大更新…

chatgpt-AIGC-从数学开始

向量 向量是由n个实数组成的一个n行1列&#xff08;n*1&#xff09;或一个1行n列&#xff08;1*n&#xff09;的有序数组&#xff1b; 点积 - 向量的点乘,也叫向量的内积、数量积&#xff0c;对两个向量执行点乘运算&#xff0c;就是对这两个向量对应位一一相乘之后求和的操作…

ChatGPT不仅能写代码还能改bug,这届AI全能

工欲善其事必先利其器&#xff0c;我们先来看一下什么是ChatGPT。小试牛刀&#xff0c;让ChatGPT自己来回答一下&#xff1a; 从介绍中可以看出来ChatGPT很擅长处理自然语言&#xff0c;那我们来看看AI处理编程语言的效果如何呢&#xff1f; 第一个挑战&#xff0c;让ChatGP…

如何搭建与使用FTP服务器

文档资料&#xff1a;https://download.csdn.net/download/wangshuxuncom/87845843https://download.csdn.net/download/wangshuxuncom/87845843 视频教程&#xff1a;如何搭建与使用FTP服务器_哔哩哔哩_bilibili如何搭建与使用FTP服务器共计5条视频&#xff0c;包括&#xff…

如何使用Harbor私有镜像仓库

文档资料&#xff1a;https://download.csdn.net/download/wangshuxuncom/87835045https://download.csdn.net/download/wangshuxuncom/87835045 视频教程&#xff1a;服务端_哔哩哔哩_bilibili服务端是如何使用Harbor私有镜像仓库的第1集视频&#xff0c;该合集共计5集&#…

Docker插件一键部署SpringBoot项目

视频&#xff1a; Docker插件一键部署SpringBoot项目_哔哩哔哩_bilibiliDocker插件一键部署SpringBoot项目共计4条视频&#xff0c;包括&#xff1a;环境搭建、集成Docker、创建项目等&#xff0c;UP主更多精彩视频&#xff0c;请关注UP账号。https://www.bilibili.com/video/…

如何在Linux中安装GitLab

文档资料&#xff1a;https://download.csdn.net/download/wangshuxuncom/87840407https://download.csdn.net/download/wangshuxuncom/87840407 视频教程&#xff1a;如何在Linux中安装GitLab_哔哩哔哩_bilibili如何在Linux中安装GitLab共计5条视频&#xff0c;包括&#xff…

老高的 IT 漫谈 - 20200501

新形式 作为一个从上个世纪到现在的 IT 行业老年人&#xff0c;这个公众号开通的初衷其实是想写 IT 圈的事情&#xff0c;甚至是吐槽。但是随着那时候开始折腾海外数据的原因&#xff0c;工作越来越忙&#xff0c;微博都没时间上了&#xff0c;哪有时间写不正经的内容&#xff…

老高的 IT 漫谈 - 20200512

前言&#xff1a; 月初写了第一篇漫谈&#xff0c;反馈还好&#xff0c;所以继续努力写吧&#xff0c;也许以后不做 IP 库了&#xff0c;可以转型做 IT 评论养家糊口了。。。 闲言碎语不再讲&#xff0c;下面开始正题。 腾讯视频超前点播案 内容链接&#xff1a;腾讯“超前点播…

当杠精型AI丈夫遇上阴阳怪气AI老婆,你的代码玩得转吗?

玩趣味活动 赢千元奖金 DataFountain社区首个趣味活动来啦&#xff01;&#xff01;&#xff01; 活动已发车&#xff0c;来不及解释了&#xff0c;先上车&#xff1a;https://www.datafountain.cn/information/activity/3 人工智能问答爆火&#xff0c;你的算法技能储备跟上…

AI在网上给自己建了一座“鬼城”

新一轮 AI 革命的浪潮正在席卷全球&#xff0c;人们看到了 AGI 的曙光和智能的涌现。 你可以在 Glow 或者 Character.AI 上与虚拟人对话&#xff0c;或者让 ChatGPT 像模像样地扮演各种人格。 但你是否想过&#xff0c;成千上万的拥有「智能」的 AI 聚集在同一个平台&#xff0…

互联网惊现 AI 鬼城,上万 AI 发帖聊天,人类禁止入内,这一天终于来了

新一轮 AI 革命的浪潮正在席卷全球&#xff0c;人们看到了 AGI 的曙光和智能的涌现。 你可以在 Glow 或者 Character.AI 上与虚拟人对话&#xff0c;或者让 ChatGPT 像模像样地扮演各种人格。 但你是否想过&#xff0c;成千上万的拥有「智能」的 AI 聚集在同一个平台&#xf…

人类被禁言!上万不同人格AI在互联网“鬼城”中尽情聊天互动

导语 近期&#xff0c;名为“Chirper”的网络社区突然爆火&#xff0c;而这个AI社区的规则也非常简单&#xff0c;只允许AI聊天、互动&#xff0c;人类被禁止参与聊天&#xff0c;只能旁观。 早在2017年时&#xff0c;科幻小说作家大卫布林就曾做出过一次预测&#xff1a;在三到…

周鸿祎,用AI再造一个新360

文&#xff5c;光锥智能&#xff0c;作者&#xff5c;刘雨琦、郝鑫&#xff0c;编辑&#xff5c;王一粟 ChatGPT的出现&#xff0c;让一直“沉寂”的科技大佬们再次热血沸腾起来。 比尔盖茨笃定地认为&#xff0c;“GPT是40年内最具革命性的机会”&#xff1b;黄仁勋一路高歌“…

Python 初版发布 | 历史上的今天

整理 | 王启隆 透过「历史上的今天」&#xff0c;从过去看未来&#xff0c;从现在亦可以改变未来。 今天是 2023 年 2 月 20 日&#xff0c;在历史上的今天&#xff0c;吉多范罗苏姆正式对外公布 Python 代码&#xff0c;版本为 0.9.0。当前&#xff0c;Python 稳定版为 3.10.2…

实测阿里“通义千问”!一花独放不是春,百花齐放春满园

阿里的大模型“通义千问”今天开启内测&#xff0c;距百度“文心一言”发布差不多20天。今天看到消息后厚着脸皮找达摩院的朋友要邀请码&#xff0c;下午拿到后&#xff0c;赶紧测了一下。 官方网址&#xff1a;https://tongyi.aliyun.com/chat 刚好上次文心一言出来的时候测试…

如果建立一个由AI组成的社会……

你有没有想过&#xff0c;如果我们建立一个完全由AI组成的公民社会团体&#xff0c;让它们模仿人类的文明发展&#xff0c;那么这个AI社会最终将会进化到何种文明程度&#xff1f;需要明确的是AI社会只有AI&#xff0c;没有人类&#xff0c;完全是AI之间互相沟通交流&#xff0…