DeepSeek V3 源码:从入门到放弃!

从入门到放弃

花了几天时间,看懂了DeepSeek V3 源码的逻辑。源码的逻辑是不难的,但为什么模型结构需要这样设计,为什么参数需要这样设置呢?知其然,但不知其所以然。除了模型结构以外,模型的训练数据、训练脚本和训练经验,也是DeepSeek V3能够训练出来的关键,但这些是DeepSeek母公司的核心机密,我们无从得知。
因此,看懂了源码,算是入门了DeepSeek V3,因为没有条件知道更多重要细节,因此不得不放弃重现整个模型的训练。

Paper 和源码

Paper URL: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf
Code URL: https://github.com/deepseek-ai/DeepSeek-V3

模型逻辑

下面这张图,代表了DeepSeek的核心逻辑。左边是Transformer的逻辑结构,可以认为有N个左边这样的Block结构不断重复,组成Transformer模型。每个Block中,分成两个部分,Attention 和 Feed-Forward Network。对这两个部分使用不同的网络结构,我们就得到了不同的模型。
DeepSeek V3 的 Attention 用的是 Multi-Head Latent Attention(MLA) ,Feed-Forward Network 用的是DeepSeekMoE。
在这里插入图片描述

MLA

Multi-Head Latent Attention(MLA)即多头潜在注意力,是DeepSeek模型中引入的一种创新注意力机制,旨在优化传统多头注意力(Multi-Head Attention,MHA)的计算效率和内存占用。具体介绍如下:

核心创新点

  • 低秩键值压缩
    • KV的低秩压缩:不直接存储原始的Key和Value,而是先将隐藏状态投影到一个更小的压缩潜在向量。在推理时,只需缓存该压缩潜在向量,而不是完整的Key和Value,从而大大降低了KV缓存的存储需求。
    • Query的低秩压缩:对Query也进行低秩压缩,虽然不会减少KV缓存的大小,但可以减少训练时的激活存储需求,进而降低计算成本。
  • 解耦旋转位置嵌入(RoPE)
    • 额外引入“解耦查询”:将查询拆分为两个部分,一部分不经过RoPE变换,代表非位置敏感的特征信息;另一部分专门用于嵌入RoPE位置编码信息。
    • 共享RoPE变换的Key:所有注意力头共用一个旋转变换后的Key,减少了计算开销,也减小了KV缓存大小,降低了GPU内存占用,提高了推理速度,特别适用于长序列任务和大规模Transformer。

推理过程中的优化

将上投影矩阵吸收到里面,简化查询计算,并优化注意力分数的计算,减少了计算步骤,提升了计算效率。避免了先计算Value向量,减少了矩阵运算的开销,使推理更快。

整体优势

  • 降低内存占用:通过对键值进行低秩联合压缩以及解耦RoPE等策略,显著减少了KV缓存的存储需求,降低了GPU内存占用。
  • 提高计算效率:减少了训练和推理过程中的计算量,加快了模型的推理速度,在保持甚至提高模型性能的同时,提升了模型的运行效率。
  • 增强模型适应性:特别适用于长序列任务和大规模Transformer模型,能够更好地处理长序列输入,提高模型在各种自然语言处理任务中的表现。

MLA 有物理意义吗?

Multi-Head Latent Attention(MLA)能够起作用主要源于其独特的技术设计,在数学和信息处理层面有清晰的逻辑,不过它是一种抽象的算法概念,并不直接对应具体的物理意义,以下是对其作用原理的分析:

起作用的原因

  • 低秩压缩的有效性
    • 信息浓缩与降噪:通过低秩键值压缩,MLA将高维的Key和Value信息投影到低维的潜在向量空间,这一过程类似于对原始信息进行浓缩,提取出最关键、最具代表性的特征,去除了一些可能的噪声和冗余信息,使得模型能够更聚焦于重要信息,从而提高信息处理的效率和准确性。
    • 减少计算量和存储需求:低秩压缩大大降低了数据的维度,减少了模型训练和推理过程中的计算量和存储需求,使得模型能够更高效地运行,尤其是在处理大规模数据和长序列数据时,这种优势更为明显。
  • 解耦旋转位置嵌入的优势
    • 位置信息与内容信息的分离:传统的位置编码方式将位置信息和内容信息混合在一起进行处理,而MLA的解耦旋转位置嵌入将查询拆分为位置敏感和非位置敏感两部分,使模型能够更清晰地分离和处理位置信息与内容信息,更好地捕捉文本中的长距离依赖关系。
    • 共享RoPE变换的Key:所有注意力头共用一个旋转变换后的Key,不仅减少了计算开销,还使得模型能够从更宏观的角度利用位置信息,增强了模型对序列数据整体结构的理解和把握能力。
  • 多头机制的协同作用
    • 捕捉多维度信息:MLA中的多头机制允许模型同时从多个不同的角度和维度去捕捉输入数据中的信息,每个头可以关注到输入序列的不同方面,通过多个头的并行计算和协同工作,模型能够更全面、更深入地理解输入数据,提高模型的表示能力和泛化能力。

