程序员学长 | 快速学会一个算法模型,LSTM

本文来源公众号“程序员学长”,仅用于学术分享,侵权删,干货满满。

原文链接:快速学会一个算法模型,LSTM

今天,给大家分享一个超强的算法模型,LSTM。

LSTM(Long Short-Term Memory)是一种特殊类型的循环神经网络(RNN),专门设计用来解决传统 RNN 在处理序列数据时面临的长期依赖问题

LSTM 的关键特征是其维持细胞状态的能力,细胞状态充当可以存储长序列信息的记忆单元。这使得 LSTM 能够随着时间的推移选择性地记住或忘记信息,使它们非常适合上下文和远程依赖性至关重要的任务。

LSTM 的核心组件

LSTM 单元由以下几个主要部分组成

案例分享

加载数据集
import numpy as np
import pandas as pd
from keras.models import Sequential, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.metrics import RootMeanSquaredError
from tensorflow.keras.optimizers import Adam
from keras.layers import LSTM, Dense, InputLayer
from sklearn.metrics import mean_squared_error as mse
from time import time
import matplotlib.pyplot as plt
import matplotlib
import warningspath_data = r'filter_pt_data.csv'
df = pd.read_csv(path_data)
df.dropna(inplace=True)
df['dt'] = pd.to_datetime(df['dt'])
df.set_index('dt', inplace=True)
df['Seconds'] = df.index.map(pd.Timestamp.timestamp)
year_secs = 60 * 60 * 24 * 365  # Number of seconds in a year
df['year_signal_sin'] = np.sin(df['Seconds'] * (2 * np.pi / year_secs))
df['year_signal_cos'] = np.cos(df['Seconds'] * (2 * np.pi / year_secs))
df.drop(columns=['Seconds'], inplace=True)
准备数据序列

LSTM 模型是专门为处理数据点序列而设计的,因此需要将数据转换为这种格式。

该方法涉及将预测问题转换为监督学习范式。在此设置中,输入 (X) 包含前面的 n 个数据点,而输出 (y) 表示后续时间步的目标值。

为了说明这个概念,假设我们正在使用包含三个特征(“a”、“b”和“c”)的数据集。我们的目标是预测特征 “a”。在这种情况下,我们的输入序列将包含三个时间戳,这意味着我们将检查三个连续时间点的特征值。

def create_sequences_unistep(data, n_steps):data_t = data.to_numpy()X = []y = []for i in range(len(data_t)-n_steps):row = [a for a in data_t[i:i+n_steps]]X.append(row)label = data_t[i+n_steps][0]y.append(label)return np.array(X), np.array(y)
创建模型
def train_model(X, y, X_val, y_val, n_steps, n_preds=1):n_features = X.shape[2]# Create lstm modelmodel = Sequential()model.add(InputLayer((n_steps, n_features)))model.add(LSTM(4, return_sequences=True))model.add(LSTM(5))model.add(Dense(n_preds, activation='linear'))# Compile modelmodel.compile(loss=MeanSquaredError(), optimizer=Adam(learning_rate=0.0001), metrics=[RootMeanSquaredError()])model.summary()# Save model with the least validation losscheckpoint_filepath = 'cps/best_model.h5'model_checkpoint_callback = ModelCheckpoint(filepath=checkpoint_filepath,monitor='val_loss',  # Monitor validation lossmode='min',          # Save the model with the minimum validation losssave_best_only=True)# Stop training if validation loss does not improve in 500 epochsearly_stopping_callback = EarlyStopping(monitor='val_loss',patience=50,  # Stop training if no improvement in validation loss for 100 epochsmode='min',verbose=1,restore_best_weights=True) # when finish train restore best model# Fit modelts = time()history = model.fit(X, y,verbose=2,epochs=500,validation_data=(X_val, y_val),callbacks=[model_checkpoint_callback, early_stopping_callback])tf = time()print('Time to train model: {} s'.format(round(tf - ts, 2)))# Plot loss evolutionplt.figure()plt.plot(history.history['loss'], label='loss')plt.plot(history.history['val_loss'], label='val_loss')plt.title('Training and Validation Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()plt.show()# Load best modeldel modelmodel = load_model(checkpoint_filepath)return model
模型训练

首先,让我们使用之前实现的函数生成序列。我们将分配 500 个值用于训练,50 个值用于验证,并将 “n_steps” 参数设置为 5。

