大语言模型推理加速技术:计算加速篇

原文:大语言模型推理加速技术:计算加速篇 - 知乎

目录

简介

Transformer和Attention

瓶颈

优化目标

计算加速

计算侧优化

KVCache

Kernel优化和算子融合

分布式推理

内存IO优化

Flash Attention

Flash Decoding

Continuous Batching

Paged Attention

SplitFuse

总结


本文是《大语言模型推理加速技术》系列的第一篇

《大语言模型推理加速技术:计算加速篇》

《大语言模型推理加速技术:模型压缩篇》

《大语言模型推理加速技术:推理框架篇》

自从去年ChatGPT横空出世之后,业界对于大语言模型的热情也愈发高涨,随着模型规模越来越大,它们的计算需求也水涨船高,大模型部署和所需的资源量也让很多团队望而却步:毕竟可以拿社区开源的预训练模型跳过训练的过程,但是部署大模型推理是无法避开的流程。本系列旨在简单讨论几个业界生产环境可用的大模型推理技术,并分析对比几个主流的推理框架。

由于各大公司和学术团队都在“卷”大模型,大模型新技术层出不穷,本系列只能保证当前的信息有效性(2023年11月初)。另外由于本文是从工程角度出发,只会介绍工业界可落地的技术,一些前沿的学术成果可能并不包含在内,敬请谅解。

简介

Transformer和Attention

当前主流的大模型都是基于2017年谷歌团队提出的Transformer架构,其核心是注意力(Attention)机制,简单来说就是计算softmax(qk^T)*v:

def attention(q_input, k_input, v_input):q = self.Q(q_input)k = self.K(k_input)v = self.V(v_input)return softmax(q * k.transpose()) * v

其中Q,K,V是模型的三个矩阵。对当前主流的Decoder-only模型来说,推理过程分为两个阶段:

  1. context phase也叫prefill phase:需要计算整个prompt的自注意力,q_input, k_input, v_input大小都为[seq_len, emb_dim],即整个prompt的embedding,context phase只需要进行一次,生成第一个token。
  2. generation phase或decoding phase:每生成一个token就要计算一次,其中q_input为[1, emb_dim],代表当前token的embedding,k_input, v_input为[n, emb_dim]代表所有前文的embedding,这个阶段计算的是当前token和所有前文的注意力。

Attention计算占据了模型推理阶段的绝大部分资源,要了解本文介绍的优化技巧,只需要了解Attention的计算过程就足够了。由于本文更多关注工程细节,对模型的设计只是一笔带过,有兴趣的读者可自行了解完整的Transformer相关信息。

Attention计算,来自Raimi Karim的Towards Data Science文章

瓶颈

想知道如何优化,我们首先要知道模型的计算慢在哪里。

  1. 大模型通常需要处理很长的输入和输出,由于当前token需要和前面所有的token进行attention计算,随着seq_len和n的增加,模型需要的计算的矩阵尺寸也越来越大。
  2. 在生成阶段,每个token的生成都依赖前面所有的计算,只能一个一个token生成,无法并发计算。

简单来说,推理瓶颈就在于Attention的计算是一个O(N^2)的操作且无法并发。本文所介绍的技术都是针对以上两个瓶颈进行优化。

优化目标

加速优化一般有两个目标,一个是更低的延迟(Latency),一个是更高的吞吐量(Throughput),其中延迟是指单个请求返回的时间,而吞吐量是一定时间内处理请求的总量。这两者有时是不可兼得的。举个例子,如果我们把一个模型切分成多个小模型进行分布式计算(张量并行),我们可以把单个请求的速度提升数倍(延迟下降),但是由于有通信和聚合的成本,系统处理单个请求的资源消耗量变多了,导致系统的吞吐量也会下降。有些加速技术只会针对其中一种目标,我们后面会详细介绍。

在本系列中,推理优化技术分为两大类:计算加速模型压缩。计算加速,通过改进算法和硬件利用率来提高效率,而不影响模型的输出质量,本质上是让模型“算得更快”。而模型压缩则是改变模型结构,减少部分计算(比如稀疏Attention)或降低计算精度(比如量化),换来更快的推理速度和更低的资源消耗,但可能会影响模型的输出质量,本质上是让模型“算得更少”。本文作为系列的第一篇,只介绍计算加速技术,我们将在下一篇文章中介绍模型压缩技术。

