代码 RNN原理及手写复现

29、PyTorch RNN的原理及其手写复现_哔哩哔哩_bilibili

笔记连接: https://pan.baidu.com/s/1_Sm7ptEiJtTTq3vQWgOTNg?pwd=2rei 提取码: 2rei

import torch
import torch.nn as nn
bs,T=2,3  # 批大小,输入序列长度
input_size,hidden_size = 2,3 # 输入特征大小,隐含层特征大小
input = torch.randn(bs,T,input_size)  # 随机初始化一个输入特征序列
h_prev = torch.zeros(bs,hidden_size) # 初始隐含状态
# step1 调用pytorch RNN API
rnn = nn.RNN(input_size,hidden_size,batch_first=True)
rnn_output,state_finall = rnn(input,h_prev.unsqueeze(0))print(rnn_output)
print(state_finall)
# step2 手写 rnn_forward函数,实现RNN的计算原理
def rnn_forward(input,weight_ih,weight_hh,bias_ih,bias_hh,h_prev):bs,T,input_size = input.shapeh_dim = weight_ih.shape[0]h_out = torch.zeros(bs,T,h_dim) # 初始化一个输出(状态)矩阵for t in range(T):x = input[:,t,:].unsqueeze(2)  # 获取当前时刻的输入特征,bs*input_size*1w_ih_batch = weight_ih.unsqueeze(0).tile(bs,1,1) # bs * h_dim * input_sizew_hh_batch = weight_hh.unsqueeze(0).tile(bs,1,1)# bs * h_dim * h_dimw_times_x = torch.bmm(w_ih_batch,x).squeeze(-1) # bs*h_dimw_times_h = torch.bmm(w_hh_batch,h_prev.unsqueeze(2)).squeeze(-1) # bs*h_himh_prev = torch.tanh(w_times_x + bias_ih + w_times_h + bias_hh)h_out[:,t,:] = h_prevreturn h_out,h_prev.unsqueeze(0)
# 验证结果
custom_rnn_output,custom_state_finall = rnn_forward(input,rnn.weight_ih_l0,rnn.weight_hh_l0,rnn.bias_ih_l0,rnn.bias_hh_l0,h_prev)
print(custom_rnn_output)
print(custom_state_finall)
print(torch.allclose(rnn_output,custom_rnn_output))
print(torch.allclose(state_finall,custom_state_finall))
# step3 手写一个 bidirectional_rnn_forward函数,实现双向RNN的计算原理
def bidirectional_rnn_forward(input,weight_ih,weight_hh,bias_ih,bias_hh,h_prev,weight_ih_reverse,weight_hh_reverse,bias_ih_reverse,bias_hh_reverse,h_prev_reverse):bs,T,input_size = input.shapeh_dim = weight_ih.shape[0]h_out = torch.zeros(bs,T,h_dim*2) # 初始化一个输出(状态)矩阵,注意双向是两倍的特征大小forward_output = rnn_forward(input,weight_ih,weight_hh,bias_ih,bias_hh,h_prev)[0]  # forward layerbackward_output = rnn_forward(torch.flip(input,[1]),weight_ih_reverse,weight_hh_reverse,bias_ih_reverse, bias_hh_reverse,h_prev_reverse)[0] # backward layer# 将input按照时间的顺序翻转h_out[:,:,:h_dim] = forward_outputh_out[:,:,h_dim:] = torch.flip(backward_output,[1]) #需要再翻转一下 才能和forward output拼接h_n = torch.zeros(bs,2,h_dim)  # 要最后的状态连接h_n[:,0,:] = forward_output[:,-1,:]h_n[:,1,:] = backward_output[:,-1,:]h_n = h_n.transpose(0,1)return h_out,h_n# return h_out,h_out[:,-1,:].reshape((bs,2,h_dim)).transpose(0,1)# 验证一下 bidirectional_rnn_forward的正确性
bi_rnn = nn.RNN(input_size,hidden_size,batch_first=True,bidirectional=True)
h_prev = torch.zeros((2,bs,hidden_size))
bi_rnn_output,bi_state_finall = bi_rnn(input,h_prev)for k,v in bi_rnn.named_parameters():print(k,v)
custom_bi_rnn_output,custom_bi_state_finall = bidirectional_rnn_forward(input,bi_rnn.weight_ih_l0,bi_rnn.weight_hh_l0,bi_rnn.bias_ih_l0,bi_rnn.bias_hh_l0,h_prev[0],bi_rnn.weight_ih_l0_reverse,bi_rnn.weight_hh_l0_reverse,bi_rnn.bias_ih_l0_reverse,bi_rnn.bias_hh_l0_reverse,h_prev[1])
print("Pytorch API output")
print(bi_rnn_output)
print(bi_state_finall)print("\n custom bidirectional_rnn_forward function output:")
print(custom_bi_rnn_output)
print(custom_bi_state_finall)
print(torch.allclose(bi_rnn_output,custom_bi_rnn_output))
print(torch.allclose(bi_state_finall,custom_bi_state_finall))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/470761.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JVM】关于JVM的内部原理你到底了解多少(八股文面经知识点)

