小白也能读懂的ConvLSTM!(开源pytorch代码)

ConvLSTM

    • 1. 算法简介与应用场景
    • 2. 算法原理
      • 2.1 LSTM基础
      • 2.2 ConvLSTM原理
        • 2.2.1 ConvLSTM的结构
        • 2.2.2 卷积操作的优点
      • 2.3 LSTM与ConvLSTM的对比分析
      • 2.4 ConvLSTM的应用
    • 3. PyTorch代码
    • 参考文献

仅需要网络源码的可以直接跳到末尾即可

1. 算法简介与应用场景

ConvLSTM(卷积长短期记忆网络)是一种结合了卷积神经网络(CNN)和长短期记忆网络(LSTM)优势的深度学习模型。它主要用于处理时空数据,特别适用于需要考虑空间特征和时间依赖关系的任务,如气象预测、视频分析、交通流量预测等。

在气象预测中,ConvLSTM可以根据过去的气象数据(如降水、温度等)预测未来的天气情况。在视频分析中,它可以帮助识别视频中的活动或事件,利用时间序列的连续性和空间信息进行更准确的分析。

2. 算法原理

2.1 LSTM基础

在介绍ConvLSTM之前,先让我们来回归一下什么是长短期记忆网络(LSTM)。LSTM是一种特殊的循环神经网络(RNN),它通过引入门控机制解决了传统RNN在长序列训练中面临的梯度消失和爆炸问题。LSTM单元主要包含三个门:输入门、遗忘门和输出门。这些门控制着信息在单元中的流动,从而有效地记住或遗忘信息。

LSTM的核心公式如下:

  • 遗忘门
    f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)

  • 输入门
    i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)
    C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC[ht1,xt]+bC)

  • 单元状态更新
    C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t = f_t \ast C_{t-1} + i_t \ast \tilde{C}_t Ct=ftCt1+itC~t

  • 输出门
    o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)
    h t = o t ∗ tanh ⁡ ( C t ) h_t = o_t \ast \tanh(C_t) ht=ottanh(Ct)

这里, C t C_t Ct 是当前的单元状态, h t h_t ht 是当前的隐藏状态, x t x_t xt 是当前的输入。

2.2 ConvLSTM原理

ConvLSTM在LSTM的基础上引入了卷积操作。传统的LSTM使用全连接层处理输入数据,而ConvLSTM则采用卷积层来处理空间数据。这样,ConvLSTM能够更好地捕捉输入数据中的空间特征。
在这里插入图片描述

2.2.1 ConvLSTM的结构

ConvLSTM的单元结构与LSTM非常相似,但是在每个门的计算中使用了卷积操作。具体来说,ConvLSTM的每个门的公式可以表示为:

i t = σ ( W x i ∗ X t + W h i ∗ H t − 1 + W c i ∘ C t − 1 + b i ) i_t = \sigma (W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \circ C_{t-1} + b_i) it=σ(WxiXt+WhiHt1+WciCt1+bi)
f t = σ ( W x f ∗ X t + W h f ∗ H t − 1 + W c f ∘ C t − 1 + b f ) f_t = \sigma (W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \circ C_{t-1} + b_f) ft=σ(WxfXt+WhfHt1+WcfCt1+bf)
C t = f t ∘ C t − 1 + i t ∘ t a n h ( W x c ∗ X t + W h c ∗ H t − 1 + b c ) C_t = f_t \circ C_{t-1} + i_t \circ tanh(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c) Ct=ftCt1+ittanh(WxcXt+WhcHt1+bc)
o t = σ ( W x o ∗ X t + W h o ∗ H t − 1 + W c o ∘ C t + b o ) o_t = \sigma (W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \circ C_t + b_o) ot=σ(WxoXt+WhoHt1+WcoCt+bo)
H t = o t ∘ t a n h ( C t ) H_t = o_t \circ tanh(C_t) Ht=ottanh(Ct)

这里的 所有 W W W都是是卷积权重, b b b是偏置项, σ \sigma σ 是 sigmoid 函数, tanh ⁡ \tanh tanh 是双曲正切函数。。
在这里插入图片描述

2.2.2 卷积操作的优点
  1. 空间特征提取:卷积操作能够有效提取输入数据中的空间特征。对于图像数据,卷积操作可以捕捉局部特征,例如边缘、纹理等,这在时间序列数据中同样适用。

  2. 参数共享:卷积操作通过使用相同的卷积核在不同位置计算特征,从而减少了模型参数的数量,降低了计算复杂度。

  3. 平移不变性:卷积网络对输入数据的平移具有不变性,即相同的特征在不同位置都会被检测到,这对于时空序列数据来说是非常重要的。

2.3 LSTM与ConvLSTM的对比分析

