【深度学习中的注意力机制6】11种主流注意力机制112个创新研究paper+代码——加性注意力(Additive Attention)

【深度学习中的注意力机制6】11种主流注意力机制112个创新研究paper+代码——加性注意力(Additive Attention)

【深度学习中的注意力机制6】11种主流注意力机制112个创新研究paper+代码——加性注意力(Additive Attention)


文章目录

  • 【深度学习中的注意力机制6】11种主流注意力机制112个创新研究paper+代码——加性注意力(Additive Attention)
  • 1. 加性注意力的起源与提出
  • 2. 加性注意力的原理
  • 3. 发展
  • 4. 代码实现
  • 5. 代码逐句解释


欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

1. 加性注意力的起源与提出

加性注意力(Additive Attention)是由Bahdanau et al. 在其2015年关于机器翻译的论文中提出的。这一注意力机制被应用于神经机器翻译(NMT)模型中,旨在提高翻译任务中序列对序列(Seq2Seq)模型的性能,尤其是解决长距离依赖问题。传统的Seq2Seq模型仅依赖于编码器的最终隐藏状态来生成翻译,这在处理长文本时容易丢失输入的细节信息。加性注意力通过在解码过程中对编码器隐藏状态进行加权求和,显著提升了模型性能

加性注意力是一种较早提出的注意力机制,与随后流行的点积注意力不同,加性注意力通过一个可学习的网络计算注意力分数,而不是直接计算向量之间的点积。加性注意力的提出标志着注意力机制在深度学习领域中的广泛应用,尤其是在处理长序列数据时的应用。

2. 加性注意力的原理

加性注意力的核心思想是通过学习一个函数来计算查询(Query)和键(Key)之间的相似性,然后根据相似性对值(Value)进行加权。

具体步骤如下:

1) 输入:

  • Query:解码器中的当前隐藏状态。
  • Key 和 Value:编码器中的隐藏状态(通常是一系列时间步的隐藏状态序列)。

2) 计算注意力分数: 通过将Query和Key进行非线性变换,再经过加性函数求得注意力分数。这个过程使用了一个可学习的权重矩阵,将查询和键分别映射到一个共同的表示空间,计算它们的相似性。

3) softmax归一化: 将上述得到的注意力分数通过softmax函数进行归一化,得到注意力权重。

4) 加权求和: 使用得到的注意力权重对值(Value)进行加权求和,生成最终的加权上下文向量。

公式如下:
在这里插入图片描述
这里, W q W_q Wq W k W_k Wk是可学习的权重矩阵, e i j e_{ij} eij 是注意力分数, v j v_j vj是Value。

3. 发展

加性注意力是最早被提出的注意力机制之一,并在神经机器翻译中取得了显著的成果。后来,随着注意力机制的发展,点积注意力(如Transformer中的缩放点积注意力)因其更高效的计算方式而逐渐取代了加性注意力。然而,加性注意力仍然在某些场景中被使用,尤其是在需要更细致的相似性计算的任务中。

在性能方面,加性注意力与点积注意力的主要区别在于计算复杂度。加性注意力通过一个可学习的神经网络计算注意力分数,计算复杂度为 O ( d ) O(d) O(d),而点积注意力直接计算点积,复杂度为 O ( d 2 ) O(d^2) O(d2),这使得加性注意力在某些场景下具有优势。

4. 代码实现