前言 🌟🌟本期讲解关于HTTPS的重要的加密原理~~~ 🌈感兴趣的小伙伴看一看小编主页:GGBondlctrl-CSDN博客 🔥 你的点赞就是小编不断更新的最大动力 🎆那么废话不…

【Pikachu】目录遍历实战

既然已经决定做一件事,那么除了当初决定做这件事的我之外,没人可以叫我傻瓜。 1.目录遍历漏洞概述 目录遍历漏洞概述 在Web功能的设计过程中,开发者经常会将需要访问的文件作为变量进行定义,以实现前端功能的灵活性。当用户发起…

如何用C#和Aspose.PDF实现PDF转Word工具

在本篇博文中,我将详细讲解如何用C#实现一个PDF转Word工具。这款工具基于Aspose.PDF库,实现PDF文件转为Word(DOC/DOCX)格式的功能,并通过用户友好的界面和状态提示提升用户体验。希望通过这篇文章帮助大家理解软件的实…

【图像压缩感知】论文阅读:Self-supervised Scalable Deep Compressed Sensing

tips:本文为个人阅读论文的笔记,仅作为学习记录所用。 Title:Self-supervised Scalable Deep Compressed Sensing Journal:IJCV 2024 代码链接:GitHub - Guaishou74851/SCNet: Self-Supervised Scalable Deep Comp…

使用elementUI实现表格行拖拽改变顺序,无需引入外部库

前言: 使用vue2element UI,且完全使用原生的拖拽事件,无需引入外部库。 如果表格数据量较大,或需要更多复杂功能,可以考虑使用 vuedraggable库,提供更多配置选项和拖拽功能。 思路: 1. 通过el-table的ro…

深入理解接口测试:实用指南与最佳实践5.0(三)

✨博客主页: https://blog.csdn.net/m0_63815035?typeblog 💗《博客内容》:.NET、Java.测试开发、Python、Android、Go、Node、Android前端小程序等相关领域知识 📢博客专栏: https://blog.csdn.net/m0_63815035/cat…

32位、64位、x86与x64:深入解析计算机架构

目录 一、32位架构(x86) 1.1 定义与历史 1.2 技术特点 1.3 优缺点 二、64位架构(x64) 2.1 定义与历史 2.2 技术特点 2.3 优缺点 三、x86与x64的关系 四、应用场景 4.1 32位架构的应用场景 4.2 64位架构的应用场景 五、总结 在计算机领域中,处理器架构的选择对…

【stable diffusion部署】超强AI绘画Stable Diffusion,本地部署使用教程,完全免费使用

前言 01 软件介绍 Stable Diffusion和Midjourney类似,都是当下AI绘画最流行的AI工具之一,都支持用文字生成AI图片或者图片生成图片的软件。 二者的区别是:Midjourney只能在网上使用,国内需要魔法才能使用,而且存在使…

【计算机网络】【网络层】【习题】

