MOH: MULTI-HEAD ATTENTION AS MIXTURE-OFHEAD ATTENTION

当前的问题

多头注意力使用多个头部可以提高模型的精度。然而,并不是所有的注意力头都具有同样的重要性。一些研究表明,许多注意力头可以被修剪而不影响准确性。

此外,在多头注意中,每个注意头并行操作,最终输出是所有注意头的总和。鉴于这些注意头独立运作,有些可能是多余的。

动机

建立一个动态的注意头路由机制。这种机制可以使每个标记自适应地选择适当的注意头,在不影响准确性的情况下提高推理效率。

方法

图1:多头注意和我们提出的头部混合注意之间的高层次比较。子图(a)展示了具有 h h h个注意头的标准多头注意层,而子图(b)展示了头部混合注意(MoH)(me:包含了共享注意力和混合注意力)架构。值得注意的是,MoH不会增加注意头的数量,从而确保MoH的总参数与多头注意的总参数相当。

知识回顾:多头注意力

注意力机制

其中 X = X ′ X=X' X=X是为自注意力, X ≠ X ′ X\ne X' X=X为交叉注意力。

交叉注意力

混合多头注意力(MIXTURE-OF-HEAD ATTENTION)

把注意力头当作专家

受MoE的巨大成功启发,我们提出了头部混合注意(MoH),它将注意头视为专家。具体来说,MoH由 h h h个头组成 H = { H 1 , H 2 , … , H h } H=\{H^1,H^2,\ldots,H^h\} H={H1,H2,,Hh}和激活 Top-K \text{Top-K} Top-K头的路由器。形式上,给定输入令牌 X X X X ′ X' X, MoH的输出是 K K K个选定的正面输出的加权和

其中 g i g_i gi表示路由得分。只有当第 i i i个注意头被激活时, g i g_i gi才不为零。

共享注意力

在注意机制中,一些注意头可能在不同的语境中捕捉到共同的知识,如语言中的语法规则。受Dai等人(2024)的启发,我们将一个头像子集指定为始终保持激活状态的共享头像。通过在共享头部内整合公共知识,我们减少了其他动态路由头部之间的冗余

路由得分g的定义

其中, h s h_s hs表示共有正面的个数 W s ∈ R h s × d i n \bm W_s\in \mathbb{R}^{h_s\times d_{in}} WsRhs×din W r ∈ R ( h − h s ) × d i n \bm W_r\in \mathbb{R}^{(h-h_s)\times d_{in}} WrR(hhs)×din分别表示共享头和路由头的投影矩阵。系数 α 1 \alpha_1 α1 α 2 \alpha_2 α2平衡了共享头和路由头的贡献,定义为:

其中, W h ∈ R 2 × d i n \bm W_h\in \mathbb{R}^{2\times d_{in}} WhR2×din为可训练投影矩阵, d i n d_{in} din x t \bm x_t xt的隐藏大小。

负载平衡损失(使专家得到充分训练)

直接训练MoE层通常会导致大多数令牌被路由给少数专家,使剩余的专家没有得到充分的训练(Shazeer等人,2017)。为了避免拟议MoH中的不平衡负载,遵循先前的MoE方法(Lepikhin等人,2021;Wei等人,2024),我们应用负载平衡损失。具体来说,对于 X ∈ R T × d i n \bm{X}\in \mathbb{R}^{T\times d_{in}} XRT×din中的第 t t t个输入令牌 x t ∈ R d i n \bm{x}_t\in \mathbb{R}^{d_{in}} xtRdin,负载均衡损失 L b \mathcal{L}_b Lb表示为:

其中 T T T为令牌数量。 1 ( ∗ ) \mathbb{1}(*) 1()表示指示函数。

L t a s k \mathcal{L}_{task} Ltask指特定于任务的损失。

其中 β \beta β是减轻路由崩溃风险的权衡超参数。默认情况下,所有任务的负载均衡损失权重 β \beta β设置为0.01。

相关工作

多头注意力。Transformers(Vaswani et al ., 2017)在自然语言处理和计算机视觉方面都获得了极大的兴趣和成功。长期以来,变形金刚的成功归功于多头注意机制(Cordonnier et al, 2020)。多头注意机制由Vaswani等人(2017)提出,通过允许多个注意头在输入的不同低维投影上操作来增强注意层的表征能力。然后将这些头部的输出连接起来形成最终结果。或者,通过按行分解输出投影矩阵,多头注意力可以用求和形式表示。在求和形式中,每个头并行操作,最终输出是所有头的和。受此启发,我们提出了MoH,一种动态注意-头部路由机制,允许每个令牌自适应地选择适当的头部。