难以直接赋予物理意义的原因及近似理解

  • 抽象的算法概念:MLA是一种基于数学和计算机科学的算法概念,主要用于处理和分析数据中的模式和关系,它不像物理概念那样具有直接可观测的物理实体或现象与之对应,更多地是在数据空间和计算逻辑中发挥作用。
  • 类比物理现象理解:可以进行一些类比来帮助理解。比如低秩压缩类似于物理中的能量聚集,将分散的能量(信息)聚集到关键的“点”上;解耦旋转位置嵌入有点像物理中对不同性质力的分解,将位置信息和内容信息这两种“力”分开处理;多头机制如同多个物理传感器从不同方向和角度对环境进行感知,然后综合这些感知信息来对整个系统进行理解和判断。

DeepSeekMoE

DeepSeekMoE是由深度求索(DeepSeek)研发的基于混合专家系统(Mixture of Experts,MoE)的技术架构,以下是具体介绍:

架构原理

  • 混合专家系统核心:采用MoE架构,核心在于通过动态路由机制,把输入数据分配给最相关的专家处理。比如在自然语言处理中,有的专家专门处理情感分析,有的处理主题建模。
  • 结合多头潜在注意力机制:与MLA相结合,MLA通过引入潜在向量,减少键值缓存(KV cache)需求,提升推理效率。
  • Transformer架构基础:以Transformer架构为基础,每个Transformer块由一个注意力模块和一个前馈网络(FFN)组成,在注意力机制和FFN方面采用创新架构。

技术优势

  • 降低算力需求:MoE的动态分配机制和MLA减少KV缓存需求等特点,使模型在训练和推理时对算力的要求降低。
  • 保持高性能:在参数量减少的情况下仍能保持高性能,例如DeepSeek-V2以236B总参数、21B激活,大致可以达到70B-110B Dense的模型能力。
  • 减少计算量:自研Sparse结构DeepSeekMoE进一步降低了计算量。
  • 长上下文理解能力强:支持超100万token的上下文窗口,显著优于行业平均水平,适用于长文档分析、代码开发等复杂场景的连贯交互。

DeepSeekMoE的物理意义是什么?

DeepSeekMoE作为一种人工智能技术架构,没有严格意义上的物理意义,但可以从一些角度进行类比和理解:

从系统资源分配角度

  • 资源按需分配类比:可以将DeepSeekMoE的专家网络和动态路由机制类比为一个智能电力分配系统。在这个系统中,不同的电器设备(任务)需要不同的电量(计算资源)来运行。专家网络就像不同功率的发电机,而动态路由机制则像是智能电表和分配器,它会根据每个电器设备的实际需求,将电力(计算资源)精准地分配给需要的设备,避免了资源的浪费,提高了整个系统的能源利用效率。
  • 负载均衡类比:类似于在一个大型物流中心,不同的仓库区域(专家)负责存储和处理不同类型的货物(数据)。当有货物运输任务时,调度系统(动态路由)会根据货物的特点和仓库的负载情况,合理地安排货物存储到哪个仓库,确保每个仓库都能在其承载能力范围内高效运作,不会出现某个仓库过度拥挤而其他仓库闲置的情况,实现了负载均衡,提高了物流中心的整体运营效率。

从信息处理角度

  • 多维度信息处理类比:可以把DeepSeekMoE处理信息的过程想象成一个由多个不同专业的侦探(专家)组成的侦探团队在调查一个复杂案件。每个侦探都有自己独特的专业技能和视角,比如有的擅长调查线索,有的擅长分析人物关系,有的擅长破解密码等。当面对案件(输入数据)时,队长(路由器)会根据案件的具体情况,分配合适的侦探去处理相应的部分,最后将各个侦探的调查结果综合起来,形成对整个案件的全面了解和判断,从而更高效地解决复杂问题。
  • 特征提取与融合类比:如同在一个化学实验中,不同的化学试剂(专家)可以与不同的物质发生反应,提取出特定的化学特征。DeepSeekMoE中的专家网络就像这些化学试剂,它们各自对输入数据进行处理,提取出不同的特征。然后通过融合机制,将这些特征像混合化学物质一样进行整合,得到更全面、更有价值的信息,用于后续的分析和决策。

从模型架构角度

  • 积木搭建类比:把DeepSeekMoE的架构比作搭建积木。每个专家网络就像不同形状和功能的积木块,有的积木块负责搭建基础结构,有的负责构建上层建筑,有的负责添加装饰等。路由器则像是搭建者的手,根据要搭建的目标模型的需求,选择合适的积木块进行组合,最终搭建出一个复杂而功能强大的模型结构,实现对各种自然语言处理任务的高效处理。
  • 人体神经系统类比:可以将DeepSeekMoE类比为人体的神经系统。专家网络类似于人体的不同神经细胞或神经中枢,它们各自负责处理特定类型的信息,如视觉神经细胞负责处理视觉信息,听觉神经细胞负责处理听觉信息等。路由器就像神经系统中的神经递质或信号传导机制,它负责将外界的刺激信号(输入数据)准确地传递给相应的神经细胞,并将各个神经细胞处理后的信号进行整合和传递,使人体能够做出协调的反应和决策,实现对外部世界的感知和交互。

