FlashAttention v1 论文解读

论文标题:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

论文地址:https://arxiv.org/pdf/2205.14135

FlashAttention 是一种重新排序注意力计算的算法,它无需任何近似即可加速注意力计算并减少内存占用。所以作为目前LLM的模型加速它是一个非常好的解决方案,本文介绍经典的V1版本。

目前FlashAttention已经推出了V1~V3版本,遗憾的是,FlashAttention V3目前只支持Nvidia Hopper架构的GPU。目前transformers库已经集成了FlashAttention。

【注】穷人玩不起系列。

FlashAttention是用于在训练或推理时加速注意力计算的方法,参考其官方仓库可以看到对于训练精度和显卡还是有较大限制的:

https://github.com/Dao-AILab/flash-attention

带有 CUDA 的 FlashAttention-2 目前支持:

GPU架构 Ampere, Ada, or Hopper GPUs(例如 A100、RTX 3090、RTX 4090、H100)。对Turing GPU(T4、RTX 2080)的支持即将推出,目前请为Turing GPU 使用 FlashAttention 1.x。

数据类型 fp16 和 bf16(bf16 需要Ampere, Ada, or Hopper GPUs)。

标准注意力机制

在介绍FlashAttention前,一定要深入了解标准注意力机制计算原理。

在 Transformer 架构当中,Attention 是整个模型中最重要的运算,而这个 Attention 的运算示意图如下:

图1 标准注意力计算图示

首先我们把 Q Q Q K K K做矩阵相乘,接下来就是除以隐藏层维度的开根号 d \sqrt{d} d ,然后我们会把运算出来的结果 S(Score)丢进 Softmax 函数得到 P,最后 P 再和 V V V做矩阵相乘就会得到 Attention 的输出 O O O

但实际上我们会发现这一连串的运算非常的耗时间,且会使用到非常大量的内存。

在我们的 GPU 架构中,可以把内存简单地分成 HBM(高带宽内存)和 SRAM(静态随机存取存储器)两个部分。

HBM 的内存空间虽然很大,但是它的带宽比较低。

SRAM 的内存空间虽然很小,但是它的带宽非常高。

所以我们常常看到 GPU 的参数,像是 Nvidia RTX 4090 24GB,就是这张 GPU 有大约 24GB 大小的 HBM。而 SRAM 这块又贵又小的内存,就是拿来做运算的。

图2 GPU存储架构与FlashAttention计算示例

因此我们可以看到今天你在GPU 上运行 标准Attention 的流程如下(N:序列长度、d 是 隐藏层维度):

图3 标准注意力机制流程

首先我们会把 Q Q Q K K K从 HBM 拉到 SRAM 运算,接下来把算出来的结果 S S S写回 HBM,然后 GPU 又把 S S S拉到 SRAM 计算 S o f t m a x Softmax Softmax,算出来的 P P P又写回 HBM,最后 P P P V V V从 HBM 写到 SRAM 做矩阵运算,最后输出 O O O写回 HBM。

而实际情况当然没那么简单,我们知道 SRAM 这块内存又贵又小,所以当然不可能直接把整个 Q Q Q或是 K K K加载进 SRAM,而是一小块一小块地加载。所以这样大量的读写导致 Attention 运算速度很慢,而且会有内存碎片化问题。

【注】有了上面的背景之后,我们来看看FlashAttention V1是如何优化的,下面为大家带来FlashAttention V1论文精读。

Abstract

针对Transformer在处理长序列时速度慢、内存消耗大的问题,论文提出了FlashAttention,一种IO感知的精确注意力算法。该算法通过使用平铺(tiling)技术减少GPU内存(HBM)与SRAM之间的内存读写次数,从而降低计算复杂性。

分析显示,FlashAttention减少了HBM访问次数,并优化了SRAM使用。此外,本研究将FlashAttention扩展至块稀疏注意力,实现了比现有近似注意力方法更快的近似注意力算法,为长序列处理提供了高效解决方案。

