【大模型LLM面试合集】大语言模型架构_MHA_MQA_GQA

MHA_MQA_GQA

1.总结

  • MHA(Multi Head Attention) 中,每个头有自己单独的 key-value 对;标准的多头注意力机制,h个Query、Key 和 Value 矩阵。
  • MQA(Multi Query Attention) 中只会有一组 key-value 对;多查询注意力的一种变体,也是用于自回归解码的一种注意力机制。与MHA不同的是,MQA 让所有的头之间共享同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量
  • GQA(Grouped Query Attention)中,会对 attention 进行分组操作,query 被分为 N 组,每个组共享一个 Key 和 Value 矩阵GQA将查询头分成G组,每个组共享一个Key 和 Value 矩阵。GQA-G是指具有G组的grouped-query attention。GQA-1具有单个组,因此具有单个Key 和 Value,等效于MQA。而GQA-H具有与头数相等的组,等效于MHA。

在这里插入图片描述

GQA-N 是指具有 N 组的 Grouped Query Attention。GQA-1具有单个组,因此具有单个Key 和 Value,等效于MQA。而GQA-H具有与头数相等的组,等效于MHA。

GQA介于MHA和MQA之间。GQA 综合 MHA 和 MQA ,既不损失太多性能,又能利用 MQA 的推理加速。不是所有 Q 头共享一组 KV,而是分组一定头数 Q 共享一组 KV,比如上图中就是两组 Q 共享一组 KV。

2.代码实现

2.1 MHA

多头注意力机制是Transformer模型中的核心组件。在其设计中,"多头"意味着该机制并不只计算一种注意力权重,而是并行计算多种权重,每种权重都从不同的“视角”捕获输入的不同信息。

  1. 为输入序列中的每个元素计算q, k, v,这是通过将输入此向量与三个权重矩阵相乘实现的:
    q = x W q k = x W k v = x W v \begin{aligned} q & =x W_{q} \\ k & =x W_{k} \\ v & =x W_{v}\end{aligned} qkv=xWq=xWk=xWv
    其中, x x x是输入词向量, W q W_q Wq, W k W_k Wk W v W_v Wv是q, k, v的权重矩阵
  2. 计算q, k 注意力得分: score ⁡ ( q , k ) = q ⋅ k T d k \operatorname{score}(q, k)=\frac{q \cdot k^{T}}{\sqrt{d_{k}}} score(q,k)=dk qkT,其中, d k d_k dk是k的维度
  3. 使用softmax得到注意力权重: Attention ⁡ ( q , K ) = softmax ⁡ ( score ⁡ ( q , k ) ) \operatorname{Attention}(q, K)=\operatorname{softmax}(\operatorname{score}(q, k)) Attention(q,K)=softmax(score(q,k))
  4. 使用注意力权重和v,计算输出: O u t p u t = Attention ⁡ ( q , K ) ⋅ V Output =\operatorname{Attention}(q, K) \cdot V Output=Attention(q,K)V
  5. 拼接多头输出,乘以 W O W_O WO,得到最终输出: M u l t i H e a d O u t p u t = C o n c a t ( O u t p u t 1 , O u t p u t 2 , … , O u t p u t H ) W O MultiHeadOutput = Concat \left(\right. Output ^{1}, Output ^{2}, \ldots, Output \left.^{H}\right) W_{O} MultiHeadOutput=Concat(Output1,Output2,,OutputH)WO

代码实现

