Moonshot AI 新突破:MoBA 为大语言模型长文本处理提效论文速读

前言

在自然语言处理领域,随着大语言模型(LLMs)不断拓展其阅读、理解和生成文本的能力,如何高效处理长文本成为一项关键挑战。近日,Moonshot AI Research 联合清华大学、浙江大学的研究人员提出了一种创新方法 —— 混合块注意力机制(Mixture of Block Attention,MoBA),它将专家混合(Mixture of Experts,MoE)原理应用于注意力机制,为解决长文本处理难题带来了新的思路。

在 Transformer 架构广泛应用的当下,其注意力机制存在明显弊端。在处理长文本时,传统注意力机制需将每个 token 与其他所有 token 进行比较,这使得计算成本随序列长度呈二次方增长。当模型处理长篇文档、多章书籍、法律简报或大型代码库等包含大量文本信息的任务时,这种计算成本会变得难以承受。此前,为解决这一问题,研究人员尝试过多种方法。例如,滑动窗口机制将 token 限制在局部邻域内,虽降低了计算量,但会忽略重要的全局关系;而一些彻底改变基本架构的方法,如用全新结构替代 softmax 注意力机制,往往需要从头开始重新训练模型,难以利用现有的预训练成果。

核心原理

MoBA 的出现有效弥补了上述方法的不足。它的核心在于将输入划分为易于管理的 “块”,并借助可训练的门控系统来确定每个查询 token 相关的块。这种设计遵循 “少结构” 原则,不预先定义哪些 token 应该相互作用,而是由学习到的门控网络做出决策。与固定结构或近似处理的方法不同,MoBA 能让模型自主学习注意力的聚焦点。而且,MoBA 可与现有的基于 Transformer 的模型无缝协作,它作为一种 “插件” 或替代方案,保持与原模型相同的参数数量,避免架构膨胀,同时保留因果掩码,确保自回归生成的准确性。在实际应用中,MoBA 能在稀疏注意力和全注意力之间灵活切换。处理超长输入时,稀疏注意力可提升速度;而在训练的某些层或阶段,若需要全注意力,模型也能切换回标准模式。

从技术细节来看,MoBA 将上下文划分为多个块,每个块包含连续的 token 序列。门控机制通过比较查询 token 与块的池化键表示,计算查询 token 与每个块之间的 “亲和度” 分数,然后选择得分最高的块。这样,只有最相关块中的 token 才会对最终的注意力分布产生影响。同时,包含查询 token 本身的块始终被纳入,以确保局部上下文信息可访问。并且,MoBA 执行因果掩码,防止 token 关注未来位置,维持从左到右的自回归属性。这种基于块的方法大幅减少了 token 比较次数,使计算规模低于二次方,随着上下文长度增加到数十万甚至数百万个 token,效率提升愈发显著。此外,MoBA 与现代加速器和专用内核兼容性良好。研究人员将 MoBA 与 FlashAttention(一种高性能的快速、内存高效的精确注意力库)相结合,根据所选块对查询 - 键 - 值操作进行精心分组,进一步优化了计算流程。实验数据显示,在处理一百万个 token 时,MoBA 相比传统全注意力机制速度提升约 6 倍,凸显了其在实际应用中的优势。

在性能测试方面,MoBA 表现出色。技术报告显示,在多种任务中,MoBA 的性能与全注意力机制相当,但在处理长序列时可显著节省计算资源。在语言建模数据测试中,当序列长度为 8192 或 32768 个 token 时,MoBA 的困惑度与全注意力 Transformer 相近。更为关键的是,当研究人员将上下文长度逐渐扩展到 128000 及更长时,MoBA 仍能保持强大的长上下文理解能力。在 “尾随 token” 评估中,MoBA 能够有效处理长提示末尾附近的 token 预测任务,且预测质量没有明显下降。研究人员还对 MoBA 的块大小和门控策略进行了敏感性探索。实验表明,细化粒度(使用更小的块但选择更多的块)有助于模型更接近全注意力的效果。即使在忽略大部分上下文的情况下,自适应门控也能识别与查询真正相关的块。此外,“混合” 模式展现出一种平衡策略:部分层继续使用 MoBA 提升速度,少数层则恢复全注意力。这种混合方法在监督微调任务中尤为有益,例如当输入中的某些位置在训练目标中被屏蔽时,保留少数上层的全注意力,可使模型保持广泛的上下文覆盖,有助于需要全局视角的任务。

