《动手学深度学习 Pytorch版》 10.3 注意力评分函数

上一节使用的高斯核的指数部分可以视为注意力评分函数(attention scoring function),简称评分函数(scoring function)。

后续把评分函数的输出结果输入到softmax函数中进行运算。最后,注意力汇聚的输出就是基于这些注意力权重的值的加权和。该过程可描述为下图:

在这里插入图片描述

用数学语言描述为:

f ( q , ( k 1 , v 1 ) , … , ( k m , v m ) ) = ∑ i = 1 m α ( q , k i ) v i ∈ R v f(\boldsymbol{q},(\boldsymbol{k}_1,\boldsymbol{v}_1),\dots,(\boldsymbol{k}_m,\boldsymbol{v}_m))=\sum^m_{i=1}{\alpha(\boldsymbol{q},\boldsymbol{k}_i)\boldsymbol{v}_i}\in\R^v f(q,(k1,v1),,(km,vm))=i=1mα(q,ki)viRv

其中查询 q \boldsymbol{q} q 和键 k i \boldsymbol{k}_i ki 的注意力权重(标量)是通过注意力评分函数 a a a 将两个向量映射成标量,再经过softmax运算得到的:

α ( q , k i ) = s o f t m a x ( a ( q , k i ) ) = a ( q , k i ) ∑ j = 1 m exp ⁡ a ( q , k i ) ∈ R \alpha(\boldsymbol{q},\boldsymbol{k}_i)=\mathrm{softmax}(a(\boldsymbol{q},\boldsymbol{k}_i))=\frac{a(\boldsymbol{q},\boldsymbol{k}_i)}{\sum^m_{j=1}{\exp{a(\boldsymbol{q},\boldsymbol{k}_i)}}}\in\R α(q,ki)=softmax(a(q,ki))=j=1mexpa(q,ki)a(q,ki)R

import math
import torch
from torch import nn
from d2l import torch as d2l

以下介绍的是两个流行的评分函数。

10.3.1 遮蔽 softmax 操作

并非所有的值都应该被纳入到注意力汇聚中。下面的 masked_softmax 函数实现了这样的掩蔽softmax操作(masked softmax operation),其中任何超出有效长度的位置都被掩蔽并置为0。

#@save
def masked_softmax(X, valid_lens):"""通过在最后一个轴上掩蔽元素来执行softmax操作"""# X:3D张量,valid_lens:1D或2D张量if valid_lens is None:return nn.functional.softmax(X, dim=-1)else:shape = X.shapeif valid_lens.dim() == 1:valid_lens = torch.repeat_interleave(valid_lens, shape[1])else:valid_lens = valid_lens.reshape(-1)# 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,value=-1e6)return nn.functional.softmax(X.reshape(shape), dim=-1)
print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3])))  # 两样本有效长度分别为 2 和 3
print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]])))  # 也可以给每一行指定有效长度
tensor([[[0.4297, 0.5703, 0.0000, 0.0000],[0.6186, 0.3814, 0.0000, 0.0000]],[[0.2413, 0.3333, 0.4254, 0.0000],[0.4165, 0.2801, 0.3034, 0.0000]]])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],[0.3277, 0.4602, 0.2121, 0.0000]],[[0.5026, 0.4974, 0.0000, 0.0000],[0.2684, 0.2599, 0.2613, 0.2103]]])

10.3.2 加性注意力

当查询和键是不同长度的矢量时,可以使用加性注意力作为评分函数。加性注意力(additive attention)的评分函数为:

a ( q , k i ) = w v T tanh ⁡ ( W q q + W k k ) ∈ R a(\boldsymbol{q},\boldsymbol{k}_i)=\boldsymbol{\mathrm{w}}_v^T\tanh{(\boldsymbol{\mathrm{W}}_q\boldsymbol{q}+\boldsymbol{\mathrm{W}}_k\boldsymbol{k})}\in\R a(q,ki)=wvTtanh(Wqq+Wkk)R

参数字典:

  • q ∈ R q \boldsymbol{q}\in\R^q qRq 表示查询

  • k ∈ R k \boldsymbol{k}\in\R^k kRk 表示键

  • W q ∈ R h × q \boldsymbol{\mathrm{W}}_q\in\R^{h\times q} WqRh×q W k ∈ R h × k \boldsymbol{\mathrm{W}}_k\in\R^{h\times k} WkRh×k W v ∈ R h \boldsymbol{\mathrm{W}}_v\in\R^h WvRh 均为可学习参数。

