(2020|ICML PMLR,线性 Transformer,核函数,RNN)Transformer 是 RNN

Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

公众号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)

目录

0. 摘要

3. 线性 Transformers

3.1. Transformer

3.2. 线性注意力机制

3.2.1. 特征映射与计算成本

3.3. 因果掩码

3.3.1. 梯度计算

3.3.2. 训练和推理

3.4. transformer 是 RNN

4. 实验


0. 摘要

Transformer 在多项任务中表现出色,但由于其对输入长度的二次复杂度,对于非常长的序列来说,速度极慢。为了解决这一限制,我们将自注意力表示为核特征映射(kernel feature maps)的线性点积,并利用矩阵乘积的结合性将复杂度从 O(N^2) 降低到 O(N),其中 N 是序列长度。我们证明了这种表达方式允许一种迭代实现,大大加速了自回归 Transformer,并揭示了它们与递归神经网络的关系。我们的线性 Transformer 在性能上与普通 Transformer 相似,并且在非常长序列的自回归预测中速度快达 4000 倍。 

3. 线性 Transformers

在本节中,我们提出了线性 Transformer。我们展示了将传统的 softmax 注意力机制改为基于特征映射的点积注意力,可以改善时间和内存复杂度,并且可以实现类似于 RNN 的线性时间序列生成模型。

3.1. Transformer

3.2. 线性注意力机制

公式 2 中的注意力定义是通用的,可以用于定义多种其他注意力实现,例如多项式注意力或 RBF 核注意力(Tsai等人,2019)。注意,为了使公式 3 定义的注意力函数有效,我们需要对 sim(·) 施加的唯一约束是非负性。这包括所有核函数 k(x, y): R^(2 × F) → R_+。

给定具有特征表示 ϕ(x) 的核函数,我们可以将公式 2 重写为:

然后利用矩阵乘法的结合性进一步简化为:

当分子以向量形式书写时,上述公式更容易理解,如下所示:

注意,特征映射 ϕ(·) 是逐行应用于矩阵 Q 和 K 的。

从公式 2 可以看出,softmax 注意力的计算成本随 O(N^2) 缩放,其中 N 表示序列长度。内存需求也是如此,因为必须存储完整的注意力矩阵以计算查询、键和值的梯度。相比之下,我们在公式 5 中提出的线性 transformer 具有 O(N) 的时间和内存复杂度,因为我们可以计算

一次,并在每个查询中重复使用它们。

3.2.1. 特征映射与计算成本

对于 softmax 注意力,就乘法和加法的总成本而言,随着 O(N^2·max(D, M)) 缩放,其中 D 是查询和键的维度,M 是值的维度。相反,对于线性注意力,我们首先计算维度为 C 的特征映射。随后,计算新值需要 O(NCM) 次加法和乘法。

上述分析未考虑核函数和特征函数的选择。需要注意的是,对应于指数核的特征函数是无限维的,这使得精确 softmax 注意力的线性化不可行。另一方面,例如多项式核具有精确的有限维特征映射,并且已证明与指数或 RBF 核(Tsai等人,2019)同样有效。线性化多项式 transformer 的计算成本为 O(N·D^2·M)。当 N > D^2 时,这使得计算复杂度更具优势。实际上,由于我们希望能够处理成千上万元素的序列,这一情况是成立的。

对于我们的实验,处理较小的序列,我们采用了一个结果为正相似函数的特征映射,如下定义:

其中 elu(·) 表示指数线性单元(Clevert等人,2015)的激活函数。我们更喜欢 elu(·) 而不是relu(·),以避免在 x 为负时将梯度设置为 0。这种特征映射导致的注意力函数需要 O(NDM) 次乘法和加法。在我们的实验部分,我们展示了公式 7 的特征映射在性能上与完整 transformer 相当,同时显著减少了计算和内存需求。

3.3. 因果掩码

transformer  架构可以通过掩蔽(masking)注意力计算来高效地训练自回归模型,使得第 i 个位置只能被第 j 个位置影响当且仅当 j ≤ i,即一个位置不能被后续位置影响。形式上,这种因果掩码将公式 3 修改如下:

按照3.2节的推理,我们如下所述对掩码注意力进行线性化:

通过引入 Si 和 Zi 如下所示:

我们可以将公式 9 简化为:

注意,Si 和 Zi 可以从 S_(i-1) 和 Z_(i-1) 在固定时间内计算得出,因此使得具有因果掩码的线性 transformer 的计算复杂度相对于序列长度为线性。

3.3.1. 梯度计算

在任何深度学习框架中,公式 12 的朴素实现需要存储所有中间值 Si,以计算梯度。这会增加max(D, M) 倍的内存消耗,从而阻碍因果线性注意力在更长序列或更深模型中的应用。为了解决这个问题,我们将公式 9 中的分子(numerator)的梯度导出为累积和。这使我们能够在线性时间和固定内存中计算因果线性注意力的前向和后向传播。详细推导见附录材料。

给定分子 ¯V_i 和标量损失函数相对于分子的梯度

推导可得:

累计和项在公式 9 和 13-15 中以线性时间计算,并且相对于序列长度需要常量内存。这导致的算法在给定维度为 C 的特征映射下,其计算复杂度为 O(NCM),内存复杂度为 O(N·max (C, M))。算法 1 是分子部分前向和后向传播的伪代码实现。

3.3.2. 训练和推理

在训练自回归 transformer 模型时,可以使用完整的真实序列。这使得公式 1 中的函数 φ(·) 和注意力计算都可以进行分层并行化。因此,transformer 比 RNN 更高效地进行训练。然而,在推理过程中,时间步 i 的输出是时间步 i + 1 的输入。这使得自回归模型无法并行化。此外,transformer 每个时间步的成本不是常量,而是随着当前序列长度的平方增长,因为必须为所有先前的时间步计算注意力。

我们提出的线性 transformer 模型结合了这两者的优点。在训练时,计算可以并行化并充分利用 GPU 或其他加速器。在推理时,我们模型的每次预测在时间和内存上的成本是常量的。这意味着我们可以简单地将

矩阵存储为内部状态,并在每个时间步像递归神经网络一样更新它。这使得推理速度比其他 transformer 模型快数千倍。

3.4. transformer 是 RNN

在文献中,transformer 模型被认为是一种与递归神经网络(RNN)根本不同的方法。然而,从 3.3 节中的因果掩码公式和前一节的讨论可以看出,任何具有因果掩码的 transformer 层都可以被表示为一种模型,该模型在给定输入后修改内部状态,然后预测输出,即 RNN。注意,与通用变压器(Universal Transformers)(Dehghani等人,2018)不同,我们考虑的是时间上的递归,而不是深度上的递归。

在以下公式中,我们将公式 1 的 Transformer 层形式化为 RNN。所得的 RNN 有两个隐藏状态,即注意力记忆 s 和归一化记忆 z。我们用下标表示递归中的时间步。

在上述公式中,x_i 表示特定 Transformer 层的第 i 个输入,y_i 表示第 i 个输出。需要注意的是,我们的公式对特征函数没有任何约束,因此可以用于表示任何 Transformer 模型,理论上甚至包括使用 softmax 注意力的模型。这一公式是更好理解 Transformer 与流行的 RNN(Hochreiter & Schmidhuber, 1997)及其存储和检索信息过程之间关系的第一步。 

4. 实验

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/334468.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

备战秋招—模拟版图面试题来了

随着暑期的脚步逐渐临近,电子工程和集成电路设计领域的毕业生们,也将迎来了另一个求职的黄金期——秋招。我们总说机会是留给有准备的人。对于有志于投身于模拟版图设计的学子们来说,为了在众多求职者中脱颖而出,充分备战模拟版图…

网络——多区域OSPF配置(OSPF系列第1篇)

简介 路由协议OSPF全称为Open Shortest Path First,也就开放是的最短路径优先协议,使用链路状态路由算法,isis协议也是使用链路状态路由算法。而RIP协议使用距离矢量路由算法。 区域 为了能够降低OSPF计算的复杂程度,OSPF采用分…

C++的AVL树

目录 基本概念 插入的语言分析 LL右旋 RR左旋 额外结论及问题1 LR左右旋 RL右左旋 额外结论及问题2 插入结点 更新bf与判断旋转方式 旋转代码实现 准备工作一 LL右旋的实现 RR左旋的实现 准备工作二 LR左右旋的实现 RL右左旋的实现 完整代码 基本概念 1、…

Android Studio 所有历史版本下载

一、官网链接 https://developer.android.google.cn/studio/archive 操作 二、AndroidDevTools地址 https://www.androiddevtools.cn/ 参考 https://blog.csdn.net/qq_27623455/article/details/103008937

Golang | Leetcode Golang题解之第102题二叉树的层序遍历

题目&#xff1a; 题解&#xff1a; func levelOrder(root *TreeNode) [][]int {ret : [][]int{}if root nil {return ret}q : []*TreeNode{root}for i : 0; len(q) > 0; i {ret append(ret, []int{})p : []*TreeNode{}for j : 0; j < len(q); j {node : q[j]ret[i] …

【加密与解密(第四版)】第十五章笔记

第十五章 专用加密软件 15.1 认识壳 15.2 压缩壳 UPX、ASPack、PECompact 15.3 加密壳 ASProtect(压缩、加密、反跟踪代码、CRC校验、花指令)、Armadillo(穿山甲)、EXECryptor、Themida 15.4 虚拟机保护软件 虚拟机引擎&#xff08;编译器解释器虚拟CPU环境指令系统&#xff…

Python代码:十七、生成列表

1、题目 描述&#xff1a; 一串连续的数据用什么记录最合适&#xff0c;牛牛认为在Python中非列表&#xff08;list&#xff09;莫属了。现输入牛牛朋友们的名字&#xff0c;请使用list函数与split函数将它们封装成列表&#xff0c;再整个输出列表。 输入描述&#xff1a; …

关于智慧校园安全用电监测系统的设计