计算机网络-传输层-习题 文章目录 13. 图 4-69 给出了距离-向量协议工作过程,表(a)是路由表 R1 初始的路由表,表(b)是相邻路由器 R2 传送来的路由表。请写出 R1 更新后的路由表(c)。…

【嵌入式开发】单片机CAN配置详解

0 前言 CAN外设作为一种传输速率较高,且连线较为简洁的通信协议,如今很多单片机内部都集成了CAN控制模块,这样只需要再外接一个CAN收发芯片,将TTL/CMOS电平转换成CAN协议的差分电平,就是一个完整的CAN收发节点。   最…

虚拟机安装Ubuntu 24.04服务器版(命令行版)

这个是专门用于服务器使用的,没有GUI,常用软件安装,见 虚拟机安装Ubuntu 24.04及其常用软件(2024.7)_ubuntu24.04-CSDN博客https://blog.csdn.net/weixin_42173947/article/details/140335522这里只记录独特的安装步骤 1 下载Ubuntu 24.04安…

ctfshow-web入门-SSTI(web361-web368)上

目录 1、web361 2、web362 3、web363 4、web364 5、web365 6、web366 7、web367 8、web368 1、web361 测试一下存在 SSTI 注入 方法很多 (1)使用子类可以直接调用的函数来打 payload1: ?name{{.__class__.__base__.__subclasses__…

Axure网络短剧APP端原型图,竖屏微剧视频模版40页

作品概况 页面数量:共 40 页 使用软件:Axure RP 9 及以上,非软件无源码 适用领域:短剧、微短剧、竖屏视频 作品特色 本作品为网络短剧APP的Axure原型设计图,定位属于免费短剧软件,类似红果短剧、河马剧场…

如何从头开始构建神经网络?(附教程)

随着流行的深度学习框架的出现,如 TensorFlow、Keras、PyTorch 以及其他类似库,学习神经网络对于新手来说变得更加便捷。虽然这些框架可以让你在几分钟内解决最复杂的计算任务,但它们并不要求你理解背后所有需求的核心概念和直觉。如果你知道…

JS 实现SSE通讯和了解SSE通讯

SSE 介绍: Server-Sent Events(SSE)是一种用于实现服务器向客户端实时推送数据的Web技术。与传统的轮询和长轮询相比,SSE提供了更高效和实时的数据推送机制。 SSE基于HTTP协议,允许服务器将数据以事件流(…

HTML之表单学习记录

如果一个页面仅仅供用户浏览,那就是静态页面。如果这个页面还能实现与服务器进行数据交互(像注册登录、话费充值、评论交流)​,那就是动态页面。表单是我们接触动态页面的第一步。其中表单最重要的作用就是:在浏览器端…

WPF学习之路,控件的只读、是否可以、是否可见属性控制

C#的控件学习之控件属性操作 控件的只读、是否可以、是否可见,是三个重要的参数,在很多表单、列表中都有用到,正常表单控制可以在父层主键控制参数是否可以编辑和可见,但是遇到个别字段需要单独控制时,可以在初始化wi…

three.js 杂记

clip: 1: 着色器 #ifdef USE_CLIP_DISTANCE vec4 worldPosition modelMatrix * vec4( position, 1.0 ); gl_ClipDistance[ 0 ] worldPosition.x - sin( time ) * ( 0.5 ); #endif gl_Position projectionMatrix * modelViewMatrix * vec4( positio…

基于混合配准策略的多模态医学图像配准方法研究

摘要: 提出了一种由“粗”到“细”的混合配准策略,该配准策略吸取了以往配准方法的优点,且在细配阶段将基于特征的配准方法和基于灰度的配准方法结合在一起,提出了基于轮廓特征点集最大互信息的配准方法,从而在速度和精…

贪心算法入门(二)

相关文章 贪心算法入门(一)-CSDN博客 1.什么是贪心算法? 贪心算法是一种解决问题的策略,它将复杂的问题分解为若干个步骤,并在每一步都选择当前最优的解决方案,最终希望能得到全局最优解。这种策略的核心…