【注】标准自注意力机制的时间复杂度是 O ( n 2 ∗ d ) O(n^2*d) O(n2d),其中 n n n是序列长度, d d d是隐藏层维度。多头注意力只是把 d d d进行了多头拆分,单头的时间复杂度是 O ( n 2 ∗ d h ) O(n^2*d_h) O(n2dh),其中 d h d_h dh是单头的隐藏层维度,虽然多头之间可以并行计算,但是仍然没有解决平方量的复杂度。

Introduction

目前许多优化 attention 的方法旨在降低 attention 的计算和内存需求。这些方法专注于减少 FLOP,并且倾向于忽略内存访问 (IO) 的开销。

但是本文认为attention的一个优化方向是使算法具有 IO 感知能力

【注】也就是说,让求注意力的操作尽可能放在SRAM里,而不是频繁的让SRAM与HBM通信。

现代的GPU,计算速度超过了内存IO速度,当读取和写入数据可能占据运行时间的很大一部分时,IO 感知算法对于加速与降内存就变得很重要了。并且深度学习的常见 Python 库(如 PyTorch 和 Tensorflow)目前还不允许对内存访问进行精细控制。

因此,FlashAttention应运而生。

论文提到,为了实现计算注意力时多使用SRAM而少与HBM交换数据,需要克服两点:

  1. 在输入不完整的情况下,计算 S o f t m a x Softmax Softmax
  2. 不存储用于反向传播的中间结果;

FlashAttention

第一招:内核融合(Kernel Fusion)

相信聪明的朋友立刻就能明白,何必这样反复加载和卸载,一次性在SRAM中完成所有计算不就好了?没错,这就是FlashAttention的精髓之一。

FlashAttention就是直接将 Q K V QKV QKV一次性加载到SRAM中完成所有计算,然后再将 O O O写回HBM。

这样大大减少了读写次数,这种一次性完成所有计算的流程被称为内核融合(Kernel Fusion)。

图4 内核融合示意图

第二招:反向重计算(Backward Recomputation)

但是等一下,我们是不是忘了什么?我们直接计算出了 O O O,那么 P P P S S S难道就直接丢弃不存回HBM吗?在进行反向传播时,我们需要从 O O O推回 P P P,再从 P P P推回 S S S,它们都被我们丢弃了,怎么进行反向传播?没错,这就是FlashAttention的第二招,反向重计算(Backward Recomputation)。

因为 P P P S S S这两者实在太占用空间了,所以

在前向传播时, P P P S S S都不会被存储起来。当进行反向传播时,我们就会重新计算一次前向传播,重新计算出 P P P S S S,以便执行反向传播。

所以说:我们执行了2次前向传播和1次反向传播。

这里大家可能又会问:啊这样计算量不是更多了吗,怎么可能会更快?事实上,虽然我们重新计算了一次前向传播,但它不仅帮我们省下了存储P和S的内存空间,还省下了 P P P S S S在HBM和SRAM之间搬运的时间,让我们可以开启更大的batch size,所以总的来说,GPU每秒能处理的数据量依然是大幅增加的。

第三招:Softmax分块(Softmax Tiling)

最后是FlashAttention的最后一招分块(Tiling)。首先我们需要知道注意力机制中的最难搞的就是 S o f t m a x Softmax Softmax函数:

s o f t m a x ( { x 1 , . . . , x N } ) = { e x i ∑ j = 1 N e x j } i = 1 N (1) softmax(\{x_1, ..., x_N\}) = \left\{\frac{e^{x_i}}{\sum_{j=1}^N e^{x_j}}\right\}_{i=1}^N \tag1 softmax({x1,...,xN})={j=1Nexjexi}i=1N(1)

主要原因是在计算分母时,我们需要将所有位的exp值加总。但由于SRAM的大小限制,我们不可能一次性计算出所有数值的 S o f t m a x Softmax Softmax,一定是需要一块一块地丢进SRAM进行计算,所以需要将所有中间计算的数值存储在HBM中。

