大模型推理——MLA实现方案

1.整体流程

先上一张图来整体理解下MLA的计算过程

2.实现代码

import math
import torch
import torch.nn as nn# rms归一化
class RMSNorm(nn.Module):""""""def __init__(self, hidden_size, eps=1e-6):super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.variance_epsilon = epsdef forward(self, hidden_states):hidden_states = hidden_states.float()variance = hidden_states.pow(2).mean(-1, keepdim=True)hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)return self.weight * hidden_states.float()def rotate_half(x):x1, x2 = x.chunk(2, dim=-1)return torch.cat((-x2, x1), dim=-1)def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):cos = cos.unsqueeze(unsqueeze_dim)sin = sin.unsqueeze(unsqueeze_dim)q_embed = (q * cos) + (rotate_half(q) * sin)k_embed = (k * cos) + (rotate_half(k) * sin)return q_embed, k_embed# 旋转位置编码
class RotaryEmbedding(nn.Module):def __init__(self, dim, max_seq_len=1024):super(RotaryEmbedding, self).__init__()self.dim = dimself.max_seq_len = max_seq_leninv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))t = torch.arange(max_seq_len).float().unsqueeze(1)freqs = t @ inv_freq.unsqueeze(0)freqs = torch.cat((freqs, freqs), dim=-1)self.register_buffer("cos_cached", freqs.cos())self.register_buffer("sin_cached", freqs.sin())def forward(self, q, k):cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)return apply_rotate_pos_emb(q, k, cos, sin)class MLA(nn.Module):def __init__(self,dim,n_heads,q_lora_rank,kv_lora_rank,qk_nope_head_dim,qk_rope_head_dim,v_head_dim,max_seq_len,max_batch_size,mode):super().__init__()self.dim = dim  # 隐藏层维度self.n_heads = n_heads  # 总头数self.q_lora_rank = q_lora_rank  # q低秩压缩到的维度self.kv_lora_rank = kv_lora_rank  # k/v低秩压缩到的维度self.qk_nope_head_dim = qk_nope_head_dim    # q/k不带旋转位置编码的维度self.qk_rope_head_dim = qk_rope_head_dim    # q/k带旋转位置编码的维度self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # q/k的总维度,不带旋转位置编码的维度加上带旋转位置编码的维度self.v_head_dim = v_head_dim  # value的维度,等于不带旋转位置编码的k维度self.mode = modeself.max_seq_len = max_seq_lenself.max_batch_size = max_batch_sizeself.wq_a = nn.Linear(self.dim, self.q_lora_rank)  # q的降维矩阵self.q_norm = RMSNorm(self.q_lora_rank)self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)  # q的升维矩阵# 4096*128+128*4864 = 524,288 + 622592 = 1146880    4096*4864 = 19,922,944self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)  # k/v的降维矩阵# nn.Linear(self.dim, self.kv_lora_rank)# nn.Linear(self.dim, self.qk_rope_head_dim)self.kv_norm = RMSNorm(self.kv_lora_rank)self.wkv_b = nn.Linear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))  # k/v的升维矩阵self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim)  # 旋转位置编码# 没有矩阵融合if self.mode == 'naive':self.register_buffer('k_cache',torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.qk_head_dim),persistent=False)self.register_buffer('v_cache',torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.v_head_dim),persistent=False)# 有矩阵融合else:self.register_buffer('kv_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank),persistent=False)self.register_buffer('pe_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim),persistent=False)def forward(self, x, mask=None):bs, seq_len, _ = x.shapeq = self.wq_a(x)  # [bs, seq_len, q_lora_rank]q = self.q_norm(q)  # [bs, seq_len, q_lora_rank]q = self.wq_b(q)  # [bs, seq_len, n_heads * qk_head_dim]q = q.view(bs, seq_len, self.n_heads, self.qk_head_dim)  # [bs, seq_len, n_heads, qk_head_dim]q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim],dim=-1)  # q_nope shape:[bs, seq_len, n_heads, qk_nope_head_dim] q_pe shape:[bs, seq_len, n_heads, qk_rope_head_dim]kv = self.wkv_a(x)  # [bs, seq_len, kv_lora_rank + qk_rope_head_dim]kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim],dim=-1)  # kv shape:[bs, seq_len, kv_lora_rank] k_pe shape:[bs, seq_len, qk_rope_head_dim]k_pe = k_pe.unsqueeze(2)  # k_pe shape:[bs, seq_len, 1, qk_rope_head_dim]   一层共享一个keyq_pe, k_pe = self.rotary_emb(q_pe, k_pe)if self.mode == 'naive':q = torch.cat([q_nope, q_pe], dim=-1)  # * [bs, seq_len, n_heads, qk_head_dim]kv = self.kv_norm(kv)  # [bs, seq_len, kv_lora_rank)]kv = self.wkv_b(kv)  # [bs, seq_len, n_heads * (qk_nope_head_dim + v_head_dim)]kv = kv.view(bs, seq_len, self.n_heads, self.qk_nope_head_dim + self.v_head_dim)k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1)# k shape:[bs, seq_len, n_heads, qk_head_dim]self.k_cache[:bs, :seq_len, :, :] = kself.v_cache[:bs, :seq_len, :, :] = v# scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bs, :seq_len]) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)scores = torch.matmul(q.transpose(1, 2),self.k_cache[:bs, :seq_len, :, :].transpose(1, 2).transpose(2, 3) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim))scores = scores.transpose(1, 2)else:k_pe = k_pe.squeeze(2)wkv_b = self.wkv_b.weight  # [n_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]wkv_b = wkv_b.view(self.n_heads, -1,self.kv_lora_rank)  # [n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]q_nope = torch.einsum("bshd,hdc->bshc", q_nope,wkv_b[:, :self.qk_nope_head_dim])  # q_nope shape:[bs, seq_len, n_heads, kv_lora_rank]# q*k(T) = x*wq*(c*wkv_b[:, :self.qk_nope_head_dim])(T) = x*wq*wkv_b[:, :self.qk_nope_head_dim](T)*c(T)    c为压缩后的k/v# wq*wkv_b[:, :self.qk_nope_head_dim](T)作为q的投影矩阵  c可以替代原先的k,这样就可以直接使用压缩后的k/v计算注意力了,kv_cache时也只需存储压缩后的k/vkv = self.kv_norm(kv)self.kv_cache[:bs, :seq_len, :] = kv  # kv shape:[bs, seq_len, kv_lora_rank]self.pe_cache[:bs, :seq_len, :] = k_pe  # k_pe shape:[bs, seq_len, qk_rope_head_dim]scores_nope = torch.einsum("bshc,btc->bsht", q_nope,self.kv_cache[:bs, :seq_len, :])  # bshc btc -> bshc bct -> bshtscores_pe = torch.einsum("bshr,btr->bsht", q_pe,self.pe_cache[:bs, :seq_len, :])  # bshr btr -> bshr bt1r -> bshr bthr -> bshtscores = (scores_nope + scores_pe) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)  # [bs, seq_len, n_heads, seq_len]if mask is not None:# mask shape:[bs, seq_len, seq_len]scores += mask.unsqueeze(2)scores = scores.softmax(dim=-1)if self.mode == 'naive':x = torch.einsum("bsht,bthd->bshd", scores,self.v_cache[:bs, :seq_len])  # bsht,bthd -> bhst, bhtd -> bhsd -> bshdelse:# scores * v = scores * c * wkv_b[:, -self.v_head_dim:]x = torch.einsum("bsht,btc->bshc", scores,self.kv_cache[:bs, :seq_len])  # x shape:[bs, seq_len, n_heads, kv_lora_rank]x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])  # bshc, hdc -> bshc,dch -> bsdh -> bshdx = x.contiguous().view(bs, seq_len, -1)x = self.wo(x) return xif __name__ == '__main__':torch.manual_seed(0)torch.set_printoptions(precision=3, sci_mode=False)x = torch.randn(1, 4, 16)dim = 16n_heads = 2q_lora_rank = 10kv_lora_rank = 6qk_nope_head_dim = 8qk_rope_head_dim = 4v_head_dim = 8max_seq_len = 10max_batch_size = 4mode = 'none'mla = MLA(dim=dim,n_heads=n_heads,q_lora_rank=q_lora_rank,kv_lora_rank=kv_lora_rank,qk_nope_head_dim=qk_nope_head_dim,qk_rope_head_dim=qk_rope_head_dim,v_head_dim=v_head_dim,max_seq_len=max_seq_len,max_batch_size=max_batch_size,mode=mode)print(mla(x))print(mla.kv_cache)