特性LSTMConvLSTM
输入类型一维序列三维数据(时序的图像数据)
处理方式全连接层卷积操作
空间特征捕捉较弱较强
应用场景自然语言处理、时间序列预测图像序列预测、视频分析

2.4 ConvLSTM的应用

ConvLSTM在多个领域中表现出色,特别适合处理具有时空特征的数据。以下是一些主要的应用场景:

  • 气象预测:利用历史气象数据(如温度、湿度、降水等)来预测未来的天气情况。
  • 视频分析:对视频中的动态场景进行建模,识别和预测视频中的活动。
  • 交通流量预测:基于历史交通数据预测未来的交通流量,帮助城市交通管理。
  • 医学影像分析:分析医学影像序列(如CT、MRI)中的变化,辅助疾病诊断。

3. PyTorch代码

以下是ConvLSTM的完整代码,可以直接拿来用:

import torch.nn as nn
import torchclass ConvLSTMCell(nn.Module):def __init__(self, input_dim, hidden_dim, kernel_size, bias):"""初始化卷积 LSTM 单元。参数:----------input_dim: int输入张量的通道数。hidden_dim: int隐藏状态的通道数。kernel_size: (int, int)卷积核的大小。bias: bool是否添加偏置项。"""super(ConvLSTMCell, self).__init__()self.input_dim = input_dimself.hidden_dim = hidden_dimself.kernel_size = kernel_size# 计算填充大小以保持输入和输出尺寸一致self.padding = kernel_size[0] // 2, kernel_size[1] // 2self.bias = bias# 定义卷积层,输入是输入维度加上隐藏维度,输出是4倍的隐藏维度(对应i, f, o, g)self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,out_channels=4 * self.hidden_dim,kernel_size=self.kernel_size,padding=self.padding,bias=self.bias)def forward(self, input_tensor, cur_state):h_cur, c_cur = cur_state# 沿着通道轴进行拼接combined = torch.cat([input_tensor, h_cur], dim=1)combined_conv = self.conv(combined)# 将输出分割成四个部分,分别对应输入门、遗忘门、输出门和候选单元状态cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)i = torch.sigmoid(cc_i)f = torch.sigmoid(cc_f)o = torch.sigmoid(cc_o)g = torch.tanh(cc_g)# 更新单元状态c_next = f * c_cur + i * g# 更新隐藏状态h_next = o * torch.tanh(c_next)return h_next, c_nextdef init_hidden(self, batch_size, image_size):height, width = image_size# 初始化隐藏状态和单元状态为零return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))class ConvLSTM(nn.Module):"""卷积 LSTM 层。参数:----------input_dim: 输入通道数hidden_dim: 隐藏通道数kernel_size: 卷积核大小num_layers: LSTM 层的数量batch_first: 批次是否在第一维bias: 卷积中是否有偏置项return_all_layers: 是否返回所有层的计算结果输入:------一个形状为 B, T, C, H, W 或者 T, B, C, H, W 的张量输出:------元组包含两个列表(长度为 num_layers 或者长度为 1 如果 return_all_layers 为 False):0 - layer_output_list 是长度为 T 的每个输出的列表1 - last_state_list 是最后的状态列表,其中每个元素是一个 (h, c) 对应隐藏状态和记忆状态示例:>>> x = torch.rand((32, 10, 64, 128, 128))>>> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)>>> _, last_states = convlstm(x)>>> h = last_states[0][0]  # 0 表示层索引,0 表示 h 索引"""def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,batch_first=False, bias=True, return_all_layers=False):super(ConvLSTM, self).__init__()# 检查 kernel_size 的一致性self._check_kernel_size_consistency(kernel_size)# 确保 kernel_size 和 hidden_dim 的长度与层数一致kernel_size = self._extend_for_multilayer(kernel_size, num_layers)hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)if not len(kernel_size) == len(hidden_dim) == num_layers:raise ValueError('不一致的列表长度。')self.input_dim = input_dimself.hidden_dim = hidden_dimself.kernel_size = kernel_sizeself.num_layers = num_layersself.batch_first = batch_firstself.bias = biasself.return_all_layers = return_all_layers# 创建 ConvLSTMCell 列表cell_list = []for i in range(0, self.num_layers):cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,hidden_dim=self.hidden_dim[i],kernel_size=self.kernel_size[i],bias=self.bias))self.cell_list = nn.ModuleList(cell_list)def forward(self, input_tensor, hidden_state=None):"""前向传播函数。参数:----------input_tensor: 输入张量,形状为 (t, b, c, h, w) 或者 (b, t, c, h, w)hidden_state: 初始隐藏状态,默认为 None返回:-------last_state_list, layer_output"""if not self.batch_first:# 改变输入张量的顺序,如果 batch_first 为 Falseinput_tensor = input_tensor.permute(1, 0, 2, 3, 4)b, _, _, h, w = input_tensor.size()# 实现状态化的 ConvLSTMif hidden_state is not None:raise NotImplementedError()else:# 初始化隐藏状态hidden_state = self._init_hidden(batch_size=b,image_size=(h, w))layer_output_list = []last_state_list = []seq_len = input_tensor.size(1)cur_layer_input = input_tensorfor layer_idx in range(self.num_layers):h, c = hidden_state[layer_idx]output_inner = []for t in range(seq_len):# 在每个时间步上更新状态h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],cur_state=[h, c])output_inner.append(h)# 将输出堆叠起来layer_output = torch.stack(output_inner, dim=1)cur_layer_input = layer_outputlayer_output_list.append(layer_output)last_state_list.append([h, c])if not self.return_all_layers:# 如果不需要返回所有层,则只返回最后一层的输出和状态layer_output_list = layer_output_list[-1:]last_state_list = last_state_list[-1:]return layer_output_list, last_state_listdef _init_hidden(self, batch_size, image_size):init_states = []for i in range(self.num_layers):# 初始化每一层的隐藏状态init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))return init_states@staticmethoddef _check_kernel_size_consistency(kernel_size):if not (isinstance(kernel_size, tuple) or(isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):raise ValueError('`kernel_size` 必须是 tuple 或者 list of tuples')@staticmethoddef _extend_for_multilayer(param, num_layers):if not isinstance(param, list):param = [param] * num_layersreturn param