代码逻辑

整体 - Transformer

下面这段代码是典型的 Transformer 实现,核心可以看 forward 函数逻辑:

  1. 进行 Embeding;
  2. 经过各个 Block;
  3. 归一化并输出。
    对应的代码:
# 通过嵌入层将输入标记转换为向量表示
h = self.embed(tokens)
# 依次通过每个Transformer块进行处理
for layer in self.layers:h = layer(h, start_pos, freqs_cis, mask)
# 对输出进行层归一化,并取最后一个时间步的输出
h = self.norm(h)[:, -1]
# 通过输出投影层得到对数概率
logits = self.head(h)

完整代码:

# 定义Transformer类,继承自PyTorch的nn.Module类
class Transformer(Module):"""Transformer模型,包含位置嵌入、多个层以及输出投影。属性:max_seq_len (int): Transformer允许的最大序列长度。embed (nn.Module): 用于输入标记的嵌入层,将输入的标记转换为向量表示。layers (torch.nn.ModuleList): 存储多个Transformer块的列表,每个块包含多头注意力和前馈网络。norm (nn.Module): 层归一化层,在所有Transformer块之后应用,用于稳定训练。head (nn.Module): 输出投影层,将模型的输出映射到词汇表大小,用于预测下一个标记。freqs_cis (torch.Tensor): 预计算的复指数值,用于旋转位置嵌入,帮助模型捕捉序列中的位置信息。"""def __init__(self, args):"""初始化Transformer模型。参数:args: 模型参数对象,包含Transformer的各种参数,如词汇表大小、维度、层数等。"""# 获取全局变量world_size和rank,分别表示分布式训练中的进程总数和当前进程的编号global world_size, rank# 如果分布式训练已初始化,则获取进程总数,否则默认为1world_size = dist.get_world_size() if dist.is_initialized() else 1# 如果分布式训练已初始化,则获取当前进程编号,否则默认为0rank = dist.get_rank() if dist.is_initialized() else 0# 根据参数设置线性层的数据类型Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16# 调用父类的初始化方法super().__init__()# 保存最大序列长度self.max_seq_len = args.max_seq_len# 初始化嵌入层,将输入标记转换为向量表示self.embed = ParallelEmbedding(args.vocab_size, args.dim)# 初始化一个空的ModuleList,用于存储Transformer块self.layers = torch.nn.ModuleList()# 循环创建指定数量的Transformer块,并添加到layers列表中for layer_id in range(args.n_layers):self.layers.append(Block(layer_id, args))# 初始化层归一化层self.norm = RMSNorm(args.dim)# 初始化输出投影层,将模型的输出映射到词汇表大小self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())# 预计算旋转位置嵌入所需的复指数值,并将其注册为缓冲区,不参与模型参数的更新self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)@torch.inference_mode()def forward(self, tokens, start_pos=0):"""Transformer模型的前向传播过程。参数:tokens (torch.Tensor): 输入的标记ID张量,形状为 (batch_size, seq_len)。start_pos (int, 可选): 旋转位置嵌入的起始位置,默认为0。返回:torch.Tensor: 对数概率张量,形状为 (batch_size, vocab_size),表示每个标记的预测概率。"""# 获取输入序列的长度seqlen = tokens.size(1)# 通过嵌入层将输入标记转换为向量表示h = self.embed(tokens)# 从预计算的复指数值中截取当前序列所需的部分freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]# 初始化掩码为Nonemask = None# 如果序列长度大于1,则创建一个上三角掩码,用于屏蔽未来的标记if seqlen > 1:mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)# 依次通过每个Transformer块进行处理for layer in self.layers:h = layer(h, start_pos, freqs_cis, mask)# 对输出进行层归一化,并取最后一个时间步的输出h = self.norm(h)[:, -1]# 通过输出投影层得到对数概率logits = self.head(h)# 如果使用分布式训练,则收集所有进程的对数概率if world_size > 1:# 创建一个列表,用于存储所有进程的对数概率all_logits = [torch.empty_like(logits) for _ in range(world_size)]# 收集所有进程的对数概率dist.all_gather(all_logits, logits)# 将所有进程的对数概率拼接在一起logits = torch.cat(all_logits, dim=-1)return logits

单个 - Block

在这里插入图片描述
核心代码非常简单MLA(attention) + MOE(Feed-Forward Network):

# 首先对输入进行层归一化,然后通过注意力层进行计算,最后将结果与输入进行残差连接
x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
# 接着对上述结果进行层归一化,再通过前馈网络层进行计算,最后将结果与之前的结果进行残差连接
x = x + self.ffn(self.ffn_norm(x))

全部代码:

# 定义一个Transformer块类,继承自PyTorch的nn.Module类
class Block(Module):"""Transformer块,结合了注意力层和前馈网络层。属性:attn (nn.Module): 注意力层(采用多头潜在注意力机制,即MLA),用于捕捉输入序列中不同位置之间的依赖关系。ffn (nn.Module): 前馈网络层(可以是多层感知机MLP或者混合专家模型MoE),对注意力层的输出进行非线性变换。attn_norm (nn.Module): 用于注意力层的层归一化层,对输入到注意力层的数据进行归一化处理,稳定训练过程。ffn_norm (nn.Module): 用于前馈网络层的层归一化层,对输入到前馈网络层的数据进行归一化处理。"""def __init__(self, layer_id, args):"""初始化Transformer块。参数:layer_id (int): 当前块在Transformer模型中的层索引,用于确定使用哪种前馈网络结构。args: 模型参数对象,包含了块的各种参数,如维度、层数等。"""# 调用父类的初始化方法super().__init__()# 初始化注意力层,使用多头潜在注意力机制(MLA)self.attn = MLA(args)# 根据当前层的索引来决定使用MLP还是MoE作为前馈网络# 如果当前层索引小于密集层的数量,则使用MLP# 否则使用混合专家模型(MoE)self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)# 初始化用于注意力层的层归一化层self.attn_norm = RMSNorm(args.dim)# 初始化用于前馈网络层的层归一化层self.ffn_norm = RMSNorm(args.dim)def forward(self, x, start_pos, freqs_cis, mask=None):"""Transformer块的前向传播过程。参数:x (torch.Tensor): 输入张量,包含了序列的特征信息。start_pos (int): 序列中的起始位置,用于旋转位置嵌入。freqs_cis (torch.Tensor): 预计算的复指数值,用于旋转位置嵌入,帮助模型捕捉序列中的位置信息。mask (Optional[torch.Tensor]): 掩码张量,用于在注意力计算中排除某些位置,避免模型关注到不应该关注的信息。返回:torch.Tensor: 经过当前Transformer块计算后的输出张量。"""# 首先对输入进行层归一化,然后通过注意力层进行计算,最后将结果与输入进行残差连接x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)# 接着对上述结果进行层归一化,再通过前馈网络层进行计算,最后将结果与之前的结果进行残差连接x = x + self.ffn(self.ffn_norm(x))return x

Attention 模块

经典的QKV计算公式。解释可以自行搜索,或者参考:Transformer结构和注意力机制
在这里插入图片描述
和传统的QKV相比,可以认为是做了压缩,主要是为了减小 KV Cache。
在这里插入图片描述
代码,就是做了一堆下采样上采样和矩阵的组合变换。最终目的是减少计算量和显存使用量。

