【transformer】自注意力源码解读和复杂度计算

Self-attention

1

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

其中, Q Q Q为查询向量, K K K V V V为键向量和值向量, d k d_k dk为向量的维度。 Q Q Q K K K V V V在一般情况下是相同的。公式中的softmax函数将分数归一化为概率,得到加权的值向量。这里的注意力机制是通过计算查询向量 Q Q Q和键向量 K K K之间的相似性,来为值向量 V V V分配不同的权重。如果两个向量越相似,则它们之间的权重应该越大,反之则越小。

def attention(query, key, value, mask=None, dropout=None):"Compute 'Scaled Dot Product Attention'"d_k = query.size(-1)  # 获取文本嵌入维度大小# 按照注意力机制的公式计算注意力分数scores = torch.matmul(query, key.transpose(-2, -1)) \/ math.sqrt(d_k)# 是否使用掩码if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)# 使用softmax对最后一个维度获得注意力张量p_attn = F.softmax(scores, dim = -1)if dropout is not None:p_attn = dropout(p_attn)# 注意力张量与value相乘得到query的注意力表示return torch.matmul(p_attn, value), p_attn

一个形状为 N × M N\times M N×M 的矩阵,与另一个形状为 M × P M\times P M×P的矩阵相乘,其运算复杂度来源于乘法操作的次数,时间复杂度为 O ( N M P ) O(NMP) O(NMP)

Self-attention的公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V其中, Q Q Q为查询向量, K K K V V V为键向量和值向量, d k d_k dk为向量的维度。 Q Q Q K K K V V V在一般情况下是相同的。公式中的softmax函数将分数归一化为概率,得到加权的值向量。
Self-Attention的计算复杂度主要来自三个方面:查询矩阵、键矩阵和值矩阵的乘积、softmax 的计算、以及输出向量和值的加权平均。
对于一个由n个单词组成的输入序列,假设有d个维度的特征,那么查询矩阵、键矩阵和值矩阵的维度都将是 n × d。

  • 对于查询矩阵 Q 和键矩阵 K 的点积, n × d n\times d n×d d × n d\times n d×n计算复杂度是 O ( n 2 d ) O(n^2d) O(n2d)
  • 每行 softmax 的计算,计算复杂度为 O ( n ) O(n) O(n),对n行做softmax,复杂度为 O ( n 2 ) O(n^2) O(n2)
  • 对于值矩阵 V (维度 n × d n\times d n×d)和 softmax 后的结果(维度 n × n n\times n n×n)进行点积,得到每个查询向量的加权平均值,复杂度是 O ( n 2 d ) O(n^2d) O(n2d)

因此,总的计算复杂度是 O ( n 2 d ) + O ( n 2 ) + O ( n 2 d ) ≃ O ( n 2 d ) O(n^2d) + O(n^2) + O(n^2d) \simeq O(n^2d) O(n2d)+O(n2)+O(n2d)O(n2d)
由于这个复杂度是关于输入序列长度n的平方级别,因此Self-Attention在处理长序列时可能会面临计算上的挑战。

多头注意力

2
多头注意力的计算公式如下:
MultiHead ⁡ ( Q , K , V ) = Concat ⁡ ( head ⁡ 1 , … , head  h ) W O where  head  i = A ( Q W i Q , K W i K , V W i V ) \begin{aligned} \operatorname{MultiHead}(Q, K, V) & =\operatorname{Concat}\left(\operatorname{head}_1, \ldots, \text { head }_{\mathrm{h}}\right) W^O \\ \text { where } \text { head }_{\mathrm{i}} & =A\left(Q W_i^Q, K W_i^K, V W_i^V\right) \end{aligned} MultiHead(Q,K,V) where  head i=Concat(head1,, head h)WO=A(QWiQ,KWiK,VWiV)其中, Q , K , V Q,K,V Q,K,V 分别表示查询、键和值, h h h 表示头数, h e a d i head_i headi 表示第 i i i 个注意力头, W O W^O WO 表示输出层的权重矩阵。

