本文参考了链接添加链接描述
flash attention介绍
flash attention的介绍可以参考论文:FlashAttention: Fast and Memory-Efficient Exact Attention
with IO-Awareness,具体的数学公式参考下面这个图片:其中注意关于矩阵S有两个维度,softmax的操作维度是dim=1,用pytorch表示就是torch.softmax(S, dim=1)
对于flash attention来说,里面有两次矩阵乘法,对于这样的二维数组矩阵乘法,一般来说都会考虑使用二维线程块,但是我们之前实现的softmax都是以一维线程块来处理,其中专门用到了一个cub库的函数BlockReduce,经过本人测试,发现这个函数只能针对一维线程块做线程块内部的规约,不能用于二维线程块内部针对某个维度规约,因此在实现flash attention之前,我们需要编写一个二维线程块实现softmax的算法,其中注意BLOCK_DIM_x和BLOCK_DIM_y都必须要选取2的幂次方。
二维线程块实现softmax
之前我们实现一维线程块处理softmax的时候,参考链接添加链接描述