Mixture-of-Experts模型。混合专家(MoE)方法(Du et al, 2022;Lewis et al, 2021;Rajbhandari等人,2022;Roller等,2021;Zhou et al ., 2022;Jin等人,2024b)的引入是为了在不增加计算成本的情况下扩展深度神经网络的容量。在这种方法中,对于每个输入,只有一个被称为专家的参数子集被激活。Shazeer等人(2017)首先在LSTM层之间引入了MoE层。Switch Transformer (Fedus et al, 2022)通过每个令牌只选择Top-1专家进一步简化了门控机制。Gshard (Lepikhin et al, 2021)改进了Top-2专家路由策略。MoE强调有效的参数缩放,同时保持可管理的计算成本,而MoH侧重于在不增加参数数量的情况下减少冗余注意头的激活

参考资料

论文下载(arixv,15 Oct 2024)

https://arxiv.org/abs/2410.11842

代码地址

https://github.com/SkyworkAI/MoH

基于MOE的ViT注意力代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Attention(nn.Module):LOAD_BALANCING_LOSSES = []def __init__(self, dim, input_resolution, num_heads=8, qkv_bias=True, attn_drop=0.,proj_drop=0., shared_head=0, routed_head=0):super().__init__()assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."self.dim = dimself.num_heads = num_headsself.head_dim = dim // num_headsself.temperature = nn.Parameter(torch.log((torch.ones(num_heads, 1, 1) / 0.24).exp() - 1))  # Initialize softplus(temperature) to 1/0.24.# Generate sequnce length scaleself.register_buffer("seq_length_scale", torch.as_tensor(np.log(input_resolution[0] * input_resolution[1])),persistent=False)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.query_embedding = nn.Parameter(nn.init.trunc_normal_(torch.empty(self.num_heads, 1, self.head_dim), mean=0, std=0.02))self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)# mlp to generate continuous relative position biasself.cpb_fc1 = nn.Linear(2, 512, bias=True)self.cpb_act = nn.ReLU(inplace=True)self.cpb_fc2 = nn.Linear(512, num_heads, bias=True)self.shared_head = shared_headself.routed_head = routed_headif self.routed_head > 0:self.wg = torch.nn.Linear(dim, num_heads - shared_head, bias=False)if self.shared_head > 0:self.wg_0 = torch.nn.Linear(dim, 2, bias=False)if self.shared_head > 1:self.wg_1 = torch.nn.Linear(dim, shared_head, bias=False)def forward(self, x, H, W, relative_pos_index, relative_coords_table):B, N, C = x.shape_x = x.reshape(B * N, C)if self.routed_head > 0:logits = self.wg(_x)gates = F.softmax(logits, dim=1)num_tokens, num_experts = gates.shape_, indices = torch.topk(gates, k=self.routed_head, dim=1)mask = F.one_hot(indices, num_classes=num_experts).sum(dim=1)if self.training:me = gates.mean(dim=0)ce = mask.float().mean(dim=0)l_aux = torch.mean(me * ce) * num_experts * num_expertsAttention.LOAD_BALANCING_LOSSES.append(l_aux)routed_head_gates = gates * maskdenom_s = torch.sum(routed_head_gates, dim=1, keepdim=True)denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)routed_head_gates /= denom_srouted_head_gates = routed_head_gates.reshape(B, N, -1) * self.routed_headqkv = self.qkv(x).reshape(B, -1, 3 * self.num_heads, self.head_dim).permute(0, 2, 1, 3)q, k, v = qkv.chunk(3, dim=1)# Use MLP to generate continuous relative positional biasrel_bias = self.cpb_fc2(self.cpb_act(self.cpb_fc1(relative_coords_table))).transpose(0, 1)[:,relative_pos_index.view(-1)].view(-1, N, N)# Calculate attention map using sequence length scaled cosine attention and query embeddingattn = ((F.normalize(q, dim=-1) + self.query_embedding) * F.softplus(self.temperature) * self.seq_length_scale) @ F.normalize(k, dim=-1).transpose(-2, -1) + rel_biasattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)if self.routed_head > 0:x = (attn @ v).transpose(1, 2)  # B, N, head, dimif self.shared_head > 1:shared_head_weight = self.wg_1(_x)shared_head_gates = F.softmax(shared_head_weight, dim=1).reshape(B, N, -1) * self.shared_headelse:shared_head_gates = torch.ones((B, N, self.shared_head)).to(_x.device).to(_x.dtype) * self.shared_headif self.shared_head == 0:masked_gates = routed_head_gateselse:weight_0 = self.wg_0(_x)weight_0 = F.softmax(weight_0, dim=1).reshape(B, N, 2) * 2shared_head_gates = torch.einsum("bn,bne->bne", weight_0[:,:,0], shared_head_gates)routed_head_gates = torch.einsum("bn,bne->bne", weight_0[:,:,1], routed_head_gates)masked_gates = torch.cat([shared_head_gates, routed_head_gates], dim=2)x = torch.einsum("bne,bned->bned", masked_gates, x)x = x.reshape(B, N, C)else:shared_head_weight = self.wg_1(_x)masked_gates = F.softmax(shared_head_weight, dim=1).reshape(B, N, -1) * self.shared_headx = (attn @ v).transpose(1, 2)  # B, N, head, dimx = torch.einsum("bne,bned->bned", masked_gates, x)x = x.reshape(B, N, C)x = self.proj(x)x = self.proj_drop(x)return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/482483.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring boot之BeanDefinition介绍

