使用pytorch实现LSTM预测交通流

原始数据:

免费可下载原始参考数据

预测结果图:

根据测试数据test_data的真实值real_flow,与模型根据测试数据得到的输出结果pre_flow

完整源码:

#!/usr/bin/env python
# _*_ coding: utf-8 _*_import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
from torchsummary import summary
import math
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import time
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error
import math
from matplotlib.font_manager import FontProperties  # 画图时可以使用中文# 加载数据
f = pd.read_csv(r'C:\Users\14600\Desktop\AE86.csv')# 从新设置列标
def set_columns():columns = []for i in f.loc[2]:columns.append(i.strip())return columnsf.columns = set_columns()
f.drop([0, 1, 2], inplace=True)
# 读取数据
data = f['Total Carriageway Flow'].astype(np.float64).values[:, np.newaxis]class LoadData(Dataset):def __init__(self, data, time_step, divide_days, train_mode):self.train_mode = train_modeself.time_step = time_stepself.train_days = divide_days[0]self.test_days = divide_days[1]self.one_day_length = int(24 * 4)# flow_norm (max_data. min_data)self.flow_norm, self.flow_data = LoadData.pre_process_data(data)# 不进行标准化# self.flow_data = datadef __len__(self, ):if self.train_mode == "train":return self.train_days * self.one_day_length - self.time_stepelif self.train_mode == "test":return self.test_days * self.one_day_lengthelse:raise ValueError(" train mode error")def __getitem__(self, index):if self.train_mode == "train":index = indexelif self.train_mode == "test":index += self.train_days * self.one_day_lengthelse:raise ValueError(' train mode error')data_x, data_y = LoadData.slice_data(self.flow_data, self.time_step, index,self.train_mode)data_x = LoadData.to_tensor(data_x)data_y = LoadData.to_tensor(data_y)return {"flow_x": data_x, "flow_y": data_y}# 这一步就是划分数据@staticmethoddef slice_data(data, time_step, index, train_mode):if train_mode == "train":start_index = indexend_index = index + time_stepelif train_mode == "test":start_index = index - time_stepend_index = indexelse:raise ValueError("train mode error")data_x = data[start_index: end_index, :]data_y = data[end_index]return data_x, data_y# 数据与处理@staticmethoddef pre_process_data(data, ):norm_base = LoadData.normalized_base(data)normalized_data = LoadData.normalized_data(data, norm_base[0], norm_base[1])return norm_base, normalized_data# 生成原始数据中最大值与最小值@staticmethoddef normalized_base(data):max_data = np.max(data, keepdims=True)  # keepdims保持维度不变min_data = np.min(data, keepdims=True)# max_data.shape  --->(1, 1)return max_data, min_data# 对数据进行标准化@staticmethoddef normalized_data(data, max_data, min_data):data_base = max_data - min_datanormalized_data = (data - min_data) / data_basereturn normalized_data@staticmethod# 反标准化  在评价指标误差以及画图的使用使用def recoverd_data(data, max_data, min_data):data_base = max_data - min_datarecoverd_data = data * data_base - min_datareturn recoverd_data@staticmethoddef to_tensor(data):return torch.tensor(data, dtype=torch.float)# 划分数据
divide_days = [25, 5]
time_step = 5
batch_size = 48
train_data = LoadData(data, time_step, divide_days, "train")
test_data = LoadData(data, time_step, divide_days, "test")
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, )
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, )# LSTM构建网络
class LSTM(nn.Module):def __init__(self, input_num, hid_num, layers_num, out_num, batch_first=True):super().__init__()self.l1 = nn.LSTM(input_size=input_num,hidden_size=hid_num,num_layers=layers_num,batch_first=batch_first)self.out = nn.Linear(hid_num, out_num)def forward(self, data):flow_x = data['flow_x']  # B * T * Dl_out, (h_n, c_n) = self.l1(flow_x, None)  # None表示第一次 hidden_state是0#         print(l_out[:, -1, :].shape)out = self.out(l_out[:, -1, :])return out# 定义模型参数
input_num = 1  # 输入的特征维度
hid_num = 50  # 隐藏层个数
layers_num = 3  # LSTM层个数
out_num = 1
lstm = LSTM(input_num, hid_num, layers_num, out_num)
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(lstm.parameters())# 训练模型
lstm.train()
epoch_loss_change = []
for epoch in range(30):epoch_loss = 0.0start_time = time.time()for data_ in train_loader:lstm.zero_grad()predict = lstm(data_)loss = loss_func(predict, data_['flow_y'])epoch_loss += loss.item()loss.backward()optimizer.step()epoch_loss_change.append(1000 * epoch_loss / len(train_data))end_time = time.time()print("Epoch: {:04d}, Loss: {:02.4f}, Time: {:02.2f} mins".format(epoch, 1000 * epoch_loss / len(train_data),(end_time - start_time) / 60))
plt.plot(epoch_loss_change)# 评价模型
lstm.eval()
with torch.no_grad():  # 关闭梯度total_loss = 0.0pre_flow = np.zeros([batch_size, 1])  # [B, D],T=1 # 目标数据的维度,用0填充real_flow = np.zeros_like(pre_flow)for data_ in test_loader:pre_value = lstm(data_)loss = loss_func(pre_value, data_['flow_y'])total_loss += loss.item()# 反归一化pre_value = LoadData.recoverd_data(pre_value.detach().numpy(),test_data.flow_norm[0].squeeze(1),  # max_datatest_data.flow_norm[1].squeeze(1),  # min_data)target_value = LoadData.recoverd_data(data_['flow_y'].detach().numpy(),test_data.flow_norm[0].squeeze(1),test_data.flow_norm[1].squeeze(1),)pre_flow = np.concatenate([pre_flow, pre_value])real_flow = np.concatenate([real_flow, target_value])pre_flow = pre_flow[batch_size:]real_flow = real_flow[batch_size:]
#     # 计算误差
mse = mean_squared_error(pre_flow, real_flow)
rmse = math.sqrt(mean_squared_error(pre_flow, real_flow))
mae = mean_absolute_error(pre_flow, real_flow)
print('均方误差---', mse)
print('均方根误差---', rmse)
print('平均绝对误差--', mae)# 画出预测结果图
font_set = FontProperties(fname=r"C:\Windows\Fonts\simsun.ttc", size=15)  # 中文字体使用宋体,15号
plt.figure(figsize=(15, 10))
plt.plot(real_flow, label='Real_Flow', color='r', )
plt.plot(pre_flow, label='Pre_Flow')
plt.xlabel('测试序列', fontproperties=font_set)
plt.ylabel('交通流量/辆', fontproperties=font_set)
plt.legend()
# 预测储存图片
# plt.savefig('...\Desktop\123.jpg')plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/460239.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Oracle视频基础1.1.3练习