#@save
class AdditiveAttention(nn.Module):"""加性注意力"""def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):super(AdditiveAttention, self).__init__(**kwargs)self.W_k = nn.Linear(key_size, num_hiddens, bias=False)self.W_q = nn.Linear(query_size, num_hiddens, bias=False)self.w_v = nn.Linear(num_hiddens, 1, bias=False)self.dropout = nn.Dropout(dropout)  # 使用了暂退法进行模型正则化def forward(self, queries, keys, values, valid_lens):# 初始 q 和 k 的形状如下,不好直接加# queries 的形状:(batch_size,查询的个数,num_hidden)# key 的形状:(batch_size,“键-值”对的个数,num_hiddens)queries, keys = self.W_q(queries), self.W_k(keys)# 在维度扩展后,# queries 的形状:(batch_size,查询的个数,1,num_hidden)# key 的形状:(batch_size,1,“键-值”对的个数,num_hiddens)features = queries.unsqueeze(2) + keys.unsqueeze(1)  # 优雅,实在优雅 使用广播方式进行求和features = torch.tanh(features)# self.w_v 仅有一个输出,因此从形状中移除最后那个维度。# scores 的形状:(batch_size,查询的个数,“键-值”对的个数)scores = self.w_v(features).squeeze(-1)  # 把最后一个维度去掉self.attention_weights = masked_softmax(scores, valid_lens)# values的形状:(batch_size,“键-值”对的个数,值的维度)return torch.bmm(self.dropout(self.attention_weights), values)
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))  # 查询、键和值的形状为(批量大小,步数或词元序列长度,特征大小)
# values的小批量,两个值矩阵是相同的
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,dropout=0.1)
attention.eval()
attention(queries, keys, values, valid_lens)  # 注意力汇聚输出的形状为(批量大小,查询的步数,值的维度)
tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],[[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),  # 本例子中每个键都是相同的,所以注意力权重是均匀的,由指定的有效长度决定。xlabel='Keys', ylabel='Queries')


在这里插入图片描述

10.3.3 缩放点积注意力

使用点积可以得到计算效率更高的评分函数,缩放点积注意力(scaled dot-product attention)评分函数为:

a ( q , k ) = q T k / d a(\boldsymbol{q},\boldsymbol{k})=\boldsymbol{q}^T\boldsymbol{k}/\sqrt{d} a(q,k)=qTk/d

在实践中,我们通常从小批量的角度来考虑提高效率:

s o f t m a x ( Q K T d ) V ∈ R n × v \mathrm{softmax}\left(\frac{\boldsymbol{Q}\boldsymbol{K}^T}{\sqrt{d}}\right)\boldsymbol{V}\in\R^{n\times v} softmax(d QKT)VRn×v

实际上就是用两个向量内积作为 Q 和 K 的相似度(越不像越正交,越正交内积越接近零),最后来个 softmax 归到概率上。中间除个 d \sqrt{d} d 是为了在向量尺寸比较大的时候使概率分布更平滑。

#@save
class DotProductAttention(nn.Module):"""缩放点积注意力"""def __init__(self, dropout, **kwargs):super(DotProductAttention, self).__init__(**kwargs)self.dropout = nn.Dropout(dropout)# queries的形状:(batch_size,查询的个数,d)# keys的形状:(batch_size,“键-值”对的个数,d)# values的形状:(batch_size,“键-值”对的个数,值的维度)# valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)def forward(self, queries, keys, values, valid_lens=None):d = queries.shape[-1]# 设置transpose_b=True为了交换keys的最后两个维度scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)self.attention_weights = masked_softmax(scores, valid_lens)return torch.bmm(self.dropout(self.attention_weights), values)
queries = torch.normal(0, 1, (2, 1, 2))  # 点积操作需要查询的特征维度与键的特征维度大小相同
attention = DotProductAttention(dropout=0.5)  # 使用了暂退法进行模型正则化
attention.eval()
attention(queries, keys, values, valid_lens)
tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],[[10.0000, 11.0000, 12.0000, 13.0000]]])
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),xlabel='Keys', ylabel='Queries')


在这里插入图片描述

练习

(1)修改小例子中的键,并且可视化注意力权重。可加性注意力和缩放的“点-积”注意力是否仍然产生相同的结果?为什么?

不一样,评分函数不一样,键值不同的话那注意力汇聚肯定不一样的。

queries_new, keys_rand = torch.normal(0, 1, (2, 1, 2)), torch.rand((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])attention_rand = AdditiveAttention(key_size=2, query_size=2, num_hiddens=8,dropout=0.1)
attention_rand.eval()
attention_rand(queries_new, keys_rand, values, valid_lens)d2l.show_heatmaps(attention_rand.attention_weights.reshape((1, 1, 2, 10)),xlabel='Keys', ylabel='Queries')


