【大模型上下文长度扩展】FlashAttention:高效注意力计算的新纪元

FlashAttention:高效注意力计算的新纪元

    • 核心思想
      • 核心操作融合,减少高内存读写成本
      • 分块计算(Tiling),避免存储一次性整个矩阵
      • 块稀疏注意力,处理长序列时的效率问题
      • 利用快速 SRAM,处理内存与计算速度不匹配
      • 算术强度优化,处理计算与内存访问的不平衡
      • 重计算,解决后向传递中存储大型中间矩阵的需求
    • 当前FlashAttention实现的局限性,并提出了未来发展的方向
      • 低级语言编程的复杂性
      • IO-感知优化的普遍性
      • 多GPU并行计算的IO优化

 


论文:https://arxiv.org/pdf/2205.14135.pdf

 

核心思想

FlashAttention 提出的是为了解决 Transformers 在处理长序列时的速度慢和内存消耗大的问题。

这个问题主要是因为,自注意力模块在长序列上的时间和内存复杂度都是二次方的。

FlashAttention的本质是通过创新的算法设计,实现了对Transformer模型中注意力机制的高效计算。

  • FlashAttention通过减少HBM访问次数和避免存储大型中间矩阵,使BERT模型比MLPerf 1.1的速度记录快15%,GPT-2的训练速度提高了最高3倍。

  • 使用FlashAttention的GPT-2模型,在4K的上下文长度下训练比Megatron在1K上下文长度下训练还快,同时困惑度(perplexity)更低,说明模型质量提高。

  • FlashAttention在常见序列长度(最高2K)上比标准注意力实现快3倍,并且其内存占用随序列长度线性增长,证明了其在效率和内存使用上的优势。

  • 块稀疏FlashAttention通过仅计算重要的注意力块来减少计算量和内存使用,使得Transformer模型能够处理高达64K序列长度,且在Path-256任务上达到了63.1%的准确率,显示了其在处理长序列任务上的能力。

它通过以下核心方法和策略,解决了传统注意力计算在长序列处理时遇到的速度慢和内存消耗大的问题:

  1. IO-感知优化:FlashAttention深入考虑了GPU内存层次之间的交互,特别是高带宽内存(HBM)与片上SRAM之间的读写操作,通过优化这些操作来减少内存访问成本,从而提高计算效率。

  2. 分块计算(Tiling):通过将输入序列分成小块并逐块处理,FlashAttention避免了一次性加载整个序列到内存中,减轻了内存压力,并使得注意力计算更加高效。

  3. 重计算策略:为了减少后向传播时对大型中间矩阵的存储需求,FlashAttention采用了在需要时重新计算这些矩阵的策略,从而节省了大量的内存空间。

  4. 核心融合:FlashAttention通过将多个计算步骤融合到一个CUDA核心中执行,减少了内存访问次数,并提高了执行速度。

这些策略共同作用,使FlashAttention能够以更少的内存访问和更低的时间复杂度,准确地计算出注意力,从而在保持模型质量的同时,显著提高了训练速度和效率。

此外,FlashAttention的设计还支持块稀疏注意力,进一步提高了处理长序列能力,使得在资源有限的情况下,Transformer模型能够处理更长的上下文信息,这在自然语言处理和其他需要长序列处理的领域中尤为重要。

FlashAttention本质上是对传统Transformer注意力机制的一个高效、内存友好的改进,它通过深入挖掘和优化计算机内存和计算资源的使用方式,推动了深度学习模型在复杂任务上的应用和发展。

 