参考文献

[1]Shi, X., Chen, Z., Wang, H., Yeung, D. Y., Wong, W. K., & Woo, W. (2015). Convolutional LSTM Network: A Machine Learning [2]Approach for Precipitation Nowcasting. Advances in Neural Information Processing Systems, 28.
[3]Hochreiter, S., & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation, 9(8), 1735-1780.
Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. MIT Press.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/386120.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【漏洞复现】phpStudy 小皮 Windows面板 存在RCE漏洞

靶场资料后台自行领取【靶场】 image-20240726092307252 PhpStudy小皮面板曝RCE漏洞,本质是存储型XSS引发。攻击者通过登录用户名输入XSS代码,结合后台计划任务功能,实现远程代码执行,严重威胁服务器安全。建议立即更新至安全版…

OpenSSL SSL_connect: Connection was reset in connection to github.com:443

OpenSSL SSL_connect: Connection was reset in connection to github.com:443 目录 OpenSSL SSL_connect: Connection was reset in connection to github.com:443 【常见模块错误】 【解决方案】 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页&…

机器视觉12-相机

相机 作用: 工业相机 是 机器视觉系统 的重要组成部分 最本质的功能就是通过CCD或CMOS成 像传感器将镜头产生的光信号转变为 有序的电信号,并将这些信息通过相 应接口传送到计算机主机 工业相机分类 目前业内没有对相机进行明确的分类定义, 以下分类是…

正点原子 通用外设配置模型 GPIO配置步骤 NVIC配置

1. 这个是通用外设驱动模式配置 除了初始化是必须的 其他不是必须的 2. gpio配置步骤 1.使能时钟是相当于开电 2.设置工作模式是配置是输出还是输入 是上拉输入还是下拉输入还是浮空 是高速度还是低速度这些 3 和 4小点就是读写io口的状态了 3. 这个图是正点原子 将GPIO 的时…

鸿蒙开发—黑马云音乐之Music页面