在FP16精度下,最大可以表示65536,而

e 12 = 162754 e^{12} = 162754 e12=162754

为了防止在计算 S o f t m a x Softmax Softmax产生数值溢出,引入了 S a f e − s o f t m a x Safe-softmax Safesoftmax概念,其公式如下:

S a f e − s o f t m a x ( { x 1 , . . . , x N } ) = { e x i − m ∑ j = 1 N e x j − m } i = 1 N (2) Safe-softmax(\{x_1, ..., x_N\}) = \left\{\frac{e^{x_i-m}}{\sum_{j=1}^N e^{x_j-m}}\right\}_{i=1}^N \tag2 Safesoftmax({x1,...,xN})={j=1Nexjmexim}i=1N(2)

在公式(2)中,有如下定义:

x = [ x 1 , . . . , x N ] (3) x=[x_1,...,x_N] \tag3 x=[x1,...,xN](3)

m ( x ) : = m a x ( x ) (4) m(x):=max(x) \tag4 m(x):=max(x)(4)

p ( x ) : = [ e x 1 − m ( x ) , . . . , e x N − m ( x ) ] (5) p(x):=[e^{x_1-m(x)},...,e^{x_N-m(x)}] \tag5 p(x):=[ex1m(x),...,exNm(x)](5)

l ( x ) : = ∑ i p ( x ) i (6) l(x):=\sum_ip(x)_i \tag6 l(x):=ip(x)i(6)

s o f t m a x ( x ) : = p ( x ) l ( x ) (7) softmax(x):=\frac{p(x)}{l(x)} \tag7 softmax(x):=l(x)p(x)(7)

其原理就是,从 x x x中找出最大值 m m m,在计算 S o f t m a x Softmax Softmax时,分子分母同除以 e m e^m em,这样既可以防止数据溢出,也能保证 S o f t m a x Softmax Softmax值保持不变。

【注】类似于归一化。

x = [ x 1 , … , x N , … , x 2 N ] x 1 = [ x 1 , … , x N ] x 2 = [ x N + 1 , … , x 2 N ] m ( x 1 ) p ( x 1 ) l ( x 1 ) m ( x 2 ) p ( x 2 ) l ( x 2 ) m ( x ) : = max ⁡ ( m ( x 1 ) , m ( x 2 ) ) p ( x ) : = [ e m ( x 1 ) − m ( x ) p ( x 1 ) , e m ( x 2 ) − m ( x ) p ( x 2 ) ] l ( x ) : = e m ( x 1 ) − m ( x ) l ( x 1 ) + e m ( x 2 ) − m ( x ) l ( x 2 ) s o f t m a x ( x ) : = p ( x ) l ( x ) (8) \begin{align*} & x = [x_1, \ldots, x_N, \ldots, x_{2N}] \\ & x^1 = [x_1, \ldots, x_N] \\ & x^2 = [x_{N+1}, \ldots, x_{2N}] \\ & m(x^1) \ p(x^1) \ l(x^1) \ m(x^2) \ p(x^2) \ l(x^2) \\ & m(x) := \max(m(x^1), m(x^2)) \\ & p(x) := [e^{m(x^1)-m(x)} p(x^1), e^{m(x^2)-m(x)} p(x^2)] \\ & l(x) := e^{m(x^1)-m(x)} l(x^1) + e^{m(x^2)-m(x)} l(x^2) \\ & softmax(x) := \frac{p(x)}{l(x)} \end{align*}\tag8 x=[x1,,xN,,x2N]x1=[x1,,xN]x2=[xN+1,,x2N]m(x1) p(x1) l(x1) m(x2) p(x2) l(x2)m(x):=max(m(x1),m(x2))p(x):=[em(x1)m(x)p(x1),em(x2)m(x)p(x2)]l(x):=em(x1)m(x)l(x1)+em(x2)m(x)l(x2)softmax(x):=l(x)p(x)(8)