计算加速

任何计算的本质都是CPU/GPU执行一系列的指令,在模型结构固定的情况下,我们能做的优化无非就是以下三种:

  1. 减少需要执行的指令数量,即减少不必要的或重复的运算。
  2. 充分利用硬件的并发度,要么是让单条指令可以一次处理多条数据(SIMD),要么是利用CPU和GPU的多核心机制,同时执行多条指令。
  3. 加速内存IO速度。利用缓存局部性加速内存读取指令的执行速度,或者减少不必要的内存读写。

前两种我们称之为计算侧优化,后一种我们称之为内存IO优化

计算侧优化

KVCache

在每一个decoding phase中,我们需要计算当前token和之前所有已生成token的attention,因此需要计算所有token的k和v向量,但是前面的token的kv值在每轮decoding中都被重复计算了,因此我们可以把它们存下来,存成两个[seq_len-1, inner_dim]的Tensor,在每轮计算中只需要计算当前token的kv值即可。

KVCache是最简单直接的优化手段,一般模型的默认实现都会自带KVCache,因此并不需要我们额外实现,以Huggingface Transformers库为例,我们只需要在配置中设置use_cache=True即可。

这里还想吐槽一下,一般我们在工程中要到的cache也都是基于hash_map这种kv map的,最开始我还以为KVCache也是基于map的复杂数据结构,没想到只是简单的两个Tensor就实现了。

KVCache图解,来自Joao Lages的Medium文章

Kernel优化和算子融合

在NVIDIA GPU环境上,我们通过CUDA Kernel来执行大部分运算,矩阵乘法(GEMM),激活函数,softmax等,一般老说每个操作都对应一次kernel调用。但是每次kernel调用都有一些额外开销,比如gpu和cpu之间的通信,内存拷贝等,因此我们可以将整个Attention的计算放进同一个kernel实现,省略这些开销,在kernel中也可以实现一些Attention专用的优化。比如Facebook的xformers库就为各种Attention变种提供了高效的CUDA实现。主流的推理库也基本都自带了高效的Kernel实现。

除了手写的Kernel外,模型编译器也可以提供类似的优化机制,编译器将模型结构转换为其中间格式然后进行一系列的优化,比如我们提到的算子融合。虽然在CUDA这种主流平台上编译器的优化效果不一定比得上专用的Kernel实现,但是多平台的通用性是编译器的一大优势,比如TVM团队的MLC-LLM和微软主导的ONNX常被用来在手机等边缘设备上运行大模型。

分布式推理

在大模型的训练和推理过程中,我们有以下几种主流的分布式并行方式:

  • 数据并行(DataParallel):将模型放在多个GPU上,每个GPU都包含完整的模型,将数据集切分成多份,每个GPU负责推理一部分数据。
  • 流水线并行(PipelineParallel):将模型纵向拆分,每个GPU只包含模型的一部分层,数据在一个GPU完成运算后,将输出传给下一个GPU继续计算。
  • 张量并行(TensorParallel):将模型横向拆分,将模型的每一层拆分开,放到不同的GPU上,每一层的计算都需要多个GPU合作完成。

流水线并行一般是一台GPU内存无法放下整个模型的妥协之举,各层之间仍然是顺序运行的,并不能加速模型的计算。而另外两种并行则可以加速模型的推理。推理阶段的数据并行非常简单,因为不需要像训练一样聚合推理结果更新参数,因此我们只需要将模型单独部署在多个GPU上推理即可。而张量并行一般使用NVIDIA的Megatron库,模型内部结构改为使用Megatron的ColumnParallelLinear、RowParallelLinear、ParallelMLP和ParallelAttention等结构实现。

流水线并行和张量并行,来自Sequence Parallelism的paper

内存IO优化