参考资料:

https://zhuanlan.zhihu.com/p/16730036197

https://github.com/wyf3/llm_related/tree/main/deepseek_learn

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/15364.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python截图轻量化工具

一、兼容局限性 这是用Python做的截图工具,不过由于使用了ctypes调用了Windows的API, 同时访问了Windows中"C:/Windows/Cursors/"中的.cur光标样式文件, 这个工具只适用于Windows环境; 如果要提升其跨平台性的话,需要考虑替换cty…

链表(LinkedList) 1

上期内容我们讲述了顺序表,知道了顺序表的底层是一段连续的空间进行存储(数组),在插入元素或者删除元素需要将顺序表中的元素整体移动,时间复杂度是O(n),效率比较低。因此,在Java的集合结构中又引入了链表来解决这一问…

SpringAI系列 - 使用LangGPT编写高质量的Prompt

目录 一、LangGPT —— 人人都可编写高质量 Prompt二、快速上手2.1 诗人 三、Role 模板3.1 Role 模板3.2 Role 模板使用步骤3.3 更多例子 四、高级用法4.1 变量4.2 命令4.3 Reminder4.4 条件语句4.5 Json or Yaml 方便程序开发 一、LangGPT —— 人人都可编写高质量 Prompt La…

jupyterLab插件开发

