Einsum(Einstein summation convention)
笔记来源:
Permute和Reshape嫌麻烦?einsum来帮忙!
The Einstein summation convention is a notational shorthand used in tensor calculus, particularly in the fields of physics and mathematics, to simplify the representation of sums over indices in tensor equations. This convention is widely used in general relativity and other areas involving tensors.
爱因斯坦求和约定(Einstein summation convention)的一种函数,它通过指定的索引规则执行张量操作。爱因斯坦求和表示法使得高维张量运算更加直观灵活,比如可以用于更复杂的张量运算,而不仅限于简单的矩阵乘法
numpy.einsum(subscripts,*operands,out=None,dtype=None,order='K',casting='safe',optimize=False)
torch.einsum(equation,*operands)
tf.einsum(equation,*inputs,**kwargs)
1.1 什么是einsum?运算规则是什么?
在求和公式中,某些下标在等式两边都有出现(例如下标 i i i)而有些下标(被求和的维度)只出现在一侧(例如下标 j j j)
所以即便是省略求和符号也不会产生歧义,即我们仍然知道哪个维度被求和了
这种运算与变量是什么并没有关系,因此上式可以进一步简化
1.在不同输入之间重复出现的索引表示沿着这一维度进行乘法(例如 k k k)
2.只出现在输入中的索引表示在这一维度上求和(例如输出有 i , j i,j i,j,也就是说 k k k只出现在输入中)
C = torch.einsum('ik,kj->ij',A,B)
# 箭头和箭头右侧的可以省略
C = torch.einsum('ik,kj',A,B)
#等价于
C = torch.matmul(A, B)
3.输出中维度的顺序可以是任意的(例如 i j ij ij或 j i ji ji)
这里 C j i C_{ji} Cji就是 C i j C_{ij} Cij的转置
C = torch.einsum('ik,kj->ji',A,B)
# 省略号可以用于broadcasting,也就是忽略不关心的维度,只对最后两个维度进行计算
C = torch.einsum('...ik,...kj->ji',A,B)
1.2 Einsum怎么用?
向量外积
C = torch.einsum('i,j->ij',a,b)
提取对角元素
a = torch.einsum('kk->k',A)
1.3 Einum在多头注意力的应用
原版本
qkv:torch.Tensor self.qkv(x) #B,patches,3*dim
qkv = qkv.reshape(B, patches, 3, self.n_heads, self.head_dim)
qkv = qkv.permute(2,0,3,1,4)
q:torch.Tensor = qkv[0]
k:torch.Tensor = qkv[1]
v:torch.Tensor = qkv[2]
k_t = k.transpose(-2,-1)
attn = torch.softmax(q k_t self.scale,dim=-1)
attn = self.attn_drop(attn)
wa = attn @ v
wa = wa.transpose(1,2)
wa = wa.flatten(2)
使用Einum简化
q,k,v = map(lambda t:rearrange(t,'b n (h d)->b h n d',h=self.num_heads),qkv)
attn = torch.einsum('bijc,bikc -bijk',q,k)*self.scale
attn = attn.softmax(dim=-1)
x = torch.einsum('bijk,bikc -bijc',attn,v)
x = rearrange(x,'b i jc->b j (i c)')
(1)使用EINOPS库中的rearrange操作
q,k,v = map(lambda t:rearrange(t,'b n (h d)->b h n d',h=self.num_heads),qkv)
(2)q乘k转置除以缩放比例
attn = torch.einsum('bijc,bikc -bijk',q,k)*self.scale
(3)softmax得到attention数值
attn = attn.softmax(dim=-1)
(4)attention值对v加权
x = torch.einsum('bijk,bikc -bijc',attn,v)
(5)将x的维度还原为输入的形式
x = rearrange(x,'b i jc->b j (i c)')
[ B , h e a d , N , C / / h e a d ] − > [ B , N , C ] [B,head,N,C//head]->[B,N,C] [B,head,N,C//head]−>[B,N,C]
1.4 Einsum优缺点
优点:
- 一次调用、一个函数完成多个操作
- 有时比多个Permute和Transpose操作组合的可读性高
- 可以避免生成中间变量
缺点:
求和表达式复杂时耗费内存,导致性能问题