【PyTorch][chapter-35][MLA]

前言:

      MLA(Multi-head Latent Attention,多头潜在注意力)旨在提高推理效率和降低计算资源的消。MLA的核心思想在于通过信息转移来优化KV缓存的使用

      MLA的技术特点主要包括:

  1. KV压缩与潜在变量:将键(Key)和值(Value)联合压缩为低维潜在向量,显著减少推理时的KV缓存,降低内存占用。计算时通过升维恢复原始信息,平衡压缩效率与计算精度。
  2. 低秩降维技术:对查询(Queries)进行低秩压缩(降维后再升维),减少训练中的激活内存(activation memory),但需注意此操作不影响KV缓存。
  3. 动态序列处理:针对可变长度输入序列优化,支持高效处理不同长度的句子(如长文本对话场景 ROPE)。

目录

  1.     KV-cache
  2.     MLA 模型简介
  3.     MLA+ROPE 
  4.     MLA 数学原理
  5.     PyTorh 代码

   


一   KV-cache

       1.1 MHA (多头注意力)

      

       

1.2 KV-cache

      

       在自回归生成过程中,每个新生成的token都会依赖于之前所有token的信息,这就需要在生成每个新token时重新计算整个序列的自注意力。然而,这种计算方式非常低效,因为大量重复的计算被浪费在了已经生成过的token上。

      为了缩短inference time, KV-Cache机制正是为了解决这一问题而提出的。它的工作原理是在生成过程中,将已经计算过的键和值向量存储在缓存中,这样在生成后续token时,可以直接从缓存中获取之前token的键和值,而不需要重新计算。具体来说,当生成一个新的token时,模型只需要计算这个新token的查询向量,并与缓存中的键向量计算注意力得分,然后使用这些得分和缓存中的值向量来计算新token的输出表示.

KV-Cache 的大小取决于以下参数:

  • n_h:  注意力头数,每层的注意力头数量。

  • d_h: 每个注意力头的维度,每个注意力头的 Key 和 Value 的维度。

  • l: 输入的层数模

      则每个token 对应的 KV-cache 为  2n_hd_h l

   不同注意力机制对应的kv-cache

  


二   MLA(Multi-Layer Adaptation)

       多头潜在注意力 (MLA) 是一种新的注意力机制,它通过将键和值压缩为一个较小的共享表示(称为潜在向量)来实现这一点。这可以减小 KV 缓存的大小,同时保持甚至提高性能。

MLA 引入了两项关键创新:

  1. Low-Rank Key-Value Compression
  2. Decoupled Rotary Position Embedding (RoPE)

2.1 MLA 架构

  2.2 计算流程

参考:

MLA reduces the KV cache size by compressing the keys and values into a smaller latent vector and decoupling the position information (RoPE). Here’s how the cache size is calculated.


三 Decoupled Rotary Position Embedding (RoPE)

       旋转位置编码(Rotary Position Embedding, RoPE)是一种用于编码序列中标记位置的技术。然而,RoPE是位置敏感的,这意味着它依赖于每个标记的具体位置。这在使用低秩压缩时会产生问题,因为位置信息会被混入压缩后的键(keys)和值(values)中,导致在推理过程中难以高效地重用它们。为了解决ROPE问题,使用了下面架构

   参考:

KV-cache 的大小(包括了ROPE 部分)


四  PyTorch  代码

 常用超参数