目录 1.外层容器效果 2.信息区-发光效果 3.信息区-内容布局 4.播放列表布局 5.播放列表动态化 6.模拟器运行并配置权限 效果: 1.外层容器效果 Entry Component export struct MuiscPage {build() {Column() {// 信息区域Column() {}.width(100%)// .backgroun…

带有扰动观测器的MPC电机控制

模型预测控制(Model Predictive Contro1, MPC)是一种先进的控制策略,虽然具有鲁棒性、建模简单、处理多变量系统、显示约束、预测未来行为和优化性能的能力等优势。它的不足在于预测控制行为的计算需要繁琐的计算量,以及抗干扰能力较弱。这里提出基于扰动…

GPIO子系统

1. GPIO子系统视频概述 1.1 GPIO子系统的作用 芯片内部有很多引脚,这些引脚可以接到GPIO模块,也可以接到I2C等模块。 通过Pinctrl子系统来选择引脚的功能(mux function)、配置引脚: 当一个引脚被复用为GPIO功能时,我们可以去设…

数组模拟单调栈--C++

本题的计算量较大,用暴力算法会超时,的用别的方法,我们假设在左边找第一个比它小的数,那么在左边出现一次的数如果比右边大了,那么就不会在出现了,我们将它删除掉就可以了,用这个方法我们可以的…

28.Labview界面设计(上篇) --- 软件登陆界面设计与控件美化

摘要: 作为GUI界面设计的老大哥般的存在,Labview本身的G语言属性就展现了其优越的外观设计能力,其中不乏许多编程爱好者、架构师等的喜欢使用Labview进行界面相关的设计,而使用Matlab、Python等软件写底层数据处理模块、自动化脚本…

Javaweb项目|springboot医院管理系统

收藏点赞不迷路 关注作者有好处 文末获取源码 一、系统展示 二、万字文档展示 基于springboot医院管理系统 开发语言:Java 数据库:MySQL 技术:SpringSpringMVCMyBatisVue 工具:IDEA/Ecilpse、Navicat、Maven 编号:…

通信原理-思科实验五:家庭终端以太网接入Internet实验

实验五 家庭终端以太网接入Internet实验 一实验内容 二实验目的 三实验原理 四实验步骤 1.按照上图选择对应的设备,并连接起来 为路由器R0两个端口配置IP 为路由器R1端口配置IP 为路由器设备增加RIP,配置接入互联网的IP的动态路由项 5.为路由器R1配置静…

如何使用Firefox浏览器连接IPXProxy设置海外代理IP教程

​Firefox浏览器是大家上网时经常会使用的一款工具。不过,有时候我们会遇到一些网站无法直接访问的情况。这时候,通过海外代理IP,比如像IPXProxy代理这样的服务,可能就能帮助我们进入那些受限制的网站,获取我们所需的资…

手写spring简易版本,让你更好理解spring源码

首先我们要模拟spring,先搞配置文件,并配置bean 创建我们需要的类,beandefito,这个类是用来装解析后的bean,主要三个字段,id,class,scop,对应xml配置的属性 package org…

【北京迅为】《i.MX8MM嵌入式Linux开发指南》-第三篇 嵌入式Linux驱动开发篇-第六十三章 输入子系统实验

i.MX8MM处理器采用了先进的14LPCFinFET工艺,提供更快的速度和更高的电源效率;四核Cortex-A53,单核Cortex-M4,多达五个内核 ,主频高达1.8GHz,2G DDR4内存、8G EMMC存储。千兆工业级以太网、MIPI-DSI、USB HOST、WIFI/BT…

IP 泄露: 原因与避免方法

始终关注您的IP信息! 您的IP地址不仅显示您的位置,它包含几乎所有的互联网活动信息! 如果出现IP泄漏,几乎所有的信息都会被捕获甚至非法利用! 那么,网站究竟如何追踪您的IP地址?您又如何有效…

Catalyst优化器:让你的Spark SQL查询提速10倍

目录 1 逻辑优化阶段 2.1 逻辑计划解析 2.2 逻辑计划优化 2.2.1 Catalys的优化过程 2.2.2 Cache Manager优化 2 物理优化阶段 2.1 优化 Spark Plan 2.1.1 Catalyst 的 Join 策略 2.1.2 如何决定选择哪一种 Join 策略 2.2 Physical Plan 2.2.1 EnsureRequirements 规则 3 相关文…

【Unity2D 2022:Data】读取csv格式文件的数据

一、创建csv文件 1. 打开Excel,创建xlsx格式文件 2. 编辑卡牌数据:这里共写了两类卡牌,第一类是灵物卡,具有编号、卡名、生命、攻击四个属性;第二类是法术卡,具有编号、卡名、效果三个属性。每类卡的第一…

qt 如何制作动态库插件

首先 首先第一点要确定我们的接口是固定的,也就是要确定 #ifndef RTSPPLUGIN_H #define RTSPPLUGIN_H #include "rtspplugin_global.h" typedef void (*func_callback)(uint8_t* data,int len,uint32_t ssrc,uint32_t ts,const char* ipfrom,uint16_t f…

【Maven学习】-3.进阶

文章目录 3. 进阶3.1 maven依赖传递特性 3.2 依赖冲突3.2.1 自动选择原则3.2.2 手动排除 3.3 聚合工程3.3.1 继承介绍继承作用继承语法父工程依赖统一管理-dependencyManagement 3.3.2 工程聚合关系简介聚合作用聚合作用 3.4 私服3.4.1 简介3.4.2 Nexus下载安装Nexus3Nexus2 3.…

带你学会Git必会操作

文章目录 带你学会Git必会操作1Git的安装2.Git基本操作2.1本地仓库的创建2.2配置本地仓库 3.认识一些Git的基本概念3.1操作流程: 4.一些使用场景4.1添加文件场景一4.2查看git文件4.3修改文件4.4Git版本回退4.5git撤销修改 5.分支管理5.1查看分支5.2创建本地分支5.3切…