# 用于深度拷贝的copy工具包
import copy# 首先需要定义克隆函数, 因为在多头注意力机制的实现中, 用到多个结构相同的线性层.
# 我们将使用clone函数将他们一同初始化在一个网络层列表对象中. 之后的结构中也会用到该函数.
def clones(module, N):"""用于生成相同网络层的克隆函数, 它的参数module表示要克隆的目标网络层, N代表需要克隆的数量"""# 在函数中, 我们通过for循环对module进行N次深度拷贝, 使其每个module成为独立的层,# 然后将其放在nn.ModuleList类型的列表中存放.return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])# 我们使用一个类来实现多头注意力机制的处理
class MultiHeadedAttention(nn.Module):def __init__(self, head, embedding_dim, dropout=0.1):"""在类的初始化时, 会传入三个参数,head代表头数,embedding_dim代表词嵌入的维度, dropout代表进行dropout操作时置0比率,默认是0.1."""super(MultiHeadedAttention, self).__init__()# 在函数中,首先使用了一个测试中常用的assert语句,判断h是否能被d_model整除,# 这是因为我们之后要给每个头分配等量的词特征.也就是embedding_dim/head个.assert embedding_dim % head == 0# 得到每个头获得的分割词向量维度d_kself.d_k = embedding_dim // head# 传入头数hself.head = head# 然后获得线性层对象,通过nn的Linear实例化,它的内部变换矩阵是embedding_dim x embedding_dim,然后使用clones函数克隆四个,# 为什么是四个呢,这是因为在多头注意力中,Q,K,V各需要一个,最后拼接的矩阵还需要一个,因此一共是四个.self.linears = clones(nn.Linear(embedding_dim, embedding_dim), 4)# self.attn为None,它代表最后得到的注意力张量,现在还没有结果所以为None.self.attn = None# 最后就是一个self.dropout对象,它通过nn中的Dropout实例化而来,置0比率为传进来的参数dropout.self.dropout = nn.Dropout(p=dropout)def forward(self, query, key, value, mask=None):"""前向逻辑函数, 它的输入参数有四个,前三个就是注意力机制需要的Q, K, V,最后一个是注意力机制中可能需要的mask掩码张量,默认是None. """# 如果存在掩码张量maskif mask is not None:# 使用unsqueeze拓展维度mask = mask.unsqueeze(0)# 接着,我们获得一个batch_size的变量,他是query尺寸的第1个数字,代表有多少条样本.batch_size = query.size(0)# 之后就进入多头处理环节# 首先利用zip将输入QKV与三个线性层组到一起,然后使用for循环,将输入QKV分别传到线性层中,# 做完线性变换后,开始为每个头分割输入,这里使用view方法对线性变换的结果进行维度重塑,多加了一个维度h,代表头数,# 这样就意味着每个头可以获得一部分词特征组成的句子,其中的-1代表自适应维度,# 计算机会根据这种变换自动计算这里的值.然后对第二维和第三维进行转置操作,# 为了让代表句子长度维度和词向量维度能够相邻,这样注意力机制才能找到词义与句子位置的关系,# 从attention函数中可以看到,利用的是原始输入的倒数第一和第二维.这样我们就得到了每个头的输入.query, key, value = \[model(x).view(batch_size, -1, self.head, self.d_k).transpose(1, 2)for model, x in zip(self.linears, (query, key, value))]# 得到每个头的输入后,接下来就是将他们传入到attention中,# 这里直接调用我们之前实现的attention函数.同时也将mask和dropout传入其中.x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)# 通过多头注意力计算后,我们就得到了每个头计算结果组成的4维张量,我们需要将其转换为输入的形状以方便后续的计算,# 因此这里开始进行第一步处理环节的逆操作,先对第二和第三维进行转置,然后使用contiguous方法,# 这个方法的作用就是能够让转置后的张量应用view方法,否则将无法直接使用,# 所以,下一步就是使用view重塑形状,变成和输入形状相同.x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.head * self.d_k)# 最后使用线性层列表中的最后一个线性层对输入进行线性变换得到最终的多头注意力结构的输出.return self.linears[-1](x)

在多头注意力中,假设有 h h h 个头,每个头的查询、键和值的维度是 d k d_k dk d k d_k dk d v d_v dv,一般情况 d q = d k = d v = d h d_q=d_k=d_v=\frac{d}{h} dq=dk=dv=hd, 输入序列的长度为 N N N

  • 输入线性映射的复杂度: n × d n\times d n×d d × d h d \times \frac{d}{h} d×hd,计算复杂度是 O ( n d 2 h ) O(\frac{nd^2 }{h}) O(hnd2)
  • 注意力计算:输入线性映射后的维度 n × d h n \times \frac{d}{h} n×hd n × d h n \times \frac{d}{h} n×hd d h × n \frac{d}{h}\times n hd×n计算复杂度是 O ( n 2 d h ) O(n^2\frac{d}{h}) O(n2hd)
  • 输出线性映射: 多个头的结果concat成一个 n × d n\times d n×d矩阵, n × d n\times d n×d d × d d \times d d×d,计算复杂度是 O ( n d 2 ) O(nd^2) O(nd2)

