深度学习基础练习:从pytorch API出发复现LSTM与LSTMP

2024/11/5-2024/11/7:

前置知识:

[译] 理解 LSTM(Long Short-Term Memory, LSTM) 网络 - wangduo - 博客园

【官方双语】LSTM(长短期记忆神经网络)StatQuest_哔哩哔哩_bilibili

大部分思路来自于:

PyTorch LSTM和LSTMP的原理及其手写复现_哔哩哔哩_bilibiliicon-default.png?t=O83Ahttps://www.bilibili.com/video/BV1zq4y1m7aH/?spm_id_from=333.880.my_history.page.click&vd_source=db0d5acc929b82408b1040d67f2b1dde

部分常量设置与官方api使用:

        其实在实现RNN之后可以发现,lstm基本是同样的套路。 在看完上面的前置知识之后,理解三个门的作用即可对lstm有一个具体的认识,这里不再赘述。

         关于输入设置这方面,参考如下:

# 定义常量
bs, T, i_size, h_size = 2, 3, 4, 5
# 输入序列
input = torch.randn(bs, T, i_size)
# 初始值,不需要训练
c0 = torch.randn(bs, h_size)
h0 = torch.randn(bs, h_size)

        将定义的常量输入官方api:

# 调用官方api
lstm_layer = nn.LSTM(i_size, h_size, batch_first=True)
# 单层单项lstm,h0与c0的第0维度为 D(是否双向)*num_layers 故增加0维,维度为1
output, (h_final, c_final) = lstm_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)))
print(output.shape)
print(output)for k, v in lstm_layer.named_parameters():print(k)print(v.shape)

        输出如下:

torch.Size([2, 3, 5])
tensor([[[ 0.1134, -0.1032,  0.1496,  0.1853, -0.3758],
         [ 0.1831,  0.0223,  0.0377,  0.0867, -0.1090],
         [ 0.1233,  0.1121,  0.0574, -0.0401, -0.1576]],

        [[-0.2761,  0.3259,  0.1687, -0.0632,  0.2046],
         [ 0.1796,  0.3110,  0.0974,  0.0294,  0.0220],
         [ 0.1205,  0.1815,  0.0840, -0.1714, -0.1216]]],
       grad_fn=<TransposeBackward0>)
weight_ih_l0
torch.Size([20, 4])
weight_hh_l0
torch.Size([20, 5])
bias_ih_l0
torch.Size([20])
bias_hh_l0
torch.Size([20])

        可以看到LSTM的内置参数有 weight_ih_l0、weight_hh_l0、bias_ih_l0、bias_hh_l0,将关于三个门的知识结合在一起看差不多就明白接下来应该怎么做了: 

        h和c都是统一经过*weight+bias的操作,加在一起后经过tahn或者sigmoid激活函数,最后或点乘或加在h或者c上进行对参数的更新。只要不把维度的对应关系搞混还是比较好复现的。

        需要注意的是:三个门中的四个weight和bias(遗忘门一个,输入门两个,输出门一个)全部都按照第0维度拼在了一起方便同时进行矩阵运算,所以我们可以看到这些权重和偏置的第0维度的大小为4*h_size。一开始这一点也带给了我比较大的困惑。

代码复现与验证:

        代码较为简单,跟上次实现RNN的思路也差不多,基本是照着官方api那给的公式一步一步来的:

# 代码复现
def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):h0, c0 = initial_statesbs, T, i_size = input.shapeh_size = w_ih.shape[0] // 4prev_h = h0     # [bs, h_size]prev_c = c0     # [bs, h_size]"""w_ih    # 4*h_size, i_sizew_hh    # 4*h_size, h_size"""# 输出序列output_size = h_sizeoutput = torch.zeros(bs, T, output_size)for t in range(T):x = input[:, t, :]  # 当前时刻输入向量, [bs, i_size]w_times_x = torch.matmul(w_ih, x.unsqueeze(-1)).squeeze(-1)         # [bs, 4*h_size]w_times_h_prev = torch.matmul(w_hh, prev_h.unsqueeze(-1)).squeeze(-1)    # [bs, 4*h_size]# 分别计算输入门(i),遗忘门(f),输出门(o),cell(g)i_t = torch.sigmoid(w_times_x[:, : h_size] + w_times_h_prev[:, : h_size] + b_ih[: h_size] + b_hh[: h_size])f_t = torch.sigmoid(w_times_x[:, h_size: 2*h_size] + w_times_h_prev[:, h_size: 2*h_size] +b_ih[h_size: 2*h_size] + b_hh[h_size: 2*h_size])g_t = torch.tanh(w_times_x[:, 2*h_size: 3*h_size] + w_times_h_prev[:, 2*h_size: 3*h_size] +b_ih[2*h_size: 3*h_size] + b_hh[2*h_size: 3*h_size])o_t = torch.sigmoid(w_times_x[:, 3*h_size:] + w_times_h_prev[:, 3*h_size:] + b_ih[3*h_size:] + b_hh[3*h_size:])# 更新流prev_c = f_t * prev_c + i_t * g_tprev_h = o_t * torch.tanh(prev_c)output[:, t, :] = prev_hreturn output, (prev_h, prev_c)

        输出结果对比验证:

