第二十一周学习周报

目录

    • 摘要
    • Abstract
    • 1. LSTM原理
    • 2. LSTM反向传播的数学推导
    • 3. LSTM模型训练实战
    • 总结

摘要

本周的学习内容是对LSTM相关内容的复习,LSTM被设计用来解决标准RNN在处理长序列数据时遇到的梯度消失和梯度爆炸问题。LSTM通过引入门控机制来控制信息的流动,从而有效地缓解了梯度消失问题,这使得LSTM在各种时序数据处理场景中有更为优秀的表现。通过使用LSTM模型预测股票市场数据的涨幅趋势和对LSTM反向传播的数学推导加深对LSTM的理解。

Abstract

This week’s learning content is a review of LSTM related content. LSTM were designed to address the vanishing and exploding gradient problems that standard RNN encounter when dealing with long sequence data. By introducing gating mechanisms to control the flow of information, LSTM effectively alleviates the vanishing gradient issue, making them perform more excellently in various time series data processing scenarios. The understanding of LSTM is deepened by using LSTM models to predict the trend of stock market data increases and by mathematically deriving the backpropagation of LSTM.

1. LSTM原理

LSTM结构如下图所示:
在这里插入图片描述
LSTM有3个gate:
(1)Input Gate:控制data输入储存单元,外界neural的输出想要写入存储单元时要经过它。
(2)Output Gate:控制data从存储单元中输出,外界neural想要从存储单元中读出值需要经过它。
(3)Forget Gate:决定什么时候要把存储单元中的data清除。
这三个Gate的开关都是神经网络自己学习的,它可以自己学习什么时候开门,什么时候关门。综上,整个LSTM有4个输入,1个输出。
在这里插入图片描述
将上述模型进一步细化,如上图所示。假定
我们要存入单元的输入为z,
控制input Gate的信号为zi,
控制Output Gate的信号为zo,
控制Forget Gate的信号为zf,

把z输入通过激活函数得到g(z),zi输入通过激活函数得到f(zi),这里用到的激活函数通常都为sigmoid函数,因为经过sigmoid函数所得值介于0和1之间,可以以此判断门是打开还是关闭的,把g(z)乘上f(zi)得到g(z)f(zi)。
zf通过激活函数得到f(zf),c’=g(z)f(zi)+cf(zf)。c’为重新存入单元的值,c’经sigmoid函数得到h(c‘)。最后由h(c‘)乘上f(zo)得到最终的输出a。输出门受f(zo) 所操控f(zo)等于 1 的话,就说明 h(c′) 能通过,f(zo) 等于 0 的话,说明记忆元里面存在的值没有办法通过输出门被读取出来。其他gate的处理与此类似。

2. LSTM反向传播的数学推导

LSTM前向传播的示意图如下所示:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在做反向传播的过程中我们发现,当时间跨度越大时,Loss与可调控参数W间的路径就越多越复杂,路径越多连乘的项也就越多,累乘项多就会带来更多的梯度消失的可能。我们可以通过调整可调控参数的大小从而抵消在传递过程种一些连乘项对模型的影响,从而降低了梯度消失的可能。

3. LSTM模型训练实战

本次实训通过Python的Keras库构建LSTM模型,旨在预测时间序列中的步骤和序列在股票市场数据的应用。
实验过程如下:
模型用到的库有Keras 、NumPy 、Matplotlib等
导入代码如下所示:

import numpy as np
import datetime as dt
from numpy import newaxis
from core.utils import Timer
from keras.layers import Dense, Activation, Dropout, LSTM
from keras.models import Sequential, load_model
from keras.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt

3.1 数据处理
DataLoader 是一个数据处理工具,能够从 CSV 文件中读取数据,生成训练和测试数据窗口,并在必要时对数据进行归一化。这些功能为 LSTM 模型的训练和评估提供了便利,确保模型可以处理时间序列数据。代码如下所示:

class DataLoader():"""A class for loading and transforming data for the lstm model"""def __init__(self, filename, split, cols):dataframe = pd.read_csv(filename)i_split = int(len(dataframe) * split)self.data_train = dataframe.get(cols).values[:i_split]self.data_test  = dataframe.get(cols).values[i_split:]self.len_train  = len(self.data_train)self.len_test   = len(self.data_test)self.len_train_windows = Nonedef get_test_data(self, seq_len, normalise):'''Create x, y test data windowsWarning: batch method, not generative, make sure you have enough memory toload data, otherwise reduce size of the training split.'''data_windows = []for i in range(self.len_test - seq_len):data_windows.append(self.data_test[i:i+seq_len])data_windows = np.array(data_windows).astype(float)data_windows = self.normalise_windows(data_windows, single_window=False) if normalise else data_windowsx = data_windows[:, :-1]y = data_windows[:, -1, [0]]return x,ydef get_train_data(self, seq_len, normalise):'''Create x, y train data windowsWarning: batch method, not generative, make sure you have enough memory toload data, otherwise use generate_training_window() method.'''data_x = []data_y = []for i in range(self.len_train - seq_len):x, y = self._next_window(i, seq_len, normalise)data_x.append(x)data_y.append(y)return np.array(data_x), np.array(data_y)def generate_train_batch(self, seq_len, batch_size, normalise):'''Yield a generator of training data from filename on given list of cols split for train/test'''i = 0while i < (self.len_train - seq_len):x_batch = []y_batch = []for b in range(batch_size):if i >= (self.len_train - seq_len):# stop-condition for a smaller final batch if data doesn't divide evenlyyield np.array(x_batch), np.array(y_batch)i = 0x, y = self._next_window(i, seq_len, normalise)x_batch.append(x)y_batch.append(y)i += 1yield np.array(x_batch), np.array(y_batch)def _next_window(self, i, seq_len, normalise):'''Generates the next data window from the given index location i'''window = self.data_train[i:i+seq_len]window = self.normalise_windows(window, single_window=True)[0] if normalise else windowx = window[:-1]y = window[-1, [0]]return x, ydef normalise_windows(self, window_data, single_window=False):'''Normalise window with a base value of zero'''normalised_data = []window_data = [window_data] if single_window else window_datafor window in window_data:normalised_window = []for col_i in range(window.shape[1]):normalised_col = [((float(p) / float(window[0, col_i])) - 1) for p in window[:, col_i]]normalised_window.append(normalised_col)normalised_window = np.array(normalised_window).T # reshape and transpose array back into original multidimensional formatnormalised_data.append(normalised_window)return np.array(normalised_data)

3.2 LSTM模型的搭建
Model 类封装了构建、训练和预测 LSTM 模型的功能,支持多种训练方式和预测方法。它能够处理不同层的添加、模型的编译和训练,并提供了灵活的预测方法,适用于时间序列数据的建模。
代码如下所示:

class Model():"""A class for an building and inferencing an lstm model"""def __init__(self):self.model = Sequential()def load_model(self, filepath):print('[Model] Loading model from file %s' % filepath)self.model = load_model(filepath)def build_model(self, configs):timer = Timer()timer.start()for layer in configs['model']['layers']:neurons = layer['neurons'] if 'neurons' in layer else Nonedropout_rate = layer['rate'] if 'rate' in layer else Noneactivation = layer['activation'] if 'activation' in layer else Nonereturn_seq = layer['return_seq'] if 'return_seq' in layer else Noneinput_timesteps = layer['input_timesteps'] if 'input_timesteps' in layer else Noneinput_dim = layer['input_dim'] if 'input_dim' in layer else Noneif layer['type'] == 'dense':self.model.add(Dense(neurons, activation=activation))if layer['type'] == 'lstm':self.model.add(LSTM(neurons, input_shape=(input_timesteps, input_dim), return_sequences=return_seq))if layer['type'] == 'dropout':self.model.add(Dropout(dropout_rate))self.model.compile(loss=configs['model']['loss'], optimizer=configs['model']['optimizer'])print('[Model] Model Compiled')timer.stop()def train(self, x, y, epochs, batch_size, save_dir):timer = Timer()timer.start()print('[Model] Training Started')print('[Model] %s epochs, %s batch size' % (epochs, batch_size))save_fname = os.path.join(save_dir, '%s-e%s.h5' % (dt.datetime.now().strftime('%d%m%Y-%H%M%S'), str(epochs)))callbacks = [EarlyStopping(monitor='val_loss', patience=2),ModelCheckpoint(filepath=save_fname, monitor='val_loss', save_best_only=True)]self.model.fit(x,y,epochs=epochs,batch_size=batch_size,callbacks=callbacks)self.model.save(save_fname)print('[Model] Training Completed. Model saved as %s' % save_fname)timer.stop()def train_generator(self, data_gen, epochs, batch_size, steps_per_epoch, save_dir):timer = Timer()timer.start()print('[Model] Training Started')print('[Model] %s epochs, %s batch size, %s batches per epoch' % (epochs, batch_size, steps_per_epoch))save_fname = os.path.join(save_dir, '%s-e%s.h5' % (dt.datetime.now().strftime('%d%m%Y-%H%M%S'), str(epochs)))callbacks = [ModelCheckpoint(filepath=save_fname, monitor='loss', save_best_only=True)]self.model.fit_generator(data_gen,steps_per_epoch=steps_per_epoch,epochs=epochs,callbacks=callbacks,workers=1)print('[Model] Training Completed. Model saved as %s' % save_fname)timer.stop()def predict_point_by_point(self, data):#Predict each timestep given the last sequence of true data, in effect only predicting 1 step ahead each timeprint('[Model] Predicting Point-by-Point...')predicted = self.model.predict(data)predicted = np.reshape(predicted, (predicted.size,))return predicteddef predict_sequences_multiple(self, data, window_size, prediction_len):#Predict sequence of 50 steps before shifting prediction run forward by 50 stepsprint('[Model] Predicting Sequences Multiple...')prediction_seqs = []for i in range(int(len(data)/prediction_len)):curr_frame = data[i*prediction_len]predicted = []for j in range(prediction_len):predicted.append(self.model.predict(curr_frame[newaxis,:,:])[0,0])curr_frame = curr_frame[1:]curr_frame = np.insert(curr_frame, [window_size-2], predicted[-1], axis=0)prediction_seqs.append(predicted)return prediction_seqsdef predict_sequence_full(self, data, window_size):#Shift the window by 1 new prediction each time, re-run predictions on new windowprint('[Model] Predicting Sequences Full...')curr_frame = data[0]predicted = []for i in range(len(data)):predicted.append(self.model.predict(curr_frame[newaxis,:,:])[0,0])curr_frame = curr_frame[1:]curr_frame = np.insert(curr_frame, [window_size-2], predicted[-1], axis=0)return predicted