jupyter lab安装、配置: jupyter lab安装、配置教程_容器里装jupyterlab-CSDN博客 『Linux笔记』服务器搭建神器JupyterLab_linux_布衣小张-腾讯云开发者社区 Jupyter Lab | 安装、配置、插件推荐、多用户使用教程-腾讯云开发者社区-腾讯云 jupyterLab插件开发教…

使用LLaMA Factory踩坑记录

前置条件:电脑显卡RTX 4080 问题:LLaMA-Factory在运行的时候,弹出未检测到CUDA的报错信息 结论:出现了以上的报错,主要可以归结于以下两个方面: 1、没有安装GPU版本的pytorch,下载的是CPU版本…

『Apisix进阶篇』结合Consul作服务发现实战演练

文章目录 一、引言二、APISIX与Consul集成2.1 环境准备2.2 配置Consul服务发现2.2.1 修改APISIX配置文件2.2.2 重启APISIX 2.3 在路由中使用Consul服务发现2.3.1 创建路由2.3.2 验证路由 2.4 高级配置2.4.1 服务过滤2.4.2 多数据中心支持 三、总结 📣读完这篇文章里…

SpringBoot速成(八)登录实战:未登录不能访问 P5-P8

1.登录 package com.itheima.springbootconfigfile.controller;import com.itheima.springbootconfigfile.pojo.Result; import com.itheima.springbootconfigfile.pojo.User; import com.itheima.springbootconfigfile.service.UserService;import com.itheima.springbootco…

对接DeepSeek

其实,整个对接过程很简单,就四步,获取key,找到接口文档,接口测试,代码对接。 获取 KEY https://platform.deepseek.com/transactions 直接付款就是了(现在官网暂停充值2025年2月7日&#xff0…

ASP.NET Core JWT

目录 Session的缺点 JWT(Json Web Token) 优点: 登录流程 JWT的基本使用 生成JWT 解码JWT 用JwtSecurityTokenHandler对JWT解码 注意 Session的缺点 对于分布式集群环境,Session数据保存在服务器内存中就不合适了&#…

【MySQL】深度学习数据库开发技术:使用CC++语言访问数据库

