GiantPandaCV | FasterTransformer Decoding 源码分析(六)-CrossAttention介绍

本文来源公众号“GiantPandaCV”,仅用于学术分享,侵权删,干货满满。

原文链接:FasterTransformer Decoding 源码分析(六)-CrossAttention介绍

GiantPandaCV | FasterTransformer Decoding 源码分析(一)-整体框架介绍-CSDN博客

GiantPandaCV | FasterTransformer Decoding 源码分析(二)-Decoder框架介绍-CSDN博客

GiantPandaCV | FasterTransformer Decoding 源码分析(三)-LayerNorm介绍-CSDN博客

GiantPandaCV | FasterTransformer Decoding 源码分析(四)-SelfAttention实现介绍-CSDN博客

GiantPandaCV | FasterTransformer Decoding 源码分析(五)-AddBiasResidualLayerNorm介绍-CSDN博客

作者丨进击的Killua

来源丨https://zhuanlan.zhihu.com/p/670739629

编辑丨GiantPandaCV

本文是FasterTransformer Decoding源码分析的第六篇,笔者试图去分析CrossAttention部分的代码实现和优化。由于CrossAttention和SelfAttention计算流程上类似,所以在实现上FasterTransformer使用了相同的底层Kernel函数,因此会有大量重复的概念和优化点,重复部分本文就不介绍了,所以在阅读本文前务必先浏览进击的Killua:FasterTransformer Decoding 源码分析(四)-SelfAttention实现介绍这篇文章,一些共性的地方会在这篇文章中做统一介绍,本文着重介绍区别点。

一、模块介绍

如下图所示,CrossAttention模块位于DecoderLayer的第4个模块,输入为经过LayerNorm后的SelfAttention结果和encoder的outputs,经过该模块处理后进行残差连接再输入LayerNorm中。

CrossAttention在decoder中的位置

CrossAttention模块本质上还是要实现如下几个公式,主要的区别在于其中 CrossAttention 的K, V矩阵不是使用 上一个 Decoder block的输出或inputs计算的,而是使用Encoder 的编码信息矩阵计算的,这里还是把公式放出来展示下。

crossAttention 公式

二、设计&优化

整体Block和Thread的执行模型还是和SelfAttention的保持一致,这里不再赘述,主要介绍一下有一些区别的KV Cache。

1. KV Cache

由于在CrossAttention中K,V矩阵是来自于已经计算完成的Encoder输出,所以KV Cache的程度会更大,即第一次运算把KV计算出来之后,后续只要读取Cache即可,不需要用本step的输入再进行线性变换得到增量的部分K,V,如下图所示。

三、源码分析

1. 方法入口

CrossAttention的调用入口如下,解释下这里的输入和输出,具体逻辑在后面。

输入Tensor

  1. input_query:normalize之后的SelfAttention输出,大小是[batch_size,hidden_units_]

  2. encoder_output: encoder模块的输出,大小是[batch_size, mem_max_seq_len, memory_hidden_dimension]

  3. encoder_sequence_length:每个句子的长度,大小是[batch_size]

  4. finished: 解码是否结束的标记,大小是[batch_size]

  5. step: 当前解码的步数

输出Tensor

  1. hidden_features:CrossAttention的输出feature,大小是[batch_size,hidden_units_],和input_query大小一致。

  2. key_cache:CrossAttention中存储key的cache,用于后续step的计算。

  3. value_cache: CrossAttention中存储Value的cache,用于后续step的计算。

        TensorMap cross_attention_input_tensors{{"input_query", Tensor{MEMORY_GPU, data_type, {batch_size, hidden_units_}, normed_self_attn_output_}},{"encoder_output", input_tensors->at(1)},{"encoder_sequence_length", input_tensors->at(2)},{"finished", input_tensors->at(3)},{"step", input_tensors->at(4)}}; TensorMap cross_attention_output_tensors{{"hidden_features", Tensor{MEMORY_GPU, data_type, {batch_size, hidden_units_}, cross_attn_output_}},{"key_cache",Tensor{MEMORY_GPU,data_type,std::vector<size_t>(output_tensors->at(3).shape.begin() + 1, output_tensors->at(3).shape.end()),output_tensors->at(3).getPtrWithOffset<T>(mem_cache_offset)}},{"value_cache",Tensor{MEMORY_GPU,data_type,std::vector<size_t>(output_tensors->at(4).shape.begin() + 1, output_tensors->at(4).shape.end()),output_tensors->at(4).getPtrWithOffset<T>(mem_cache_offset)}}};cross_attention_layer_->forward(&cross_attention_output_tensors,&cross_attention_input_tensors,&decoder_layer_weight->at(l).cross_attention_weights);