以其中train方法为例介绍:

def train(self, x, y, epochs, batch_size, save_dir):timer = Timer()timer.start()print('[Model] Training Started')print('[Model] %s epochs, %s batch size' % (epochs, batch_size))

train 方法接收训练数据 x 和 y,训练的轮数 epochs,批量大小 batch_size 和保存模型的目录 save_dir。
3.3 数据的预测和可视化

绘图函数plot_results_multiple:

def plot_results_multiple(predicted_data, true_data, prediction_len):fig = plt.figure(facecolor='white')ax = fig.add_subplot(111)ax.plot(true_data, label='True Data')for i, data in enumerate(predicted_data):padding = [None for p in range(i * prediction_len)]plt.plot(padding + data, label='Prediction')plt.legend()plt.show()

plot_results_multiple的功能是绘制多组预测数据与真实数据的对比图。首先绘制真实数据。对于每组预测数据,先创建填充(padding),以便将每组预测数据在图上正确对齐。通过plt.plot绘制每组预测数据,并使用图例标识。最后显示绘制的图形。
主函数main:

def main():configs = json.load(open('config.json', 'r'))if not os.path.exists(configs['model']['save_dir']): os.makedirs(configs['model']['save_dir'])data = DataLoader(os.path.join('data', configs['data']['filename']),configs['data']['train_test_split'],configs['data']['columns'])model = Model()model.build_model(configs)x, y = data.get_train_data(seq_len=configs['data']['sequence_length'],normalise=configs['data']['normalise'])

主函数,用于执行整个模型的训练与预测流程。从config.json中加载配置文件,获取模型保存目录和数据相关信息。如果保存目录不存在,则创建该目录。使用DataLoader类加载数据,包括文件路径、训练测试划分比例及需要的列。创建模型实例并构建模型。从数据集中获取训练数据(x为特征,y为标签)。
模型训练与预测

steps_per_epoch = math.ceil((data.len_train - configs['data']['sequence_length']) / configs['training']['batch_size'])
model.train_generator(data_gen=data.generate_train_batch(seq_len=configs['data']['sequence_length'],batch_size=configs['training']['batch_size'],normalise=configs['data']['normalise']),epochs=configs['training']['epochs'],batch_size=configs['training']['batch_size'],steps_per_epoch=steps_per_epoch,save_dir=configs['model']['save_dir']
)
x_test, y_test = data.get_test_data(seq_len=configs['data']['sequence_length'],normalise=configs['data']['normalise']
)predictions = model.predict_sequences_multiple(x_test, configs['data']['sequence_length'], configs['data']['sequence_length'])
plot_results_multiple(predictions, y_test, configs['data']['sequence_length'])

该部分的代码代码则是实现了基于生成器的训练方式,适用于数据量较大的情况。计算每个epoch的步骤数(steps_per_epoch),使用生成器训练模型。data.generate_train_batch生成训练数据的批次。从数据集中获取测试数据(x_test和y_test)。使用模型进行多序列预测。调用plot_results_multiple函数绘制预测结果与真实结果的对比图。