关键代码分析:

以下是对 MoBA 库关键代码 MixedAttention 类的分析以及关键代码的摘录与注释:

整体分析

MixedAttention 类是一个自定义的 torch.autograd.Function,用于实现混合块注意力机制。这个类主要包含两个静态方法:forward 和 backward,分别用于前向传播和反向传播。

class MixedAttention(torch.autograd.Function):# 前向传播函数@staticmethoddef forward(ctx,q,  # 查询张量k,  # 键张量v,  # 值张量self_attn_cu_seqlen,  # 自注意力累积序列长度moba_q,  # MoBA 查询张量moba_kv,  # MoBA 键值张量moba_cu_seqlen_q,  # MoBA 查询累积序列长度moba_cu_seqlen_kv,  # MoBA 键值累积序列长度max_seqlen,  # 最大序列长度moba_chunk_size,  # MoBA 块大小moba_q_sh_indices,  # MoBA 查询块索引):# 保存一些参数,用于后续的反向传播ctx.max_seqlen = max_seqlenctx.moba_chunk_size = moba_chunk_sizectx.softmax_scale = softmax_scale = q.shape[-1] ** (-0.5)# 自注意力计算_, _, _, _, self_attn_out_sh, self_attn_lse_hs, _, _ = (_flash_attn_varlen_forward(q=q,k=k,v=v,cu_seqlens_q=self_attn_cu_seqlen,cu_seqlens_k=self_attn_cu_seqlen,max_seqlen_q=max_seqlen,max_seqlen_k=max_seqlen,softmax_scale=softmax_scale,causal=True,dropout_p=0.0,))# MoBA 注意力计算_, _, _, _, moba_attn_out, moba_attn_lse_hs, _, _ = _flash_attn_varlen_forward(q=moba_q,k=moba_kv[:, 0],v=moba_kv[:, 1],cu_seqlens_q=moba_cu_seqlen_q,cu_seqlens_k=moba_cu_seqlen_kv,max_seqlen_q=max_seqlen,max_seqlen_k=moba_chunk_size,softmax_scale=softmax_scale,causal=False,dropout_p=0.0,)# 转换 lse 形状,从 hs 转换为 sh(遵循传统混合注意力逻辑)self_attn_lse_sh = self_attn_lse_hs.t().contiguous()moba_attn_lse = moba_attn_lse_hs.t().contiguous()# 初始化输出缓冲区,形状与 q 相同output = torch.zeros((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)# 将输出张量展平为二维,便于后续索引操作output_2d = output.view(-1, q.shape[2])# 计算混合 lse# 减去最大 lse 以避免指数爆炸max_lse_1d = self_attn_lse_sh.view(-1)max_lse_1d = max_lse_1d.index_reduce(0, moba_q_sh_indices, moba_attn_lse.view(-1), "amax")self_attn_lse_sh = self_attn_lse_sh - max_lse_1d.view_as(self_attn_lse_sh)moba_attn_lse = (moba_attn_lse.view(-1).sub(max_lse_1d.index_select(0, moba_q_sh_indices)).reshape_as(moba_attn_lse))# 计算自注意力和 MoBA 注意力的 softmax 结果mixed_attn_se_sh = self_attn_lse_sh.exp()moba_attn_se = moba_attn_lse.exp()# 将 MoBA 注意力结果累加到自注意力结果上mixed_attn_se_sh.view(-1).index_add_(0, moba_q_sh_indices, moba_attn_se.view(-1))mixed_attn_lse_sh = mixed_attn_se_sh.log()# 加权自注意力输出factor = (self_attn_lse_sh - mixed_attn_lse_sh).exp()  # [ vS, H ]self_attn_out_sh = self_attn_out_sh * factor.unsqueeze(-1)output_2d += self_attn_out_sh.reshape_as(output_2d)# 加权 MoBA 输出mixed_attn_lse = (mixed_attn_lse_sh.view(-1).index_select(0, moba_q_sh_indices).view_as(moba_attn_lse))factor = (moba_attn_lse - mixed_attn_lse).exp()  # [ vS, H ]moba_attn_out = moba_attn_out * factor.unsqueeze(-1)raw_attn_out = moba_attn_out.view(-1, moba_attn_out.shape[-1])output_2d.index_add_(0, moba_q_sh_indices, raw_attn_out)# 将输出转换为与输入相同的数据类型output = output.to(q.dtype)# 恢复最大 lsemixed_attn_lse_sh = mixed_attn_lse_sh + max_lse_1d.view_as(mixed_attn_se_sh)# 保存中间结果,用于反向传播ctx.save_for_backward(output,mixed_attn_lse_sh,q,k,v,self_attn_cu_seqlen,moba_q,moba_kv,moba_cu_seqlen_q,moba_cu_seqlen_kv,moba_q_sh_indices,)return output# 反向传播函数@staticmethoddef backward(ctx, d_output):# 从上下文中获取保存的参数max_seqlen = ctx.max_seqlenmoba_chunk_size = ctx.moba_chunk_sizesoftmax_scale = ctx.softmax_scale(output,mixed_attn_vlse_sh,q,k,v,self_attn_cu_seqlen,moba_q,moba_kv,moba_cu_seqlen_q,moba_cu_seqlen_kv,moba_q_sh_indices,) = ctx.saved_tensors# 确保输入梯度连续d_output = d_output.contiguous()# 计算自注意力的梯度dq, dk, dv, _ = _flash_attn_varlen_backward(dout=d_output,q=q,k=k,v=v,out=output,softmax_lse=mixed_attn_vlse_sh.t().contiguous(),dq=None,dk=None,dv=None,cu_seqlens_q=self_attn_cu_seqlen,cu_seqlens_k=self_attn_cu_seqlen,max_seqlen_q=max_seqlen,max_seqlen_k=max_seqlen,softmax_scale=softmax_scale,causal=True,dropout_p=0.0,window_size=(-1, -1),softcap=0.0,alibi_slopes=None,deterministic=True,)# 计算 MoBA 注意力的梯度headdim = q.shape[-1]d_moba_output = (d_output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1))moba_output = (output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1))mixed_attn_vlse = (mixed_attn_vlse_sh.view(-1).index_select(0, moba_q_sh_indices).view(1, -1))dmq, dmk, dmv, _ = _flash_attn_varlen_backward(dout=d_moba_output,q=moba_q,k=moba_kv[:, 0],v=moba_kv[:, 1],out=moba_output,softmax_lse=mixed_attn_vlse,dq=None,dk=None,dv=None,cu_seqlens_q=moba_cu_seqlen_q,cu_seqlens_k=moba_cu_seqlen_kv,max_seqlen_q=max_seqlen,max_seqlen_k=moba_chunk_size,softmax_scale=softmax_scale,causal=False,dropout_p=0.0,window_size=(-1, -1),softcap=0.0,alibi_slopes=None,deterministic=True,)# 合并 MoBA 的键和值的梯度dmkv = torch.stack((dmk, dmv), dim=1)return dq, dk, dv, None, dmq, dmkv, None, None, None, None, None