# 调用官方api
lstm_layer = nn.LSTM(i_size, h_size, batch_first=True)
# 单层单项lstm,h0与c0的第0维度为 D(是否双向)*num_layers 故增加0维,维度为1
output, (h_final, c_final) = lstm_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)))
print(output.shape)
print(output)for k, v in lstm_layer.named_parameters():print(k, v.shape)
output, (h_final, c_final) = lstm_forward(input, (h0, c0), lstm_layer.weight_ih_l0,lstm_layer.weight_hh_l0, lstm_layer.bias_ih_l0, lstm_layer.bias_hh_l0)print(output)

        结果如下:

torch.Size([2, 3, 5])
tensor([[[-0.6394, -0.1796,  0.0831,  0.0816, -0.0620],
         [-0.5798, -0.2235,  0.0539, -0.0120, -0.0272],
         [-0.4229, -0.0798, -0.0762, -0.0030, -0.0668]],

        [[ 0.0294,  0.3240, -0.4318,  0.5005, -0.0223],
         [-0.1458,  0.0472, -0.1115,  0.3445,  0.3558],
         [-0.2922, -0.1013, -0.1755,  0.3065,  0.1130]]],
       grad_fn=<TransposeBackward0>)
weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 5])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])
tensor([[[-0.6394, -0.1796,  0.0831,  0.0816, -0.0620],
         [-0.5798, -0.2235,  0.0539, -0.0120, -0.0272],
         [-0.4229, -0.0798, -0.0762, -0.0030, -0.0668]],

        [[ 0.0294,  0.3240, -0.4318,  0.5005, -0.0223],
         [-0.1458,  0.0472, -0.1115,  0.3445,  0.3558],
         [-0.2922, -0.1013, -0.1755,  0.3065,  0.1130]]], grad_fn=<CopySlices>)

         复现成功。

appendix:

        这里放下LSTMP的参数设置:

# lstmp对h_size进行压缩
proj_size = 3
# h0的h_size也改为proj_size,而c0不变
h0 = torch.randn(bs, proj_size)# 调用官方api
lstmp_layer = nn.LSTM(i_size, h_size, batch_first=True, proj_size=proj_size)
# 单层单项lstm,h0与c0的第0维度为 D(是否双向)*num_layers 故增加0维,维度为1
output, (h, c) = lstmp_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)), )print(output)
print(output.shape)
print(h.shape)
print(c.shape)
for k, v in lstmp_layer.named_parameters():print(k, v.shape)

tensor([[[-0.0492,  0.0265,  0.0883],
         [-0.1028, -0.0327, -0.0542],
         [ 0.0250, -0.0231, -0.1199]],

        [[-0.2417, -0.1737, -0.0755],
         [-0.2351, -0.0837, -0.0376],
         [-0.2527, -0.0258, -0.0236]]], grad_fn=<TransposeBackward0>)
torch.Size([2, 3, 3])
torch.Size([1, 2, 3])
torch.Size([1, 2, 5])
weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 3])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])
weight_hr_l0 torch.Size([3, 5])

         其实LSTMP就多出了个weight_hr_l0对h进行压缩,但是不对cell压缩,目的是减少lstm的参数量,在小一点的sequence上基本没啥区别。若要支持lstmp,在前面的代码上改动几行即可:

def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh, w_hr=None):h0, c0 = initial_statesbs, T, i_size = input.shapeh_size = w_ih.shape[0] // 4prev_h = h0     # [bs, h_size]prev_c = c0     # [bs, h_size]"""w_ih    # 4*h_size, i_sizew_hh    # 4*h_size, h_size"""if w_hr is not None:# 输出压缩至p_sizep_size = w_hr.shape[0]output_size = p_sizeelse:output_size = h_sizeoutput = torch.zeros(bs, T, output_size)for t in range(T):x = input[:, t, :]  # 当前时刻输入向量, [bs, i_size]w_times_x = torch.matmul(w_ih, x.unsqueeze(-1)).squeeze(-1)         # [bs, 4*h_size]w_times_h_prev = torch.matmul(w_hh, prev_h.unsqueeze(-1)).squeeze(-1)    # [bs, 4*h_size]# 分别计算输入门(i),遗忘门(f),输出门(o),cell(g)i_t = torch.sigmoid(w_times_x[:, : h_size] + w_times_h_prev[:, : h_size] + b_ih[: h_size] + b_hh[: h_size])f_t = torch.sigmoid(w_times_x[:, h_size: 2*h_size] + w_times_h_prev[:, h_size: 2*h_size] +b_ih[h_size: 2*h_size] + b_hh[h_size: 2*h_size])g_t = torch.tanh(w_times_x[:, 2*h_size: 3*h_size] + w_times_h_prev[:, 2*h_size: 3*h_size] +b_ih[2*h_size: 3*h_size] + b_hh[2*h_size: 3*h_size])o_t = torch.sigmoid(w_times_x[:, 3*h_size:] + w_times_h_prev[:, 3*h_size:] + b_ih[3*h_size:] + b_hh[3*h_size:])# 更新流prev_c = f_t * prev_c + i_t * g_tprev_h = o_t * torch.tanh(prev_c)   # [bs, h_size]if w_hr is not None:prev_h = torch.matmul(w_hr, prev_h.unsqueeze(-1)).squeeze(-1)   # [bs, p_size]output[:, t, :] = prev_hreturn output, (prev_h, prev_c)

        经过验证,复现成功。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/468847.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

scala学习记录,Set,Map

set&#xff1a;集合&#xff0c;表示没有重复元素的集合&#xff0c;特点&#xff1a;唯一 语法格式&#xff1a;val 变量名 Set [类型]&#xff08;元素1&#xff0c;元素2...&#xff09; 可变不可变 可变&#xff08;mutable&#xff09;可对元素进行添加&#xff0c;删…

Docker了解

Docker是一种容器化技术&#xff0c;它可以将应用程序和其依赖项打包到一个独立的、可移植的容器中&#xff0c;以便在不同的环境中运行。Docker基于Linux操作系统的容器化技术&#xff0c;可以提供更轻量、更快速、更灵活、更一致的应用部署和管理方式。 Docker的基本概念包括…

‘conda‘ 不是内部或外部命令,也不是可运行的程序或批处理文件,Miniconda

下载了conda&#xff0c;但是在cmd里执行conda --version会显示’conda’ 不是内部或外部命令&#xff0c;也不是可运行的程序或批处理文件。 原因是环境变量里没有添加conda&#xff0c;无法识别路径。 需要在系统环境变量里添加如下路径&#xff1a; 保存之后重新打开cmd&am…

《TCP/IP网络编程》学习笔记 | Chapter 4:基于TCP的服务器端/客户端(1)

《TCP/IP网络编程》学习笔记 | Chapter 4&#xff1a;基于TCP的服务器端/客户端&#xff08;1&#xff09; 《TCP/IP网络编程》学习笔记 | Chapter 4&#xff1a;基于TCP的服务器端/客户端&#xff08;1&#xff09;理解TCP和UDPTCP/IP协议栈TCP/IP协议的诞生背景链路层网络层T…

【C语言】。末尼

用js写一个冒泡排序 当然可以。在JavaScript中&#xff0c;你可以使用递归的方式来实现冒泡排序算法。这是一个简单的示例&#xff1a; function bubbleSort(arr) { let len arr.length; if(len < 1){ return arr; // 如果数组只有一个元素或为空&#xf…

深度学习笔记12

1.神经网络的代价函数 神经网络可同时用于解决分类问题和回归问题&#xff0c;对于不同的问题会在输出层后&#xff0c;加上不同的变换函数。一般来说&#xff0c;回归问题使用恒等函数,分类问题使用sigmoid或softmax函数。而不同的变换函数&#xff0c;也对应不同的代价函数。…

RabbitMQ队列详细属性(重要)

RabbitMQ队列详细属性 1、队列的属性介绍1.1、Type&#xff1a;队列类型1.2、Name&#xff1a;队列名称1.3、Durability&#xff1a;声明队列是否持久化1.4、Auto delete&#xff1a; 是否自动删除1.5、Exclusive&#xff1a;1.6、Arguments&#xff1a;队列的其他属性&#xf…

json即json5新特性,idea使用json5,fastjson、gson、jackson对json5支持

文章目录 1.新特性1.1.JSON&#xff06;JSON5官网2.示例2.1. IntelliJ IDEA2.1.1.支持.json5文件2.1.2.md支持json5代码块 2.9. 示例源码 1.新特性 【通用】 注释尾随逗号key无需引号&#xff08;或单引号&#xff09; 【字符串】 字符串可以用单引号引起来。字符串可以通过转…

【NOIP普及组】摆花

【NOIP普及组】摆花 C语言代码C 代码Java代码Python代码 &#x1f490;The Begin&#x1f490;点点关注&#xff0c;收藏不迷路&#x1f490; 小明的花店新开张&#xff0c;为了吸引顾客&#xff0c;他想在花店的门口摆上一排花&#xff0c;共 m 盆。通过调 查顾客的喜好&am…