模型的训练过程如下所示:
在这里插入图片描述
模型的预测结果如下所示:
在这里插入图片描述
可以看到,预训模型预测的趋势与实际股票趋势比较吻合!

总结

通过本周的复习我对LSTM有了进一步的理解,不同于传统的神经网络,LSTM由4个input和一个Output组成,它通过引入复杂的门控机制和内部状态来增强模型的记忆能力和对序列数据的理解。通过LSTM的反向传播数学推导明白了LSTM为什么可以缓解RNN中的梯度消失问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/473834.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《Spring 基础之 IoC 与 DI 入门指南》

一、IoC 与 DI 概念引入 Spring 的 IoC&#xff08;控制反转&#xff09;和 DI&#xff08;依赖注入&#xff09;在 Java 开发中扮演着至关重要的角色&#xff0c;是提升代码质量和可维护性的关键技术。 &#xff08;一&#xff09;IoC 的含义及作用 IoC 全称为 Inversion of…

Vulnhub靶场案例渗透[9]- HackableIII

文章目录 一、靶场搭建1. 靶场描述2. 下载靶机环境3. 靶场搭建 二、渗透靶场1. 确定靶机IP2. 探测靶场开放端口及对应服务3. 扫描网络目录结构4. 敏感数据获取5. 获取shell6. 提权6.1 敏感信息获取6.2 lxd提权 一、靶场搭建 1. 靶场描述 Focus on general concepts about CTF…

抖音热门素材去哪找?优质抖音视频素材网站推荐!

是不是和我一样&#xff0c;刷抖音刷到停不下来&#xff1f;越来越多的朋友希望在抖音上创作出爆款视频&#xff0c;但苦于没有好素材。今天就来推荐几个超级实用的抖音视频素材网站&#xff0c;让你的视频内容立刻变得高大上&#xff01;这篇满是干货&#xff0c;直接上重点&a…

如何轻松导出所有 WordPress URL 为纯文本格式

作为一名多年的 WordPress 使用者&#xff0c;我深知管理一个网站的复杂性。从迁移网站、设置重定向到整理内容结构&#xff0c;每一步都需要精细处理。而拥有所有 URL 的清单&#xff0c;不仅能让这些工作变得更加简单&#xff0c;还能为后续的管理提供极大的便利。其实&#…

vue项目使用eslint+prettier管理项目格式化

代码格式化、规范化说明 使用eslintprettier进行格式化&#xff0c;vscode中需要安装插件ESLint、Prettier - Code formatter&#xff0c;且格式化程序选择为后者&#xff08;vue文件、js文件要分别设置&#xff09; 对于eslint规则&#xff0c;在格式化时不会全部自动调整&…

Ubuntu 18.04 配置sources.list源文件(无法安全地用该源进行更新,所以默认禁用该源)

如果你 sudo apt update 时出现诸如 无法安全地用该源进行更新&#xff0c;所以默认禁用该源 的错误&#xff0c;那就换换源吧&#xff0c;链接&#xff1a; https://mirror.tuna.tsinghua.edu.cn/help/ubuntu/ 注意版本&#xff1a; 修改源文件&#xff1a; sudo nano /etc…

5. langgraph中的react agent使用 (从零构建一个react agent)

1. 定义 Agent 状态 首先&#xff0c;我们需要定义 Agent 的状态&#xff0c;这包括 Agent 所持有的消息。 from typing import (Annotated,Sequence,TypedDict, ) from langchain_core.messages import BaseMessage from langgraph.graph.message import add_messagesclass …

【网络】什么是交换机?switch

交换机&#xff08;Switch&#xff09;意为“开关”&#xff0c;是一种用于电&#xff08;光&#xff09;信号转发的网络设备。以下是关于交换机的详细解释&#xff1a; 一、交换机的基本定义 功能&#xff1a;交换机能为接入交换机的任意两个网络节点提供独享的电信号通路&am…

【AlphaFold3】开源本地的安装及使用

文章目录 安装安装DockerInstalling Docker on Host启用Rootless Docker 安装 GPU 支持安装 NVIDIA 驱动程序安装 NVIDIA 对 Docker 的支持 获取 AlphaFold 3 源代码获取基因数据库获取模型参数构建将运行 AlphaFold 3 的 Docker 容器 参考 AlphaFold3: https://github.com/goo…

【免越狱】iOS砸壳 可下载AppStore任意版本 旧版本IPA下载