而本文softmax分块的做法如公式(8)所示。

我们首先将一块数据 x x x中的第一块 x 1 x_1 x1丢进去计算出softmax,这里的 m 1 m_1 m1代表的是这一块加载到SRAM的最大值,所以我们称之为局部最大值。接下来,我们可以根据 m 1 m_1 m1计算出局部softmax。

接下来第二块数据进来时,我们将第一块的最大值 m 1 m_1 m1和第二块的最大值 m 2 m_2 m2取最大值,就可以得到这两块数据的最大值 m ( x ) m(x) m(x)。这个时候定义 p ( x ) : = [ e m ( x 1 ) − m ( x ) p ( x 1 ) , e m ( x 2 ) − m ( x ) p ( x 2 ) ] p(x) := [e^{m(x^1)-m(x)} p(x^1), e^{m(x^2)-m(x)} p(x^2)] p(x):=[em(x1)m(x)p(x1),em(x2)m(x)p(x2)],再与公式(5)结合,只会出现两种情况:

  1. m ( x 1 ) m(x^1) m(x1)最大,最后可化简为 p ( x ) : = [ e x 1 − m ( x 1 ) , . . . , e x N − m ( x 1 ) ] p(x) := [e^{x_1-m(x^1)},...,e^{x_N-m(x^1)}] p(x):=[ex1m(x1),...,exNm(x1)]
  2. m ( x 2 ) m(x^2) m(x2)最大,最后可化简为 p ( x ) : = [ e x 1 − m ( x 2 ) , . . . , e x N − m ( x 2 ) ] p(x) := [e^{x_1-m(x^2)},...,e^{x_N-m(x^2)}] p(x):=[ex1m(x2),...,exNm(x2)]

l ( x ) l(x) l(x)的计算化简也同理,所以我们只需要将第一块的局部softmax乘上这次更新的数值。如此一来,我们就得到了这两块的局部softmax。

没错!接下来依此类推,我们就可以将整个softmax计算完。而通过这种方式:

我们就不需要将每块计算出来的数值存储在HBM中,我们只需要存储当前的最大值 m ( x ) m(x) m(x)和分母加总值 l ( x ) l(x) l(x)就可以了。

而这两者都非常小,所以可以进一步帮我们节省更多内存空间。

另外,这里还有一个小细节,就是由于softmax计算出来后需要与value state进行矩阵相乘,但同样由于SRAM有限,我们一次只能加载一块进行内核融合运算,所以第一块QKV进去后,它计算出来的O是不准确的。但由于矩阵相乘就是数字相乘,所以同样道理,我们只要在计算到下一块时,使用l和m更新O就可以了。

我们可以看到实际的流程就是这样,蓝色的区域就是HBM,橙色虚线的区域就是SRAM。每次运算时,由于SRAM大小有限,所以我们只加载一部分的Key和Value。红色的字就是我们的第一个block的计算,蓝色的字就是我们的第二个block的计算。

图5 block计算演示图

这边我们可以更深入探讨算法和实现部分。静态随机存取存储器(SRAM)容量较小,当序列长度很长时,根本不可能一次性将如此庞大的查询(query)、键(key)、值(value)状态全部塞进SRAM。

一开始我们会把查询状态(Query State)切成 T r T_r Tr块,键/值状态(Key/Value State)切成 T c T_c Tc块,查询状态块的大小为 ( B r , d ) (B_r, d) (Br,d),键/值状态块的大小为 ( B c , d ) (B_c, d) (Bc,d)。切好的这些块再放入SRAM进行Flash Attention运算。 你可能会好奇 B r B_r Br B c B_c Bc是什么神奇的数字,其实非常简单, M M M是我们SRAM的大小,并且查询(Q)、键(K)、值(V)、输出(O)这四个矩阵大小完全相同,所以当然是 M / 4 d M/4d M/4d啦,这样Q、K、V、O四个矩阵的块加起来不就刚好是 M M M嘛,也就是说刚好填满SRAM。