def preprocess_input(X, mean, std):X[:, :, 0] = (X[:, :, 0] - mean) / stdreturn Xdef preprocess_output(y, mean, std):y = (y - mean) / stdreturn ydef postprocess_output(y, mean, std):y = (y * std) + meanreturn ydef plot_predictions_unistep(model, X_test, y_test, mean_ref, std_ref):preds = model.predict(X_test).flatten().tolist()# preprocess preds to actual scalepreds = [postprocess_output(i, mean_ref, std_ref) for i in preds]y_t = [postprocess_output(i, mean_ref, std_ref) for i in y_test.tolist()]er = mse(y_test, preds)plt.figure(figsize=(12, 8))plt.plot(y_t, label='Actual values')plt.plot(preds, label='Predictions', alpha=.7)plt.legend()plt.title('MSE = {}'.format(er))return predsn_steps = 5
X, y = create_sequences_unistep(df, n_steps)# Prepare train and validation data
nr_vals_train = 500
nr_vals_validation = 50X_train = X[:nr_vals_train]
y_train = y[:nr_vals_train]X_val = X[nr_vals_train: nr_vals_train + nr_vals_validation]
y_val = y[nr_vals_train: nr_vals_train + nr_vals_validation]X_test = X[nr_vals_train:]
y_test = y[nr_vals_train:]print('X train shape: {}'.format(X_train.shape))
print('y train shape: {}'.format(y_train.shape))print('X validation shape: {}'.format(X_val.shape))
print('y validation shape: {}'.format(y_val.shape))# Scale temp value with standard scaler -> mean 0 and std 1
mean_ref = np.mean(X_train[:, :, 0])
std_ref = np.std(X_train[:, :, 0])
# Scale X's
X_train = preprocess_input(X_train, mean_ref, std_ref)
X_val = preprocess_input(X_val, mean_ref, std_ref)
X_test = preprocess_input(X_test, mean_ref, std_ref)# Scale y's
y_train = preprocess_output(y_train, mean_ref, std_ref)
y_val = preprocess_output(y_val, mean_ref, std_ref)
y_test = preprocess_output(y_test, mean_ref, std_ref)model = train_model(X_train, y_train, X_val, y_val, n_steps)# Plot train predictions set
plot_predictions_unistep(model, X_train, y_train, mean_ref, std_ref)

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/364036.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【算法专题--栈】栈的压入、弹出序列 -- 高频面试题(图文详解,小白一看就懂!!)

目录 一、前言 二、题目描述 三、解题方法 💧栈模拟法💧-- 双指针 ⭐ 解题思路 ⭐ 案例图解 四、总结与提炼 五、共勉 一、前言 栈的压入、弹出序列 这道题,可以说是--栈专题--,最经典的一道题,也是在…

PD芯片OTG功能的应用 LDR6500