# -*- coding: utf-8 -*-
"""
Created on Sat Mar 15 18:24:47 2025@author: cxf
"""# -*- coding: utf-8 -*-
"""
Created on Thu Mar 13 13:51:48 2025@author: chengxf2
"""import torch
import torch.nn  as nn
import torch.nn.functional as F
import mathclass Config:def __init__(self):self.vocab_size = 32000#词向量的维度self.d_model = 1024#number of attention heads self.n_heads = 8#dDmension of per head =64self.d_head = self.d_model//self.n_heads#ROPE dimension, typically 128self.d_rope =  self.d_head//2#compression dimension KV_cache <<n_head*d_hself.d_kv_cache = 4*self.d_head self.seq_len = 10self.batch_size = 1#256class RotaryEmbedding(nn.Module):def __init__(self, dim):super().__init__()#Dimension must be even for Rotary Embeddingassert dim % 2 == 0, "Dimension must be even for rotary embeddings"self.dim = dim//2inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim, 2).float() / self.dim))self.register_buffer("inv_freq", inv_freq)def forward(self, seq_len):t = torch.arange(seq_len)freqs = torch.einsum("i,j->ij",t, self.inv_freq)output = torch.cat((freqs, freqs), dim=-1)return outputdef rotate_half(x):"""Apply rotary embeddings to the first half of x."""x1 ,x2 = x.chunk(2,dim=-1)output = torch.cat((-x2,x1),dim=-1)return outputdef apply_rotary(x, cos, sin):"""Apply rotary embeddings to the first half of x."""#x.shape batch_size, seq_len, head, d_h# Split x into two parts: one for rotary embeddings and the other untouched    x_rot, x_base = x.split(cos.shape[-1],dim=-1)print("\n apply _rotary ",x.shape)print("\n cos x ",cos.shape, x.shape)x_rot, x_base = x.split(cos.shape[-1],dim=-1)x_rot =(x_rot*cos)+(rotate_half(x_rot)*sin)output = torch.cat([x_rot,x_base],dim=-1)return outputconfig = Config()
class MemoryOptimizedMLA(nn.Module):def __init__(self):super().__init__()self.d_head = config.d_headself.d_split = config.d_model-config.d_rope#down-projectionself.W_DQ =  nn.Linear(config.d_model,  config.d_kv_cache)self.W_DKV = nn.Linear(config.d_model,  config.d_kv_cache)print("\n kv cache size ",config.d_kv_cache)# RoPEself.W_q_rope = nn.Linear(config.d_kv_cache,  config.d_rope)self.W_k_rope = nn.Linear(config.d_model,     config.d_rope)#step2:  Up Projectionsself.W_UQ = nn.Linear(config.d_kv_cache, self.d_split)self.W_UK = nn.Linear(config.d_kv_cache, self.d_split)self.W_UV = nn.Linear(config.d_kv_cache, config.d_model)  #rotary Embeddingself.rotary = RotaryEmbedding(config.d_rope//config.n_heads)#step3 outputself.output = nn.Linear(config.d_model, config.d_model)def forward(self, x):batch_size, seq_len, d_model = x.shapeprint("\n bat_size %d seq_len: %d d_model: %d "%(batch_size, seq_len, d_model))#step1: down-projection Compressionprint("\n step1 : down projection")#query compressionq_c      =  self.W_DQ(x)kv_cache =  self.W_DKV(x)#print("\n kv-cache",kv_cache.shape,"\t q_c",q_c.shape)#Apply RoPEprint("\n step2 : apply ROPE ")rotary_emb = self.rotary(seq_len)cos = torch.cos(rotary_emb).view(1, seq_len, 1, -1)  sin = torch.sin(rotary_emb).view(1, seq_len, 1, -1)q_rot =  self.W_q_rope(q_c)q_rot = q_rot.view(batch_size, seq_len, config.n_heads, -1)q_rot = apply_rotary(q_rot, cos, sin)k_rot_cache =   self.W_k_rope(x)k_rot_cache =   k_rot_cache.view(batch_size, seq_len, config.n_heads,-1)k_rot_cache =   apply_rotary(k_rot_cache,cos, sin)#up-projectionprint("\n step3 : up projection ")q_base = self.W_UQ(q_c).view(batch_size, seq_len, config.n_heads, -1)k = self.W_UK(kv_cache).view(batch_size, seq_len, config.n_heads, -1)v = self.W_UV(kv_cache).view(batch_size, seq_len, config.n_heads, -1)# concateq = torch.cat([q_base, q_rot], dim=-1)k = torch.cat([k, k_rot_cache], dim=-1)# Attention computationscores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(self.d_head)attn = F.softmax(scores, dim=-1)out = torch.einsum("bhqk,bkhd->bqhd", attn, v)out = self.output(out.contiguous().view(batch_size, seq_len, -1))output =  out, (kv_cache, k_rot_cache)print("\n output ",out.shape)return outputnet= MemoryOptimizedMLA()
x  = torch.randn((config.batch_size, config.seq_len, config.d_model))
out = net(x)

     