核心操作融合,减少高内存读写成本

  • 子解法: IO-感知算法(IO-Awareness)
    • 解释: 传统的注意力算法没有考虑到 GPU 内存层次之间的读写成本,导致了大量的内存访问,进而增加了计算时间和内存消耗。
    • FlashAttention 通过考虑 IO,即输入/输出操作,特别是在 GPU 高带宽存储器(HBM)与 GPU 上的 SRAM 之间的读写操作,来降低这些成本。
    • 例子: 在传统的 Transformer 模型中,整个注意力矩阵需要从 HBM 读入到 SRAM 中进行计算,
    • 然后结果再写回 HBM,这个过程中的读写操作非常耗时和耗内存。
    • FlashAttention 通过减少这种读写操作的次数,来减少内存访问成本。

在标准注意力计算中,每个操作(如 softmax、矩阵乘法等)都需要从 HBM 读取输入,计算后再将结果写回 HBM,导致高内存访问成本。

如果我们可以将多个操作合并为一个操作(核心融合),那么输入只需从 HBM 加载一次,这样就减少了内存访问次数,从而降低了内存访问成本。
 

分块计算(Tiling),避免存储一次性整个矩阵

  • 子解法: 增量式 softmax 计算(Tiling)
    • 解释: 标准的注意力机制需要存储整个注意力矩阵以便于后向传播,这在长序列上是非常内存消耗的。
    • FlashAttention 通过将输入分块(tiling)并多次通过输入块逐步执行 softmax 减少(也称为 tiling),避免了一次性处理整个大矩阵。
    • 例子: 假设有一个很长的序列,传统方法需要一次性计算和存储整个序列的注意力矩阵。
    • FlashAttention 则将序列分成小块,每次只处理一个块,并逐步累积计算结果,从而不需要存储整个大矩阵。

在标准注意力机制中,整个注意力矩阵需要一次性计算并存储,导致对 HBM 的大量访问。

通过将输入矩阵 Q、K、V 分块并逐块计算,我们可以逐步生成注意力输出,减少了一次性对大量数据的访问需求。

一个大型矩阵乘法,通过将矩阵分为小块,每次只处理一部分数据,就可以减少内存的即时需求。

 

块稀疏注意力,处理长序列时的效率问题

  • 子解法: 块稀疏注意力(Block-sparse Attention)
    • 解释: 长序列上的注意力计算复杂度高,导致计算缓慢。
    • FlashAttention 引入了块稀疏技术,通过只计算序列中重要部分的注意力,忽略其他不重要的部分,从而减少计算量。
    • 例子: 在处理一个长文本时,可能只有部分词语之间存在强关联,而其他词语的关联性较弱。块稀疏注意力允许模型只关注那些重要的词语间的关联,忽略其他,从而加速计算并降低内存使用。

 

利用快速 SRAM,处理内存与计算速度不匹配

  • 子解法: 利用快速 SRAM
    • 原因: 现代 GPU 的计算速度相比内存速度增长得更快,使得大多数操作成为内存访问受限。
    • 例子: 通过更多地利用每个流式多处理器上的快速 SRAM(与 HBM 相比,SRAM 速度快得多但容量小得多),我们可以加速那些内存访问受限的操作,例如通过在 SRAM 中计算部分结果来减少对 HBM 的访问。

 

算术强度优化,处理计算与内存访问的不平衡

  • 子解法: 算术强度优化
    • 原因: 操作可以根据计算和内存访问之间的平衡被分类为计算密集型或内存访问密集型。
    • 标准注意力实现中,很多操作(如 softmax)是内存访问密集型的。
    • 例子: 通过优化算术强度,即每字节内存访问的算术操作数量,我们可以尽量将操作转变为计算密集型,从而减轻内存访问的瓶颈。

 

重计算,解决后向传递中存储大型中间矩阵的需求

  • 子解法: 重计算(Recomputation)
    • 原因: 标准实现中,后向传递需要访问前向传递计算时产生的大型中间矩阵(如 S 和 P 矩阵)。通过存储必要的统计量而非整个矩阵,并在需要时重计算这些矩阵,可以避免大量的内存使用。
    • 例子: 类似于梯度检查点技术,我们不存储整个计算过程中的中间状态,而是仅存储关键节点,需要时再重建整个状态。

 