总时间复杂度 O ( n d 2 h + n 2 d h + n d 2 ) O(\frac{nd^2}{h}+n^2\frac{d}{h}+nd^2) O(hnd2+n2hd+nd2)


参考:
传智博客-Transformer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/121757.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

油田钻井平台三维应急仿真系统降低事故发生概率

海上钻井是供海上作业人员进行生产作业或者其他活动使用的仿陆地区域,被称为“流动的国土”,主要由上部平台、下浮体(沉垫浮箱)和中部立柱三部分组成,平台上安装钻井、动力、通讯、导航等设备,以及安全救生和人员生活等设施&#…

OB Cloud 初体验

文章来源:韩锋频道 韩锋 数据库行业资深从业者,著有《SQL 优化最佳实践》、《数据库高效优化》等数据库相关著作。 OceanBase(下文简称OB) 作为国内一款优秀的分布式数据库,这些年来发展很快,在金融、电商…

Java之文件操作与IO

目录 一.认识文件 1.1文件是什么? 1.2文件的组织 1.3文件路径 1.4文件的分类 二.文件操作 2.1File概述 三.文件内容操作--IO 3.1JavaIO的认识 3.2Reader和Writer ⭐Reader类 ⭐Writer类 3.2FileInputStream和FileOutputStream ⭐FileInputStream类 …

继承(个人学习笔记黑马学习)

1、基本语法 #include <iostream> using namespace std; #include <string>//普通实现页面//Java页面 //class Java { //public: // void header() { // cout << "首页、公开课、登录、注册...(公共头部)" << endl; // } // void footer() …

【广州华锐互动】VR全景工厂虚拟导览,虚拟现实技术提升企业数字化信息管理水平

随着工业4.0的到来&#xff0c;VR工厂全景制作成为了越来越多工业企业的选择。传统的工厂管理方式往往存在诸多问题&#xff0c;如信息不对称、安全隐患等。为了解决这些问题&#xff0c;VR工厂全景制作应运而生&#xff0c;它通过结合虚拟现实现实技术和数据采集技术&#xff…

设计模式-迭代器

文章目录 1. 引言1.1 概述1.2 设计模式1.3 迭代器模式的应用场景1.4 迭代器模式的作用 2. 基本概念2.1 迭代器 Iterator2.2 聚合 Aggregate2.3 具体聚合 ConcreteAggregate 3. Java 实现迭代器模式3.1 Java 集合框架3.2 Java 迭代器接口3.3 Java 迭代器模式实现示例 4. 迭代器模…

Android 音频框架 基于android 12

文章目录 前言音频服务audioserver音频数据链路hal 提供什么样的作用 前言 Android 的音频是一个相当复杂的部分。从应用到框架、hal、kernel、最后到硬件&#xff0c;每个部分的知识点都相当的多。而android 这部分代码在版本之间改动很大、其中充斥着各种workaround的处理&a…

介绍GitHub

GitHub 是一个基于互联网的源代码托管平台&#xff0c;可以帮助软件开发者存储和管理源代码&#xff0c;方便团队协作和版本控制。GitHub 的主要功能包括&#xff1a; 代码托管&#xff1a;开发者可以在 GitHub 上创建远程代码仓库&#xff0c;存储和管理他们的源代码。 版本控…

聚观早报|多邻国推出进阶英文课程;电动汽车成本将高于燃油车

【聚观365】9月5日消息 多邻国即将推出进阶英文课程 未来电动汽车成本仍将高于燃油车 戴尔科技2024财年第二季度营收229亿美元 现代汽车电动汽车销量在8月份环比继续下滑 马斯克称将用X数据训练AI 多邻国即将推出进阶英文课程 语言学习平台多邻国宣布&#xff0c;为了满…

TCP机制之连接管理(三次握手和四次挥手详解)

