第R4周:LSTM-火灾温度预测

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

    文章目录

    • 一、代码流程
      • 1、导入包,设置GPU
      • 2、导入数据
      • 3、数据集可视化
      • 4、数据集预处理
      • 5、设置X,y
      • 6、划分数据集
      • 7、构建模型
      • 8、定义训练函数
      • 9、定义测试函数
      • 10、正式训练
      • 11、模型评估- LOSS图
      • 12、调用模型进行训练

电脑环境:
语言环境:Python 3.8.0

一、代码流程

1、导入包,设置GPU

import torch.nn.functional as F
import torch.nn as nn
import torch
import numpy as np
import pandas as pd

2、导入数据

data = pd.read_csv('woodpine2.csv')
data
	Time	Tem1	CO 1	Soot 1
0	0.000	25.0	0.000000	0.000000
1	0.228	25.0	0.000000	0.000000
2	0.456	25.0	0.000000	0.000000
3	0.685	25.0	0.000000	0.000000
4	0.913	25.0	0.000000	0.000000
...	...	...	...	...
5943	366.000	295.0	0.000077	0.000496
5944	366.000	294.0	0.000077	0.000494
5945	367.000	292.0	0.000077	0.000491
5946	367.000	291.0	0.000076	0.000489
5947	367.000	290.0	0.000076	0.000487
5948 rows × 4 columns

3、数据集可视化

from os import confstr_names
import matplotlib.pyplot as plt
import seaborn as snsplt.rcParams['figure.dpi'] = 500
plt.rcParams['savefig.dpi'] = 500fig, ax = plt.subplots(1, 3, constrained_layout=True, figsize=(14, 3))
sns.lineplot(data=data['Tem1'], ax=ax[0])
sns.lineplot(data=data['CO 1'], ax=ax[1])
sns.lineplot(data=data['Soot 1'], ax=ax[2])
plt.show()

在这里插入图片描述

dataFrame = data.iloc[:, 1:]
dataFrame

4、数据集预处理

from sklearn.preprocessing import MinMaxScalerdataFrame = data.iloc[:, 1:].copy()scaler = MinMaxScaler(feature_range=(0, 1))for i in ['CO 1', 'Soot 1', 'Tem1']:dataFrame[i] = scaler.fit_transform(dataFrame[i].values.reshape(-1, 1))  dataFrame.shape

(5948, 3)

5、设置X,y

width_X = 8
width_Y = 1X = []
y = []in_start = 0for _, _ in data.iterrows():in_end = in_start + width_Xout_end = in_end + width_Yif out_end < len(dataFrame):X_ = np.array(dataFrame.iloc[in_start:in_end, :])y_ = np.array(dataFrame.iloc[in_end:out_end, :])X.append(X_)y.append(y_)in_start += 1 
X = np.array(X)
y = np.array(y)X.shape, y.shape

((5939, 8, 3), (5939, 1, 1))

检查数据集中是否有空值

print(np.any(np.isnan(X)))
print(np.any(np.isnan(y)))

6、划分数据集

X_train = torch.tensor(np.array(X[:5000]), dtype=torch.float32)
y_train = torch.tensor(np.array(y[:5000]), dtype=torch. float32)X_test = torch.tensor(np.array(X[5000:]), dtype=torch.float32)
y_test = torch.tensor(np.array(y[5000:]), dtype=torch. float32)X_train.shape, y_train.shape

(torch.Size([5000, 8, 3]), torch.Size([5000, 1, 3]))

from torch.utils.data import TensorDataset, DataLoader
train_dl = DataLoader(TensorDataset(X_train, y_train),batch_size=64,shuffle=False)
test_dl = DataLoader(TensorDataset(X_test, y_test),batch_size=64,shuffle=False)

7、构建模型

class model_lstm(nn.Module):def __init__(self):super(model_lstm, self).__init__()self.lstm0 = nn.LSTM(input_size=3, hidden_size=320,num_layers=1, batch_first=True)self.lstm1 = nn.LSTM(input_size=320, hidden_size=320,num_layers=1, batch_first=True)self.fc0 = nn.Linear(320, 1)def forward(self, x):out, hidden1 = self.lstm0(x)out, _ = self. lstm1(out, hidden1)out = self.fc0(out)return out[:, -1:, :]#取2个预测值,否则经过1stm会得到8*2个预
model = model_lstm()
model