通过子解法的组合,FlashAttention 成功地解决了 Transformers 在处理长序列时速度慢和内存消耗大的问题。

FlashAttention 提出了一种计算精确注意力的算法,其关键在于通过减少对高带宽内存(HBM)的读写操作以及避免在后向传递中存储大型中间矩阵,从而实现了既节省内存又加速计算的目标。

在探索传统注意力机制在现代硬件(尤其是 GPU)上的执行效率时,遇到了一系列的具体问题,这些问题导致了处理速度慢和高内存消耗。

每种解决方案都直接针对了标准注意力实现中的效率瓶颈,通过改善内存访问模式、减少不必要的内存写入和读取、以及优化计算流程来提高整体性能。

 
在这里插入图片描述

左侧:展示了在GPU中的内存层次结构和FlashAttention如何在这种结构中工作。

它说明了:

  • GPU的不同内存层次及其带宽和大小,包括片上SRAM(20MB, 19TB/s),高带宽内存HBM(40GB, 1.5TB/s),以及主内存DRAM(12.8GB/s, 大于1TB)。
  • FlashAttention使用分块计算(Tiling)来避免实现大型 N×N 注意力矩阵。
  • 在外部循环(红色箭头)中,FlashAttention遍历K和V矩阵的块,并将它们加载到快速的片上SRAM中。
  • 在每个块中,FlashAttention遍历Q矩阵的块(蓝色箭头),加载到SRAM中,并将注意力计算的输出写回到HBM。

右侧:显示了使用PyTorch实现的注意力计算与FlashAttention实现在GPT-2模型上的速度对比。

它说明了:

  • FlashAttention与PyTorch实现相比在各个组件(矩阵乘法、Dropout、Softmax、Mask和Fused Kernel)上的时间消耗。
  • FlashAttention没有读写大型 N×N 注意力矩阵到HBM,因此在注意力计算上得到了约7.6倍的加速。

 


当前FlashAttention实现的局限性,并提出了未来发展的方向

 

低级语言编程的复杂性

  • 子解法1: 高级语言到CUDA的自动编译
    • 原因: 目前,IO-感知的注意力实现需要在CUDA中手动编写新的核函数,这不仅需要在比PyTorch这样的高级语言更低级的语言中编程,而且还需要大量的工程努力。
    • 例子: 类似于图像处理领域的Halide工具,可以让研究人员用高级语言编写算法,然后自动编译成优化的CUDA代码,减少直接使用CUDA编程的复杂性。

 

IO-感知优化的普遍性

  • 子解法2: 扩展IO-感知实现到其他模块
    • 原因: 虽然注意力计算是Transformer模型中最耗内存的部分,但模型的每一层都需要与GPU的高带宽内存(HBM)交互。
    • 例子: 在深度学习模型的其他组件,如卷积层或循环层,也采用IO-感知的实现方法,可以进一步提高整个模型的效率。

 

多GPU并行计算的IO优化

  • 子解法3: 多GPU间的IO-感知方法
    • 原因: FlashAttention的当前实现在单GPU上是最优的,但注意力计算可以跨多GPU并行化,这引入了考虑GPU间数据传输的额外IO分析层。
    • 例子: 通过设计能够优化GPU间数据传输的IO-感知算法,可以在不牺牲性能的前提下,实现更大规模的模型训练和更高效的并行计算。

从提高开发效率、扩展IO-感知优化的应用范围,到优化多GPU并行计算的效率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/252452.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

二分查找------蓝桥杯

题目描述&#xff1a; 请实现无重复数字的升序数组的二分查找 给定一个元素升序的、无重复数字的整型数组 nums 和一个目标值 target&#xff0c;写一个函数搜索 nums 中的target&#xff0c;如果目标值存在返回下标 (下标从0 开始)&#xff0c;否则返回-1 数据范围: 0 < l…