在spring框架中IOC容器进行bean的创建和管理。Bean的创建是一个比较复杂的过程,它并不像我们创建对象一样只是直接new一下就行,虽然有些bean确实就是New一下。但在Spring中可以通过一些途径对bean进行增强扩展。在这个过程中,BeanDefinition作…

Ubuntu 服务器部署 Tomcat 并配置 SSL/TLS 证书

本文目录 准备登陆云服务器安装 Java下载 tomcat 包配置防火墙浏览器访问 Tomcat 默认页面以服务的形式运行 Tomcat创建 Tomcat 用户和组创建 systemd 服务文件启动 tomcat 服务 Tomcat webapps 文件目录部署一个静态网站tomcat 的配置文件 将域名解析到服务器Tomcat 配置 SSL/…

C++小问题

怎么分辨const修饰的是谁 是限定谁不能被改变的? 在C中,const关键字的用途和位置非常关键,它决定了谁不能被修改。const可以修饰变量、指针、引用等不同的对象,并且具体的作用取决于const的修饰位置。理解const的规则能够帮助我们…

PPT不能编辑,按钮都是灰色,怎么办?

PPT文件打开之后,发现无法编辑,再仔细查看发现工具栏中的功能按钮都是灰色的,无法使用,这是什么原因?该如何解决? 原因:无法编辑PPT文件,并且功能按钮都是灰色,这是因为…

相交链表和环形链表

(一)相交链表 相交链表 思路:先分别计算出A列表和B列表的长度,判断它们的尾节点是否相等,如果不相等就不相交,直接返回空。然后让两个列表中的长的列表先走它们的差距步,然后再一起走&#xff…

ARM架构下安装新版docker及docker-compose

一、常见CPU 架构: 二、环境信息 CPU架构操作系统配置HUAWEI Kunpeng 920 5220 aarch64openEuler 22.03 (LTS-SP3)64C128g15T 三、安装docker 3.1 二进制包下载 docker-ce 社区下载地址: wget https://mirrors.nju.edu.cn/docker-ce/linux/static/s…

LeetCode-315. Count of Smaller Numbers After Self

目录 题目描述 解题思路 【C】 【Java】 复杂度分析 LeetCode-315. Count of Smaller Numbers After Selfhttps://leetcode.com/problems/count-of-smaller-numbers-after-self/description/ 题目描述 Given an integer array nums, return an integer array counts whe…

【NLP 4、数学基础】

此去经年,应是良辰美景虚设 —— 24.11.28 一、线性代数 1.标量和向量 ① 标量 Scalar 一个标量就是一个单独的数 ② 向量 Vector 一个向量是一列数 可以把向量看作空间中的点,每个元素是不同坐标轴上的坐标 向量中有几个数,就叫作几维…

VideoBooth: Diffusion-based Video Generation with Image Prompts

VideoBooth: Diffusion-based Video Generation with Image Prompts 概括 文章提出了一个视频生成模型VideoBooth,输入一张图片和一个文本提示词,即可输出保持图片中物体且符合文本提示词要求的视频。 方法 粗-细两阶段设计:1)…

