探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(四)分组多查询注意力

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(四)分组多查询注意力

Grouped-query Attention,简称GQA

分组查询注意力(Grouped-query Attention,简称GQA)是多查询和多头注意力的插值。它在保持与多查询注意力相当的处理速度的同时,实现了与多头注意力相似的质量。

在这里插入图片描述

自回归解码的标准做法是缓存序列中先前标记的键和值,以加快注意力计算的速度。

  • 然而,随着上下文窗口或批量大小的增加,多头注意力(Multi-Head Attention,简称MHA)模型中键值缓存(Key-Value Cache,简称KV Cache)的大小所关联的内存成本显著增加。

  • 多查询注意力(Multi-Query Attention,简称MQA)是一种机制,它对多个查询仅使用单个键值头,这可以节省内存并大幅加快解码器的推理速度。

  • Llama(一种模型)整合了GQA,以解决在Transformer模型自回归解码期间的内存带宽挑战。主要问题源于GPU进行计算的速度比它们将数据移入内存的速度快。在每个阶段都需要加载解码器权重和注意力键,这消耗了大量的内存。

在这里插入图片描述
在这里插入图片描述

class SelfAttention(nn.Module): def  __init__(self, args: ModelArgs):super().__init__()self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads# Indicates the number of heads for the queriesself.n_heads_q = args.n_heads# Indiates how many times the heads of keys and value should be repeated to match the head of the Queryself.n_rep = self.n_heads_q // self.n_kv_heads# Indicates the dimentiona of each headself.head_dim = args.dim // args.n_headsself.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))def forward(self, x: torch.Tensor, start_pos: int, freq_complex: torch.Tensor):batch_size, seq_len, _ = x.shape #(B, 1, dim)# Apply the wq, wk, wv matrices to query, key and value# (B, 1, dim) -> (B, 1, H_q * head_dim)xq = self.wq(x)# (B, 1, dim) -> (B, 1, H_kv * head_dim)xk = self.wk(x)xv = self.wv(x)# (B, 1, H_q * head_dim) -> (B, 1, H_q, head_dim)xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)# (B, 1, H_kv * head_dim) -> (B, 1, H_kv, head_dim)xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)# Apply the rotary embeddings to the keys and values# Does not chnage the shape of the tensor# (B, 1, H_kv, head_dim) -> (B, 1, H_kv, head_dim)xq = apply_rotary_embeddings(xq, freq_complex, device=x.device)xk = apply_rotary_embeddings(xk, freq_complex, device=x.device)# Replace the enty in the cache for this tokenself.cache_k[:batch_size, start_pos:start_pos + seq_len] = xkself.cache_v[:batch_size, start_pos:start_pos + seq_len] = xv# Retrive all the cached keys and values so far# (B, seq_len_kv, H_kv, head_dim)keys = self.cache_k[:batch_size, 0:start_pos + seq_len]values = self.cache_v[:batch_size, 0:start_pos+seq_len] # Repeat the heads of the K and V to reach the number of heads of the querieskeys = repeat_kv(keys, self.n_rep)values = repeat_kv(values, self.n_rep)# (B, 1, h_q, head_dim) --> (b, h_q, 1, head_dim)xq = xq.transpose(1, 2)keys = keys.transpose(1, 2)values = values.transpose(1, 2)# (B, h_q, 1, head_dim) @ (B, h_kv, seq_len-kv, head_dim) -> (B, h_q, 1, seq_len-kv)scores = torch.matmul(xq, keys.transpose(2,3)) / math.sqrt(self.head_dim)scores = F.softmax(scores.float(), dim=-1).type_as(xq)# (B, h_q, 1, seq_len) @ (B, h_q, seq_len-kv, head_dim) --> (b, h-q, q, head_dim)output = torch.matmul(scores, values)# (B, h_q, 1, head_dim) -> (B, 1, h_q, head_dim) -> ()output = (output.transpose(1,2).contiguous().view(batch_size, seq_len, -1))return self.wo(output) # (B, 1, dim) -> (B, 1, dim)

系列博客

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(一)
https://duanzhihua.blog.csdn.net/article/details/138208650
探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(二)
https://duanzhihua.blog.csdn.net/article/details/138212328

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(三)KV缓存
https://duanzhihua.blog.csdn.net/article/details/138213306
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/314090.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

grafana报错This panel requires Angular (deprecated)

1.原因 报错解释: Grafana在更新到7.0版本后,弃用了AngularJS(一种用于构建大型Web应用的JavaScript框架)。在早期的Grafana版本中,某些面板可能依赖于AngularJS,但这种依赖已经逐步被新的React或Vue面板所…

【软考经验分享】软考-中级-嵌入式备考

这里写目录标题 教辅用书嵌入式系统设计师考试大纲嵌入式系统设计师教程嵌入式系统设计师5天修炼嵌入式系统设计师考前冲刺100题 刷题软件希赛网软考真题 视频教程希赛网王道-计组计网 教辅用书 嵌入式系统设计师考试大纲 50页左右,内容为罗列一些考点&#xff0c…

Java基础之JVM基础调优与常见问题