2. 主体框架

主体框架代码由三部分构成,分别是该step的QKV生成、output生成和Linear输出。其中第一部分和第三部分都使用了cublas的封装矩阵乘方法gemm,这里就不多介绍了,主要功能逻辑在第二部分output生成。

第一部分:QKV生成

如上所述,代码中Q矩阵是需要每个step生成的,而KV矩阵只有第一个step需要生成,后续步骤读取cache即可。

    cublas_wrapper_->Gemm(CUBLAS_OP_N,CUBLAS_OP_N,hidden_units_,  // n                          batch_size,d_model_,  // k                          attention_weights->query_weight.kernel,hidden_units_,  // n                          attention_input,d_model_,  // k                          q_buf_,hidden_units_ /* n */);if (step == 1) {cublas_wrapper_->Gemm(CUBLAS_OP_N,CUBLAS_OP_N,hidden_units_,batch_size * mem_max_seq_len,encoder_output_tensor.shape[2],attention_weights->key_weight.kernel,hidden_units_,encoder_output_tensor.getPtr<T>(),encoder_output_tensor.shape[2],key_mem_cache,hidden_units_);cublas_wrapper_->Gemm(CUBLAS_OP_N,CUBLAS_OP_N,hidden_units_,batch_size * mem_max_seq_len,encoder_output_tensor.shape[2],attention_weights->value_weight.kernel,hidden_units_,encoder_output_tensor.getPtr<T>(),encoder_output_tensor.shape[2],value_mem_cache,hidden_units_);}

第二部分:output生成

核心函数调用,这里参数较多不一一介绍了,非常多(像一些has_ia3等参数应该是在不断迭代的过程中加入的),在后面函数实现中会将重点参数进行阐述。

    cross_attention_dispatch<T>(q_buf_,attention_weights->query_weight.bias,key_mem_cache,attention_weights->key_weight.bias,value_mem_cache,attention_weights->value_weight.bias,memory_sequence_length,context_buf_,finished,batch_size,batch_size,head_num_,size_per_head_,step,mem_max_seq_len,is_batch_major_cache_,q_scaling_,output_attention_param,has_ia3 ? input_tensors->at("ia3_tasks").getPtr<const int>() : nullptr,has_ia3 ? attention_weights->ia3_key_weight.kernel : nullptr,has_ia3 ? attention_weights->ia3_value_weight.kernel : nullptr,stream_);

第三部分:Linear输出

这里就是简单地对上步输出结果乘以一个权重矩阵。

    cublas_wrapper_->Gemm(CUBLAS_OP_N,CUBLAS_OP_N,d_model_,  // nbatch_size,hidden_units_,  // kattention_weights->attention_output_weight.kernel,d_model_,  // ncontext_buf_,hidden_units_,  // kattention_out,d_model_ /* n */);

3. kernel函数调用

上述output生成步骤中会调用如下代码,这里针对每个head中需要处理的层数进行了分类,这个也是大量优化中的常用方案,针对不同的入参大小选择不同size和配置的kernel函数进行处理,这里有经验的一些成分在里面,我们常用的case是hidden_size_per_head=64(head=8)的情况。

template<typename T, typename KERNEL_PARAMS_TYPE>void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream){switch (params.hidden_size_per_head) {case 32:mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);break;case 48:mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);break;case 64:mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);break;case 80:mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);break;case 96:mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);break;case 112:mmha_launch_kernel<T, 112, 128, KERNEL_PARAMS_TYPE>(params, stream);break;case 128:mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);break;case 144:mmha_launch_kernel<T, 144, 256, KERNEL_PARAMS_TYPE>(params, stream);break;case 160:mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);break;case 192:mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);break;case 224:mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);break;case 256:mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);break;default:assert(false);}}

4. kernel函数实现

这个函数和SelfAttention中的kernel函数是同一个,流程如图所示,这里只介绍下区别点。