https://medium.com/@shaiknagurshareef/multi-head-latent-attention-mla-secret-behind-the-success-of-deepseek-large-language-models-66612071d756

DeepSeek's Multi-Head Latent Attention - Lior Sinai

https://www.youtube.com/watch?v=s9R5s4U1WH8

https://medium.com/@atulit23/implementing-multi-head-latent-attention-from-scratch-in-python-1e14d03fbc91

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/35060.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Cloud 中的服务注册与发现: Eureka详解

1. 背景 1.1 问题描述 我们如果通过 RestTamplate 进行远程调用时&#xff0c;URL 是写死的&#xff0c;例如&#xff1a; String url "http://127.0.0.1:9090/product/" orderInfo.getProductId(); 当机器更换或者新增机器时&#xff0c;这个 URL 就需要相应地变…

微服务存在的问题及解决方案

微服务存在的问题及解决方案 1. 存在问题 1.1 接口拖慢 因为一个接口在并发时&#xff0c;正好执行时长又比较长&#xff0c;那么当前这个接口占用过多的 Tomcat 连接&#xff0c;导致其他接口无法即时获取到 Tomcat 连接来完成请求&#xff0c;导致接口拖慢&#xff0c;甚至…

centos 安装pip时报错 Cannot find a valid baseurl for repo: centos-sclo-rh/x86_64

centos 安装pip时报错 [rootindex-es app-ai]# yum update Loaded plugins: fastestmirror Repository centos-sclo-rh is listed more than once in the configuration Determining fastest mirrors Could not retrieve mirrorlist http://mirrorlist.centos.org?archx86_64…

解决图片转 ICO 图标难题,支持批量处理

还在为图片转 ICO 图标发愁吗&#xff1f;别担心&#xff0c;今天为大家带来一款超实用的工具 ——Any to Icon。它功能强大&#xff0c;可实现批量图片转 ICO 图标&#xff0c;轻松解决格式转换难题。更棒的是&#xff0c;这款工具极为小巧&#xff0c;无需安装&#xff0c;即…

MultiPost--多平台博客发布工具

网站介绍 一键发布内容到多个社交平台的浏览器插件&#xff0c;支持知乎、微博、小红书、抖音等主流平台&#xff0c;支持文字、图片、视频等内容形式. 地址 GitHub &#xff1a; https://github.com/leaper-one/MultiPost-Extension Chorme: https://chromewebstore.google.…

Linux进程状态详解:僵尸进程与孤儿进程的深度探索与实践

文章目录 前言一、进程状态概述1.1 运行状态1.2 阻塞状态1.3 挂起状态 二、具体的Linux操作系统中的进程状态2.1 Linux内核源代码2.2 查看进程状态2.3 D磁盘休眠状态(Disk sleep)D状态的定义&#xff1a; 2.4 T停止状态(stopped)停止状态的概述&#xff1a;停止状态的触发条件&…

【Linux】深入理解进程和文件及内存管理

个人主页~ 深入理解进程和文件及内存管理 一、重谈Linux下一切皆文件二、操作系统对物理内存的管理1、物理内存与磁盘的数据交互2、操作系统对物理内存的管理 三、文件页缓冲区向文件写入数据的过程 四、动态库是如何被加载的关于动态库中的全局变量 五、深入理解地址1、程序地…

★9.4.2 context2D 绘图

返回目录&#xff1a; Qt QML专栏目录结构_qml 项目 目录-CSDN博客 ★9.4.2 context2D 绘图 Object <- context 属性 canvas : QtQuick::Canvas fillRule : enumeration fillStyle : variant fillStyle: 设置或获取当前填充颜色或样式。 font : string g…

汇编基础知识

CPU&#xff1a;一种可以执行机器指令进行运算的芯片&#xff08;微处理器&#xff09;。 存储器&#xff08;内存&#xff09;&#xff1a;存放CPU可以工作的指令和数据&#xff08;指令和数据都是二进制信息&#xff09;。 磁盘不同于内存&#xff0c;磁盘中的数据要读到内…