1.1.3 需求: 完整格式查看所有用户进程里的oracle后台进程 查看物理网卡,虚拟网卡的ip地址 ps -ef | grep oracle /sbin/ifconfig要以完整格式查看所有用户进程中的 Oracle 后台进程,并查看物理和虚拟网卡的 IP 地址,可以使用以下…

【数据集】MODIS地表温度数据(MOD11)

【数据集】MODIS地表温度数据(MOD11) 数据概述MYD11A2数据下载MYD11A2 v006MYD11A2 v061参考MODIS(Moderate Resolution Imaging Spectroradiometer)是美国国家航空航天局(NASA)和美国国家海洋和大气管理局(NOAA)联合开发的一种遥感仪器,搭载于Terra和Aqua卫星上。MOD…

SpringBoot最佳实践之 - 项目中统一记录正常和异常日志

1. 前言 此篇博客是本人在实际项目开发工作中的一些总结和感悟。是在特定需求背景下,针对项目中统一记录日志(包括正常和错误日志)需求的实现方式之一,并不是普适的记录日志的解决方案。所以阅读本篇博客的朋友,可以参考此篇博客中记录日志的…

2024年优秀的天气预测API

准确、可操作的天气预报对于许多组织的成功至关重要。 事实上,在整个行业中,天气条件会直接影响日常运营,包括航运、按需、能源和供应链(仅举几例)。 以公用事业为例。根据麦肯锡的数据,在 1.4 年的时间里…

Tenda路由器 敏感信息泄露

0x01 产品描述: ‌ Tenda路由器‌是由深圳市吉祥腾达科技有限公司(Tenda)生产的一系列网络通信产品。Tenda路由器以其高性能、高性价比和广泛的应用场景而闻名,适合家庭、办公室和各种网络环境。0x02 漏洞描述&#xff1a…

net mvc中使用vue自定义组件遇到的坑

自定义一个ButtonCounter.js组件 export default {data() {return {count: 0}},template: <van-button type"primary" click"count">You clicked me {{ count }} times.</van-button> }按照官网文档的意思&#xff0c;组件命名需要大写驼峰命…

Python第六次作业

01.求第n项的斐波那契数列值 #求第n项的斐波那契数列值 #1、1、2、3、5、8、13、21、34…… #F(0)0&#xff0c;F(1)1, F(n)F(n - 1)F(n - 2)&#xff08;n ≥ 2&#xff0c;n ∈ N*&#xff09;def shulie ():print("求第n项的斐波那契数列值:",end"")xev…

Vue3 学习笔记(十三)Vue组件详解

1、组件&#xff08;Component&#xff09; 介绍 组件&#xff08;Component&#xff09;是 Vue.js 最强大的功能之一。 组件可以扩展 HTML 元素&#xff0c;封装可重用的代码&#xff0c;可以帮助你将用户界面拆分成独立和可复用的部分。 每个 Vue 组件都是一个独立的 Vue 实…

MySQL基础(二)