【Java 数据结构】反射

反射 1 定义2 用途(了解)3 反射基本信息4 反射相关的类&#xff08;重要&#xff09;4.1 Class类(反射机制的起源 )4.1.1 Class类中的相关方法(方法的使用方法在后边的示例当中) 4.2 反射示例4.2.1 获得Class对象的三种方式4.2.2 反射的使用 5、反射优点和缺点 1 定义 Java的反…

网络编程套接字

目录 本节重点一、预备知识1.1 理解源IP地址和目的IP地址1.2 认识端口号1.3 理解 "端口号" 和 "进程ID"1.4 理解源端口号和目的端口号1.5 认识TCP协议1.6 认识UDP协议1.7 网络字节序 二、socket编程接口2.1 socket常见的API2.2 sockaddr结构2.3 in_addr结构…

SpringBoot集成axis发布WebService服务

文章目录 1、使用maven-web项目生成server-config.wsdd文件1.1、新建maven-web项目1.1.1、新建项目1.1.2、添加依赖 1.2、编写服务接口和实现类1.2.1、OrderService接口1.2.2、OrderServiceImpl实现类 1.3、配置deploy.wsdd文件deploy.wsdd文件 1.4、配置tomcat1.4.1、配置tomc…

MySQL数据库练习【一】

MySQL数据库练习【一】 一、建库建表-数据准备二、习题2.1. 查询部门编号为30的部门的员工详细信息2.2.查询从事clerk工作的员工的编号、姓名以及其部门号2.3.查询奖金多于基本工资的员工的信息、查询奖金小于基本工资的员工的信息2.4.查询奖金多于基本工资60%的员工的信息2.5.…

宠物空气净化器适合养猫家庭吗?除猫毛好的猫用空气净化器推荐

宠物掉毛是一个普遍存在的问题&#xff0c;尤其在脱毛季节&#xff0c;毛发似乎无处不在。这给家中的小孩和老人带来了很多麻烦&#xff0c;他们容易流鼻涕、过敏等不适。此外&#xff0c;宠物有时还会不规矩地拉扯和撒尿&#xff0c;这股气味实在是难以忍受。家人们对宠物的存…

【DC渗透系列】DC-4靶场

主机发现 arp-scan -l┌──(root㉿kali)-[~] └─# arp-scan -l Interface: eth0, type: EN10MB, MAC: 00:0c:29:6b:ed:27, IPv4: 192.168.100.251 Starting arp-scan 1.10.0 with 256 hosts (https://github.com/royhills/arp-scan) 192.168.100.1 00:50:56:c0:00:08 …

掌握Go的加密技术:crypto/rsa库的高效使用指南

掌握Go的加密技术&#xff1a;crypto/rsa库的高效使用指南 引言crypto/rsa 库概览RSA 加密算法基本原理crypto/rsa 库的功能和应用 安装和基本设置在 Go 项目中引入 crypto/rsa 库基本环境设置和配置 密钥生成与管理生成 RSA 密钥对密钥存储和管理 加密和解密操作使用 RSA 加密…

Kafka零拷贝技术与传统数据复制次数比较

读Kafka技术书遇到困惑: "对比传统的数据复制和“零拷贝技术”这两种方案。假设有10个消费者&#xff0c;传统复制方式的数据复制次数是41040次&#xff0c;而“零拷贝技术”只需110 11次&#xff08;一次表示从磁盘复制到页面缓存&#xff0c;另外10次表示10个消费者各自…

加固平板电脑丨三防智能平板丨工业加固平板丨智能城市管理

随着智能城市的不断发展&#xff0c;人们对于城市管理的要求也在不断提高&#xff0c;这就需要高效、智能的城市管理平台来实现。而三防平板就是一款可以满足这一需求的智能设备。 三防平板是一种集防水、防尘、防摔于一体的智能平板电脑&#xff0c;它可以在复杂的环境下稳定运…