# 定义多头注意力层类,继承自PyTorch的nn.Module类
class MLA(Module):"""多头注意力层(MLA)。属性:dim (int): 输入特征的维度。n_heads (int): 注意力头的数量。n_local_heads (int): 分布式系统中本地注意力头的数量。q_lora_rank (int): 查询(query)的低秩投影的秩。kv_lora_rank (int): 键(key)和值(value)的低秩投影的秩。qk_nope_head_dim (int): 非位置相关的查询/键投影的维度。qk_rope_head_dim (int): 旋转位置编码的查询/键投影的维度。qk_head_dim (int): 查询/键投影的总维度。v_head_dim (int): 值投影的维度。softmax_scale (float): 注意力计算中softmax函数的缩放因子。"""def __init__(self, args):# 调用父类的初始化方法super().__init__()# 保存输入特征的维度self.dim = args.dim# 保存注意力头的数量self.n_heads = args.n_heads# 计算分布式系统中本地注意力头的数量self.n_local_heads = args.n_heads // world_size# 保存查询的低秩投影的秩self.q_lora_rank = args.q_lora_rank# 保存键和值的低秩投影的秩self.kv_lora_rank = args.kv_lora_rank# 保存非位置相关的查询/键投影的维度self.qk_nope_head_dim = args.qk_nope_head_dim# 保存旋转位置编码的查询/键投影的维度self.qk_rope_head_dim = args.qk_rope_head_dim# 计算查询/键投影的总维度self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim# 保存值投影的维度self.v_head_dim = args.v_head_dim# 如果查询的低秩投影的秩为0,直接使用列并行线性层进行查询投影if self.q_lora_rank == 0:self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)# 否则,使用低秩分解的方式进行查询投影else:self.wq_a = Linear(self.dim, self.q_lora_rank)self.q_norm = RMSNorm(self.q_lora_rank)self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)# 对输入进行线性变换得到键和值的低秩表示self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)# 对键和值的低秩表示进行归一化self.kv_norm = RMSNorm(self.kv_lora_rank)# 对归一化后的键和值进行线性变换self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))# 对多头注意力的输出进行行并行线性变换self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)# 计算softmax函数的缩放因子self.softmax_scale = self.qk_head_dim ** -0.5# 如果最大序列长度大于原始序列长度,对缩放因子进行调整if args.max_seq_len > args.original_seq_len:mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0self.softmax_scale = self.softmax_scale * mscale * mscale# 如果注意力实现方式为朴素方式if attn_impl == "naive":# 注册键缓存self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)# 注册值缓存self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)# 否则else:# 注册键值缓存self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)# 注册位置编码缓存self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)def forward(self, x, start_pos, freqs_cis, mask=None):"""多头注意力层(MLA)的前向传播过程。参数:x (torch.Tensor): 输入张量,形状为 (batch_size, seq_len, dim)。start_pos (int): 序列中用于缓存的起始位置。freqs_cis (torch.Tensor): 预计算的复指数值,用于旋转位置编码。mask (Optional[torch.Tensor]): 掩码张量,用于在注意力计算中排除某些位置。返回:torch.Tensor: 输出张量,形状与输入相同。"""# 获取输入张量的批次大小、序列长度bsz, seqlen, _ = x.size()# 计算序列的结束位置end_pos = start_pos + seqlen# 如果查询的低秩投影的秩为0,直接通过线性层得到查询if self.q_lora_rank == 0:q = self.wq(x)# 否则,通过低秩分解的方式得到查询else:q = self.wq_b(self.q_norm(self.wq_a(x)))# 调整查询的形状,将其划分为多个头q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)# 将查询划分为非位置相关部分和位置相关部分q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)# 对位置相关部分应用旋转位置编码q_pe = apply_rotary_emb(q_pe, freqs_cis)# 通过线性层得到键和值的低秩表示kv = self.wkv_a(x)# 将键和值的低秩表示划分为低秩部分和位置编码部分kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)# 对位置编码部分应用旋转位置编码k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)# 如果注意力实现方式为朴素方式if attn_impl == "naive":# 将非位置相关部分和位置相关部分拼接得到完整的查询q = torch.cat([q_nope, q_pe], dim=-1)# 对键和值的低秩表示进行归一化和线性变换kv = self.wkv_b(self.kv_norm(kv))# 调整键和值的形状,将其划分为多个头kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)# 将键和值划分为非位置相关部分和值部分k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)# 将非位置相关部分和位置编码部分拼接得到完整的键k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)# 将键存入缓存self.k_cache[:bsz, start_pos:end_pos] = k# 将值存入缓存self.v_cache[:bsz, start_pos:end_pos] = v# 计算注意力分数scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale# 否则else:# 获取键和值的线性变换层的权重wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) # 调整权重的形状wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)# 计算非位置相关部分的注意力分数q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])# 将键和值的低秩表示归一化后存入缓存self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)# 将位置编码部分存入缓存self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)# 计算注意力分数scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale# 如果存在掩码,将掩码加到注意力分数上if mask is not None:scores += mask.unsqueeze(1)# 对注意力分数应用softmax函数scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)# 如果注意力实现方式为朴素方式if attn_impl == "naive":# 通过注意力分数和值缓存计算输出x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])# 否则else:# 通过注意力分数和键值缓存计算中间结果x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])# 通过中间结果和权重计算输出x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])# 对输出进行线性变换x = self.wo(x.flatten(2))return x

代码解释总结

这段代码定义了一个多头注意力层(MLA)类。在初始化时,根据传入的参数设置各种维度、低秩投影的秩等,并初始化相应的线性层和归一化层,同时根据注意力实现方式注册不同的缓存。在前向传播过程中,对输入进行处理得到查询、键和值,应用旋转位置编码,根据不同的注意力实现方式计算注意力分数,最后通过注意力分数和缓存得到输出并进行线性变换。

Feed-Forward Network

在这里插入图片描述
这个MoE分成两个部分,左边是一些可以分享的专家,就是每次都需要去计算的,右边的是根据分数来选择的。如何选择,是通过一个门控机制来选择。这个门控是如何设计的,代码里有实现,但论文和代码都没有对它物理意义的解释。门控机制,简单来说,就是设计了一个网络,来选出K个候选。

核心逻辑:

  1. 通过门控确定本轮要用到的本地专家:weights, indices = self.gate(x)
  2. 用选择的每个本地专家进行计算:y[idx] += expert(x[idx]) * weights[idx, top, None]
  3. 用共享专家进行计算:z = self.shared_experts(x)
  4. 将本地专家的输出和共享专家的输出相加,并恢复到原始形状: return (y + z).view(shape)

全部代码:

# 定义混合专家(Mixture-of-Experts, MoE)模块类,继承自PyTorch的nn.Module类
class MoE(nn.Module):"""混合专家(Mixture-of-Experts, MoE)模块。属性:dim (int): 输入特征的维度。n_routed_experts (int): 模型中专家的总数。n_local_experts (int): 在分布式系统中本地处理的专家数量。n_activated_experts (int): 每个输入激活的专家数量。gate (nn.Module): 门控机制,用于将输入路由到不同的专家。experts (nn.ModuleList): 专家模块列表,包含多个专家网络。shared_experts (nn.Module): 共享专家模块,应用于所有输入。"""def __init__(self, args):"""初始化MoE模块。参数:args: 模型参数对象,包含MoE模块的相关参数。"""# 调用父类的初始化方法super().__init__()# 保存输入特征的维度self.dim = args.dim# 确保专家总数能被分布式系统中的进程数整除assert args.n_routed_experts % world_size == 0, f"专家数量必须能被进程数整除 (进程数={world_size})"# 保存模型中专家的总数self.n_routed_experts = args.n_routed_experts# 计算本地处理的专家数量self.n_local_experts = args.n_routed_experts // world_size# 保存每个输入激活的专家数量self.n_activated_experts = args.n_activated_experts# 计算本地专家在所有专家中的起始索引self.experts_start_idx = rank * self.n_local_experts# 计算本地专家在所有专家中的结束索引self.experts_end_idx = self.experts_start_idx + self.n_local_experts# 初始化门控机制self.gate = Gate(args)# 初始化专家模块列表,本地负责的专家使用Expert模块,其他位置置为Noneself.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else Nonefor i in range(self.n_routed_experts)])# 初始化共享专家模块self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)def forward(self, x):"""MoE模块的前向传播过程。参数:x (torch.Tensor): 输入张量。返回:torch.Tensor: 经过专家路由和计算后的输出张量。"""# 保存输入张量的原始形状shape = x.size()# 将输入张量展平为二维张量,方便后续处理x = x.view(-1, self.dim)# 通过门控机制得到每个输入分配到各个专家的权重和对应的专家索引weights, indices = self.gate(x)# 初始化输出张量,形状与输入相同,初始值全为0y = torch.zeros_like(x)# 统计每个专家被分配到的输入数量counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()# 遍历本地负责的专家for i in range(self.experts_start_idx, self.experts_end_idx):# 如果该专家没有被分配到输入,则跳过if counts[i] == 0:continue# 获取当前专家模块expert = self.experts[i]# 找出分配到当前专家的输入的索引idx, top = torch.where(indices == i)# 将这些输入通过当前专家模块进行计算,并乘以对应的权重,累加到输出张量中y[idx] += expert(x[idx]) * weights[idx, top, None]# 将输入通过共享专家模块进行计算z = self.shared_experts(x)# 如果使用分布式训练,对本地专家的输出进行全局归约操作if world_size > 1:dist.all_reduce(y)# 将本地专家的输出和共享专家的输出相加,并恢复到原始形状return (y + z).view(shape)

代码解释总结

这段代码定义了一个混合专家(MoE)模块。在初始化时,根据传入的参数设置专家的数量、门控机制、专家模块列表和共享专家模块。在前向传播过程中,首先通过门控机制将输入路由到不同的专家,然后对本地负责的专家进行计算并累加结果,同时将输入通过共享专家模块进行计算,最后将两部分结果相加并恢复原始形状。如果使用分布式训练,还会对本地专家的输出进行全局归约操作。

Gate

不考虑分组路由来看看它的核心逻辑,实际上就是线下变换,然后激活,选择K个极值(如果用了分组,就是选择K的方式发生了一些变化):

# 通过线性变换计算每个输入对应各个专家的分数
scores = linear(x, self.weight)
# 根据评分函数类型对分数进行处理
if self.score_func == "softmax":scores = scores.softmax(dim=-1, dtype=torch.float32)
else:scores = scores.sigmoid()
# 选择分数最高的若干专家
indices = torch.topk(scores, self.topk, dim=-1)[1]
# 根据选择的专家索引,从原始分数中获取对应的权重
weights = scores.gather(1, indices)
# 如果评分函数是sigmoid,对权重进行归一化
if self.score_func == "sigmoid":weights /= weights.sum(dim=-1, keepdim=True)
# 对权重进行缩放
weights *= self.route_scale
return weights.type_as(x), indices
# 定义门控机制类,用于在混合专家(MoE)模型中对输入进行路由
class Gate(nn.Module):"""混合专家(MoE)模型中用于输入路由的门控机制。属性:dim (int): 输入特征的维度。topk (int): 每个输入激活的顶级专家数量。n_groups (int): 用于路由的分组数量。topk_groups (int): 输入将被路由到的分组数量。score_func (str): 评分函数,取值为 'softmax' 或 'sigmoid'。route_scale (float): 路由权重的缩放因子。weight (torch.nn.Parameter): 门控机制的可学习权重。bias (Optional[torch.nn.Parameter]): 门控机制的可选偏置项。"""def __init__(self, args):"""初始化门控机制模块。参数:args: 模型参数对象,包含门控机制的相关参数。"""# 调用父类的初始化方法super().__init__()# 保存输入特征的维度self.dim = args.dim# 保存每个输入激活的顶级专家数量self.topk = args.n_activated_experts# 保存用于路由的分组数量self.n_groups = args.n_expert_groups# 保存输入将被路由到的分组数量self.topk_groups = args.n_limited_groups# 保存评分函数类型self.score_func = args.score_func# 保存路由权重的缩放因子self.route_scale = args.route_scale# 初始化可学习权重self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))# 根据输入特征维度决定是否初始化偏置项self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else Nonedef forward(self, x):"""门控机制的前向传播过程。参数:x (torch.Tensor): 输入张量。返回:Tuple[torch.Tensor, torch.Tensor]: 路由权重和选择的专家索引。"""# 通过线性变换计算每个输入对应各个专家的分数scores = linear(x, self.weight)# 根据评分函数类型对分数进行处理if self.score_func == "softmax":scores = scores.softmax(dim=-1, dtype=torch.float32)else:scores = scores.sigmoid()# 保存原始分数,后续计算权重时使用original_scores = scores# 如果存在偏置项,将其加到分数上if self.bias is not None:scores = scores + self.bias# 如果分组数量大于1,进行分组路由操作if self.n_groups > 1:# 调整分数的形状,以便按组处理scores = scores.view(x.size(0), self.n_groups, -1)# 根据是否有偏置项,计算每个组的分数表示if self.bias is None:group_scores = scores.amax(dim=-1)else:group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)# 选择分数最高的若干组indices = group_scores.topk(self.topk_groups, dim=-1)[1]# 创建掩码,用于屏蔽未选择的组mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)# 将屏蔽组的分数设为负无穷scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)# 选择分数最高的若干专家indices = torch.topk(scores, self.topk, dim=-1)[1]# 根据选择的专家索引,从原始分数中获取对应的权重weights = original_scores.gather(1, indices)# 如果评分函数是sigmoid,对权重进行归一化if self.score_func == "sigmoid":weights /= weights.sum(dim=-1, keepdim=True)# 对权重进行缩放weights *= self.route_scalereturn weights.type_as(x), indices