随着科技的飞速发展,智能手机、平板电脑等电子设备已经成为我们日常生活和工作中不可或缺的工具。这些设备的功能日益强大,应用场景也愈发广泛,但随之而来的是对充电和数据传输效率的高要求。在这一背景下,PD(Power De…

使用shell脚本编写监控系统资源(CPU,内存,磁盘)使用情况

🏡作者主页:点击! 🛠️Shell编程专栏:点击! ⏰️创作时间:2024年6月20日16点30分 🀄️文章质量:95分 目录 ————前言———— 1.本章目标 2.编写脚本 1.获取内…

新能源汽车 LabCar 测试系统方案(二)

什么是LabCar测试 LabCar测试目标是进行整车黄板台架功能测试,用于整车开发和测试阶段,满足设计人员和测试人员的试验需求,以验证整车性能,减少开发工作量。系统主要用于测试静态及动态工况下的纯电动汽车的各项功能实现情况。 …

CPN Tools学习——案例实操

基于CPN的空竭服务单重休假 M/G/1 型排队系统建模与分析 目录 1创建CPN模型 1.1函数和变量声明 1.2监控声明 2 创建分层CPN 3 添加monitor 3.1 Predicate谓词函数 3.2 Observer观察函数 3.3 Init function &Stop初始化和停止函数 4CPN Tools工具补充 1创建CPN模…

【ARM】MDK工程切换高版本的编译器后出现error A1137E报错

【更多软件使用问题请点击亿道电子官方网站】 1、 文档目标 解决工程从Compiler 5切换到Compiler 6进行编译时出现一些非语法问题上的报错。 2、 问题场景 对于一些使用Compiler 5进行编译的工程,要切换到Compiler 6进行编译的时候,原本无任何报错警告…

微信商家转账到零钱

1.发起商家转账 发起商家转账接口。商户可以通过该接口同时向多个用户微信零钱进行转账操作。请求消息中应包含商家批次单号、转账名称、appid、转账总金额、转账总笔数、转账openid、收款用户姓名等信息。注意受理成功将返回批次单号,此时并不代表转账成功&#x…

【涵子来信】——社交宝典:克服你心中的内向,世界总有缺陷

内向,你是内向的吗?想必每个人不同,面对的情形也是不同的。 暑假是一个很好的机会,我是可以去多社交社交。但是,面对着CSDN上这么多技术人er,那么,我的宝典,对于大家,有…

麒麟桌面系统CVE-2024-1086漏洞修复

原文链接:麒麟桌面操作系统上CVE-2024-1086漏洞修复 Hello,大家好啊!今天给大家带来一篇在麒麟桌面操作系统上修复CVE-2024-1086漏洞的文章。漏洞CVE-2024-1086是一个新的安全漏洞,如果不及时修复,可能会对系统造成安全…

MySQL之如何分析慢查询

1、一个SQL语句执行很慢,如何分析? 可使用“explain”或者“desc”命令获取MySQL如何执行select语句的信息。 语法:直接在select语句前加关键字 explain或desc explain select job_desc from xxl_job_info where id 1; 2、执行计划中五个重…

电商平台数据爬取经验分享

一、引言 在电商领域,数据的重要性不言而喻。无论是市场趋势分析、竞争对手研究,还是用户行为洞察,都离不开数据的支持。而数据爬虫作为获取这些数据的重要工具,其技术的掌握和运用对于电商平台来说至关重要。本文将结合个人实际…

什么是指令微调(LLM)

经过大规模数据预训练后的语言模型已经具备较强的模型能力,能够编码丰富的世界知识,但是由于预训练任务形式所限,这些模型更擅长于文本补全,并不适合直接解决具体的任务。 指令微调是相对“预训练”来讲的,预训练的时…

electron-builder 打包过慢解决

报错内容如下 > 6-241.0.0 build > electron-builder • electron-builder version24.13.3 os10.0.22631 • loaded configuration filepackage.json ("build" field) • writing effective config filedist\builder-effective-config.yaml • pack…

FinalShell:功能强大的 SSH 工具软件,Mac 和 Win 系统的得力助手

在当今数字化的时代,SSH 工具软件成为了许多开发者、运维人员以及技术爱好者不可或缺的工具。而 FinalShell 作为一款出色的中文 SSH 工具软件,无论是在 Mac 系统还是 Windows 系统上,都展现出了卓越的性能和便捷的使用体验。 FinalShell 拥…

详解ApplicationRunner和CommandLineRunner

一、前言 springBoot框架项目,有时候有预加载数据需求——提前加载到缓存中或类的属性中,并且希望执行操作的时间是在容器启动末尾时间执行操作。比如笔者工作中遇到了一个预加载redis中的缓存数据,加载为java对象。针对这种场景&#xff0c…

Linux /proc目录总结

1、概念 在Linux系统中,/proc目录是一个特殊的文件系统,通常被称为"proc文件系统"或"procfs"。这个文件系统以文件系统的方式为内核与进程之间的通信提供了一个接口。/proc目录中的文件大多数都提供了关于系统状态的信息&#xff0…

51-52Windows密码安全性测试与Windows提权

目录 Windows密码安全性测试 一、本地管理员密码如何直接提取 1、直接通过mimikatz读取管理员密码 2、使用laZagne工具读取管理员密码 二、利用Hash远程登录系统 window提权 三、远程webshell执行命令解决 不能执行原因: 解决方法:单独上传cmd.e…

利用python爬取上证指数股吧评论并保存到mongodb数据库

大家好,我是带我去滑雪! 东方财富网是中国领先的金融服务网站之一,以提供全面的金融市场数据、资讯和交易工具而闻名。其受欢迎的“股吧”论坛特别适合爬取股票评论,东方财富网的股吧聚集了大量投资者和金融分析师,他们…

夏令营1期-对话分角色要素提取挑战赛-第①次打卡

零基础入门大模型技术竞赛 简介: 本次学习是 Datawhale 2024 年 AI 夏令营第一期,学习活动基于讯飞开放平台“基于星火大模型的群聊对话分角色要素提取挑战赛”开展实践学习。 适合想 入门并实践大模型 API 开发、了解如何微调大模型的学习者参与 快来…

【C++】哈希表

目录 一、unordered系列关联式容器 二、哈希 2.1 概念 2.2 哈希冲突 2.3 哈希函数 (1)直接定址法 (2)除留余数法 (3)平方取中法 (4)折叠法 (5)随机…