比如说,假设M = 1000, d = 5。那么块大小为(1000/4*5)= 50。所以一次加载50个q, k, v, o个向量的块,这样可以减少HBM/SRAM之间的读/写次数。

性能

图6 FlashAttention V1性能

我们可以看到 FlashAttention 大大地加速了运算,达到 3 倍以上。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/11569.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

stm32硬件实现与w25qxx通信

使用的型号为stm32f103c8t6与w25q64。 STM32CubeMX配置与引脚衔接 根据stm32f103c8t6引脚手册,采用B12-B15四个引脚与W25Q64连接,实现SPI通信。 W25Q64SCK(CLK)PB13MOSI(DI)PB15MISO(DO)PB14CS&#xff08…

软件工程概论试题五

一、多选 1.好的软件的基本属性包括()。 A. 效率 B. 可依赖性和信息安全性 C. 可维护性 D.可接受性 正答:ABCD 2.软件工程的三要素是什么()? A. 结构化 B. 工具 C.面向对象 D.数据流! E.方法 F.过程 正答:BEF 3.下面中英文术语对照哪些是正确的、且是属…

FBX SDK的使用:基础知识

Windows环境配置 FBX SDK安装后,目录下有三个文件夹: include 头文件lib 编译的二进制库,根据你项目的配置去包含相应的库samples 官方使用案列 动态链接 libfbxsdk.dll, libfbxsdk.lib是动态库,需要在配置属性->C/C->预…

知识库管理在提升企业决策效率与知识共享中的应用探讨

内容概要 知识库管理是指企业对内部知识、信息进行系统化整理和管理的过程,其重要性在于为企业决策提供了坚实的数据支持与参考依据。知识库管理不仅能够提高信息的获取速度,还能有效减少重复劳动,提升工作效率。在如今快速变化的商业环境中…

基于vue船运物流管理系统设计与实现(源码+数据库+文档)

船运物流管理系统目录 目录 基于springboot船运物流管理系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、管理员登录 2、货运单管理 3、公告管理 4、公告类型管理 5、新闻管理 6、新闻类型管理 四、数据库设计 1、实体ER图 五、核心代码 六、论文参考…

【自然语言处理(NLP)】深度学习架构:Transformer 原理及代码实现

文章目录 介绍Transformer核心组件架构图编码器(Encoder)解码器(Decoder) 优点应用代码实现导包基于位置的前馈网络残差连接后进行层规范化编码器 Block编码器解码器 Block解码器训练预测 个人主页:道友老李 欢迎加入社…

Spring Boot 实例解析:配置文件

SpringBoot 的热部署&#xff1a; Spring 为开发者提供了一个名为 spring-boot-devtools 的模块来使用 SpringBoot 应用支持热部署&#xff0c;提高开发者的效率&#xff0c;无需手动重启 SpringBoot 应用引入依赖&#xff1a; <dependency> <groupId>org.springfr…

Linux网络 HTTPS 协议原理

概念 HTTPS 也是一个应用层协议&#xff0c;不过 是在 HTTP 协议的基础上引入了一个加密层。因为 HTTP的内容是明文传输的&#xff0c;明文数据会经过路由器、wifi 热点、通信服务运营商、代理服务器等多个物理节点&#xff0c;如果信息在传输过程中被劫持&#xff0c;传输的…

java练习(5)

ps:题目来自力扣 给你两个 非空 的链表&#xff0c;表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的&#xff0c;并且每个节点只能存储 一位 数字。 请你将两个数相加&#xff0c;并以相同形式返回一个表示和的链表。 你可以假设除了数字 0 之外&#xff0c;这…

深入 Rollup:从入门到精通(三)Rollup CLI命令行实战

准备阶段&#xff1a;初始化项目 初始化项目&#xff0c;这里使用的是pnpm&#xff0c;也可以使用yarn或者npm # npm npm init -y # yarn yarn init -y # pnpm pnpm init安装rollup # npm npm install rollup -D # yarn yarn add rollup -D # pnpm pnpm install rollup -D在…