在做性能优化时,我们除了要考虑单纯计算的速度,也要考虑内存访问的速度。一方面,和CPU缓存一样,GPU也有类似L1、L2这样的分级缓存(SRAM和HBM),级数越低的缓存大小越小,访问速度越快,因此我们在优化模型推理时也要考虑内存访问的局部性(Cache Locality)。另一方面,KVCache随着batch size和seq len的增加而扩张,在推理过程中会占据超过30%的内存,可能会出现因为内存不够用而限制最大并发度的问题。在并发度较高或者输入输出长度较大时,内存访问反而可能成为计算的瓶颈,而非CPU/GPU的计算量。

GPU的分级缓存,来自Flash Attention的paper

推理时的内存占用,来自vllm的paper

Flash Attention

在进行Attention计算时,QKV都是非常大的矩阵,直接进行矩阵乘法运算是非常缓存不友好的。因此我们可以考虑对矩阵进行分块乘法,每次只计算一个小的block,保证block可以放进SRAM而非HBM中。实际上这是一个很经典的思路,大部分的矩阵乘法kernel也是这样实现的。

而FlashAttention则更进一步,我们观察到Attention计算分为三步:

  1. 从HBM读取QK,计算S = QK^T,将S写回HBM
  2. 从HBM读出S,计算P = softmax(S),将P写回HBM
  3. 从HBM读出P和V,计算O=PV,将O写回HBM

我们然需要在整个计算过程中HBM读三次写三次,有没有办法只读写一次呢?如果Attention只是简单的矩阵乘法,可以通过分块计算的方法避免写回HBM,但是由于softmax的存在,我们无法直接这样做。因为softmax需要计算矩阵中每一行元素的最大值,所以我们必须等待所有分块遍历完成后才能计算下一步。

FlashAttention巧妙地利用了类似于动态规划的技巧,实现了online softmax,可以在一个循环中计算出一个分块的最终结果。FlashAttention算法在遍历过程中需要不断地对中间结果进行重新计算,但是得益于整个过程不需要读HBM,再增大了计算量的情况下仍然可以提升运算速度。

关于Flash Attention的完整介绍和数学推导,我推荐华盛顿大学的这个课件,里面非常直观地解释了Flash Attention背后的想法和推导过程。用户可以通过使用Flash Attention的库调用它的kernel,PyTorch官方也对Flash Attention提供了官方支持。

Flash Decoding

由于Flash Attention优化的是大矩阵乘法,矩阵越大优化效果应当越好。但是在在线推理的场景中,输入的batch size为1,Q矩阵实际上是一个向量而非矩阵,在这种场景下,Flash Attention无法充分地利用GPU的并发能力。而Flash Decoding通过以seq_len为维度并发,即将K和V分成多个部分,并发地与Q相乘而解决了这个问题。

与Flash Attention适合离线训练和批量推理不同,Flash Decoding在在线单次推理且上下文长度较长时效果更好。用户可以通过FlashAttention库或者xFormers的attention kernel来使用Flash Decoding。

动图封面

Flash decoding示例

Continuous Batching

在批量推理过程中我们一般使用固定的Batch Size,将多个请求Batch起来一起推理。在分配KVCache时,我们需要分配两个shape为[batch_size, seq_len, inner_dim]的tensor,但是不同的请求可能有不同输入和输出长度,而且我们无法预知最终的输出长度,无法固定seq_len,因此我们通常分配[batch_size, max_seq_len, inner_dim]这样的shape,保证所有请求的cache都放得下。

但是这样的分配策略有两个问题:

  1. 不是每个请求都可以达到max_seq_len,因此KVCache中很多的内存都被浪费掉了
  2. 即使一些请求输出长度很短,它们仍然需要等待输出较长的请求结束后才能返回

这种固定的Batch策略叫做静态Batching(Static Batching),为了解决这个问题,Orca提出了Continuous Batching策略,也叫Dynamic Batching或Inflight Batching。Continuous Batching允许输出较短的请求提前结束,并由新请求占用已结束请求的KVCache空间。

Continuous Batching示例,来自Anyscale官网,黄色为prompt,蓝色为生成的token

在批量推理场景中,Continuous Batching可以将模型的吞吐量提升两到三倍。当前主流的推理框架比如Huggingface TGI, Ray serve, vllm, TensorRT-LLM等都支持Continuous Batching策略。

Paged Attention