【EI会议征稿通知】第三届智能控制与应用技术国际学术会议(AICAT 2024)

第三届智能控制与应用技术国际学术会议&#xff08;AICAT 2024&#xff09; 2024 3rd International Symposium on Artificial Intelligence Control and Application Technology 2024年第三届智能控制与应用技术国际学术会议&#xff08;AICAT 2024&#xff09;定于2024年5月…

Leetcode24:两两交换链表中的节点

一、题目 给你一个链表&#xff0c;两两交换其中相邻的节点&#xff0c;并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题&#xff08;即&#xff0c;只能进行节点交换&#xff09;。 示例&#xff1a; 输入&#xff1a;head [1,2,3,4] 输出&#xff…

[Linux 进程控制(二)] 写时拷贝 - 进程终止

文章目录 1、写时拷贝2、进程终止2.1 进程退出场景2.1.1 退出码2.1.2 错误码错误码 vs 退出码2.1.3 代码异常终止引入 2.2 进程常见退出方法2.2.1 exit函数2.2.2 _exit函数 本片我们主要来讲进程控制&#xff0c;讲之前我们先把写时拷贝理清&#xff0c;然后再开始讲进程控制。…

JAVA面试汇总总结更新中ing

本人面试积累面试题 多线程微服务JVMKAFKAMYSQLRedisSpringBoot/Spring 1.面向对象的三个特征 封装&#xff0c;继承&#xff0c;多态&#xff0c;有时候也会加上抽象。 2.多态的好处 允许不同类对象对同一消息做出响应&#xff0c;即同一消息可以根据发送对象的不同而采用多种…

软考21-上午题-数组、矩阵

数组&#xff1a;一组地址连续的空间。 数组是定长线性表在维数上的扩展&#xff0c;即&#xff0c;线性表中的元素又是一个线性表。 一、数组 数组的特点&#xff1a; 数组数目固定&#xff0c;一旦定义了数组结构&#xff0c;不再有元素个数的增减变化。因此&#xff0c;数…

C# Onnx GroundingDINO 开放世界目标检测

目录 介绍 效果 模型信息 项目 代码 下载 介绍 地址&#xff1a;https://github.com/IDEA-Research/GroundingDINO Official implementation of the paper "Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection" 效果 …

MATLAB矩阵的操作(第二部分)

师从清风 矩阵的创建方法 在MATLAB中&#xff0c;矩阵的创建方法主要有三种&#xff0c;分别是&#xff1a;直接输入法、函数创建法和导入本地文件中的数据。 直接输入法 输入矩阵时要以中括号“[ ]”作为标识符号&#xff0c;矩阵的所有元素必须都在中括号内。 矩阵的同行元…

openssl3.2 - use openssl cmd create ca and p12

文章目录 openssl3.2 - use openssl cmd create ca and p12概述笔记实验的openssl环境建立CA生成私钥和证书请求生成CA证书用CA签发应用证书用CA对应用证书进行签名将已经签名好的PEM证书封装为P12证书验证P12证书是否可用END openssl3.2 - use openssl cmd create ca and p12 …

Redis(三)(实战篇)

查漏补缺 1.spring 事务失效 有时候我们需要在某个 Service 类的某个方法中&#xff0c;调用另外一个事务方法&#xff0c;比如&#xff1a; Service public class UserService {Autowiredprivate UserMapper userMapper;public void add(UserModel userModel) {userMapper.…

睿尔曼超轻量仿人机械臂—外置按钮盒使用说明

睿尔曼RM系列机械臂的控制方式有很多种&#xff0c;包括&#xff1a;示教器、JSON、API等。在此为大家介绍外置按钮盒的使用方法。 按钮盒接线安装 按钮盒外观如下图所示&#xff0c;有&#xff1a;急停、暂停、开始、继续。四个功能按钮。用户可通过这四个按钮来实现对机械臂运…