TCP的连接管理机制描述了连接如何创建以及如何断开! 建立连接(三次握手) 三次握手的过程 所谓建立连接就是通信双方各自要记录对方的信息,彼此之间要相互认同;这里以A B双方确立男女朋友关系为例: 从图中可以看出,通信双方各自向对方发起一个"建立连接"的请求,同时…

SOME/IP TTL 在各种Entry 中各是什么意思?有什么限制?

1 服务发现 SOME/IP SD 服务发现主要用于 定位服务实例检测服务实例状态是否在运行发布/订阅行为管理SOME/IP SD 也是 SOME/IP 消息,遵循 SOME/IP 消息格式,有固定的 Message ID、Request ID 以及 Message Type 等。并对 SOME/IP Payload 进行了详细的定义。 SOME/IP SD …

STM32CUBEMX_创建时间片轮询架构的软件框架

STM32CUBEMX_创建时间片轮询架构的软件框架 说明&#xff1a; 1、这种架构避免在更新STM32CUBEMX配置后把用户代码清除掉 2、利用这种时间片的架构可以使得代码架构清晰易于维护 创建步骤&#xff1a; 1、使用STM32CUBEMX创建基础工程 2、新建用户代码目录 3、构建基础的代码框…

Unity生命周期函数

1、Awake 当对象&#xff08;自己这个类对象&#xff0c;就是这个脚本&#xff09;被创建时 才会调用该生命周期函数 类似构造函数的存在 我们可以在一个类对象创建时进行一些初始化操作 2、OnEnable 失活激活&#xff08;这个勾&#xff09; 想要当一个对象&#xff08;游戏…

C语言基础知识理论版(很详细)

文章目录 前述一、数据1.1 数据类型1.2 数据第一种数据&#xff1a;常量第二种数据&#xff1a;变量第三种数据&#xff1a;表达式1、算术运算符及算术表达式2、赋值运算符及赋值表达式3、自增、自减运算符4、逗号运算符及其表达式&#xff08;‘顺序求值’表达式&#xff09;5…

光电耦合器市场在预测期内预计将以8.99%的复合年增长率增长!

光耦合器是一种用于传输光信号的电子器件。它具有以下特性&#xff1a; 1. 传输性能&#xff1a;光耦合器能够实现光电转化和信号传输&#xff0c;具有良好的传输性能和抗干扰能力&#xff0c;可以避免外部环境的干扰。 2. 隔离性能&#xff1a;光耦合器能够实现电路之间的隔…

学习MATLAB

今日&#xff0c;在大学慕课上找了一门关于MATLAB学习的网课&#xff0c;MATLAB对于我们这种自动化的学生应该是很重要的&#xff0c;之前也是在大三的寒假做自控的课程设计时候用到过&#xff0c;画一些奈奎斯特图&#xff0c;根轨迹图以及伯德图&#xff0c;但那之后也就没怎…

【构造】CF Edu 12 D

Problem - D - Codeforces 题意&#xff1a; 思路&#xff1a; 这种题一定要从小数据入手&#xff0c;不然很有可能走歪思路 先考虑n 1的情况&#xff0c;直接输出即可 然后是n 2的情况&#xff0c;如果相加是质数&#xff0c;就输出2个&#xff0c;否则就输出一个 然后…

死锁是什么?死锁的字节码指令了解?

用幽默浅显的言语来说死锁 半生&#xff1a;我已经拿到了机考的第一名&#xff0c;就差笔试第一名了 小一&#xff1a;我已经拿到了笔试的第一名&#xff0c;就差机考第一名了 面试官&#xff1a;我很看好你俩&#xff0c;继续"干", 同时拿到2个的第一名才能拿到offe…

算法训练 第一周

一、合并两个有序数组 本题给出了两个整数数组nums1和nums2&#xff0c;这两个数组均是非递减排列&#xff0c;要求我们将这两个数组合并成一个非递减排列的数组。题目中还要求我们把合并完的数组存储在nums1中&#xff0c;并且为了存储两个数组中全部的数据&#xff0c;nums1中…

签到系统怎么设计

背景 相信签到系统大家都有接触过&#xff0c;更多的是使用。但是有思考过这种系统是怎么设计的吗&#xff1f;比方说我统计一下每个月中每天的签到情况&#xff0c;怎么设计呢&#xff1f;今天一篇文章告诉你。 首先&#xff0c;我们熟悉的思维是&#xff1a;我设计一个数据…