**前言:**本节内容介绍使用C/C访问数据库, 包括对数据库的增删查改操作。 主要是学习一些接口的调用, 废话不多说, 开始我们的学习吧! ps:本节内容比较容易, 友友们放心观看哦! 目录 准备mysql…

postgreSQL16.6源码安装

1.获取源码 从PostgreSQL: File Browser获取tar.bz2或者tar.gz源码 2.解压 tar xf postgresql-version.tar.bz2 roothwz-VMware-Virtual-Platform:/usr/local# tar xf postgresql-16.6.tar.bz2 roothwz-VMware-Virtual-Platform:/usr/local# ll 总计 24324 drwxr-xr-x 12 ro…

音频进阶学习十一——离散傅里叶级数DFS

文章目录 前言一、傅里叶级数1.定义2.周期信号序列3.表达式DFSIDFS参数含义 4.DFS公式解析1)右边解析 T T T、 f f f、 ω \omega ω的关系求和公式N的释义求和公式K的释义 e j ( − 2 π k n N ) e^{j(\frac{-2\pi kn}{N})} ej(N−2πkn​)的释义 ∑ n 0 N − 1 e…

【kafka系列】Topic 与 Partition

Kafka 的 Topic(主题) 和 Partition(分区) 是数据组织的核心概念,它们的映射关系及在 Broker 上的分布直接影响 Kafka 的性能、扩展性和容错能力。以下是详细解析: 一、Topic 与 Partition 的映射关系 Top…

卷积神经网络CNN如何处理语音信号

卷积神经网络(CNN)在处理语音数据时通常不直接处理原始的一维波形信号,而是处理经过预处理的二维语音特征图。以下是CNN处理语音数据时的常见数据类型和步骤: 1. 语音信号预处理 语音信号通常是一维的时间序列(波形信…

【MQ】Spring3 中 RabbitMQ 的使用与常见场景

一、初识 MQ 传统的单体架构,分布式架构的同步调用里,无论是方法调用,还是 OpenFeign 难免会有以下问题: 扩展性差(高耦合,需要依赖对应的服务,同样的事件,不断有新需求&#xff0…

GB/T 43698-2024 《网络安全技术 软件供应链安全要求》标准解读

一、43698-2024标准图解 https://mmbiz.qpic.cn/sz_mmbiz_png/rwcfRwCticvgeBPR8TWIPywUP8nGp4IMFwwrxAHMZ9Enfp3wibNxnfichT5zs7rh2FxTZWMxz0je9TZSqQ0lNZ7lQ/640?wx_fmtpng&fromappmsg 标准在线预览: 国家标准|GB/T 43698-2024 相关标准: &a…

Linux系统-centos防火墙firewalld详解

Linux系统-centos7.6 防火墙firewalld详解 1 firewalld了解 CentOS 7.6默认的防火墙管理工具是firewalld,它取代了之前的iptables防火墙。firewalld属于典型的包过滤防火墙或称之为网络层防火墙,与iptables一样,都是用来管理防火墙的工具&a…

Gitlab中如何进行仓库迁移

需求:之前有一个自己维护的新仓库A,现在需要将这个仓库提交并覆盖另一个旧的仓库B,需要保留A中所有的commit信息。 1.方法一:将原有仓库A导出后再导入到新的仓库B中 适用场景:新的仓库B是一个待建仓库,相当…

微信点餐系统小程序ssm+论文源码调试讲解

第4章 系统设计 一个成功设计的系统在内容上必定是丰富的,在系统外观或系统功能上必定是对用户友好的。所以为了提升系统的价值,吸引更多的访问者访问系统,以及让来访用户可以花费更多时间停留在系统上,则表明该系统设计得比较专…

01单片机上电后没有正常运行怎么办

单片机上电后没有运转, 首先要检查什么? 1、单片机供电是否正常? &电路焊接检查 如果连最基本的供电都没有,其它都是空谈啊!检查电路断路了没有?短路了没有?电源合适吗?有没有虚焊? 拿起万用表之前,预想一下测量哪里?供电电压应该是多少?对PCB上电压测量点要…