LSTM变种模型

一、GRU

1.概念

GRU(门控循环单元,Gated Recurrent Unit)是一种循环神经网络(RNN)的变体,旨在解决标准 RNN 在处理长期依赖关系时遇到的梯度消失问题。GRU 通过引入门控机制简化了 LSTM(长短期记忆网络)的设计,使得模型更轻便,同时保留了 LSTM 的优点。

2.原理

2.1.两个重大改进

1.将输入门、遗忘门、输出门三个门变为更新门(Updata Gate)和重置门(Reset Gate)两个门。

2.将 (候选) 单元状态 与 隐藏状态 (输出) 合并,即只有 当前时刻候选隐藏状态 \tilde{h_t}当前时刻隐藏状态 h_t

2.2模型结构

简化图:

内部结构:

GRU通过其门控机制能够有效地捕捉到序列数据中的时间动态,同时相较于LSTM来说,由于其结构更加简洁,通常参数更少,计算效率更高。

2.2.1 重置门

重置门决定在计算当前候选隐藏状态时,忽略多少过去的信息。

2.2.2 更新门

更新门决定了多少过去的信息将被保留。它使用前一时间步的隐藏状态 ( h_{t-1} ) 和当前输入 ( x_t ) 来计算得出。

2.2.3 候选隐藏状态

候选隐藏状态是当前时间步的建议更新,它包含了当前输入和过去的隐藏状态的信息。重置门的作用体现在它可以允许模型抛弃或保留之前的隐藏状态。

2.2.4 最终隐藏状态

最终隐藏状态是通过融合过去的隐藏状态和当前候选隐藏状态来计算得出的。更新门 ​控制了融合过去信息和当前信息的比例。

h_{t}忘记传递下来的 h_{t-1}中的某些信息,并加入当前节点输入的某些信息。这就是最终的记忆。

3. 代码实现

3.1 原生代码
import numpy as np
​
class GRU:def __init__(self, input_size, hidden_size):self.input_size = input_sizeself.hidden_size = hidden_size# 初始化w和b 更新门self.W_z = np.random.rand(hidden_size, input_size + hidden_size)self.b_z = np.zeros(hidden_size)#重置门self.W_r = np.random.rand(hidden_size, input_size + hidden_size)self.b_r = np.zeros(hidden_size)#候选隐藏状态self.W_h = np.random.rand(hidden_size, input_size + hidden_size)self.b_h = np.zeros(hidden_size)def tanh(self, x):return np.tanh(x)def sigmoid(self, x):return 1 / (1 + np.exp(-x))def forward(self, x):#初始化隐藏状态h_prev=np.zeros((self.hidden_size,))concat_input=np.concatenate([x, h_prev],axis=0)
​z_t=self.sigmoid(np.dot(self.W_z,concat_input)+self.b_z)r_t=self.sigmoid(np.dot(self.W_r,concat_input)+self.b_r)
​concat_reset_input=np.concatenate([x,r_t*h_prev],axis=0)h_hat_t=self.tanh(np.dot(self.W_h,concat_reset_input)+self.b_h)
​h_t=(1-z_t)*h_prev+z_t*h_hat_t
​return h_t
​
# 测试数据
input_size=3
hidden_size=2
seq_len=4
​
x=np.random.randn(seq_len,input_size)
​
gru=GRU(input_size,hidden_size)
​
all_h=[]
for t in range(seq_len):h_t=gru.forward(x[t,:])all_h.append(h_t)print(h_t.shape)
print(np.array(all_h).shape)
3.2 PyTorch
nn.GRUCell
import torch
import torch.nn as nn
​
class GRUCell(nn.Module): def __init__(self,input_size,hidden_size):super(GRUCell,self).__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.gru_cell=nn.GRUCell(input_size,hidden_size)def forward(self,x):h_t=self.gru_cell(x)return h_t# 测试数据
input_size=3
hidden_size=2
seq_len=4
​
gru_model=GRUCell(input_size,hidden_size)
​
x=torch.randn(seq_len,input_size)
​
for t in range(seq_len):h_t=gru_model(x[t])print(h_t)​
​
nn.GRU
import torch
import torch.nn as nn
​
class GRU(nn.Module):def __init__(self,input_size,hidden_size):super(GRU,self).__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.gru=nn.GRU(input_size,hidden_size)def forward(self,x):out,_=self.gru(x)return out# 测试数据
input_size=3
hidden_size=2
seq_len=4
batch_size=5
x=torch.randn(seq_len,batch_size,input_size)
gru_mosel=GRU(input_size,hidden_size)
​
out=gru_mosel(x)
print(out)
print(out.shape)
​

二、BiLSTM

1.概述