pdf转excel;pdf中表格提取

一、问题描述 在工作中或多或少会遇到&#xff1a;需要将某份pdf中的表格数据提取出来&#xff0c;以便能够“修改使用”数据 可将pdf中的表格提取出来&#xff0c;解决办法还有点复杂 尤其涉及“pdf中表格不是标准的单元格”的时候&#xff0c;提取数据到excel不太容易 比…

Qt中 QWidget 和 QMainWindow 区别

QWidget 用来构建简单窗口 QMainWindow 用来构建更复杂的窗口&#xff0c;QMainWindow 继承自QWidget&#xff0c;在QWidget 的基础上提供了菜单栏、工具栏、状态栏等功能 菜单栏&#xff08;QMenuBar&#xff09;工具栏&#xff08;QToolBar&#xff09;状态栏&#xff08;Q…

《深入浅出Apache Spark》系列③:Spark SQL解析层优化策略与案例解析

导读&#xff1a;本系列是Spark系列分享的第三期。第一期分享了Spark Core的一些基本原理和一些基本概念&#xff0c;包括一些核心组件。Spark的所有组件都围绕Spark Core来运转&#xff0c;其中最活跃的一个上层组件是Spark SQL。第二期分享则专门介绍了Spark SQL的基本架构和…

安全的时钟启动

Note&#xff1a;文章内容以 Xilinx 系列 FPGA 进行讲解 1、什么是安全启动时钟 通常情况下&#xff0c;在MMCM/PLL的LOCKED信号抬高之后&#xff08;由0变为1&#xff09;&#xff0c;MMCM/PLL就处于锁定状态&#xff0c;输出时钟已保持稳定。但在此之前&#xff0c;输出时钟会…

【mongodb】数据库的安装及连接初始化简明手册

NoSQL(NoSQL Not Only SQL )&#xff0c;意即"不仅仅是SQL"。 在现代的计算系统上每天网络上都会产生庞大的数据量。这些数据有很大一部分是由关系数据库管理系统&#xff08;RDBMS&#xff09;来处理。 通过应用实践证明&#xff0c;关系模型是非常适合于客户服务器…

丹韵红墙成红毯至美背景!冠珠华脉「雍华京韵」于M essential大秀绽放京韵时尚

东方美学代表品牌M essential近日于上海科学会堂举办十周年大秀&#xff0c;并发布品牌全新2024/25冬春系列。冠珠瓷砖作为国风新韵合作品牌&#xff0c;以高定岩板华脉「雍华京韵」系列的宫墙丹韵打造红毯背景墙&#xff0c;中国高定岩板与中国高级时装作品碰撞着“中国美”的…

工程认证与Spring Boot:计算机课程管理的新探索

摘要 随着信息技术在管理上越来越深入而广泛的应用&#xff0c;管理信息系统的实施在技术上已逐步成熟。本文介绍了基于工程教育认证的计算机课程管理平台的开发全过程。通过分析基于工程教育认证的计算机课程管理平台管理的不足&#xff0c;创建了一个计算机管理基于工程教育认…

excel功能

统计excel中每个名字出现的次数 在Excel中统计每个名字出现的次数&#xff0c;您可以使用COUNTIF函数或数据透视表。以下是两种方法的详细步骤&#xff1a; 方法一&#xff1a;使用COUNTIF函数 准备数据&#xff1a;确保您的姓名列表位于一个连续的单元格区域&#xff0c;例如…

【flask开启进程,前端内容图片化并转pdf-会议签到补充】

flask开启进程,前端内容图片化并转pdf-会议签到补充 flask及flask-socketio开启threading页面内容转图片转pdf流程前端主js代码内容转图片-browser端browser端的同步编程flask的主要功能route,def 总结 用到了pdf,来回数据转发和合成,担心flask卡顿,响应差,于是刚好看到threadi…

聊一聊Spring中的自定义监听器

前言 通过一个简单的自定义的监听器&#xff0c;从源码的角度分一下Spring中监听的整个过程&#xff0c;分析监听的作用。 一、自定义监听案例 1.1定义事件 package com.lazy.snail;import lombok.Getter; import org.springframework.context.ApplicationEvent;/*** Class…

VMWareTools安装及文件无法拖拽解决方案

文章目录 1 安装VMWare Tools2 安装vmware tools之后还是无法拖拽文件解决方案2.1 确认vmware tools安装2.2 客户机隔离2.3 修改自定义配置文件2.4 安装open-vm-tools-desktop软件 1 安装VMWare Tools 打开虚拟机VMware Workstation&#xff0c;启动Ubuntu系统&#xff0c;菜单…