8、定义训练函数

import copy
def train(train_dl, model, loss_fn, opt, lr_scheduler=None) :size = len(train_dl.dataset)num_batches = len(train_dl)train_loss = 0 #初始化训练损失和正确率for x, y in train_dl:x, y = x.to(device), y.to(device)#计算预测误差pred = model(x) #网络输出loss = loss_fn(pred, y) #计算网络输出和真实值之间的差距# 反向传播opt.zero_grad()#grad属性归零loss.backward()# 反向传播opt.step()# 每一步自动更新#记录Losstrain_loss += loss. item()if lr_scheduler is not None:lr_scheduler.step()print ("learning rate = {:.5f}". format(opt.param_groups[0]['lr']), end='  ')train_loss /= num_batchesreturn train_loss

9、定义测试函数

def test (dataloader, model, loss_fn) :size = len(dataloader.dataset) #测试集的大小num_batches = len(dataloader)# 批次数目test_loss = 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for x, y in dataloader:x, y = x.to(device), y.to(device)# 计算lossy_pred = model(x)loss = loss_fn(y_pred, y)test_loss += loss.item()test_loss /= num_batchesreturn test_loss

10、正式训练

# 设置GPU训练
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")#训练模型
model = model_lstm()
model = model.to(device)
loss_fn = nn.MSELoss() #创建损失函数
learn_rate = 1e-1 #学习率
opt = torch.optim.SGD(model.parameters(),lr=learn_rate, weight_decay=1e-4)
epochs = 50
train_loss = []
test_loss = []
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, epochs, last_epoch=-1)for epoch in range(epochs):model.train()epoch_train_loss = train(train_dl, model, loss_fn, opt, lr_scheduler)model.eval()epoch_test_loss = test(test_dl, model, loss_fn)train_loss.append(epoch_train_loss)test_loss.append(epoch_test_loss)template = ('Epoch: {:2d}, Train loss: {:.5f}, Test loss: {:.5f}')print(template.format(epoch+1, epoch_train_loss,epoch_test_loss))
print("="*20, 'Done', "="*70)

11、模型评估- LOSS图

