Group Query Attention (GQA) 机制详解以及手动实现计算

Group Query Attention (GQA) 机制详解

1. GQA的定义

Grouped-Query Attention (GQA) 是对 Multi-Head Attention (MHA) 和 Multi-Query Attention (MQA) 的扩展。通过提供计算效率和模型表达能力之间的灵活权衡,实现了查询头的分组。GQA将查询头分成了G个组,每个组共享一个公共的键(K)和值(V)投影。

2. GQA的变体

GQA有三种变体:

  • GQA-1:一个单独的组,等同于 Multi-Query Attention (MQA)。
  • GQA-H:组数等于头数,基本上与 Multi-Head Attention (MHA) 相同。
  • GQA-G:一个中间配置,具有G个组,平衡了效率和表达能力。
3. GQA的优势

使用G个组可以减少存储每个头的键和值所需的内存开销,特别是在具有大的上下文窗口或批次大小的情况下。GQA提供了对模型质量和效率的细致控制。

4. GQA的实现

GQA的最简形式可以通过实现 GroupedQueryAttention 类来实现。GroupedQueryAttention 类继承自 Attention 类,重写了 forward 方法,其中使用了 MultiQueryAttention 类的实例来处理每个组的查询。通过将每个组的结果拼接起来,然后与投影矩阵进行矩阵乘法运算,最终得到 GQA 的输出。[1]

pytorch 示例实现:

假设我们有以下初始化的query, key, value:

# shapes: (batch_size, seq_len, num_heads, head_dim)
query = torch.randn(1, 256, 8, 64)
key = torch.randn(1, 256, 2, 64)
value = torch.randn(1, 256, 2, 64)
1. 确定分组数量

首先,我们需要确定将查询头分为多少组。在这个例子中,我们有8个查询头,而键和值的头数为2,所以我们可以将查询头分为4组,每组有2个查询头。

2. 对查询进行分组

然后,我们将查询头分组。我们可以使用 torch.chunk 函数将查询张量沿着头维度分割成4个组,每个组有2个头。

query_groups = torch.chunk(query, 4, dim=2)  # shape of each group: (1, 256, 2, 64)
3. 计算注意力分数

对于每一个查询组,我们计算它与键的注意力分数。我们首先计算查询组和键的点积,然后通过 torch.softmax 函数得到注意力分数。

attention_scores = []
for query_group in query_groups:score = torch.matmul(query_group, key.transpose(-2, -1))  # shape: (1, 256, 2, 256)score = torch.softmax(score, dim=-1)attention_scores.append(score)
4. 计算注意力输出

接下来,我们使用注意力分数对值进行加权求和,得到每一个查询组的注意力输出。

attention_scores = []
for query_group in query_groups:score = torch.matmul(query_group, key.transpose(-2, -1))  # shape: (1, 256, 2, 256)score = torch.softmax(score, dim=-1)attention_scores.append(score)
5. 拼接输出

最后,我们将所有查询组的注意力输出拼接起来,得到最终的 Grouped Query Attention 的输出。

attention_outputs = []
for score in attention_scores:output = torch.matmul(score, value)  # shape: (1, 256, 2, 64)attention_outputs.append(output)

这就是 Grouped Query Attention 的实现过程。在这个过程中,我们将查询头分组,然后对每一个查询组分别计算注意力分数和输出,最后将所有查询组的输出拼接起来。这样可以减少存储每个头的键和值所需的内存开销,特别是在具有大的上下文窗口或批次大小的情况下。


  1. Grouped-Query Attention (GQA) - The Large Language Model Playbook

  2. 安全验证 - 知乎
  3. 安全验证 - 知乎
  4. 安全验证 - 知乎
  5. Grouped-Query Attention (GQA) - The Large Language Model Playbook

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/313852.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Jammy@Jetson Orin - Tensorflow Keras Get Started: 000 setup for tutorial

JammyJetson Orin - Tensorflow & Keras Get Started: 000 setup for tutorial 1. 源由2. 搭建环境2.1 安装IDE环境2.2 安装numpy2.3 安装keras2.4 安装JAX2.5 安装tensorflow2.6 安装PyTorch2.7 安装nbdiff 3. 测试DEMO3.1 numpy版本兼容问题3.2 karas API - model.compil…

Docker命令总结

目录 一.Docker常用命令总结 1.镜像命令管理 2.容器命令管理 二.Docker镜像命令 1.docker search:搜索镜像 2.docker pull:下载镜像 3.docker push:上传镜像 4.docker images:查看本地镜像 5.docker inspect:…

李彦宏:程序员将不复存在! 周鸿祎回怼!网友:先把百度程序员都开除了!

近日,百度创始人、董事长兼首席执行官李彦宏在央视《对话》•开年说的访谈中指出:“基本上说以后其实不会存在“程序员”这种职业了,因为只要会说话,人人都会具备程序员的能力”。 “未来的编程语言只会剩下两种,一种…

Python 网络与并发编程(四)

文章目录 协程Coroutines协程的核心(控制流的让出和恢复)协程和多线程比较协程的优点协程的缺点 asyncio实现协程(重点) 协程Coroutines 协程,全称是“协同程序”,用来实现任务协作。是一种在线程中,比线程更加轻量级的存在,由程…

wstunnel (websocket模式ssh)

接上一篇 修改客户端运行参数 ssh -o ProxyCommand"./wstunnel client -L stdio://%h:%p ws://192.168.254.131:8080" 127.0.0.1 其中127.0.0.1为服务端的本地ssh访问,可以修改为通过服务端访问其他设备的ssh服务。例如: ssh -o ProxyComma…