import torch
from torch import nn
class MutiHeadAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(MutiHeadAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_heads## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, hidden_size)self.v_linear = nn.Linear(hidden_size, hidden_size)## 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key)value = self.split_head(value)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)## 对注意力输出进行拼接output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x):batch_size = x.size()[0]return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)

2.2 MQA

上图最右侧,直观上就是在计算多头注意力的时候,query仍然进行分头,和多头注意力机制相同,而key和value只有一个头。

正常情况在计算多头注意力分数的时候,query、key的维度是相同的,所以可以直接进行矩阵乘法,但是在多查询注意力(MQA)中,query的维度为 [batch_size, num_heads, seq_len, head_dim],key和value的维度为 [batch_size, 1, seq_len, head_dim]。这样就无法直接进行矩阵的乘法,为了完成这一乘法,可以采用torch的广播乘法

## 多查询注意力
import torch
from torch import nn
class MutiQueryAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(MutiQueryAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_heads## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, self.head_dim) ###self.v_linear = nn.Linear(hidden_size, self.head_dim) ##### 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key, 1)value = self.split_head(value, 1)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x, head_num=None):batch_size = x.size()[0]if head_num == None:return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)else:return x.view(batch_size, -1, head_num, self.head_dim).transpose(1,2)

相比于多头注意力,多查询注意力在W_k和W_v的维度映射上有所不同,还有就是计算注意力分数采用的是广播机制,计算最后的output也是广播机制,其他的与多头注意力完全相同。

2.3 GQA

GQA将MAQ中的key、value的注意力头数设置为一个能够被原本的注意力头数整除的一个数字,也就是group数。

不同的模型使用GQA有着不同的实现方式,但是总体的思路就是这么实现的,注意,设置的组一定要能够被注意力头数整除。

## 分组注意力查询
import torch
from torch import nn
class GroupQueryAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads, group_num):super(MutiQueryAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_headsself.group_num = group_num## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)self.v_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)## 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key, self.group_num)value = self.split_head(value, self.group_num)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x, group_num=None):batch_size,seq_len = x.size()[:2]if group_num == None:return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)else:x = x.view(batch_size, -1, group_num, self.head_dim).transpose(1,2)x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads // group_num * group_num, seq_len, self.head_dim)return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/11198.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Transformer】手撕Attention

import torch from torch import nn import torch.functional as F import mathX torch.randn(16,64,512) # B,T,Dd_model 512 # 模型的维度 n_head 8 # 注意力头的数量多头注意力机制 class multi_head_attention(nn.Module): def __init__(self, d_model, n_hea…

【Linux】 冯诺依曼体系与计算机系统架构全解

Linux相关知识点可以通过点击以下链接进行学习一起加油!初识指令指令进阶权限管理yum包管理与vim编辑器GCC/G编译器make与Makefile自动化构建GDB调试器与Git版本控制工具Linux下进度条 冯诺依曼体系是现代计算机设计的基石,其统一存储和顺序执行理念推动…

冯·诺依曼体系结构

目录 冯诺依曼体系结构推导 内存提高冯诺依曼体系结构效率的方法 你使用QQ和朋友聊天时,整个数据流是怎么流动的(不考虑网络情况) 与冯诺依曼体系结构相关的一些知识 冯诺依曼体系结构推导 计算机的存在就是为了解决问题,而解…

全面认识了解DeepSeek+利用ollama在本地部署、使用和体验deepseek-r1大模型

文章目录 一、DeepSeek简介二、技术特点三、架构设计3.1、DeepSeek-V33.2、DeepSeek-V23.3、DeepSeek-R1 四、DeepSeek算法4.1、DeepSeek LLM 算法4.2、DeepSeek-V2 算法4.3、DeepSeek-R1 算法4.4、DeepSeek 在算力优化上的算法 五、DeepSeek的使用六、本地部署DeepSeek R1模型…

Python 梯度下降法(七):Summary

文章目录 Python 梯度下降法(七):Summary一、核心思想1.1 核心思想1.2 优化方法概述1.3 第三方库的使用 二、 BGD2.1 介绍2.2 torch 库算法2.2 代码示例2.3 SGD2.4 SGD代码示例2.5 MBGD2.6 MBGD 代码示例 三、 Adagrad3.1 介绍3.2 torch 库算…

SpringBoot Web开发(SpringMVC)

SpringBoot Web开发(SpringMVC) MVC 核心组件和调用流程 Spring MVC与许多其他Web框架一样,是围绕前端控制器模式设计的,其中中央 Servlet DispatcherServlet 做整体请求处理调度! . 除了DispatcherServletSpringMVC还会提供其他…

Web_php_unserialize

代码审计 <?php class Demo { private $file index.php;public function __construct($file) { $this->file $file; }、 //接收一个参数 $file 并赋值给私有属性 $filefunction __destruct() { echo highlight_file($this->file, true); } //在对象销毁时调用&…

Spring Web MVC基础第一篇

目录 1.什么是Spring Web MVC&#xff1f; 2.创建Spring Web MVC项目 3.注解使用 3.1RequestMapping&#xff08;路由映射&#xff09; 3.2一般参数传递 3.3RequestParam&#xff08;参数重命名&#xff09; 3.4RequestBody&#xff08;传递JSON数据&#xff09; 3.5Pa…

安装anaconda3 后 电脑如何单独运行python,python还需要独立安装吗?

安装anaconda3 后 电脑如何单独运行python&#xff0c;python还需要独立安装吗? 电脑第一此安装anaconda用于jupyter notebook使用。 但是在运行cmd的时候&#xff0c;输入python --version 显示未安装或跳转商店提示安装。 明明我可以运行python但是为什么cmd却说我没安装呢…

分布式事务组件Seata简介与使用,搭配Nacos统一管理服务端和客户端配置

文章目录 一. Seata简介二. 官方文档三. Seata分布式事务代码实现0. 环境简介1. 添加undo_log表2. 添加依赖3. 添加配置4. 开启Seata事务管理5. 启动演示 四. Seata Server配置Nacos1. 修改配置类型2. 创建Nacos配置 五. Seata Client配置Nacos1. 增加Seata关联Nacos的配置2. 在…

使用真实 Elasticsearch 进行高级集成测试

作者&#xff1a;来自 Elastic Piotr Przybyl 掌握高级 Elasticsearch 集成测试&#xff1a;更快、更智能、更优化。 在上一篇关于集成测试的文章中&#xff0c;我们介绍了如何通过改变数据初始化策略来缩短依赖于真实 Elasticsearch 的集成测试的执行时间。在本期中&#xff0…

OpenEuler学习笔记(十四):在OpenEuler上搭建.NET运行环境

一、在OpenEuler上搭建.NET运行环境 基于包管理器安装 添加Microsoft软件源&#xff1a;运行命令sudo rpm -Uvh https://packages.microsoft.com/config/centos/8/packages-microsoft-prod.rpm&#xff0c;将Microsoft软件源添加到系统中&#xff0c;以便后续能够从该源安装.…

基于Python的简单企业维修管理系统的设计与实现

以下是一个基于Python的简单企业维修管理系统的设计与实现&#xff0c;这里我们会使用Flask作为Web框架&#xff0c;SQLite作为数据库来存储相关信息。 1. 需求分析 企业维修管理系统主要功能包括&#xff1a; 维修工单的创建、查询、更新和删除。设备信息的管理。维修人员…

Van-Nav:新年,将自己学习的项目地址统一整理搭建自己的私人导航站,供自己后续查阅使用,做技术的同学应该都有一个自己网站的梦想

嗨&#xff0c;大家好&#xff0c;我是小华同学&#xff0c;关注我们获得“最新、最全、最优质”开源项目和高效工作学习方法 Van-Nav是一个基于Vue.js开发的导航组件库&#xff0c;它提供了多种预设的样式和灵活的配置选项&#xff0c;使得开发者可以轻松地定制出符合项目需求…

Android 音视频编解码 -- MediaCodec

引言 如果我们只是简单玩一下音频、视频播放&#xff0c;那么使用 MediaPlayer SurfaceView 播放就可以了&#xff0c;但如果想加个水印&#xff0c;加点其他特效什么的&#xff0c;那就不行了&#xff1b; 学习 Android 自带的硬件码类 – MediaCodec。 MediaCodec 介绍 在A…

UE 5.3 C++ 对垃圾回收的初步认识

一.UObject的创建 UObject 不支持构造参数。 所有的C UObject都会在引擎启动的时候初始化&#xff0c;然后引擎会调用其默认构造器。如果没有默认的构造器&#xff0c;那么 UObject 将不会编译。 有修改父类参数的需求&#xff0c;就使用指定带参构造 // Sets default value…

使用LLaMA-Factory对AI进行认知的微调

使用LLaMA-Factory对AI进行认知的微调 引言1. 安装LLaMA-Factory1.1. 克隆仓库1.2. 创建虚拟环境1.3. 安装LLaMA-Factory1.4. 验证 2. 准备数据2.1. 创建数据集2.2. 更新数据集信息 3. 启动LLaMA-Factory4. 进行微调4.1. 设置模型4.2. 预览数据集4.3. 设置学习率等参数4.4. 预览…

2025最新源支付V7全套开源版+Mac云端+五合一云端

2025最新源支付V7全套开源版Mac云端五合一云端 官方1999元&#xff0c; 最新非网上那种功能不全带BUG开源版&#xff0c;可以自己增加授权或二开 拥有卓越的性能和丰富的功能。它采用全新轻量化的界面UI&#xff0c;让您能更方便快捷地解决知识付费和运营赞助的难题 它基于…

Linux02——Linux的基本命令

目录 ls 常用选项及功能 综合示例 注意事项 cd和pwd命令 cd命令 pwd命令 相对路径、绝对路径和特殊路径符 特殊路径符号 mkdir命令 1. 功能与基本用法 2. 示例 3. 语法与参数 4. -p选项 touch-cat-more命令 1. touch命令 2. cat命令 3. more命令 cp-mv-rm命…

[EAI-023] FAST,机器人动作专用的Tokenizer,提高VLA模型的能力和训练效率

Paper Card 论文标题&#xff1a;FAST: Efficient Action Tokenization for Vision-Language-Action Models 论文作者&#xff1a;Karl Pertsch, Kyle Stachowicz, Brian Ichter, Danny Driess, Suraj Nair, Quan Vuong, Oier Mees, Chelsea Finn, Sergey Levine 论文链接&…