当前的问题
多头注意力使用多个头部可以提高模型的精度。然而,并不是所有的注意力头都具有同样的重要性。一些研究表明,许多注意力头可以被修剪而不影响准确性。
此外,在多头注意中,每个注意头并行操作,最终输出是所有注意头的总和。鉴于这些注意头独立运作,有些可能是多余的。
动机
建立一个动态的注意头路由机制。这种机制可以使每个标记自适应地选择适当的注意头,在不影响准确性的情况下提高推理效率。
方法
图1:多头注意和我们提出的头部混合注意之间的高层次比较。子图(a)展示了具有 h h h个注意头的标准多头注意层,而子图(b)展示了头部混合注意(MoH)(me:包含了共享注意力和混合注意力)架构。值得注意的是,MoH不会增加注意头的数量,从而确保MoH的总参数与多头注意的总参数相当。
知识回顾:多头注意力
注意力机制:
其中 X = X ′ X=X' X=X′是为自注意力, X ≠ X ′ X\ne X' X=X′为交叉注意力。
交叉注意力:
混合多头注意力(MIXTURE-OF-HEAD ATTENTION)
把注意力头当作专家
受MoE的巨大成功启发,我们提出了头部混合注意(MoH),它将注意头视为专家。具体来说,MoH由 h h h个头组成 H = { H 1 , H 2 , … , H h } H=\{H^1,H^2,\ldots,H^h\} H={H1,H2,…,Hh}和激活 Top-K \text{Top-K} Top-K头的路由器。形式上,给定输入令牌 X X X和 X ′ X' X′, MoH的输出是 K K K个选定的正面输出的加权和:
其中 g i g_i gi表示路由得分。只有当第 i i i个注意头被激活时, g i g_i gi才不为零。
共享注意力
在注意机制中,一些注意头可能在不同的语境中捕捉到共同的知识,如语言中的语法规则。受Dai等人(2024)的启发,我们将一个头像子集指定为始终保持激活状态的共享头像。通过在共享头部内整合公共知识,我们减少了其他动态路由头部之间的冗余。
路由得分g的定义
其中, h s h_s hs表示共有正面的个数。 W s ∈ R h s × d i n \bm W_s\in \mathbb{R}^{h_s\times d_{in}} Ws∈Rhs×din和 W r ∈ R ( h − h s ) × d i n \bm W_r\in \mathbb{R}^{(h-h_s)\times d_{in}} Wr∈R(h−hs)×din分别表示共享头和路由头的投影矩阵。系数 α 1 \alpha_1 α1和 α 2 \alpha_2 α2平衡了共享头和路由头的贡献,定义为:
其中, W h ∈ R 2 × d i n \bm W_h\in \mathbb{R}^{2\times d_{in}} Wh∈R2×din为可训练投影矩阵, d i n d_{in} din为 x t \bm x_t xt的隐藏大小。
负载平衡损失(使专家得到充分训练)
直接训练MoE层通常会导致大多数令牌被路由给少数专家,使剩余的专家没有得到充分的训练(Shazeer等人,2017)。为了避免拟议MoH中的不平衡负载,遵循先前的MoE方法(Lepikhin等人,2021;Wei等人,2024),我们应用负载平衡损失。具体来说,对于 X ∈ R T × d i n \bm{X}\in \mathbb{R}^{T\times d_{in}} X∈RT×din中的第 t t t个输入令牌 x t ∈ R d i n \bm{x}_t\in \mathbb{R}^{d_{in}} xt∈Rdin,负载均衡损失 L b \mathcal{L}_b Lb表示为:
其中 T T T为令牌数量。 1 ( ∗ ) \mathbb{1}(*) 1(∗)表示指示函数。
L t a s k \mathcal{L}_{task} Ltask指特定于任务的损失。
其中 β \beta β是减轻路由崩溃风险的权衡超参数。默认情况下,所有任务的负载均衡损失权重 β \beta β设置为0.01。
相关工作
多头注意力。Transformers(Vaswani et al ., 2017)在自然语言处理和计算机视觉方面都获得了极大的兴趣和成功。长期以来,变形金刚的成功归功于多头注意机制(Cordonnier et al, 2020)。多头注意机制由Vaswani等人(2017)提出,通过允许多个注意头在输入的不同低维投影上操作来增强注意层的表征能力。然后将这些头部的输出连接起来形成最终结果。或者,通过按行分解输出投影矩阵,多头注意力可以用求和形式表示。在求和形式中,每个头并行操作,最终输出是所有头的和。受此启发,我们提出了MoH,一种动态注意-头部路由机制,允许每个令牌自适应地选择适当的头部。
Mixture-of-Experts模型。混合专家(MoE)方法(Du et al, 2022;Lewis et al, 2021;Rajbhandari等人,2022;Roller等,2021;Zhou et al ., 2022;Jin等人,2024b)的引入是为了在不增加计算成本的情况下扩展深度神经网络的容量。在这种方法中,对于每个输入,只有一个被称为专家的参数子集被激活。Shazeer等人(2017)首先在LSTM层之间引入了MoE层。Switch Transformer (Fedus et al, 2022)通过每个令牌只选择Top-1专家进一步简化了门控机制。Gshard (Lepikhin et al, 2021)改进了Top-2专家路由策略。MoE强调有效的参数缩放,同时保持可管理的计算成本,而MoH侧重于在不增加参数数量的情况下减少冗余注意头的激活。
参考资料
论文下载(arixv,15 Oct 2024)
https://arxiv.org/abs/2410.11842
代码地址
https://github.com/SkyworkAI/MoH
基于MOE的ViT注意力代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Attention(nn.Module):LOAD_BALANCING_LOSSES = []def __init__(self, dim, input_resolution, num_heads=8, qkv_bias=True, attn_drop=0.,proj_drop=0., shared_head=0, routed_head=0):super().__init__()assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."self.dim = dimself.num_heads = num_headsself.head_dim = dim // num_headsself.temperature = nn.Parameter(torch.log((torch.ones(num_heads, 1, 1) / 0.24).exp() - 1)) # Initialize softplus(temperature) to 1/0.24.# Generate sequnce length scaleself.register_buffer("seq_length_scale", torch.as_tensor(np.log(input_resolution[0] * input_resolution[1])),persistent=False)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.query_embedding = nn.Parameter(nn.init.trunc_normal_(torch.empty(self.num_heads, 1, self.head_dim), mean=0, std=0.02))self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)# mlp to generate continuous relative position biasself.cpb_fc1 = nn.Linear(2, 512, bias=True)self.cpb_act = nn.ReLU(inplace=True)self.cpb_fc2 = nn.Linear(512, num_heads, bias=True)self.shared_head = shared_headself.routed_head = routed_headif self.routed_head > 0:self.wg = torch.nn.Linear(dim, num_heads - shared_head, bias=False)if self.shared_head > 0:self.wg_0 = torch.nn.Linear(dim, 2, bias=False)if self.shared_head > 1:self.wg_1 = torch.nn.Linear(dim, shared_head, bias=False)def forward(self, x, H, W, relative_pos_index, relative_coords_table):B, N, C = x.shape_x = x.reshape(B * N, C)if self.routed_head > 0:logits = self.wg(_x)gates = F.softmax(logits, dim=1)num_tokens, num_experts = gates.shape_, indices = torch.topk(gates, k=self.routed_head, dim=1)mask = F.one_hot(indices, num_classes=num_experts).sum(dim=1)if self.training:me = gates.mean(dim=0)ce = mask.float().mean(dim=0)l_aux = torch.mean(me * ce) * num_experts * num_expertsAttention.LOAD_BALANCING_LOSSES.append(l_aux)routed_head_gates = gates * maskdenom_s = torch.sum(routed_head_gates, dim=1, keepdim=True)denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)routed_head_gates /= denom_srouted_head_gates = routed_head_gates.reshape(B, N, -1) * self.routed_headqkv = self.qkv(x).reshape(B, -1, 3 * self.num_heads, self.head_dim).permute(0, 2, 1, 3)q, k, v = qkv.chunk(3, dim=1)# Use MLP to generate continuous relative positional biasrel_bias = self.cpb_fc2(self.cpb_act(self.cpb_fc1(relative_coords_table))).transpose(0, 1)[:,relative_pos_index.view(-1)].view(-1, N, N)# Calculate attention map using sequence length scaled cosine attention and query embeddingattn = ((F.normalize(q, dim=-1) + self.query_embedding) * F.softplus(self.temperature) * self.seq_length_scale) @ F.normalize(k, dim=-1).transpose(-2, -1) + rel_biasattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)if self.routed_head > 0:x = (attn @ v).transpose(1, 2) # B, N, head, dimif self.shared_head > 1:shared_head_weight = self.wg_1(_x)shared_head_gates = F.softmax(shared_head_weight, dim=1).reshape(B, N, -1) * self.shared_headelse:shared_head_gates = torch.ones((B, N, self.shared_head)).to(_x.device).to(_x.dtype) * self.shared_headif self.shared_head == 0:masked_gates = routed_head_gateselse:weight_0 = self.wg_0(_x)weight_0 = F.softmax(weight_0, dim=1).reshape(B, N, 2) * 2shared_head_gates = torch.einsum("bn,bne->bne", weight_0[:,:,0], shared_head_gates)routed_head_gates = torch.einsum("bn,bne->bne", weight_0[:,:,1], routed_head_gates)masked_gates = torch.cat([shared_head_gates, routed_head_gates], dim=2)x = torch.einsum("bne,bned->bned", masked_gates, x)x = x.reshape(B, N, C)else:shared_head_weight = self.wg_1(_x)masked_gates = F.softmax(shared_head_weight, dim=1).reshape(B, N, -1) * self.shared_headx = (attn @ v).transpose(1, 2) # B, N, head, dimx = torch.einsum("bne,bned->bned", masked_gates, x)x = x.reshape(B, N, C)x = self.proj(x)x = self.proj_drop(x)return x