C# 生成图形验证码

目录 应用场景 开发运行环境 设计 生成内容 生成图片 实现 核心代码 调用示例 小结 应用场景 我们当用户登录系统时经常会用到图形验证码技术,要求用户识别图片中的内容,并正确输入,方可尝试登录。类似的场景还有用户注册或者涉及…

C#带引导窗体的窗体设计方法:创建特殊窗体

目录 1.设计操作流程 2.实例 (1)Resources.Designer.cs (2)Frm_Main.Designer.cs (3)Frm_Main.cs (4)Frm_Start.Designer.cs (5)Frm_Start.cs &#…

Kubernetes:云原生时代的核心引擎

文章目录 一、Kubernetes简介:引领云原生潮流二、K8s的核心特性:自动化与智能化三、K8s的实践应用:打造高效云原生应用架构四、K8s的挑战与应对:安全与性能并重五、K8s的未来展望:无限可能与挑战并存《Kubernetes快速进…

【windows-搭建Ubuntu22LTS】

一、环境要求 1. windows版本要求 至少Windows 10 2020年5月(2004) 版, Windows 10 2019年5月(1903) 版,或者 Windows 10 2019年11月(1909) 版 2. 控制面板开启相关的程序(需要重启) 二、Microsoft store安装unbuntu 下载后直接运行(稍微等会&#…

HTML 中创建 WebSocket服务与接收webSocket发送内容

效果图 服务端 html客户端接受的消息 接下来开始实现服务端 创建server.js const WebSocket require(ws);const wss new WebSocket.Server({ port: 8877 });wss.on(connection, function connection(ws) {console.log(WebSocket connection opened.);// 每隔 5 秒发送一次…

百种提权及手段一览系列第10集

特权升级的危险是显而易见的。通过提升权限,攻击者可以绕过网络安全措施,从而损害数据完整性、机密性和系统可用性。对于组织而言,这可能会导致数据泄露、系统停机以及潜在的法律和声誉后果。识别权限升级的迹象并部署预防性网络安全措施对于…

linux下 Mysql8.0 离线安装

环境:centos7.9 MysqlL8.0.36安装包 链接:https://pan.baidu.com/s/1bKwHr05z8Ye82dT9tntdUA 提取码:3a5z 参考Centos安装MYSQL8(离线可用) 文章目录 1、解压安装2、配置启动2.1 修改配置文件2.2 mysql 启动 3、mysql 测试 1、解压安装 #…

《欢乐钓鱼大师》攻略,钓友入坑必备!

欢迎来到《欢乐钓鱼大师》!在这个游戏里,你可以尽情享受垂钓的乐趣,通过不断更换和升级高阶鱼竿,轻松地钓到各种稀有鱼类。因为许多玩家在挑战关卡时遇到了一些困难,所以今天我给大家带来了《欢乐钓鱼大师攻略指南》&a…

1-内核开发环境ubuntu+virtualbox+mobaXterm搭建

内核开发环境 ubuntuvirtualboxmobaXterm搭建 目录 内核开发环境 ubuntuvirtualboxmobaXterm搭建 1.virtualbox 安装 2.ubuntu 安装 3.网络设置 4.虚拟机安装ssh 服务,更新ubuntu 源安装基本软件 5.mobaXterm 个人免费版本安装 6.总结 本课程教程从0-1开始教…

VS调试技巧

Hi~!这里是奋斗的小羊,很荣幸各位能阅读我的文章,诚请评论指点,关注收藏,欢迎欢迎~~ 💥个人主页:小羊在奋斗 💥所属专栏:C语言 本系列文章为个人学习笔记&#x…

GaN HEMT中短沟道效应的建模

来源:Modeling of Short-Channel Effects in GaN HEMTs(TED 20年) 摘要 在本文中,我们提出了一种用于估算GaN高电子迁移率晶体管(HEMT)器件中短沟道效应(SCEs)的显式和解析的基于电…

node.js egg.js

Egg 是 Node.js 社区广泛使用的框架,简洁且扩展性强,按照固定约定进行开发,低协作成本。 在Egg.js框架中,ctx 是一个非常核心且常用的对象,全称为 Context,它代表了当前 HTTP 请求的上下文。ctx 对象封装了…

Golang内存、指针逃逸、垃圾回收机制概览

最近看到了一篇文章是关于go的内存、指针逃逸和垃圾回收机制的,发现自己并未很细致的了解过这方面的内容,于是在翻阅各种文章的情况下,写出了这篇总结,参考文章放在文末,可自取 内存 Go 语言使用一个自带的垃圾收集器…

YOLOV5 TensorRT部署 BatchedNMS(engine模型推理)(下)

主要是在王新宇代码的基础上改进,引入对BatchedNMS的解码 文章目录 1. 修改yolov5.cpp2.修改yololayer.h1. 修改yolov5.cpp 首先增加全局变量,名字根据转onnx时修改的节点名字来,查看onnx文件可以看到,顺序不要弄错。 const char *INPUT_NAME = “images”; const char …

Aigtek:介电弹性体高压放大器在软体机器人研究中的应用

近年来软体机器人的研究成为目前机器人研究领域的热点,由于软体材料的自由度可以根据需求自由变化,因此软体机器人有着极高的灵活性,而且软体机器人因其材料的柔软性有着很好的人机交互性能和安全性。它的出现成功解决了传统的刚性机器人人机…