vLLM团队分析了推理时的内存浪费问题,认为推理中存在三种内存浪费

  1. Reservation:由于不确定每个请求的输出长度,我们需要给每个请求预留max_seq_len的空间。
  2. Internal Fragmentation:在Static Batching策略下,一个请求结束时,其剩余的空间旧被浪费掉了。
  3. External Fragmentation:由于KVCache是一个巨大的矩阵,且必须占用连续内存,操作系统如果只分配大的连续内存,势必有很多小的内存空间被浪费掉。

请求中的内存浪费,来自vLLM paper

vLLM团队认为,Continuous Batching可以部分解决Internal Fragmentation问题,但是Reservation和External Fragmentation的浪费仍然存在。因此他们提出了Paged Attention,其借鉴了操作系统中通过Page管理虚拟内存的思想:将KVCache分割为固定大小的Block,这些block不需要存储在连续内存中,由一个统一的内存分配器管理。请求按需申请内存,不需要预先留好max_seq_len大小的内存,解决了Reservation的浪费,请求结束后释放掉自己的blocks,解决了Internal Fragmentation,而系统只需要分配小的block,解决了External Fragmentation的问题。

vLLM团队的benchmark显示,使用PagedAttention可以将模型批量推理的吞吐量再提升3倍以上,达到Static Batching的6倍,而Paged Attention的另一个好处是不同的请求可以共享cache block,比如在beam search场景中,我们需要对同一个prompt生成多个结果,这些子请求就可以共享同一批prompt cache,PageAttention可以将beam search的吞吐量提升10倍以上。

用户可以通过官方的vLLM库使用Paged Attention,英伟达的TensorRT-LLM库和微软的Deepspeed-MII库也对部分模型提供了支持。

SplitFuse

微软的DeepSpeed团队观察到:如果我们要在F个前向推理中处理P个token,最高效的分配策略是将它们均分,即每个前向推理处理P/F个Token。但是在模型推理过程中,我们需要先在context phase一次性处理整个prompt,然后在genration phase一个一个生成token,是不符合最优策略的,因此deepspeed提出了SplitFuse:

  1. 将长prompt分割成多个短的输入,分在多个前向推理处理,只有最后一次推理会生成新token。
  2. 短的prompt会被合并在一起进行前向推理。
  3. 保证每次前向推理的输入token数是固定的。

在vLLM中,不同请求的prompt和generation需要分开处理,而Deepspeed SplitFuse可以把它们混合处理,图片来自deepseed blog

SplitFuse在输入长度越长时效果越明显,在Deepspeed自己的benchmark中,输入长度为2600时吞吐量可以达到vLLM的2.3倍,不过vLLM团队不久前推出了PagedAttentionV2,专门对长prompt进行了优化,不知道deepspeed的benchmark有没有对比最新的vLLM。

由于SplitFuse刚出不久,我对它的理解还不是很透彻,它声称在generation阶段也是每次输入一个>1的固定长度n,但是模型generation阶段一次只能生成一个token,是怎么做到输入n个token的,难道是直接pad成n个token,这样真的不会导致generation很浪费计算资源吗?我看到Deepspeed的benchmark里都是prompt length=2600,generation length=60,我怀疑它只适合prompt很长generation很短的情况,欢迎了解Deepspeed的读者在评论区解读。

SplitFuse目前被包含在DeepSpeed MII框架里,暂时只支持LLAMA,Mistral,OPT三种模型架构。

总结

本文是《大语言模型推理加速技术》系列的第一篇,简单介绍了大模型的计算过程和一些主流的推理加速技术。本篇所介绍的技术都是不改变模型结构和精度的前提下,以目标为最大化硬件利用率的优化技术。在后续的两篇文章中,我们将继续探讨模型压缩技术和当前主流的推理框架。

由于我并不是大模型方面的专家,本篇文章中可能出现了很多不准确的解读,也欢迎各位读者赐教~

Reference:

[1] Attention介绍:https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a

[2] KVCache介绍:https://medium.com/@joaolages/kv-caching-explained-276520203249

[3] SequenceParallel: https://arxiv.org/pdf/2105.13120.pdf