1536数字三角形

1536数字三角形 ⭐️难度&#xff1a;中等 &#x1f31f;考点&#xff1a;动态规划 &#x1f4d6; &#x1f4da; import java.util.Arrays; import java.util.LinkedList; import java.util.Queue; import java.util.Scanner;public class Main {public static void main(…

基于VMware的虚拟机集群搭建

本文作者&#xff1a; slience_me 文章目录 基于VMware的虚拟机集群搭建1. 安装Vmware2. 构建虚拟机3. 安装Linux4. 网络配置5. 开始克隆6. 初始化系统6.1 开放root账户6.2 SSH服务6.3 设置静态IP6.4 镜像源 host 主机名 基于VMware的虚拟机集群搭建 该集群采用镜像ubuntu-20.0…

windows平台搭建python环境

python语言 Python 是一种高级、解释型、跨平台的编程语言&#xff0c;由Guido van Rossum于1991年设计&#xff0c;并发展成为全球最受欢迎的编程语言之一。它以简单易读的语法、灵活的特性和丰富的标准库闻名&#xff0c;适合初学者和经验丰富的开发者。 Python 支持多种编…

【系统架构设计师】操作系统 - 文件管理 ② ( 位示图 | 空闲区域 管理 | 位号 | 字号 )

文章目录 一、空闲区域 管理1、空闲区域分配2、空闲区域 管理方式 简介 二、位示图 简介1、位示图 表示2、位示图 字号3、位示图 位号4、位示图 中 比特位 分组管理 三、位示图 考点1、计算磁盘 位示图 的大小2、位示图 位置计算 一、空闲区域 管理 1、空闲区域分配 在 索引文件…

SpringData Redis:RedisTemplate配置与数据操作

文章目录 引言一、Redis概述与环境准备二、RedisTemplate基础配置三、连接属性配置四、操作String类型数据五、操作Hash类型数据六、操作List类型数据七、操作Set类型数据八、操作ZSet类型数据九、事务与管道操作总结 引言 Redis作为高性能的NoSQL数据库&#xff0c;在分布式系…

串口烧录出现频繁回复乱码 频繁回复一个数字且烧录失败 字节混乱

这是因为你的芯片没有处于系统存储区启动一直未进入bootloader 解决办法是检查boot引脚接正确没&#xff0c;要在系统存储器启动

共享经济再中介化进程中的技术创新与模式重构研究——以“开源AI智能名片链动2+1模式S2B2C商城小程序“为例

摘要 本文基于共享经济中介化演进的双重逻辑&#xff0c;通过案例研究与技术解构&#xff0c;探讨"开源AI智能名片链动21分销机制S2B2C商城小程序"集成系统如何重构数字经济时代的价值网络。研究发现&#xff0c;该技术生态通过三维需求匹配、动态价值分配与智能风险…

【linux】虚拟机执行sudo yum isntall perl报错 could not retrieve mirrorlist htt:

项目场景&#xff1a; 提示&#xff1a;虚拟机安装拓展包&#xff0c;sudo yum install perl Virtualbox 在不安装增强功能扩展的情况下, 无法自适应分辨率和共享剪切板等操作 问题描述 原因分析&#xff1a; 提示&#xff1a;这里填写问题的分析&#xff1a; 出现这个错误是因…

网络编程知识预备阶段

1. OSI七层模型 OSI&#xff08;Open System Interconnect&#xff09;七层模型是一种将计算机网络通信协议划分为七个不同层次的标准化框架。每一层都负责不同的功能&#xff0c;从物理连接到应用程序的处理。这种模型有助于不同的系统之间进行通信时&#xff0c;更好地理解和…

我的Gitee

算法与数据结构: 浙海大小趴菜的一些记录 后续也会更新一些项目&#xff0c;小趴菜以后也会变得很厉害

Collection合集(单列集合)

Collection代表单列集合&#xff0c;每个元素&#xff08;数据&#xff09;只包含一个值。Collection实际上是一个泛型接口 Collection集合常用API&#xff1a; 代码实现&#xff1a; Collection集合遍历 遍历方式一&#xff1a;迭代器 迭代器是用来遍历集合的专用方式&#…