代码关键部分解释

  • 前向传播 (forward)

    • 分别计算自注意力和 MoBA 注意力的结果。
    • 对注意力分数进行处理,包括形状转换、归一化等操作,以避免指数爆炸。
    • 将自注意力和 MoBA 注意力的结果进行加权合并,得到最终的输出。
    • 保存中间结果,用于后续的反向传播。
  • 反向传播 (backward)

    • 根据前向传播保存的中间结果,计算自注意力和 MoBA 注意力的梯度。
    • 最终返回各个输入张量的梯度。

小结

通过这种方式,MixedAttention 类实现了 MoBA 混合块注意力机制,通过将上下文划分为块并进行选择性的注意力计算,有效减少了计算量,提升了处理长文本的效率。

总结

总体而言,MoBA 非常适合处理涉及大量上下文的任务,如长篇文档阅读理解、大规模代码补全以及需要完整对话历史的多轮对话系统。它在提高效率的同时,性能损失极小,为大规模训练大语言模型提供了一种极具吸引力的方法。虽然目前 MoBA 主要应用于文本领域,但研究人员认为,其底层机制在其他数据模态中也具有应用潜力。只要序列长度足够长,引发计算或内存问题,将查询分配给块 “专家” 的思路就有望缓解瓶颈,同时保持处理关键全局依赖关系的能力。随着语言应用中的序列长度持续增长,像 MoBA 这样的方法可能会在推动神经语言建模的可扩展性和成本效益方面发挥关键作用,为人工智能的发展注入新的活力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/22329.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