[4] From Online Softmax to FlashAttention: https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf

[5] Flash Decoding: https://pytorch.org/blog/flash-decoding/

[6] How continuous batching enables 23x throughput in LLM inference while reducing p50 latency: Achieve 23x LLM Inference Throughput & Reduce p50 Latency

[7] vLLM: Efficient Memory Management for Large Language Model Serving with PagedAttention

[8] SplitFuse: https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen

[9] xFormers: GitHub - facebookresearch/xformers: Hackable and optimized Transformers building blocks, supporting a composable construction.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/264699.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

车载测试面试:题库+项目

车载测试如何面试(面试技巧)https://blog.csdn.net/2301_79031315/article/details/136229809 入职车载测试常见面试题(附答案)https://blog.csdn.net/2301_79031315/article/details/136229946 各大车企面试题汇总(含答案&…

深搜+素数筛,LeetCode 2867. 统计树中的合法路径数目

一、题目 1、题目描述 给你一棵 n 个节点的无向树,节点编号为 1 到 n 。给你一个整数 n 和一个长度为 n - 1 的二维整数数组 edges ,其中 edges[i] [ui, vi] 表示节点 ui 和 vi 在树中有一条边。 请你返回树中的 合法路径数目 。 如果在节点 a 到节点 …

请求包的大小会影响Redis每秒处理请求数量

文章目录 🔊博主介绍🥤本文内容压测规划客户端长连接数量对性能的影响请求包大小的影响Pipleline模式对Redis的影响 📢文章总结📥博主目标 🔊博主介绍 🌟我是廖志伟,一名Java开发工程师、Java领…

rtthread stm32h743的使用(二)DAPLINK下载程序不能启动

我们要在rtthread studio 开发环境中建立stm32h743xih6芯片的工程。我们使用一块stm32h743及fpga的核心板完成相关实验。 但是发现程序编译下载后不能启动,换为stlink后能够启动,反复试了好多次DAPLINK都不行,发送指令也没有反馈&#xff1a…

关于python的数据可视化与可视化:数据读取

带着问题寻找答案可以使自己不再迷茫或者不知所措! 了解什么python的数据可视化? 数据的读取(一般伴随着课程文件中会进行提供和利用) 数据可视化是将Python应用于大气海洋科学中数据处理及分析过程的重要环节,它可以…

Python爬虫实战:从API获取数据

引言 在现代软件开发中,API已经成为获取数据的主要方式之一。API允许不同的软件应用程序相互通信,共享数据和功能。在本文中,我们将学习如何使用Python从API获取数据,并探讨其在实际应用中的价值。 目录 引言 二、API基础知识 …

tigramite教程(二)生物地球科学案例研究

文章目录 数据生成与绘图因果发现分析平稳性假设、确定性、潜在混杂因素结构假设参数假设使用PCMCIplus的滑动窗口分析聚合因果图非参数因果效应估计假设的图形和调整集干预的真实情况假设的参数模型和因果效应的估计使用关于图的不同假设进行估计非因果估计项目地址 这个文件…

目标检测-Transformer-ViT和DETR

文章目录 前言一、ViT应用和结论结构及创新点 二、DETR应用和结论结构及创新点 总结 前言 随着Transformer爆火以来,NLP领域迎来了大模型时代,成为AI目前最先进和火爆的领域,介于Transformer的先进性,基于Transformer架构的CV模型…

FL Studio20.8.3.2304 中文破解版功能更新(附2024最新图文激活教程)

FL Studio20.8中文开心版下载 (有能力的话还是建议支持正版) 绝对能用,我现在就在用着 版本是FL Studio 20.8.3.230420.8.3.2304 Image-Line公司成立20周年而发布的一个版本,FL Studio中文开心版是目前互联网上最优秀的完整的软件音乐制作环境或数字…

SpringBoot使用classfinal-maven-plugin插件加密Jar包

jar包加密 1、在启动类的pom.xml中加入classfinal-maven-plugin插件 <build><plugins><plugin><groupId>org.springframework.boot</groupId><artifactId>spring-boot-maven-plugin</artifactId></plugin><plugin><…

from tensorflow.keras.layers import Dense,Flatten,Input报错无法引用