代码解释总结

这段代码定义了一个门控机制(Gate)类,用于在混合专家(MoE)模型中对输入进行路由。在初始化时,根据传入的参数设置各种属性,如输入维度、激活专家数量、分组数量等,并初始化可学习的权重和偏置项。在前向传播过程中,首先计算每个输入对应各个专家的分数,然后根据评分函数类型进行处理,接着根据分组情况进行分组路由操作,选择激活的专家并计算对应的权重,最后返回路由权重和选择的专家索引。

B站大牛详解视频链接

B站大牛有详细视频讲解:
https://www.bilibili.com/video/BV1RtNLeqEeu/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/29420.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

thinkphp5.1 在fetch模版就超时

场景 当被渲染模版不存在&#xff0c;请求不响应任何内容&#xff0c;过一会就timeout 排查过程 使用xdebug,追踪代码&#xff0c;发现走到D:\temporary_files\m40285_mini\40285_mini\thinkphp\library\think\exception\Handle.php&#xff0c;进入死循环&#xff0c;一直…

【Vue CLI脚手架开发】——6.scoped样式

文章目录 一、scoped是什么二、应用案例1.使用代码2.原理3父组件App未添加scoped影响 一、scoped是什么 我们知道vue为了防止css样式污染&#xff0c;在每个组件中提供了 scoped属性进行限定css作用域&#xff1b;当<style>标签有 scoped 属性时&#xff0c;它的 CSS 只…

微服务,服务治理nacos,负载均衡LOadBalancer,OpenFeign

1.微服务 简单来说&#xff0c;微服务架构风格[1]是一种将一个单一应用程序开发为一组小型服务的方法&#xff0c;每个服务运行在 自己的进程中&#xff0c;服务间通信采用轻量级通信机制(通常用HTTP资源API)。这些服务围绕业务能力构建并 且可通过全自动部署机制独立部署。这…

STM32-USART串口数据包

一&#xff1a;HEX数据包发送 1.为了收发数据包&#xff0c;先定义两个缓存区的数组 &#xff0c;这4个数据只存储发送或者接收的载荷数据&#xff0c;包头和包尾不存 uint8_t Serial_TxPacket[4]; uint8_t Serial_RxPacket[4]; uint8_t Serial_RxFlag;//接收一个数据包就置F…

[项目]基于FreeRTOS的STM32四轴飞行器: 四.LED控制

基于FreeRTOS的STM32四轴飞行器: 四.LED控制 一.配置Com层二.编写驱动 一.配置Com层 先在Com_Config.h中定义灯位置的枚举类型&#xff1a; 之后定义Led的结构体&#xff1a; 定义飞行器状态&#xff1a; 在Com_Config.c中初始化四个灯&#xff1a; 在Com_Config.h外部声明…

RT-DETR融合YOLOv12中的R-ELAN结构

RT-DETR使用教程&#xff1a; RT-DETR使用教程 RT-DETR改进汇总贴&#xff1a;RT-DETR更新汇总贴 《YOLOv12: Attention-Centric Real-Time Object Detectors》 一、 模块介绍 论文链接&#xff1a;https://arxiv.org/abs/2502.12524 代码链接&#xff1a;https://gitcode.com…

“深入浅出”系列之Linux篇:(13)socket编程实战+TCP粘包解决方案

从日常使用的APP&#xff0c;到背后支撑的各类服务器&#xff0c;网络通信无处不在&#xff0c;而socket作为实现网络通信的关键技术&#xff0c;更是开发者们必须掌握的核心知识。但在socket编程的道路上&#xff0c;TCP粘包问题宛如一只拦路虎&#xff0c;让无数开发者头疼不…

【计算机操作系统】操作系统的功能和目标

1、操作系统的功能和目标---作为系统资源的管理者 作为系统资源的管理者提供的功能&#xff1a; &#xff08;1&#xff09;处理机管理 &#xff08;2&#xff09;存储器管理 &#xff08;3&#xff09;文件管理 &#xff08;4&#xff09;设备管理 作为系统资源的管理者…