在这里插入图片描述

attention_rand = DotProductAttention(dropout=0.5)
attention_rand.eval()
attention_rand(queries_new, keys_rand, values, valid_lens)d2l.show_heatmaps(attention_rand.attention_weights.reshape((1, 1, 2, 10)),xlabel='Keys', ylabel='Queries')


在这里插入图片描述


(2)只使用矩阵乘法,能否为具有不同矢量长度的查询和键设计新的评分函数?

可以想办法把他俩映射到一个长度。


(3)当查询和键具有相同的矢量长度时,矢量求和作为评分函数是否比“点-积”更好?为什么?

不会,略。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/172931.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

表的约束【MySQL】

文章目录 什么是约束DEFAULT&#xff08;默认约束&#xff09;NULL 与 NOT NULL&#xff08;非空约束&#xff09;COMMENT&#xff08;注释约束&#xff09;ZEROFILL&#xff08;零填充约束&#xff09;UNIQUE&#xff08;唯一键约束&#xff09;*PRIMARY KEY&#xff08;主键约…

Linux常用命令——chown命令

在线Linux命令查询工具 chown 用来变更文件或目录的拥有者或所属群组 补充说明 chown命令改变某个文件或目录的所有者和所属的组&#xff0c;该命令可以向某个用户授权&#xff0c;使该用户变成指定文件的所有者或者改变文件所属的组。用户可以是用户或者是用户D&#xff0…

Vuex模块化(modules)与namespaced(命名空间)的搭配

Vuex模块化&#xff08;modules&#xff09;与namespaced&#xff08;命名空间&#xff09;的搭配 Vuex模块化&#xff08;modules&#xff09;格式 原理&#xff1a;可以对Vuex的actions&#xff0c;mutations&#xff0c;state&#xff0c;getters四个属性综合成一个部分&a…

对Happens-Before的理解

Happens-Before Happens-Before 是一种可见性模型&#xff0c;也就是说&#xff0c;在多线程环境下。原本因为指令重排序的存在会导致数据的可见性问题&#xff0c;也就是 A 线程修改某个共享变量对 B 线程不可见。因此&#xff0c;JMM 通过 Happens-Before 关系向开发人员提供…

2023年香水行业数据分析:国人用香需求升级,高端香水高速增长

在人口结构变迁的背景下&#xff0c;“Z世代”作为当下我国的消费主力&#xff0c;正在将“悦己”消费推动成为新潮流。具备经济基础的“Z世代”倡导“高颜值”、“个性化”、“精致主义”&#xff0c;这和香水、香氛为代表的“嗅觉经济”的特性充分契合&#xff0c;因此&#…

【Docker从入门到入土 6】Consul详解+Docker https安全认证(附证书申请方式)

Part 6 一、服务注册与发现的概念1.1 cmp问题1.2 服务注册与发现 二、Consul ----- 服务自动发现和注册2.1 简介2.2 为什么要用consul&#xff1f;2.3 consul的架构2.3 Consul-template 三、consul架构部署3.1 Consul服务器Step1 建立 Consul 服务Step2 查看集群信息Step3 通过…

Flutter笔记:完全基于Flutter绘图技术绘制一个精美的Dash图标(中)

Flutter笔记 完全基于Flutter绘图技术绘制一个精美的Dart语言吉祥物Dash&#xff08;中&#xff09; 作者&#xff1a;李俊才 &#xff08;jcLee95&#xff09;&#xff1a;https://blog.csdn.net/qq_28550263 邮箱 &#xff1a;291148484163.com 本文地址&#xff1a;https://…

Android framework服务命令行工具框架 - Android13

Android framework服务命令行工具框架 - Android13 1、framework服务命令行工具简介2、cmd 执行程序2.1 目录和Android.bp2.2 cmdMain 执行入口2.3 cmd命令 3、am命令工具&#xff0c;实质脚本执行cmd activity3.1 sh脚本3.2 activity服务注册3.3 onShellCommand执行 4、简易时…

《从零开始大模型开发与微调 :基于PyTorch与ChatGLM》简介

内 容 简 介 大模型是深度学习自然语言处理皇冠上的一颗明珠&#xff0c;也是当前AI和NLP研究与产业中最重要的方向之一。本书使用PyTorch 2.0作为学习大模型的基本框架&#xff0c;以ChatGLM为例详细讲解大模型的基本理论、算法、程序实现、应用实战以及微调技术&#xff0c;…