from tensorflow.keras.layers import Dense,Flatten,Input 打印一下路径&#xff1a; import tensorflow as tf import keras print(tf.__path__) print(keras.__path__) [E:\\开发工具\\pythonProject\\studyLL\\venv\\lib\\site-packages\\keras\\api\\_v2, E:\\开发工具\\…

小程序应用、页面、组件生命周期

引言 微信小程序生命周期是指在小程序运行过程中&#xff0c;不同阶段触发的一系列事件和函数。这一概念对于理解小程序的整体架构和开发流程非常重要。本文将介绍小程序生命周期的概念以及在不同阶段触发的关键事件&#xff0c;帮助开发者更好地理解和利用小程序的生命周期。 …

Docker基础(一)

文章目录 1. 基础概念2. 安装docker3. docker常用命令3.1 帮助命令3.2 镜像命令3.3 容器命令3.4 其他命令 4. 使用案例 1. 基础概念 镜像&#xff08;Image&#xff09;&#xff1a;Docker 镜像&#xff08;Image&#xff09;&#xff0c;就相当于是一个 root 文件系统。比如官…

ShardingSphere 5.x 系列【15】分布式主键生成器

有道无术,术尚可求,有术无道,止于术。 本系列Spring Boot 版本 3.1.0 本系列ShardingSphere 版本 5.4.0 源码地址:https://gitee.com/pearl-organization/study-sharding-sphere-demo 文章目录 1. 概述2. 配置3. 内置算法3.1 UUID3.2 Snowflake3.3 NanoId3.4 CosId3.5 Co…

深度学习目标检测】二十、基于深度学习的雾天行人车辆检测系统-含数据集、GUI和源码(python,yolov8)

雾天车辆行人检测在多种场景中扮演着至关重要的角色。以下是其作用的几个主要方面&#xff1a; 安全性提升&#xff1a;雾天能见度低&#xff0c;视线受阻&#xff0c;这使得驾驶者和行人在道路上的感知能力大大降低。通过车辆行人检测技术&#xff0c;可以在雾天条件下及时发现…

等保测评与商用密码共铸工控安全“双评合规”新篇章

最近听说了一个段子&#xff1a;“网络安全就像美女的内衣&#xff0c;等保和密评就是最贴身的内衣两件套&#xff0c;上下身一件都不能少。否则你的魔鬼身材&#xff08;核心数据&#xff09;就有可能被色狼&#xff08;黑客&#xff09;一览无余&#xff08;数据泄漏&#xf…

什么是nginx 、安装nginx、nginx调优

一、 什么是nginx 1.1 nginx的概念 一款高新能、轻量级Web服务软件系统资源消耗低对HTTP并发连接的处理能力高单台物理服务器可支持30 000&#xff5e;50 000个并发请求。 1.2 nginx模块与作用 核心模块&#xff1a;是 Nginx 服务器正常运行必不可少的模块&#xff0c;提供错…

配置MMDetection的solov2攻略整理

目录 一、MMDetection 特性 常见用法 二、ubuntu20.04配置solov2 三、Windows11配置solov2 一、MMDetection MMDetection是一个用于目标检测的开源框架&#xff0c;由OpenMMLab开发和维护。它提供了丰富的预训练模型和模块&#xff0c;可以用于各种目标检测任务&#xff…

ChatGPT在数据处理中的应用

ChatGPT在数据处理中的应用 今天的这篇文章&#xff0c;让我不断体会AI的强大&#xff0c;愿人类社会在AI的助力下走向更加灿烂辉煌的明天。 扫描下面二维码注册 ​ 数据处理是贯穿整个数据分析过程的关键步骤&#xff0c;主要是对数据进行各种操作&#xff0c;以达到最终的…

亿道丨三防平板丨如何从多方面选择合适的三防加固平板?

在如今这个信息爆炸的时代&#xff0c;移动设备已经成为我们生活和工作的必备工具。然而&#xff0c;在一些特殊的场合中&#xff0c;普通的平板电脑可能无法满足需求&#xff0c;比如工厂车间、野外作业、极端天气等环境下。此时&#xff0c;三防平板就成了不二之选。那么&…