MySQL数据库环境搭建

下载MySQL 官网&#xff1a;https://downloads.mysql.com/archives/installer/ 下载社区版就行了。 安装流程 看b站大佬的视频吧&#xff1a;https://www.bilibili.com/video/BV12q4y1477i/?spm_id_from333.337.search-card.all.click&vd_source37dfd298d2133f3e1f3e3c…

松灵机器人 scout ros2 驱动 安装

必须使用 ubuntu22 必须使用 链接的humble版本 #打开can 口 sudo modprobe gs_usbsudo ip link set can0 up type can bitrate 500000sudo ip link set can0 up type can bitrate 500000sudo apt install can-utilscandump can0mkdir -p ~/ros2_ws/srccd ~/ros2_ws/src git cl…

【最长上升子序列Ⅱ——树状数组,二分+DP,纯DP】

题目 代码&#xff08;只给出树状数组的&#xff09; #include <bits/stdc.h> using namespace std; const int N 1e510; int n, m; int a[N], b[N], f[N], tr[N]; //f[i]表示以a[i]为尾的LIS的最大长度 void init() {sort(b1, bn1);m unique(b1, bn1) - b - 1;for(in…

Linux安装zookeeper

1, 下载 Apache ZooKeeperhttps://zookeeper.apache.org/releases.htmlhttps://zookeeper.apache.org/releases.htmlhttps://zookeeper.apache.org/releases.htmlhttps://zookeeper.apache.org/releases.htmlhttps://zookeeper.apache.org/releases.htmlhttps://zookeeper.apa…

day6手机摄影社区,可以去苹果摄影社区学习拍摄技巧

逛自己手机的社区&#xff1a;即&#xff08;手机牌子&#xff09;摄影社区 拍照时防止抖动可以控制自己的呼吸&#xff0c;不要大喘气 拍一张照片后&#xff0c;如何简单的用手机修图&#xff1f; HDR模式就是让高光部分和阴影部分更协调&#xff08;拍风紧时可以打开&…

linux本地部署deepseek-R1模型

国产开源大模型追平甚至超越了CloseAI的o1模型&#xff0c;大国崛起时刻&#xff01;&#xff01;&#xff01; DeepSeek R1 本地部署指南   在人工智能技术飞速发展的今天&#xff0c;本地部署AI模型成为越来越多开发者和企业关注的焦点。本文将详细介绍如何在本地部署DeepS…

小程序-基础加强-自定义组件

前言 这次讲自定义组件 1. 准备今天要用到的项目 2. 初步创建并使用自定义组件 这样就成功在home中引入了test组件 在json中引用了这个组件才能用这个组件 现在我们来实现全局引用组件 在app.json这样使用就可以了 3. 自定义组件的样式 发现页面里面的文本和组件里面的文…

c语言(关键字)

前言&#xff1a; 感谢b站鹏哥c语言 内容&#xff1a; 栈区&#xff08;存放局部变量&#xff09; 堆区 静态区&#xff08;存放静态变量&#xff09; rigister关键字 寄存器&#xff0c;cpu优先从寄存器里边读取数据 #include <stdio.h>//typedef&#xff0c;类型…

【最长不下降子序列——树状数组、线段树、LIS】

题目 代码 #include <bits/stdc.h> using namespace std; const int N 1e510; int a[N], b[N], tr[N];//a保存权值&#xff0c;b保存索引,tr保存f&#xff0c;g前缀属性最大值 int f[N], g[N]; int n, m; bool cmp(int x, int y) {if(a[x] ! a[y]) return a[x] < a[…

springboot 启动原理

目标&#xff1a; SpringBootApplication注解认识了解SpringBoot的启动流程 了解SpringFactoriesLoader对META-INF/spring.factories的反射加载认识AutoConfigurationImportSelector这个ImportSelector starter的认识和使用 目录 SpringBoot 启动原理SpringBootApplication 注…