双向长短期记忆网络(BiLSTM)是长短期记忆网络(LSTM)的扩展,旨在同时考虑序列数据中的过去和未来信息。BiLSTM 通过引入两个独立的 LSTM 层,一个正向处理输入序列,另一个逆向处理,使得每个时间步的输出包含了该时间步前后的信息。这种双向结构能够更有效地捕捉序列中的上下文关系,从而提高模型对语义的理解能力。

  • 正向传递: 输入序列按照时间顺序被输入到第一个LSTM层。每个时间步的输出都会被计算并保留下来。

  • 反向传递: 输入序列按照时间的逆序(即先输入最后一个元素)被输入到第二个LSTM层。与正向传递类似,每个时间步的输出都会被计算并保留下来。

  • 合并输出: 在每个时间步,将两个LSTM层的输出通过某种方式合并(如拼接或加和)以得到最终的输出。

2. BILSTM模型应用背景

命名体识别

标注集

BMES标注集

分词的标注集并非只有一种,举例中文分词的情况,汉子作为词语开始Begin,结束End,中间Middle,单字Single,这四种情况就可以囊括所有的分词情况。于是就有了BMES标注集,这样的标注集在命名实体识别任务中也非常常见。

词性标注

在序列标注问题中单词序列就是x,词性序列就是y,当前词词性的判定需要综合考虑前后单词的词性。而标注集最著名的就是863标注集和北大标注集。

3. 代码实现

原生代码

import numpy as np
import torch
​
class BiLSTM():def __init__(self, input_size, hidden_size,output_size):self.input_size = input_sizeself.hidden_size = hidden_sizeself.output_size = output_size#正向self.lstm_forward = LSTM(input_size, hidden_size,output_size)#反向self.lstm_backward = LSTM(input_size, hidden_size,output_size)def forward(self,x):# 正向LSTMoutput,_,_=self.lstm_forward.forward(x)# 反向LSTM,np.flip()是将数组进行翻转output_backward,_,_=self.lstm_backward.forward(np.flip(x,1))#合并两层的隐藏状态combine_output=[np.concatenate((x,y),axis=0) for x,y in zip(output,output_backward)]return combine_outputclass LSTM:def __init__(self, input_size, hidden_size,output_size):""":param input_size: 词向量大小:param hidden_size: 隐藏层大小:param output_size: 输出类别"""self.input_size = input_sizeself.hidden_size = hidden_sizeself.output_size = output_size
​# 初始化权重和偏置 我们把结构图上的W U 拼接在了一起 所以参数是 input_size+hidden_sizeself.w_f = np.random.rand(hidden_size, input_size+hidden_size)self.b_f = np.random.rand(hidden_size)
​self.w_i = np.random.rand(hidden_size, input_size+hidden_size)self.b_i = np.random.rand(hidden_size)
​self.w_c = np.random.rand(hidden_size, input_size+hidden_size)self.b_c = np.random.rand(hidden_size)
​self.w_o = np.random.rand(hidden_size, input_size+hidden_size)self.b_o = np.random.rand(hidden_size)
​# 输出层self.w_y = np.random.rand(output_size, hidden_size)self.b_y = np.random.rand(output_size)
​def tanh(self,x):return np.tanh(x)
​def sigmoid(self,x):return 1/(1+np.exp(-x))
​def forward(self,x):h_t = np.zeros((self.hidden_size,)) # 初始隐藏状态c_t = np.zeros((self.hidden_size,)) # 初始细胞状态
​h_states = [] # 存储每个时间步的隐藏状态c_states = [] # 存储每个时间步的细胞状态
​for t in range(x.shape[0]):x_t = x[t] # 当前时间步的输入# concatenate 将x_t和h_t拼接 垂直方向x_t = np.concatenate([x_t,h_t])
​# 遗忘门f_t = self.sigmoid(np.dot(self.w_f,x_t)+self.b_f)
​# 输入门i_t = self.sigmoid(np.dot(self.w_i,x_t)+self.b_i)# 候选细胞状态c_hat_t = self.tanh(np.dot(self.w_c,x_t)+self.b_c)
​# 更新细胞状态c_t = f_t*c_t + i_t*c_hat_t
​# 输出门o_t = self.sigmoid(np.dot(self.w_o,x_t)+self.b_o)# 更新隐藏状态h_t = o_t*self.tanh(c_t)
​# 保存每个时间步的隐藏状态和细胞状态h_states.append(h_t)c_states.append(c_t)
​# 输出层 对最后一个时间步的隐藏状态进行预测,分类类别y_t = np.dot(self.w_y,h_t)+self.b_y# 转成张量形式 dim 0 表示行的维度output = torch.softmax(torch.tensor(y_t),dim=0)
​return np.array(h_states), np.array(c_states), output​
# 测试数据
input_size=3
hidden_size=8
output_size=5
seq_len=4
​
x=np.random.randn(seq_len,input_size)
​
bilstm=BiLSTM(input_size,hidden_size,output_size)
outputs=bilstm.forward(x)        
print(outputs)
print(np.array(outputs).shape)
# ---------------------------------------------------------------------------
import numpy as np
​
# 创建一个包含两个二维数组的列表
inputs = [np.array([[0.1], [0.2], [0.3]]), np.array([[0.4], [0.5], [0.6]])]
​
# 使用 numpy 库中的 np.stack 函数。这会将输入的二维数组堆叠在一起,从而形成一个新的三维数组
inputs_3d = np.stack(inputs)
​
# 将三维数组转换为列表
list_from_3d_array = inputs_3d.tolist()
​
print(list_from_3d_array)