“深入浅出”系列之Linux篇:(10)基于C++实现分布式网络通信RPC框架

分布式网络通信rpc框架 项目是分布式网络通信rpc框架&#xff0c; 文中提到单机服务器的缺点&#xff1a; 硬件资源的限制影响并发&#xff1a;受限于硬件资源&#xff0c;聊天服务器承受的用户的并发有限 模块的编译部署难&#xff1a;任何模块小的修改&#xff0c;都导致整…

Aws batch task 无法拉取ECR 镜像unable to pull secrets or registry auth 问题排查

AWS batch task使用了自定义镜像&#xff0c;在提作业后出现错误 具体错误是ResourceInitializationError: unable to pull secrets or registry auth: The task cannot pull registry auth from Amazon ECR: There is a connection issue between the task and Amazon ECR. C…

机器学习之无监督学习

无监督学习&#xff08;Unsupervised Learning&#xff09;是机器学习的一个重要分支&#xff0c;其特点是在训练过程中不使用标签数据。与有监督学习不同&#xff0c;无监督学习的目标是从未标记的数据中发现隐藏的结构、模式或关系。无监督学习广泛应用于聚类、降维、异常检测…

自然语言处理:朴素贝叶斯

介绍 大家好&#xff0c;博主又来和大家分享自然语言处理领域的知识了。按照博主的分享规划&#xff0c;本次分享的核心主题本应是自然语言处理中的文本分类。然而&#xff0c;在对分享内容进行细致梳理时&#xff0c;我察觉到其中包含几个至关重要的知识点&#xff0c;即朴素…

HTML label 标签使用

点击 <label> 标签通常会使与之关联的表单控件获得焦点或被激活。 通过正确使用 <label> 标签&#xff0c;可以使表单更加友好和易于使用&#xff0c;同时提高整体的可访问性。 基本用法 <label> 标签通过 for 属性与 id 为 username 的 <input> 元素…

Ubuntu20.04双系统安装及软件安装(五):VSCode

Ubuntu20.04双系统安装及软件安装&#xff08;五&#xff09;&#xff1a;VSCode 打开VScode官网&#xff0c;点击中间左侧的deb文件下载&#xff1a; 系统会弹出下载框&#xff0c;确定即可。 在文件夹的**“下载”目录**&#xff0c;可看到下载的安装包&#xff0c;在该目录下…

SparkSQL全之RDD、DF、DS ,UDF、架构、资源划分、sql执行计划、调优......

1 SparkSQL概述 1.1 sparksql简介 Shark是专门针对于spark的构建大规模数据仓库系统的一个框架Shark与Hive兼容、同时也依赖于Spark版本Hivesql底层把sql解析成了mapreduce程序&#xff0c;Shark是把sql语句解析成了Spark任务随着性能优化的上限&#xff0c;以及集成SQL的一些…

Linux总结

1 用户与用户组管理 1.1 用户与用户组 //linux用户和用户组 Linux系统是一个多用户多任务的分时操作系统 使用系统资源的用户需要账号进入系统 账号是用户在系统上的标识&#xff0c;系统根据该标识分配不同的权限和资源 一个账号包含用户和用户组 //用户分类 超级管理员 UID…

【AI深度学习网络】卷积神经网络(CNN)入门指南:从生物启发的原理到现代架构演进

深度神经网络系列文章 【AI深度学习网络】卷积神经网络&#xff08;CNN&#xff09;入门指南&#xff1a;从生物启发的原理到现代架构演进【AI实践】基于TensorFlow/Keras的CNN&#xff08;卷积神经网络&#xff09;简单实现&#xff1a;手写数字识别的工程实践 引言 在当今…

Qt之QGraphicsView图像操作

QGraphicsView图像操作:旋转、放大、缩小、移动、图层切换 1 摘要 GraphicsView框架结构主要包含三个主要的类QGraphicsScene(场景)、QGraphicsView(视图)、QGraphicsItem(图元)。QGraphicsScene本身不可见,是一个存储图元的容器,必须通过与之相连的QGraphicsView视图来显…

【Azure 架构师学习笔记】- Azure Databricks (14) -- 搭建Medallion Architecture part 2

本文属于【Azure 架构师学习笔记】系列。 本文属于【Azure Databricks】系列。 接上文 【Azure 架构师学习笔记】- Azure Databricks (13) – 搭建Medallion Architecture part 1 前言 上文搭建了ADB 与外部的交互部分&#xff0c;本篇搭建ADB 内部配置来满足medallion 架构。…

AI视频领域的DeepSeek—阿里万相2.1图生视频

让我们一同深入探索万相 2.1 &#xff0c;本文不仅介绍其文生图和文生视频的使用秘籍&#xff0c;还将手把手教你如何利用它实现图生视频。 如下为生成的视频效果&#xff08;我录制的GIF动图&#xff09; 如下为输入的图片 目录 1.阿里巴巴全面开源旗下视频生成模型万相2.1模…