Chapter1:C++概述

此专栏为移动机器人知识体系的 C {\rm C} C基础&#xff0c;基于《深入浅出 C {\rm C} C》(马晓锐)的笔记&#xff0c; g i t e e {\rm gitee} gitee链接: 移动机器人知识体系. 1.C概述 1.1 C概述 计算机系统分为硬件系统和软件系统。 硬件系统&#xff1a;指组成计算机的电子…

通过阿里云创建accessKeyId和accessKeySecret

我们想实现服务端向个人发送短信验证码 需要通过accessKeyId和accessKeySecret 这里可以白嫖阿里云的 这里 我们先访问阿里云官网 阿里云地址 进入后搜索并进入短信服务 如果没登录 就 登录一下先 然后在搜索框搜索短信服务 点击进入 因为我也是第一次操作 我们一起点免费开…

2017年上半年上午易错题(软件设计师考试)

CPU 执行算术运算或者逻辑运算时&#xff0c;常将源操作数和结果暂存在&#xff08; &#xff09;中。 A &#xff0e; 程序计数器 (PC) B. 累加器 (AC) C. 指令寄存器 (IR) D. 地址寄存器 (AR) 某系统由下图所示的冗余部件构成。若每个部件的千小时可靠度都为 R &…

深度学习之基于YoloV8的行人跌倒目标检测系统

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 文章目录 一项目简介 二、功能三、行人跌倒目标检测系统四. 总结 一项目简介 世界老龄化趋势日益严重&#xff0c;现代化的生活习惯又使得大多数老人独居&#xff0c;统计数据表…

美术如何创建 skybox 贴图资源?

文章目录 目的PS手绘Panorama To CubemapPS手绘Pano2VRSkybox & Cubemap Tutorial (Maya & Photoshop)Unity 中使用 ReflectionProbe 生成 Cubemap 然后再 PS 调整PS直接手绘 cubemapBlender 导入 Panorama&#xff0c;然后烘焙到 cubemap&#xff0c;再导入unity中使用…

【ARMv8 SIMD和浮点指令编程】NEON 通用数据处理指令——复制、反转、提取、转置...

NEON 通用数据处理指令包括以下指令(不限于): • DUP 将标量复制到向量的所有向量线。 • EXT 提取。 • REV16、REV32、REV64 反转向量中的元素。 • TBL、TBX 向量表查找。 • TRN 向量转置。 • UZP、ZIP 向量交叉存取和反向交叉存取。 1 DUP (element) 将…

基于计算机视觉的坑洼道路检测和识别-MathorCup A(深度学习版本)

1 2023 年 MathorCup 高校数学建模挑战赛——大数据竞赛 赛道 A&#xff1a;基于计算机视觉的坑洼道路检测和识别 使用深度学习模型&#xff0c;pytorch版本进行图像训练和预测&#xff0c;使用ResNet50模型 2 文件夹预处理 因为给定的是所有图片都在一个文件夹里面&#xf…

前端将图片储存table表格中,页面回显

<el-table :data"tableData" v-loading"loading" style"width: 100%" height"calc(100vh - 270px)" :size"tableSize"row-dblclick"enterClick"><el-table-column prop"name" label"文档…

图像数据噪音种类以及Python生成对应噪音

前言 当涉及到图像处理和计算机视觉任务时&#xff0c;噪音是一个不可忽视的因素。噪音可以由多种因素引起&#xff0c;如传感器误差、通信干扰、环境光线变化等。这些噪音会导致图像质量下降&#xff0c;从而影响到后续的图像分析和处理过程。因此&#xff0c;对于从图像中获…

数据结构时间复杂度(补充)和空间复杂度

Hello&#xff0c;今天事10月27日&#xff0c;距离刚开始写博客已经过去挺久了&#xff0c;我也不知道是什么让我坚持这么久&#xff0c;但是学校的课真的很多&#xff0c;很少有时间多出来再学习&#xff0c;有些科目马上要考试了&#xff0c;我还不知道我呢不能过哈哈哈&…

新的iLeakage攻击从Apple Safari窃取电子邮件和密码

图片 导语&#xff1a;学术研究人员开发出一种新的推测性侧信道攻击&#xff0c;名为iLeakage&#xff0c;可在所有最新的Apple设备上运行&#xff0c;并从Safari浏览器中提取敏感信息。 攻击概述 iLeakage是一种新型的推测性执行攻击&#xff0c;针对的是Apple Silicon CPU和…