下面是一个使用加性注意力机制的简化实现,基于PyTorch框架。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass AdditiveAttention(nn.Module):def __init__(self, query_dim, key_dim, hidden_dim):super(AdditiveAttention, self).__init__()# 定义线性层,用于将查询和键映射到同一空间self.query_layer = nn.Linear(query_dim, hidden_dim)self.key_layer = nn.Linear(key_dim, hidden_dim)# 定义一个线性层,用于计算注意力分数self.energy_layer = nn.Linear(hidden_dim, 1)def forward(self, query, keys, values):# query: [batch_size, query_dim]# keys: [batch_size, seq_len, key_dim]# values: [batch_size, seq_len, value_dim]# 计算查询和键的投影query_proj = self.query_layer(query)  # [batch_size, hidden_dim]keys_proj = self.key_layer(keys)  # [batch_size, seq_len, hidden_dim]# 将查询扩展到和键的时间步相同的维度query_proj = query_proj.unsqueeze(1).expand_as(keys_proj)  # [batch_size, seq_len, hidden_dim]# 计算 e_ij = tanh(W_q q + W_k k)energy = torch.tanh(query_proj + keys_proj)  # [batch_size, seq_len, hidden_dim]# 计算注意力分数,并去掉最后一维attention_scores = self.energy_layer(energy).squeeze(-1)  # [batch_size, seq_len]# 通过softmax得到注意力权重attention_weights = F.softmax(attention_scores, dim=1)  # [batch_size, seq_len]# 加权求和值context = torch.bmm(attention_weights.unsqueeze(1), values).squeeze(1)  # [batch_size, value_dim]return context, attention_weights# 测试加性注意力
batch_size = 2
query_dim = 5
key_dim = 5
value_dim = 6
seq_len = 10
hidden_dim = 20# 随机生成查询、键和值
query = torch.randn(batch_size, query_dim)
keys = torch.randn(batch_size, seq_len, key_dim)
values = torch.randn(batch_size, seq_len, value_dim)# 实例化加性注意力
additive_attention = AdditiveAttention(query_dim, key_dim, hidden_dim)# 前向传播
context, attention_weights = additive_attention(query, keys, values)print("上下文向量:", context)
print("注意力权重:", attention_weights)

5. 代码逐句解释

1. 导入库:

import torch
import torch.nn as nn
import torch.nn.functional as F

导入PyTorch库,其中torch用于张量操作,nn包含神经网络模块,F提供常用函数如softmax。

2. 定义加性注意力类:

class AdditiveAttention(nn.Module):def __init__(self, query_dim, key_dim, hidden_dim):super(AdditiveAttention, self).__init__()# 定义线性层,用于将查询和键投影到同一维度self.query_layer = nn.Linear(query_dim, hidden_dim)self.key_layer = nn.Linear(key_dim, hidden_dim)# 定义计算注意力能量的线性层self.energy_layer = nn.Linear(hidden_dim, 1)

这里定义了AdditiveAttention类,继承自nn.Modulequery_layerkey_layer分别是将查询和键投影到同一维度的线性层,energy_layer用于计算注意力能量分数。

3. 前向传播函数:

def forward(self, query, keys, values):query_proj = self.query_layer(query)  # [batch_size, hidden_dim]keys_proj = self.key_layer(keys)  # [batch_size, seq_len, hidden_dim]# 扩展查询的维度,使其与键对齐query_proj = query_proj.unsqueeze(1).expand_as(keys_proj)# 计算注意力能量:e_ij = tanh(W_q q + W_k k)energy = torch.tanh(query_proj + keys_proj)  # [batch_size, seq_len, hidden_dim]# 通过线性层计算注意力分数,并去掉最后一维attention_scores = self.energy_layer(energy).squeeze(-1)  # [batch_size, seq_len]# 使用softmax归一化得到注意力权重attention_weights = F.softmax(attention_scores, dim=1)  # [batch_size, seq_len]# 计算上下文向量,通过加权求和值context = torch.bmm(attention_weights.unsqueeze(1), values).squeeze(1)  # [batch_size, value_dim]return context, attention_weights
  • forward函数负责计算加性注意力的前向传播过程。首先,将查询和键分别通过线性层映射到相同的维度。
  • 然后,计算注意力能量,并使用softmax进行归一化,得到注意力权重。
  • 最后,使用这些注意力权重对值进行加权求和,生成上下文向量。
    4. 测试模型:
# 测试加性注意力
query = torch.randn(batch_size, query_dim)
keys = torch.randn(batch_size, seq_len, key_dim)
values = torch.randn(batch_size, seq_len, value_dim)# 实例化加性注意力
additive_attention = AdditiveAttention(query_dim, key_dim, hidden_dim)# 前向传播
context, attention_weights = additive_attention(query, keys, values)print("上下文向量:", context)
print("注意力权重:", attention_weights)

在这里,使用随机生成的张量querykeysvalues来测试加性注意力的输出。

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/454447.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

kubernetes(三)

k8s之持久化存储pv&pvc 存储资源管理 在基于k8s容器云平台上,对存储资源的使用需求通常包括以下几方面: 1.应用配置文件、密钥的管理; 2.应用的数据持久化存储; 3.在不同的应用间共享数据存储; k8s支持Volume类…

Spring MVC文件请求处理-MultipartResolver

Spring Boot中的MultipartResolver是一个用于解析multipart/form-data类型请求的策略接口,通常用于文件上传。 对应后端使用MultipartFile对象接收。 RequestMapping("/upload")public String uploadFile(MultipartFile file) throws IOException {Strin…

十一、数据库配置

一、Navicat配置 这个软件需要破解 密码是:123456; 新建连接》新建数据库 创建一个表 保存出现名字设置 双击打开 把id设置为自动递增 这里就相当于每一次向数据库添加一个语句,会自动增长id一次 二、数据库的增删改查 1、Vs 建一个控…

磁编码器的工作原理和特点

目录 概述 1 磁编码器的构造 1.1 霍尔元件 1.2 永磁体 1.3 永磁体和霍尔元件的配置 2 磁编码器的工作原理 2.1 原理介绍 2.2 电气信号转换成角度 2.3 旋转角度传感器IC 3 磁编码器的特点和主要应用 概述 本文主要介绍磁编码器的构造原理,工作特性和应用特…

C/C++函数调用约定:__cdecl、__stdcall、__fastcall和__thiscall

目录 1.引言 2.常见函数调用约定 2.1.__cdecl 2.2.__stdcall 2.3.__fastcall 2.4.__thiscall 3.几种调用约定比较 4.注意事项 1.引言 在C和C编程中,函数调用约定(Calling Convention)定义了函数如何接收参数、如何返回值以及由谁来清…

【小沐学Golang】基于Go语言搭建静态文件服务器

文章目录 1、简介2、安装2.1 安装版2.2 压缩版 3、基本操作3.1 go run3.2 go build3.3 go install3.4 go env3.5 go module 4、文件服务器4.1 filebrowser4.2 gohttpserver4.3 goFile 5、FAQ5.1 go.mod 为空5.2 超时 结语 1、简介 https://golang.google.cn/ Go语言诞生于2007…

word表格跨页后自动生成的顶部横线【去除方法】

Hello World! Its been a long time. 这一年重心放在了科研、做事、追寻新的经历上,事有正事、琐事、幸事、哀事,内心与认知成长了一些,思想成熟了几分,技艺也有若干收获。不管怎样,来打个卡吧,纪念一下&…

Web前端高级工程师培训:使用 Node.js 构建一个 Web 服务端程序(3)

11、HTTP 协议 11-1、协议的定义 HTTP 是一种能够获取如 HTML 这样的网络资源的 protocol(通讯协议)。它是在 Web 上进行数据交换的基础,是一种 client-server 协议,也就是说,请求通常是由像浏览器这样的接受方发起的。一个完整的Web文档通…

Tailwind Starter Kit 一款极简的前端快速启动模板

Tailwind Starter Kit 是基于TailwindCSS实现的一款开源的、使用简单的极简模板扩展。会用Tailwincss就可以快速入手使用。Tailwind Starter Kit 是免费开源的。它不会在原始的TailwindCSS框架中更改或添加任何CSS。它具有多个HTML元素,并附带了ReactJS、Vue和Angul…

Docker安装Mysql5.7,解决无法访问DockerHub问题

Docker安装Mysql5.7,解决无法访问DockerHub问题 简介 Docker Hub 无法访问,应用安装失败,镜像拉取超时的解决方案。 摘要 : 当 Docker Hub 无法访问时,可以通过配置国内镜像加速来解决应用安装失败和镜像拉取超时的…

使用爬虫爬取Python中文开发者社区基础教程的数据

👨‍💻个人主页:开发者-曼亿点 👨‍💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍💻 本文由 曼亿点 原创 👨‍💻 收录于专栏&#xff1a…

微信小程序文本收起展开

这里写自定义目录标题 微信小程序文本收起展开常见问题的梯形背景框 微信小程序文本收起展开 参考 https://juejin.cn/post/6963904955262435336 <!-- 常见问题解答 --><view classcontentBottom><view classBottomFirst><text id0 data-id0 class&quo…

python + mitmproxy 爬手机app (1)

起因&#xff0c; 目的: 想爬手机上某鱼。 mitmproxy 简介: 一句话: mitmproxy 就是中间人攻击. (只不过&#xff0c; 你安装&#xff0c;就代表你愿意承担风险。)源码&#xff1a;https://github.com/mitmproxy/mitmproxy文档: https://mitmproxy.org/ 安装过程: 见聊天记…

eCAP超声波测距-ePWM电机调速

目录 eCAP超声波测距 整体框架 关键模块 实验效果 PWM电机调速 DRV8833基本介绍 整体框架 eCAP超声波测距 本实验所用的超声波HC-SR04模块如下图所示&#xff0c;左边为正面图&#xff0c;右边为反面图。 HC-SR04基本工作原理&#xff1a; &#xff08;1&#xff09;采…

spring源码中的,函数式接口,注解@FunctionalInterface

调用方 /org/springframework/beans/factory/support/AbstractBeanFactory.java:333sharedInstance getSingleton(beanName, () -> {try {return createBean(beanName, mbd, args);}catch (BeansException ex) {// Explicitly remove instance from singleton cache: It mi…

Kafka之消费者客户端

1、历史上的二个版本 与生产者客户端一样&#xff0c;在Kafka的发展过程当中&#xff0c;消费者客户端主要有两个大的版本&#xff1a; 旧消费者客户端&#xff08;Old Consumer&#xff09;&#xff1a;基于Scala语言开发的版本&#xff0c;又称为Scala消费者客户端。新消费…

rpm 命令

rpm&#xff08;Red Hat Package Manager&#xff09;是 Red Hat Linux 及其衍生发行版&#xff08;如 CentOS、Fedora&#xff09;中用于管理软件包的系统。它允许用户安装、卸载、升级、查询和验证软件包。 一、安装软件包 &#xff08;1&#xff09;安装一个 RPM 软件包&a…

高并发下如何保证接口的幂等性?

前言 接口幂等性问题,对于开发人员来说,是一个跟语言无关的公共问题。本文分享了一些解决这类问题非常实用的办法,绝大部分内容我在项目中实践过的,给有需要的小伙伴一个参考。 不知道你有没有遇到过这些场景: 有时我们在填写某些form表单时,保存按钮不小心快速点了两次…

十二、【智能体】深入剖析:大模型节点的全面解读,举例说明,教你如何在扣子中嵌入代码

大模型节点 大模型节点主要分为5部分&#xff1a; 处理类型 单次批处理 模型类型&#xff1a;目前可以选择的模型有 豆包、通义千问、智谱、MinMax和Kimi输入:此时的参数可以被下面的提示词所用提示词&#xff1a;给大模型使用的提示词输出&#xff1a;经过此大模型处理后的输…

Vehicle Spy3.9如何新建工程—总览

1&#xff1a;写作目的 学习和精通SPY的使用&#xff0c;对于spy&#xff0c;目前主要是通用系用的比较多&#xff0c;本身spy的生产厂家英特佩斯也是美国的公司&#xff0c;除了软件自带教程。中文网上很少能找到相关的中文教程。 故写下这篇文章&#xff0c;帮助自己和大家…