人生人身安全是大家关注的话题&#xff0c;2019年12月中国消防统计近五年发生在全国学生宿舍的火灾2314起&#xff08;中国消防2019.12.应急管理部消防救援局官方微博&#xff09;&#xff0c;违规电器是引发火灾的主因。如果在各寝室安装智能用电监测器实时监督线路参数&#…

python-绘制五星红旗(非标准)

完整代码如下&#xff1a; #五星红旗&#xff08;非标准版&#xff09; from turtle import* import math from random import* tracer(0) penup() goto(-640,220) pendown() color(gold,gold) begin_fill() for i in range(5): fd(150) right(144) # 大五角星 penup(…

[9] CUDA性能测量与错误处理

CUDA性能测量与错误处理 讨论如何通过CUDA事件来测量它的性能如何通过CUDA代码进行调试 1.测量CUDA程序的性能 1.1 CUDA事件 CPU端的计时器可能无法给出正确的内核执行时间CUDA事件等于是在你的CUDA应用运行的特定时刻被记录的时间戳&#xff0c;通过使用CUDA事件API&#…

JVM内存模型详解

Java虚拟机&#xff08;JVM&#xff09;是Java程序运行的基础环境&#xff0c;它负责将Java代码转换为机器码并执行。在JVM中&#xff0c;内存管理是一个核心部分&#xff0c;它决定了Java程序如何分配、使用和回收内存。了解JVM的内存模型对于编写高效、健壮的Java程序至关重要…

生成式AI模型大PK——GPT-4、Claude 2.1和Claude 3.0 Opus

RAG(检索增强生成)系统的新评估似乎每天都在发布&#xff0c;其中许多都集中在有关框架的检索阶段。然而&#xff0c;生成方面——模型如何合成和表达这些检索到的信息&#xff0c;在实践中可能具有同等甚至更大的意义。许多实际应用中的案例证明&#xff0c;系统不仅仅要求从上…

centos下给es7.12.1设置密码

安装可参考&#xff1a; centos7下安装elasticsearch7.8.1并配置远程连接_在一台服务器centos7上安装和配置elasticsearch。-CSDN博客 1、先停掉es进程 2、设置输入密码后访问配置 cd /home/soft/elasticsearch-7.12.1/config vim elasticsearch.yml 3、启动es服务 cd /home/…

echarts全局设置饼图的颜色

&#x1f337;第一步 在js文件中写入你需要的颜色 这里的颜色也可以写渐变的 &#x1f337;下一步 在main.is中引用全局挂载 &#x1f337;最后一步 在初始化的时候加一个macarons即可 &#x1f337;第一步 在js文件中写入你需要的颜色 这里的颜色也可以写渐变的 (functi…

LeetCode199二叉树的右视图

题目描述 给定一个二叉树的 根节点 root&#xff0c;想象自己站在它的右侧&#xff0c;按照从顶部到底部的顺序&#xff0c;返回从右侧所能看到的节点值。 解析 这一题的关键其实就是找到怎么去得到当前是哪一层级&#xff0c;可以利用队列对二叉树进行层次遍历&#xff0c;但…

人才测评的应用:人才选拔,岗位晋升,面试招聘测评

人才测评自诞生以来&#xff0c;就被广泛应用在各大方面&#xff0c;不仅是我们熟悉的招聘上&#xff0c;还有其他考核和晋升&#xff0c;都会需要用到人才测评。不知道怎么招聘&#xff1f;或者不懂得如何实现人才晋升&#xff1f;都可以参考人才测评&#xff0c;利用它帮我们…

linux 定时执行shell、python脚本

在linux里设置定时执行一般是用crontab&#xff0c;如果没有的话&#xff0c;可以先安装&#xff1a; 安装 查看是否安装 cron -v # 对于基于Debian的系统&#xff08;如Ubuntu&#xff09; sudo apt-get install cron# 对于基于RedHat的系统&#xff08;如CentOS&#xff…

FL Studio v21.2.3.4004中文破解版百度网盘下载

FL Studio v21.2.3.4004中文破解版是一款完整的软件音乐制作环境或数字音频工作站 (DAW)。代表了超过 18 年的创新发展&#xff0c;它在一个软件包中提供了您创作、编曲、录制、编辑、混音和掌握专业品质音乐所需的一切。FL Studio v21.2.3.4004中文破解版现在是世界上最受欢迎…

解决LabVIEW通过OPC Server读取PLC地址时的错误180121602

在使用LabVIEW通过OPC Server读取PLC地址时&#xff0c;若遇到错误代码180121602&#xff0c;建议检查网络连接、OPC Server和PLC配置、用户权限及LabVIEW设置。确保网络畅通&#xff0c;正确配置OPC变量&#xff0c;取消缓冲设置以实时读取数据&#xff0c;并使用诊断工具验证…

蓝桥杯—SysTick中断精准定时实现闪烁灯

在嵌入式系统中&#xff0c;SysTick_Handler 是一个中断服务例程&#xff08;Interrupt Service Routine, ISR&#xff09;&#xff0c;用于处理 SysTick 定时器的中断。SysTick 定时器通常用于提供一个周期性的定时中断&#xff0c;可以用来实现延时或者周期性任务。 SysTick…