软件介绍 下载iOS旧版应用&#xff0c;简化繁琐的抓包流程。 一键生成去更新IPA&#xff08;手机安装后&#xff0c;去除App Store的更新检测&#xff09;。 软件界面 支持系统 Windows 10/Windows 8/Windows 7&#xff08;由于使用了Fiddler库&#xff0c;因此需要.Net环境…

shell 100例

1、每天写一个文件 (题目要求&#xff09; 请按照这样的日期格式(xxxx-xx-xx每日生成一个文件 例如生成的文件为2017-12-20.log&#xff0c;并且把磁盘的使用情况写到到这个文件中不用考虑cron&#xff0c;仅仅写脚本即可 [核心要点] date命令用法 df命令 知识补充&#xff1…

Acrobat Pro DC 2023(pdf免费转化word)

所在位置 通过网盘分享的文件&#xff1a;Acrobat Pro DC 2023(64bit).tar 链接: https://pan.baidu.com/s/1_m8TT1rHTtp5YnU8F0QGXQ 提取码: 1234 --来自百度网盘超级会员v4的分享 安装流程 打开安装所在位置 进入安装程序 找到安装程序 进入后点击自定义安装&#xff0c;这里…

linux之调度管理(5)-实时调度器

一、概述 在Linux内核中&#xff0c;实时进程总是比普通进程的优先级要高&#xff0c;实时进程的调度是由Real Time Scheduler(RT调度器)来管理&#xff0c;而普通进程由CFS调度器来管理。 实时进程支持的调度策略为&#xff1a;SCHED_FIFO和SCHED_RR。 SCHED_FIFO&#xff…

在arm64架构下, Ubuntu 18.04.5 LTS 用命令安装和卸载qt4、qt5

问题&#xff1a;需要在 arm64下安装Qt&#xff0c;QT源码编译失败以后&#xff0c;选择在线安装&#xff01; 最后安装的版本是Qt5.9.5 和QtCreator 4.5.2 。 一、ubuntu安装qt4的命令(亲测有效)&#xff1a; sudo add-apt-repository ppa:rock-core/qt4 sudo apt updat…

Qt 之 qwt和QCustomplot对比

QWT&#xff08;Qt Widgets for Technical Applications&#xff09;和 QCustomPlot 都是用于在 Qt 应用程序中绘制图形和图表的第三方库。它们各有优缺点&#xff0c;适用于不同的场景。 以下是 QWT 和 QCustomPlot 的对比分析&#xff1a; 1. 功能丰富度 QWT 功能丰富&a…

实用教程:如何无损修改MP4视频时长

如何在UltraEdit中搜索MP4文件中的“mvhd”关键字 引言 在视频编辑和分析领域&#xff0c;有时我们需要深入到视频文件的底层结构中去。UltraEdit&#xff08;UE&#xff09;和UEStudio作为强大的文本编辑器&#xff0c;允许我们以十六进制模式打开和搜索MP4文件。本文将指导…

使用nossl模式连接MySQL数据库详解

使用nossl模式连接MySQL数据库详解 摘要一、引言二、nossl模式概述2.1 SSL与nossl模式的区别2.2 选择nossl模式的场景三、在nossl模式下连接MySQL数据库3.1 准备工作3.2 C++代码示例3.3 代码详解3.3.1 初始化MySQL连接对象3.3.2 连接到MySQL数据库3.3.3 执行查询操作3.3.4 处理…

Linux下编译MFEM

本文记录在Linux下编译MFEM的过程。 零、环境 操作系统Ubuntu 22.04.4 LTSVS Code1.92.1Git2.34.1GCC11.4.0CMake3.22.1Boost1.74.0oneAPI2024.2.1 一、安装依赖 二、编译代码 附录I: CMakeUserPresets.json {"version": 4,"configurePresets": [{&quo…

号卡分销系统,号卡系统,物联网卡系统源码安装教程

号卡分销系统&#xff0c;号卡系统&#xff0c;物联网卡系统&#xff0c;&#xff0c;实现的高性能(PHP协程、PHP微服务)、高灵活性、前后端分离(后台)&#xff0c;PHP 持久化框架&#xff0c;助力管理系统敏捷开发&#xff0c;长期持续更新中。 主要特性 基于Auth验证的权限…

Java基础-集合

(创作不易&#xff0c;感谢有你&#xff0c;你的支持&#xff0c;就是我前行的最大动力&#xff0c;如果看完对你有帮助&#xff0c;请留下您的足迹&#xff09; 目录 前言 一、Java集合框架概述 二、Collection接口及其实现 2.1 Collection接口 2.2 List接口及其实现 …