cs224w课程学习笔记-第2课

cs224w课程学习笔记-第2课 传统图学习 前言一、节点任务1、任务背景2、特征节点度3、特征节点中心性3.1 特征向量中心性(Eigenvector Centrality)3.2 中介中心性(Betweenness Centrality)3.3 接近中心性(Closeness Cen…

Centos虚拟机扩展磁盘空间

Centos虚拟机扩展磁盘空间 扩展前后效果1 虚拟机vmware关机后,编辑2 扩展2.1 查看2.2 新建分区2.3 格式化新建分区ext42.3.1 格式化2.3.2 创建2.3.3 修改2.3.4 查看 2.4 扩容2.4.1 扩容2.4.1 查看 扩展前后效果 df -h1 虚拟机vmware关机后,编辑 2 扩展 …

1.13作业

1 if(!preg_match("/[0-9]|\~|\|\|\#|\\$|\%|\^|\&|\*|\&#xff08;|\&#xff09;|\-|\|\|\{|\[|\]|\}|\:|\|\"|\,|\<|\.|\>|\/|\?|\\\\/i", $c)){eval($c); 构造数组rce ?ceval(array_pop(next(get_defined_vars()))); post传参:asystem("c…

如何在 SpringBoot 项目使用 Redis 的 Pipeline 功能

本文是博主在批量存储聊天中用户状态和登陆信息到 Redis 缓存中时&#xff0c;使用到了 Pipeline 功能&#xff0c;并对此做出了整理。 一、Redis Pipeline 是什么 Redis 的 Pipeline 功能可以显著提升 Redis 操作的性能&#xff0c;性能提升的原因在于可以批量执行命令。当我…

力扣LeetCode: 2209 用地毯覆盖后的最少白色砖块

题目&#xff1a; 给你一个下标从 0 开始的 二进制 字符串 floor &#xff0c;它表示地板上砖块的颜色。 floor[i] 0 表示地板上第 i 块砖块的颜色是 黑色 。floor[i] 1 表示地板上第 i 块砖块的颜色是 白色 。 同时给你 numCarpets 和 carpetLen 。你有 numCarpets 条 黑…

RabbitMQ 消息队列

1. 消息队列是什么&#xff1f; 当用户注册成功后&#xff0c;就发送邮件。当邮件发送成功了&#xff0c;接口才会提示注册成功信息。但由于发送邮件&#xff0c;依赖于其他厂商的服务&#xff0c;有可能他们的接口会非常耗时。那么用户就一直要等着邮件发送成功了&#xff0c;…

【SQL实验】触发器

下载素材文件”tsgl”、“成绩管理”,将tsgl.bak和成绩管理.bak数据库还原到库中【导入操作在之前的文章中详细讲过】 触发器 1、为图书表设置更新触发器&#xff0c;根据总编号来更新书名、作者、出版社、分类号和单价(根据总编号找到相应记录&#xff0c;然后更新书名、作者…

Win10系统Docker+DeepSeek+ragflow搭建本地知识库

文章目录 1、安装ollama1.1 下载1.2 安装1.3 cmd命令行测试安装成功1.4 拉取模型2、安装ragflow2.1 下载项目2.2 通过docker拉取镜像安装2.3 查看docker日志是否安装成功3、模型配置3.1 第一次登录需要注册3.2 模型添加4、知识库配置4.1 创建知识库4.2 上传文档4.3 解析5、聊天…

redis的应用,缓存,分布式锁

1.应用 1.1可以用作缓存 作用&#xff1a;提交数据的查询效率&#xff0c;减少对数据库的访问频率 什么数据适合放入缓存 1.查询频率高&#xff0c;修改频率低 2.对安全系数比较低 如何实现 Service public class DeptServer {Autowiredprivate DeptMapper deptMapper;Auto…

springboot整合 xxl-job

文章目录 一、xxl-job是什么二、使用步骤 1. 下载并运行管理端代码2. 访问管理页面&#xff0c;确认是否启动成功3. 配置执行器【在自己的springboot项目中配置】4. 在页面上创建执行器和任务&#xff0c;与项目中绑定 总结参考 一、xxl-job是什么 XXL-JOB 是一个分布式任务调…

Jenkins 环境搭建---基于 Docker

前期准备 提前安装jdk、maven、nodeJs&#xff08;如果需要的话&#xff09; 创建 jenkins 环境目录&#xff0c;用来当做挂载卷 /data/jenkins/ 一&#xff1a;拉取 Jenkins 镜像 docker pull jenkins/jenkins:lts 二&#xff1a;设置 Jenkins挂载目录 mkdir -p ~/jen…

小米路由器 AX3000T 降级后无法正常使用,解决办法

问题描述 买了个 AX3000T 路由器&#xff0c;想安装 OpenWRT 或者 安装 Clash 使用&#xff0c;看教程说是需要降级到 v1.0.47 版本。 结果刷机之后路由器无法打开了&#xff0c;一直黄灯亮&#xff0c;中间灭一下&#xff0c;又是黄灯长亮&#xff0c;没有 WIFI 没有连接。以…

金融学-金融机构

前言 金融机构在金融体系运行体系运营中起着不可获缺的关键作用.如规则的制定与监管-中央银行,体系的运营证券公司,体系的供贷的参与者金融中介.本章将用一种说明我们的金融体系是怎样改进经济效率的经济分析,来讲述相关金融机构 金融结构的经济学分析 世界各国的金融体系在…

公网远程家里局域网电脑过程详细记录,包含设置路由器。

由于从校内迁居小区,校内需要远程控制访问小区内个人电脑,于是早些时间刚好自己是电信宽带,可以申请公网ipv4不需要花钱,所以就打电话直接申请即可,申请成功后访问光猫设备管理界面192.168.1.1,输入用户名密码登录超管(密码是网上查下就有了)设置了光猫为桥接模式,然后…

002 SpringCloudAlibaba整合 - Feign远程调用、Loadbalancer负载均衡

前文地址&#xff1a; 001 SpringCloudAlibaba整合 - Nacos注册配置中心、Sentinel流控、Zipkin链路追踪、Admin监控 文章目录 8.Feign远程调用、loadbalancer负载均衡整合1.OpenFeign整合1.引入依赖2.启动类添加EnableFeignClients注解3.yml配置4.日志配置5.远程调用测试6.服务…

基于javaweb的SpringBoot校园二手商品系统设计和实现(源码+文档+部署讲解)

技术范围&#xff1a;SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、小程序、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容&#xff1a;免费功能设计、开题报告、任务书、中期检查PPT、系统功能实现、代码编写、论文编写和辅导、论…

国产开源PDF解析工具MinerU

前言 PDF的数据解析是一件较困难的事情&#xff0c;几乎所有商家都把PDF转WORD功能做成付费产品。 PDF是基于PostScript子集渲染的&#xff0c;PostScript是一门图灵完备的语言。而WORD需要的渲染&#xff0c;本质上是PDF能力的子集。大模型领域&#xff0c;我们的目标文件格…

stm32单片机个人学习笔记16(SPI通信协议)

前言 本篇文章属于stm32单片机&#xff08;以下简称单片机&#xff09;的学习笔记&#xff0c;来源于B站教学视频。下面是这位up主的视频链接。本文为个人学习笔记&#xff0c;只能做参考&#xff0c;细节方面建议观看视频&#xff0c;肯定受益匪浅。 STM32入门教程-2023版 细…

Springboot + Ollama + IDEA + DeepSeek 搭建本地deepseek简单调用示例

1. 版本说明 springboot 版本 3.3.8 Java 版本 17 spring-ai 版本 1.0.0-M5 deepseek 模型 deepseek-r1:7b 需要注意一下Ollama的使用版本&#xff1a; 2. springboot项目搭建 可以集成在自己的项目里&#xff0c;也可以到 spring.io 生成一个项目 生成的话&#xff0c;如下…

Ubuntu 的RabbitMQ安装

目录 1.安装Erlang 查看erlang版本 退出命令 2. 安装 RabbitMQ 3.确认安装结果 4.安装RabbitMQ管理界面 5.启动服务并访问 1.启动服务 2.查看服务状态 3.通过IP:port 访问界面 4.添加管理员用户 a&#xff09;添加用户名&#xff1a;admin&#xff0c;密码&#xff1…