import matplotlib.pyplot as plt
plt. figure(figsize=(5, 3), dpi=120)plt.plot(train_loss, label='LSTM Training Loss')
plt.plot(test_loss, label='LSTM Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

在这里插入图片描述

12、调用模型进行训练

predicted_y_lstm = sc.inverse_transform(model(X_test).detach().numpy().reshape(-1,1))
y_test_1 = sc.inverse_transform(y_test.reshape(-1,1))
y_test_one = [i[0] for i in y_test_1]
predicted_y_lstm_one = [i[0] for i in predicted_y_lstm]plt.figure(figsize=(5, 3) , dpi=120)
# 画出真实数据和预测数据的对比曲线
plt.plot(y_test_one[:2000], color='red', label='real_temp')
plt.plot(predicted_y_lstm_one[:2000], color='blue', label='prediction')
plt.title('Title')
plt.xlabel('X')
plt.ylabel('y')
plt.legend ()
plt.show( )

在这里插入图片描述

from sklearn import metrics
'''
RMSE:均方根误差--->对均方误差开方
R2:决定系数,可以简单理解为反映模型拟合优度的重要的统计量
'''
RMSE_lstm = metrics.mean_squared_error(predicted_y_lstm_one, y_test_1)**0.5
R2_lstm = metrics.r2_score(predicted_y_lstm_one, y_test_1)
print('均方根误差:%.5f' % RMSE_lstm)
print('R2: %.5f' % R2_lstm)

均方根误差:7.01314
R2: 0.82595

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/1804.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机组存储系统

局部性 理论 程序执行&#xff0c;会不均匀访问主存&#xff0c;有些被频繁访问&#xff0c;有些很少被访问 时间局部性 被用到指令&#xff0c;不久可能又被用到 产生原因是大量循环操作 空间局部性 某个数据和指令被使用&#xff0c;附近数据也可能使用 主要原因是顺序存…

Windows图形界面(GUI)-QT-C/C++ - Qt图形绘制详解

公开视频 -> 链接点击跳转公开课程博客首页 -> ​​​链接点击跳转博客主页 目录 Qt绘图基础 QPainter概述 基本工作流程 绘图事件系统 paintEvent事件 重绘机制 文字绘制技术 基本文字绘制 ​编辑 高级文字效果 基本图形绘制 线条绘制 ​编辑 形状绘制 …

mapbox进阶,添加绘图控件

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:mapbox 从入门到精通 文章目录 一、🍀前言1.1 ☘️mapboxgl.Map 地图对象1.2 ☘️mapboxgl.Map style属性1.3 ☘️MapboxDraw 绘图控件二、🍀添加绘图控…

client-go 的 QPS 和 Burst 限速

1. 什么是 QPS 和 Burst &#xff1f; 在 kubernetes client-go 中&#xff0c;QPS 和 Burst 是用于控制客户端与 Kubernetes API 交互速率的两个关键参数&#xff1a; QPS (Queries Per Second) 定义&#xff1a;表示每秒允许发送的请求数量&#xff0c;即限速器的平滑速率…

WeakAuras NES Script(lua)

WeakAuras NES Script 修星脚本字符串 脚本1&#xff1a;NES !WA:2!TMZFWXX1zDxVAs4siiRKiBN4eV(sTRKZ5Z6opYbhQQSoPtsxr(K8ENSJtS50(J3D7wV3UBF7E6hgmKOXdjKsgAvZFaPTtte0mD60XdCmmecDMKruyykDcplAZiGPfWtSsag6myGuOuq89EVDV9wPvKeGBM7U99EFVVVV33VFFB8Z2TJ8azYMlZj7Ur3QDR(…

【make】makefile 函数全解

目录 makefile简介函数全解介绍相关链接字符串处理函数subst 函数—字符串替换patsubst 函数 — 模式字符串替换strip 函数 — 去空格findstring 函数 — 查找字符串filter 函数 — 过滤器filter-out 函数 — 过滤器sort 函数 — 排序word 函数 — 取单词wordlist函数 — 取一串…

Android 15应用适配指南:所有应用的行为变更

Android系统版本适配&#xff0c;一直是影响App上架Google Play非常重要的因素。 当前Google Play政策规定 新应用和应用更新 必须以 Android 14&#xff08;API 级别 34&#xff09;为目标平台&#xff0c;才能提交到Google Play。现有应用 必须以 Android 13&#xff08;AP…

Java Agent(三)、ASM 操作字节码入门

目录 1、前言 2、什么是ASM&#xff1f; 2.1、工作流程 2.2、ASM集合核心API 2.1.1、ClassReader 2.1.2、ClassWriter 2.1.3、 ClassVisitor 2.1.4、MethodVisitor 2.1.5、 FieldVisitor 2.1.6、Opcodes 3、简单示例 3.1、maven依赖 3.2、hello world 3.3、执行结…

MySQL数据库(SQL分类)

SQL分类 分类全称解释DDLData Definition Language数据定义语言&#xff0c;用来定义数据库对象&#xff08;数据库&#xff0c;表&#xff0c;字段&#xff09;DMLData Manipulation Language数据操作语言&#xff0c;用来对数据库表中的数据进行增删改DQLData Query Languag…

[微服务]redis数据结构

介绍 我们常用的Redis数据类型有5种&#xff0c;分别是&#xff1a; StringListSetSortedSetHash 还有一些高级数据类型&#xff0c;比如Bitmap、HyperLogLog、GEO等&#xff0c;其底层都是基于上述5种基本数据类型。因此在Redis的源码中&#xff0c;其实只有5种数据类型。 …

PyQt5

PyQt5 环境搭建安装 pycharm安装 PyQt5 打包成exe安装 pyinstaller打包 报错进程已结束&#xff0c;退出代码-1073740791&#xff08;0xC0000409&#xff09; 环境搭建 安装 pycharm 安装 PyQt5 pip install pyqt5 -i https://pypi.tuna.tsinghua.edu.cn/simplepip install …

高级运维:shell练习2

1、需求&#xff1a;判断192.168.1.0/24网络中&#xff0c;当前在线的ip有哪些&#xff0c;并编写脚本打印出来。 vim check.sh #!/bin/bash# 定义网络前缀 network_prefix"192.168.1"# 循环遍历1-254的IP for i in {1..254}; do# 构造完整的IP地址ip"$network_…

Grails应用http.server.requests指标数据采集问题排查及解决

问题 遇到的问题&#xff1a;同一个应用&#xff0c;Spring Boot(Java)和Grails(Groovy)混合编程&#xff0c;常规的Spring Controller&#xff0c;可通过Micromete Pushgateway&#xff0c; 采集到http.server.requests指标数据&#xff0c;注意下面的指标名称是点号&#…

pycharm+pyside6+desinger实现查询汉字笔顺GIF动图

一、引言 这学期儿子语文期末考试有一道这样的题目&#xff1a; 这道题答案是B&#xff0c;儿子做错了选了C。我告诉他“车字旁”和“车”的笔顺是不一样的&#xff0c;因为二者有一个笔画是不一样的&#xff0c;“车字旁”下边那笔是“提”&#xff0c;而“车”字是“横”&am…

【2025 Rust学习 --- 17 文本和格式化 】

字符串与文本 Rust 的主要文本类型 String、str 和 char 内容概括&#xff1a; Unicode 背景知识&#xff1f;单个 Unicode 码点的 char&#xff1f;String 类型和 str 类型都是表示拥有和借用的 Unicode 字符序列。Rust 的字符串格式化工具&#xff0c;比如 println! 宏和 …

EasyCVR视频汇聚平台如何配置webrtc播放地址?

EasyCVR安防监控视频系统采用先进的网络传输技术&#xff0c;支持高清视频的接入和传输&#xff0c;能够满足大规模、高并发的远程监控需求。平台支持多协议接入&#xff0c;能将接入到视频流转码为多格式进行分发&#xff0c;包括RTMP、RTSP、HTTP-FLV、WebSocket-FLV、HLS、W…

rknn环境搭建之docker篇

目录 1. rknn简介2. 环境搭建2.1 下载 RKNN-Toolkit2 仓库2.2 下载 RKNN Model Zoo 仓库2.3 下载交叉编译器2.4 下载Docker镜像2.5 下载ndk2.5 加载docker镜像2.6 docker run 命令创建并运行 RKNN Toolkit2 容器2.7 安装cmake 3. 模型转换3.1 下载模型3.2 模型转换 4. 编译cdem…

【MySQL实战】mysql_exporter+Prometheus+Grafana

要在Prometheus和Grafana中监控MySQL数据库&#xff0c;如下图&#xff1a; 可以使用mysql_exporter。 以下是一些步骤来设置和配置这个监控环境&#xff1a; 1. 安装和配置Prometheus&#xff1a; - 下载和安装Prometheus。 - 在prometheus.yml中配置MySQL通过添加以下内…

W25Q64-FLASH

前言&#xff1a; 1.理解flash的组织结构&#xff0c;block块, sector扇区&#xff0c;page页&#xff0c;之间的结构怎么组织安排划分的。 2.理解flash的特性&#xff0c;只能从1写为0&#xff0c;不能从0写为1&#xff0c;这就是为什么写之前要先擦除操作。(这个特性一直困扰…

FPGA EDA软件的位流验证

位流验证&#xff0c;对于芯片研发是一个非常重要的测试手段&#xff0c;对于纯软件开发人员&#xff0c;最难理解的就是位流验证。在FPGA芯片研发中&#xff0c;位流验证是在做什么&#xff0c;在哪些阶段需要做位流验证&#xff0c;如何做&#xff1f;都是问题。 我们先整体的…