常见命令 以下命令的介绍,全部在jdk8环境下运行的; jps ☆☆☆☆☆ 查看当前运行的进程号; jmap ☆☆☆ jmap命令可以查看jvm的内存信息,class对应的实例个数以及占用的内存大小 jmap -histo 查看当前java进程 [rdVM-8-12-c…

ARM学习(26)链接库的依赖查看

笔者今天来聊一下查看链接库的依赖。 通常情况下,运行一个可执行文件的时候,可能会出现找不到依赖库的情况,比如图下这种情况,可以看到是缺少了license.dll或者libtest.so,所以无法运行。怎么知道它到底缺少什么dll呢&…

多输入多输出 | Matlab实现WOA-LSSVM鲸鱼算法优化最小二乘支持向量机多输入多输出预测

多输入多输出 | Matlab实现WOA-LSSVM鲸鱼算法优化最小二乘支持向量机多输入多输出预测 目录 多输入多输出 | Matlab实现WOA-LSSVM鲸鱼算法优化最小二乘支持向量机多输入多输出预测预测效果基本介绍程序设计往期精彩参考资料 预测效果 基本介绍 Matlab实现WOA-LSSVM鲸鱼算法优化…

SpringCloud系列(5)--SpringCloud微服务工程公共部分提取

前言:在上一章节中我们创建了两个个SpringCloud工程,但在两个工程中分别存在着一些重复的部分,例如重复的实体类(如图所示),这样会造成系统的冗余,所以我们需要把公共的类提取到一个工程里&…

AOC vs. DAC:哪个更适合您的网络需求?

在现代网络通信中,选择合适的连接线缆对于数据传输的稳定性和速度至关重要。两种常见的线缆类型是 AOC(Active Optical Cable) 和 DAC(Direct Attach Cable)。本文将详细介绍这两种线缆的特点、优势和适用场景&#xf…

03_Scala变量和数据类型

文章目录 [toc] **变量和数据类型****1.注释****2.变量和常量****3. 标识符的命名规范****4.scala的字符串****5.键盘输入****5.1 StdIn.readLine()****5.2 从文件中读取数据****5.3 Scala向外写数据** 变量和数据类型 1.注释 和Java完全一样 ** ** 2.变量和常量 var name…

LiveNVR监控流媒体Onvif/RTSP常见问题-如何对比监控摄像头延时视频流延时支持webrtc视频流播放超低延时播放

LiveNVR如何对比监控摄像头延时视频流延时支持webrtc视频流播放超低延时播放 1、问题场景2、如何对比延时?3、WEBRTC延时对比4、LiveNVR支持WEBRTC输出5、RTSP/HLS/FLV/RTMP拉流Onvif流媒体服务 1、问题场景 需要低延时的视频流监控播放,之前可以用rtmp…

OpenHarmony实战开发-媒体查询 (@ohos.mediaquery)

概述 媒体查询作为响应式设计的核心,在移动设备上应用十分广泛。媒体查询可根据不同设备类型或同设备不同状态修改应用的样式。媒体查询常用于下面两种场景: 针对设备和应用的属性信息(比如显示区域、深浅色、分辨率)&#xff0…

新手Pytorch入门笔记-transforms.Compose()

我使用的图片是上图,直接下载即可 transforms.Compose 是PyTorch中的一个实用工具,用于创建一个包含多个数据变换操作的变换对象。这些变换操作通常用于数据预处理,例如图像数据的缩放、裁剪、旋转等。使用transforms.Compose 可以将多个数据…

如何在 Flutter 中制作多种颜色的 TextField

TextField widget 本身并不施加任何样式。相反,它会要求 TextEditingController 生成一个样式化的 TextSpan 对象,即一段带有样式的文本。 TextField 将其样式传递给 TextEditingController ,默认实现只是将其放入 TextSpan 对象中&#xff0…

基于SSM+Vue的护工预约服务小程序和后台管理系统

1、系统演示视频(演示视频) 2、需要请联系

【GIS教程】ArcGIS做日照分析(附练习数据下载)

我国对住宅日照标准的规定是:冬至日住宅底层日照不少于1小时或大寒日住宅层日照不少于2小时(通常以当地冬至日正午12时的太阳高度角作为依据)。因冬至日太阳高度角最低,照射范围最小,如果冬至日12:00建筑物底层能够接收到阳光,那么…

kubernetes部署控制器Deployment

一、概念 在学习rc和rs控制器资源时,这两个资源都是控制pod的副本数量的,但是,他们两个有个缺点,就是在部署新版本pod或者回滚代码的时候,需要先apply资源清单,然后再删除现有pod,通过资源控制&…

[Flutter3] 记录Dio的简单封装(一)

文章目录 效果使用ResponseEntity类DioManager封装_onResponse / _onDioException 的设计Response的处理catch处理 效果 请求成功/失败/异常的日志输出效果 成功: 失败:500 失败:404 网络异常: 使用 举个使用的例子, 在调用 DioManager的时候, 直接通过返回值的状态, 来…

森林消防隔膜泵的应用与前景——恒峰智慧科技

随着全球气候变暖,森林火灾频发,给生态环境和人类安全带来严重威胁。为有效应对这一挑战,森林消防领域不断引入新技术、新装备。其中,隔膜泵作为一种高效、可靠的消防设备,正逐渐受到广泛关注。本文将探讨森林消防隔膜…

Python Flask Web框架快速入门

Flask 入门Demo Flask 开发环境搭建,执行如下指令: pip install flask # 第一节: Flask 快速入门from flask import Flask app Flask(__name__)app.route(/flask) def hello_flask():return Hello Flaskapp.run() 核心代码剖析: 从 fla…

五种主流数据库:集合运算

关系型数据库中的表与集合理论中的集合类似,表是由行(记录)组成的集合。因此,SQL 支持基于数据行的各种集合运算,包括并集运算(Union)、交集运算(Intersect)和差集运算&a…

springcloud Ribbon的详解

1、Ribbon是什么 Ribbon是Netflix发布的开源项目,Spring Cloud Ribbon是基于Netflix Ribbon实现的一套客户端负载均衡的框架。 2、Ribbon能干什么 LB负载均衡(Load Balance)是什么?简单的说就是将用户的请求平摊的分配到多个服务上,从而达…