Pytorch

import torch
import torch.nn as nn
​
class BiLSTM(nn.Module):def __init__(self, input_size, hidden_size,output_size):super(BiLSTM, self).__init__()#定义双向LSTMself.lstm=nn.LSTM(input_size,hidden_size,bidirectional=True)#输出层 因为双向LSTM的输出是双向的,所以第一个参数是隐藏层*2self.linear=nn.Linear(hidden_size*2,output_size)
​def forward(self,x):out,_=self.lstm(x)linear_out=self.linear(out)return linear_out# 测试数据
input_size=3
hidden_size=8
output_size=5
seq_len=4
batch_size=6
​
x=torch.randn(seq_len,batch_size,input_size)   
model=BiLSTM(input_size,hidden_size,output_size)
outputs=model(x)
print(outputs)
print(outputs.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/442329.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

APP测试中ios和androis的区别,有哪些注意点

一、运行机制不同 IOS采用的是沙盒运行机制,安卓采用的是虚拟机运行机制。 1、沙盒机制: 概念:沙盒是一种安全机制,用于防止不同应用之间互相访问 作用:就是存储数据,每个沙盒就相当于每个每个应用的系…

1688商品详情关键词数据-API

要利用 Python 爬虫采集 1688 商品详情数据,需要先了解 1688 网站的页面结构和数据请求方式。一般使用 requests 库请求网站的数据,使用 BeautifulSoup 库解析网页中的数据。 以下是一个简单的 Python 爬虫采集 1688 商品详情数据的示例代码&#xff1a…

红队老子养成记2 - 不想渗透pc?我们来远控安卓!(全网最详细)

大家好,我是Dest1ny。 今天我们是红队专题中的远控安卓。 这个实验会非常有趣,大家多多点赞! 环境: 一台有公网ip的vps / kali / 带msf工具即可 一台安卓手机(最好老一点,因为我们这里不涉及免杀&#…

typora笔记导出word格式:

Pandoc:各系统下载github链接 https://github.com/jgm/pandoc/releases/latest windows安装包 链接:https://pan.baidu.com/s/17AZNIMImbzFtWJAcAfAB0g?pwd55l2 提取码:55l2 先解压压缩包 点击 设置Pandoc路径,然后选择pa…

如何搭建自己的域名邮箱服务器?Poste.io邮箱服务器搭建教程,Linux+Docker搭建邮件服务器的教程

Linux系统Docker搭建Poste.io电子邮件服务器,搭建属于自己的域名邮箱服务器,可以无限收发电子邮件(Email)! 视频教程:https://www.bilibili.com/video/BV11p1mYaEpM/ 前言 什么是域名邮箱? …

vscode中安装python的包

首先需要调出命令行。然后运行代码,找到你所需要的环境。 PS C:\Users\Administrator\AppData\Local\ESRI\conda\envs\arcgispro-env> conda env list # conda environments: #C:\ProgramData\Anaconda3 base * C:\Users\Administrator\.con…

[C++ 核心编程]笔记 2 栈区和堆区

栈区: 由编译器自动分配释放&#xff0c;存放函数的参数值,同部变量等 注意事项&#xff1a;不要返回局部变量的地址&#xff0c;栈区开辟的数据由编译器自动释放 #define _CRT_SECURE_NO_WARNINGS 1 #include<iostream> using namespace std;//栈区数据注意事项 不要…

让机器来洞察他的内心!

本文所涉及所有资源均在传知代码平台可获取。 目录 洞察你的内心&#xff1a;你真的这么认为吗&#xff1f; 一、研究背景 二、模型结构和代码 D. 不一致性学习网络 E. 多模态讽刺分类 三、数据集介绍 四、性能展示 五、实现过程 1. 下载预训练的 GloVe 词向量&#xff08;Comm…

Hydra 新手友好使用教程

1. Hydra 简介 Hydra是一款强大的网络登录暴力破解工具&#xff0c;支持多种协议。本教程将帮助新手快速上手&#xff0c;掌握常用指令和操作。 2. 基本语法 hydra [参数] 目标 3. 核心参数详解 3.1 用户名和密码设置 单个用户名: -l LOGIN 例&#xff1a;-l admin 用户名…

Redis:事务

Redis&#xff1a;事务 事务事务操作MULTI & EXECDISCARDWATCH 事务 在MySQL中&#xff0c;事务遵循CIRD特性&#xff1a; 原子性&#xff1a;事务是一个整体&#xff0c;要么没有发生&#xff0c;要么已经执行完毕一致性&#xff1a;事务执行前后&#xff0c;数据都要符…

基于Arduino的遥控自平衡小车

基于Arduino的遥控自平衡小车 一、项目简介二、所需材料三、理论支持四、外壳设计五、线路连接六、检查MPU6050连接七、烧录库八、PID控制设置九、设置传感器参数十、无线移动控制十一、超声波模块 一、项目简介 一个使用Arduino Nano、MPU-6050以及便宜的6伏直流齿轮电机的自…

活久见!2024年诺贝尔物理学奖颁给了AI大佬Hinton 和 Hopfield

家人们&#xff01;让我们暂停手中工作&#xff0c;庆祝AI届的科学家首次获得诺贝尔物理学奖&#xff01;&#xff01;&#xff01;&#xff01;&#xff01; 刚刚出炉的热乎消息&#xff1a;今年的诺贝尔物理学奖颁发给了约翰霍普菲尔德 (John Joseph Hopfield)与杰弗里辛顿&a…

充电桩用能计量有序充电服务的探索应用

关键词&#xff1a;云平台&#xff1b;自动检测&#xff1b;能源管理&#xff1b;有序充电 今年&#xff0c;电动汽车行业抓住了疫情影响洼地&#xff0c;迅速找到了发展突破口&#xff0c;从电动汽车发行政策到锂电池开发技术均出台了多层面利好消息&#xff0c;未来一段时间…

【JavaEE初阶】深入理解不同锁的意义,synchronized的加锁过程理解以及CAS的原子性实现(面试经典题);

前言 &#x1f31f;&#x1f31f;本期讲解关于锁的相关知识了解&#xff0c;这里涉及到高频面试题哦~~~ &#x1f308;上期博客在这里&#xff1a;【JavaEE初阶】深入理解线程池的概念以及Java标准库提供的方法参数分析-CSDN博客 &#x1f308;感兴趣的小伙伴看一看小编主页&am…

SpringBoot日常:redission的接入使用和源码解析

文章目录 一、简介二、集成redissionpom文件redission 配置文件application.yml文件启动类 三、JAVA 操作案例字符串操作哈希操作列表操作集合操作有序集合操作布隆过滤器操作分布式锁操作 四、源码解析 一、简介 Redisson 是一个在 Redis 的基础上实现的 Java 驻内存数据网格…

Windows Ubuntu下搭建深度学习Pytorch训练框架与转换环境TensorRT

Windows Ubuntu下搭建深度学习Pytorch训练框架与转换环境TensorRT JetBrains2024&#xff08;IntelliJ IDEA、PhpStorm、RubyMine、Rider……&#xff09;安装包Anaconda Miniconda安装.condarc 文件配置镜像源查看conda的配置和源(channel)自定义conda虚拟环境路径conda常用命…

破解反编译:使用 ClassFinal 保护你的SpringBoot代码

在当今数字化时代&#xff0c;保护源代码的安全性变得愈发重要。无论是企业的核心算法还是独特的业务逻辑&#xff0c;代码一旦暴露&#xff0c;便可能导致竞争优势的丧失和商业机密的泄露。因此&#xff0c;在使用 Java 和 Spring Boot 开发项目时&#xff0c;理解从源代码到可…

websocket连接异常报错1006

目录&#xff1a; 1、问题现象2、问题原因3、解决方案 1、问题现象 WebSocket状态码的作用&#xff1a; 在WebSocket协议中&#xff0c;状态码用于表示连接状态和错误信息。通过状态码&#xff0c;我们可以快速判断连接是否成功&#xff0c;以及出现错误时的原因。常见的WebSo…

教培机构如何向知识付费转型

在数字化时代&#xff0c;知识付费已成为一股不可忽视的潮流。面对这一趋势&#xff0c;教育培训机构必须积极应对&#xff0c;实现向知识付费的转型&#xff0c;以在新的市场环境中立足。 一、教培机构应明确自身的知识定位。 在知识付费领域&#xff0c;专业性和独特性是关键…

VUE前后端分离毕业设计题目项目有哪些,VUE程序开发常见毕业论文设计推荐

目录 0 为什么选择Vue.js 1 Vue.js 的主要特点 2 前后端分离毕业设计项目推荐 3 后端推荐 4 总结 0 为什么选择Vue.js 使用Vue.js开发计算机毕业设计是一个很好的选择&#xff0c;因为它不仅具有现代前端框架的所有优点&#xff0c;还能让你专注于构建高性能、高可用性的W…