目录 一. 数据库命令行基本操作指令 1. 查看当前有哪些数据库——show databases; 2. 创建数据库——create database 数据库名 charset utf8 3. 选中数据库——use 数据库名; 4. 删除数据库——drop database 数据库名; 二. 常用数据类型 2.1 数值类型 2.2. 字符串类型 …

详细解读 CVPR2024:VideoBooth: Diffusion-based Video Generation with Image Prompts

Diffusion Models专栏文章汇总:入门与实战 前言:今天是程序员节,先祝大家节日快乐!文本驱动的视频生成正在迅速取得进展。然而,仅仅使用文本提示并不足以准确反映用户意图,特别是对于定制内容的创建。个性化图片领域已经非常成功了,但是在视频个性化领域才刚刚起步,这篇…

深度学习案例:带有一个隐藏层的平面数据分类

该案例来自吴恩达深度学习系列课程一《神经网络和深度学习》第三周编程作业&#xff0c;作业内容是设计带有一个隐藏层的平面数据分类。作业提供的资料包括测试实例&#xff08;testCases.py&#xff09;和任务功能包&#xff08;planar_utils.py&#xff09;&#xff0c;下载请…

SD教程 重绘 ControlNet-Inpain

SD教程 重绘 ControlNet-Inpain ———————————————— 版权声明&#xff1a;本文为博主原创文章&#xff0c;遵循 CC 4.0 BY-SA 版权协议&#xff0c;转载请附上原文出处链接和本声明。原文链接&#xff1a;https://blog.csdn.net/A1353192296/article/details/13…

【界面改版】JimuReport 积木报表 v1.9.0 版本发布,填报能力和大屏能力

项目介绍 积木报表JimuReport&#xff0c;是一款免费的数据可视化报表&#xff0c;含报表、仪表盘和大屏设计&#xff0c;像搭建积木一样完全在线设计&#xff01;功能涵盖&#xff1a;数据报表、打印设计、图表报表、门户设计、大屏设计等&#xff01; Web版报表设计器&#x…

【网络】1.UDP通信

UDP通信 1 server1.1 server建立的步骤1.2 运行server 2 client2.1 client的建立步骤2.2 运行client 3 总结3.1 server3.2 client 1 server server的启动方式是&#xff1a;./udpserver 8080 --> 格式就是./proc port端口 port端口自己指定 1.1 server建立的步骤 获取文件描…

告别冰冷机器声:GLM-4-Voice开启情感语音交互新时代!

目录 引言一、GLM-4-Voice概述二、GLM-4-Voice的架构三、GLM-4-Voice的主要功能四、GLM-4-Voice的技术原理五、GLM-4-Voice的应用场景六、GLM-4-Voice体验快速开始结语 引言 在人工智能的不断进步中&#xff0c;语音交互技术正逐渐成为人机沟通的重要桥梁。它不仅极大地提升了…

MySQL定时异机备份

场景&#xff1a;将A机器MySQL数据库部分表每日定时备份到B机器上 &#xff08;只适用于Linux&#xff09; 实现方式算是比简单了&#xff0c;就是用mysqldump生成文件&#xff0c;使用scp命令传输到另一台机器上。 1. 编写备份shell脚本 在A机器新建脚本 (当然没有vim的话vi…

使用VS2019将C#代码生成DLL文件在Unity3D里面使用(一)

系列文章目录 untiy知识点 文章目录 系列文章目录&#x1f449;前言&#x1f449;一、首先你要先有VS&#x1f449;二、引用UnityAPI使用步骤&#x1f449;2-1.引用unitydll文件到项目里面&#x1f449;2-2.导入Dll文件 &#x1f449;三、编辑dll代码&#x1f449;四、导出dll…

平台化运营公司如何在创业市场招商

在当今商业环境中&#xff0c;平台化运营的公司正成为推动经济发展的重要力量。对于这类公司而言&#xff0c;在创业市场招商意义重大。 平台化运营公司具有独特特点&#xff1a;通过搭建开放共享平台连接供需双方&#xff0c;实现资源优化配置与价值创造。比如电子商务平台、社…

聚类分析算法——K-means聚类 详解

K-means 聚类是一种常用的基于距离的聚类算法&#xff0c;旨在将数据集划分为 个簇。算法的目标是最小化簇内的点到簇中心的距离总和。下面&#xff0c;我们将从 K-means 的底层原理、算法步骤、数学基础、距离度量方法、参数选择、优缺点 和 源代码实现 等角度进行详细解析。…

SpringMVC执行流程(视图阶段JSP、前后端分离阶段)、面试题

目录 1.SpringMVC执行流程分为以下两种 2.非前后端分离的SpringMVC的执行流程 3.前后端分离的项目SpringMVC执行流程 4. 面试题 1.SpringMVC执行流程分为以下两种 2.非前后端分离的SpringMVC的执行流程 流程图&#xff1a; 更加生动的描述&#xff1a; DisPatcherServlet…