Graphy 是一款终极、易于使用、功能齐全的 FPS 计数器、统计监视器和调试器,适用于您的 Unity 项目。

主要特点: Graph & Text: 图文: FPSMemory 记忆Audio 声音的Advanced device information 高级设备信息Debugging tools 调试工具 GitHub - Tayx94/graphy:Graphy 是适用于 Unity 项目的终极、易于使用、功能丰富的 FPS 计数器、统计监视器和调试…

ASP.NET Core 负载/压力测试

文章目录 一、第三方工具二、使用发布版本进行负载测试和压力测试 负载测试和压力测试对于确保 web 应用的性能和可缩放性非常重要。 尽管负载测试和压力测试的某些测试相似,但它们的目标不同。 负载测试:测试应用是否可以在特定情况下处理指定的用户负…

008静态路由-特定主机路由

按照如上配置,用192.168.0.1 电脑ping 192.168.1.1 发现能够ping通 用192.168.0.1 电脑ping 192.168.2.1 发现不能ping通 这是因为192.168.0.1 和 192.168.1.1 使用的是同一个路由器R1。 192.168.0.1 和 192.168.2.1 通信需要先经过R1,再经过R2 &#xf…

基于yolov4深度学习网络的排队人数统计系统matlab仿真,带GUI界面

目录 1.算法仿真效果 2.算法涉及理论知识概要 3.MATLAB核心程序 4.完整算法代码文件获得 1.算法仿真效果 matlab2022a仿真结果如下(完整代码运行后无水印): 仿真操作步骤可参考程序配套的操作视频。 2.算法涉及理论知识概要 在现代社会…

周鸿祎再次“创业”,盯上百度

周鸿祎特地拍了部短剧来推广的新产品,终于上线了。 11月27日晚间,360正式发布多模态内容创作引擎“纳米搜索”。 作为当前AI应用最红的赛道之一,AI搜索已经有腾讯、秘塔、商汤、抖音等公司入局。传统搜索老大百度也在发力。竞争不妨碍有搜索…

pytorch中一个tensor经过多次softmax会有什么变化?

在 PyTorch 中,一个 Tensor 经过多次 softmax 操作时,其值会逐渐趋向于某种分布,但并不会无限变化。以下是具体的行为与原因分析: 1. Softmax 的作用: Softmax 将输入张量的值转换为一个概率分布,满足以下…

汽车轮毂结构分析有哪些?国产3D仿真分析实现静力学+模态分析

本文为CAD芯智库原创,未经允许请勿复制、转载! 之前分享了如何通过国产三维CAD软件如何实现「汽车/汽配行业产品设计」,兼容NX(UG)、Creo(Proe),轻松降低企业上下游图纸交互成本等。…

深度学习中的生成对抗网络(GAN)原理与应用

引言 生成对抗网络(Generative Adversarial Network,简称GAN)是由Ian Goodfellow等人在2014年提出的一种深度学习模型,它通过对抗训练的方式生成与真实数据分布相似的假数据。GAN的出现极大地推动了深度学习和生成模型的研究&…

前端学习笔记之FileReader

概念 FileReader接口允许网页应用程序异步读取用户计算机上存储的文件&#xff08;或原始数据缓冲区&#xff09;的内容&#xff0c;使用File或Blob对象来制定要读取的文件或数据。 File对象可以通过用户使用<input>元素选择文件后返回的FileList对象获得&#xff0c;或…

Unity类银河战士恶魔城学习总结(P149 Screen Fade淡入淡出菜单)

【Unity教程】从0编程制作类银河恶魔城游戏_哔哩哔哩_bilibili 教程源地址&#xff1a;https://www.udemy.com/course/2d-rpg-alexdev/ 本章节实现了进入游戏和死亡之后的淡入淡出动画效果 UI_FadeScreen.cs 1. Animator 组件的引用 (anim) 该脚本通过 Animator 控制 UI 元…

win10系统部署RAGFLOW+Ollama教程

本篇主要基于linux服务器部署ragflowollama&#xff0c;其他操作系统稍有差异但是大体一样。 一、先决条件 CPU ≥ 4核&#xff1b; RAM ≥ 16 GB&#xff1b; 磁盘 ≥ 50 GB&#xff1b; Docker ≥ 24.0.0 & Docker Compose ≥ v2.26.1。 如果尚未在本地计算机&#xff…