1. CrossAttention中只有第一个step需要将KV存入Cache,其他step不需要。

        const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0);if (handle_kv) {// Trigger the stores to global memory.            if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {*reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);}}

2. 处理本轮step的KV时,也是从cache中取得KV,无需进行本轮计算得到增量KV。

    if (DO_CROSS_ATTENTION) {// The 16B chunk written by the thread.        int co = tidx / QK_VECS_IN_16B;// The position of the thread in that 16B chunk.        int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;// Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.        int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +// params.timestep*QK_ELTS_IN_16B +                     tlength * QK_ELTS_IN_16B + ci;k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k_cache[offset])) :k;}else {if (params.int8_mode == 2) {using Packed_Int8_t  = typename packed_type<int8_t, num_elems<Qk_vec_m>::value>::type;using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec_m>::value>::type;const auto k_scaling = params.qkv_scale_out[1];const auto k_quant =*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[qk_offset]);convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));}else {k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k[qk_offset])) :k;}}if (DO_CROSS_ATTENTION) {v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache[tlength * Dh]));}

四、总结

本文相对简单,分析了FasterTransformer中CrossAttention模块的设计方法和代码实现,和SelfAttention基本一致,只是对KV Cache的处理细节上有一点区别,整体上看缓存的使用会比SelfAttention多一些,所以速度应该还会快一点。

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/337884.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HttpSecurity 是如何组装过滤器链的

有小伙伴们问到这个问题&#xff0c;简单写篇文章和大伙聊一下。 一 SecurityFilterChain 首先大伙都知道&#xff0c;Spring Security 里边的一堆功能都是通过 Filter 来实现的&#xff0c;无论是认证、RememberMe Login、会话管理、CSRF 处理等等&#xff0c;各种功能都是通…

数字信号处理实验四:IIR数字滤波器设计及软件实现

一、实验目的 1. 掌握MATLAB中进行IIR模拟滤波器的设计的相关函数的应用&#xff1b; 2. 掌握MATLAB的工具箱中提供的常用IIR数字滤波器的设计函数的应用&#xff1b; 3.掌握MATLAB的工具箱中提供的模拟滤波器转数字滤波器的相关的设计函数的应用。 二、实验内容 本实验为…

微软远程连接工具:Microsoft Remote Desktop for Mac 中文版

Microsoft Remote Desktop 是一款由微软开发的远程桌面连接软件&#xff0c;它允许用户从远程地点连接到远程计算机或虚拟机&#xff0c;并在远程计算机上使用桌面应用程序和文件。 下载地址&#xff1a;https://www.macz.com/mac/5458.html?idOTI2NjQ5Jl8mMjcuMTg2LjEyNi4yMz…

AI网络爬虫:无限下拉滚动页面的另类爬取方法

现在很多网页都是无限下拉滚动的。可以拉动到底部&#xff0c;然后保存网页为mhtml格式文件。 接着&#xff0c;在ChatGPT中输入提示词&#xff1a; 你是一个Python编程高手&#xff0c;要完成一个关于爬取网页内容的Python脚本的任务&#xff0c;下面是具体步骤&#xff1a; …

vs - 在win10中安装vs2013update5

文章目录 vs - 在win10中安装vs2013update5概述笔记直接安装vs2013-update5报错先安装vs2013原版安装 vs2013 update5测试备注END vs - 在win10中安装vs2013update5 概述 用VS2019写的程序&#xff0c;在早期windows(e.g. win7, win8.1)上安装时&#xff0c;需要UCRT。 UCRT是…

unity2020打包webGL时卡进程问题

我使用的2020.3.0f1c1&#xff0c;打包发布WEB版的时候会一直卡到asm2wasm.exe这个进程里&#xff0c;而且CPU占用率90%以上。 即使是打包一个新建项目的空场景也是同样的问题&#xff0c;我尝试过一直卡在这里会如何&#xff0c;结果还真打包成功了。只是打包一个空场景需要20…

latex bib引参考文献

1.bib内容 2.sn-mathphys-num是官方的参考文献格式 3.不用导cite包&#xff0c;文中这么写 4.end document前ckwx是自己命名的bib的名字

【自动化运维】不要相信人,把所有的东西都交给机器去处理

不积跬步&#xff0c;无以至千里&#xff1b;不积小流&#xff0c;无以成江海。 大家好&#xff0c;我是闲鹤&#xff0c;十多年开发、架构经验&#xff0c;先后在华为、迅雷服役过&#xff0c;也在高校从事教学3年&#xff1b;目前已创业了7年多&#xff0c;主要从事物联网/车…

【运维项目经历|023】Docker自动化部署与监控项目

目录 项目名称 项目背景 项目目标 项目成果 我的角色与职责 我主要完成的工作内容 本次项目涉及的技术 本次项目遇到的问题与解决方法 本次项目中可能被面试官问到的问题 问题1&#xff1a;项目周期是多久&#xff1f; 问题2&#xff1a;服务器部署架构方式及数量配置…

【SpringMVC】_SpringMVC实现用户登录

目录 1、需求分析 2、接口定义 2.1 校验接口 请求参数 响应数据 2.2 查询登录用户接口 请求参数 响应数据 4、服务器代码 5、前端代码 5.1 登录页面login.html 5.2 首页页面index.html 6、运行测试 1、需求分析 用户输入账号与密码&#xff0c;后端校验密码是否正确&a…

FineBi导出Excel后台版实现

就是不通过浏览器,在后台运行的导出 参考文档在:仪表板查看接口- FineBI帮助文档 FineBI帮助文档 我这里是将这个帮助文档中导出的excel文件写到服务器某个地方后,对excel进行其他操作后再下载。由于原有接口耦合了HttpServletRequest req, HttpServletResponse res对象,…

可变参数

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 在Python中&#xff0c;还可以定义可变参数。可变参数也称不定长参数&#xff0c;即传入函数中的实际参数可以是任意多个。 定义可变参数时&#xf…

SRS视频服务器应用研究

1.SRS尝试从源码编译启动 1.1.安装ubuntu 下载镜像文件 使用VMWare安装&#xff0c;过程中出现蓝屏&#xff0c;后将VM的软件版本从15.5升级到17&#xff0c;就正常了。 1.2.更新ubuntu依赖 1.3.下载源码 官方推荐下载develop 切换到用户目录&#xff0c;开始安装 安装后 突然…

[AI OpenAI] 为非营利组织推出OpenAI

我们正在启动一项新计划&#xff0c;以增强非营利组织对我们工具的可访问性&#xff0c;包括ChatGPT Team和Enterprise的折扣优惠。 今天&#xff0c;我们推出了OpenAI for Nonprofits&#xff0c;这是一项旨在增强非营利组织对我们工具的可访问性的新计划。 非营利组织已经在…

5G专网驻网失败分析(suci无效)

suci 5G终端第一次驻网时&#xff0c;注册消息Registartion request中携带的5GS mobile identity要携带suci类型的mobile identity。 注册消息协议规范见5G NAS 协议3gpp TS24.501 8.2.6 Registration request。 suci协议规范参见3gpp TS24.501 9.11.3.4 5GS mobile identity …

python zip()函数(将多个可迭代对象的元素配对,创建一个元组的迭代器)zip_longest()

文章目录 Python zip() 函数深入解析基本用法函数原型基础示例 处理不同长度的迭代器高级用法多个迭代器使用 zip() 与 dict()解压序列 注意事项内存效率&#xff1a;zip() 返回的是一个迭代器&#xff0c;这意味着直到迭代发生前&#xff0c;元素不会被消耗。这使得 zip() 特别…

Mysql | select语句导入csv后再导入excel表格

需求 从mysql数据库中导出数据到excel 解决方案 sql导出csv文件 sql SELECT col1,col2 FROM tab_01 WHERE col3 xxx INTO OUTFILE /tmp/result.csv FIELDS TERMINATED BY , ENCLOSED BY " LINES TERMINATED BY \n;csv文件导出excel文件 1、【数据】-【导入数据】 …

【redis】宝塔,线上环境报Redis error: ERR unknown command del 错误

两种方式&#xff1a; 1.打开宝塔上的redis&#xff0c;通过配置文件修改权限&#xff0c;注释&#xff1a;#rename-command DEL “” 2.打开服务器&#xff0c;宝塔中默认redis安装位置是&#xff1a;cd /www/server/redis 找到redis.conf,拉到最后&#xff0c;注释#rename-co…

『 Linux 』文件系统

文章目录 磁盘构造磁盘抽象化 磁盘的寻址方式磁盘控制器磁盘数据传输文件系统Inode数据块(Data Blocks)超级块(SuperBlock)块组描述符(Group Descriptor) 磁盘构造 磁盘内部构造由磁头臂,磁头,主轴,盘片,盘面,磁道,柱面,扇区构成; 磁头臂&#xff1a;控制磁头的移动,可以精确地…

测试工具fio

一、安装部署 fio是一款优秀的磁盘IO测试工具&#xff0c;在Linux中比较常用于测试磁盘IO 其下载地址&#xff1a;https://brick.kernel.dk/snaps/fio-2.1.10.tar.gz 或者登录其官网&#xff1a;http://freshmeat.sourceforge.net/projects/fio/ 进行下载。 tar -zxvf fio-…