LLM各层参数详细分析(以LLaMA为例)

网上大多分析LLM参数的文章都比较粗粒度,对于LLM的精确部署不太友好,在这里记录一下分析LLM参数的过程。


首先看QKV。先上transformer原文
在这里插入图片描述
也就是说,当h(heads) = 1时,在默认情况下, W i Q W_i^Q WiQ W i K W_i^K WiK W i V W_i^V WiV都是2维方阵,方阵维度是 d m o d e l × d m o d e l d_{model} \times d_{model} dmodel×dmodel.

结合llama源码 (https://github.com/facebookresearch/llama/blob/main/llama/model.py)

class ModelArgs:dim: int = 4096n_layers: int = 32n_heads: int = 32n_kv_heads: Optional[int] = Nonevocab_size: int = -1  # defined later by tokenizermultiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2ffn_dim_multiplier: Optional[float] = Nonenorm_eps: float = 1e-5max_batch_size: int = 32max_seq_len: int = 2048
# ...class Attention(nn.Module):"""Multi-head attention module."""def __init__(self, args: ModelArgs):"""Initialize the Attention module.Args:args (ModelArgs): Model configuration parameters.Attributes:n_kv_heads (int): Number of key and value heads.n_local_heads (int): Number of local query heads.n_local_kv_heads (int): Number of local key and value heads.n_rep (int): Number of repetitions for local heads.head_dim (int): Dimension size of each attention head.wq (ColumnParallelLinear): Linear transformation for queries.wk (ColumnParallelLinear): Linear transformation for keys.wv (ColumnParallelLinear): Linear transformation for values.wo (RowParallelLinear): Linear transformation for output.cache_k (torch.Tensor): Cached keys for attention.cache_v (torch.Tensor): Cached values for attention."""super().__init__()self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_headsmodel_parallel_size = fs_init.get_model_parallel_world_size()self.n_local_heads = args.n_heads // model_parallel_sizeself.n_local_kv_heads = self.n_kv_heads // model_parallel_sizeself.n_rep = self.n_local_heads // self.n_local_kv_headsself.head_dim = args.dim // args.n_heads

计算出
self.n_kv_heads = h = 32
self.head_dim = 4096/32=128
所以 W i Q W_i^Q WiQ W i K W_i^K WiK W i V W_i^V WiV 大小都为(4096, 128).(在未拆分前 W Q W^Q WQ W K W^K WK W V W^V WV都是 ( d i m , d i m ) = ( 4096 , 4096 ) (dim, dim) = (4096,4096) (dim,dim)=(4096,4096)大小)

Q , K , V Q,K,V Q,K,V的大小都是 ( n c t x , d i m ) = ( 2048 , 4096 ) (n_{ctx}, dim) = (2048,4096) (nctx,dim)=(2048,4096)在多头公式里。在self-attention里,其实他们都是同一个值:输入X),所以 Q × W i Q Q×W_i^Q Q×WiQ K × W i K K×W_i^K K×WiK Q × W i Q Q×W_i^Q Q×WiQ 都是 ( n c t x , d k ) = ( 2048 , 128 ) (n_{ctx}, d_k)=(2048,128) (nctx,dk)=(2048,128)。带入原文attention公式后,大小为(2048, 128)不变。Attention不改变大小(在默认 d k = d v d_k=d_v dk=dv情况下)。
在这里插入图片描述

经过Cancat,分开的头又合并,大小变为(2048, 4096)矩阵,经过 W O W^O WO大小是(4096,4096))全连接,还是(2048, 4096)矩阵。


然后看Feed forward.根据源码,

class FeedForward(nn.Module):def __init__(self,dim: int,hidden_dim: int,multiple_of: int,ffn_dim_multiplier: Optional[float],):"""Initialize the FeedForward module.Args:dim (int): Input dimension.hidden_dim (int): Hidden dimension of the feedforward layer.multiple_of (int): Value to ensure hidden dimension is a multiple of this value.ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.Attributes:w1 (ColumnParallelLinear): Linear transformation for the first layer.w2 (RowParallelLinear): Linear transformation for the second layer.w3 (ColumnParallelLinear): Linear transformation for the third layer."""super().__init__()hidden_dim = int(2 * hidden_dim / 3)# custom dim factor multiplierif ffn_dim_multiplier is not None:hidden_dim = int(ffn_dim_multiplier * hidden_dim)hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)def forward(self, x):return self.w2(F.silu(self.w1(x)) * self.w3(x))

multiattention layer过后,经过加法和normlayer(RMS norm),进入feed_forward前馈网络。注意这里的前馈网络其中一个维度会有8/3≈2.7的放缩,然后multiple_of又保证必须是256的倍数,所以这里算出来hidden_dim是256的倍数中与8/3*4096最接近的,是11008。以这里的w1,w3大小为(4096,11008),w2大小为(11008,4096). 输出结果大小

整个decode layer计算如图所示,
在这里插入图片描述

来源:https://github.com/microsoft/Llama-2-Onnx/blob/main/Images/DecoderLayer.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/140509.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2023/09/20 day4 qt

做一个动态指针钟表 头文件 #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QPainter> //绘制事件类 #include <QPaintEvent> //画家类 #include <QTime> #include <QTimer> #include <QTimerEvent> QT_BEGIN…

用于设计 CNN 的 7 种不同卷积

一 说明 最近对CNN架构的研究包括许多不同的卷积变体&#xff0c;这让我在阅读这些论文时感到困惑。我认为通过一些更流行的卷积变体的精确定义&#xff0c;效果和用例&#xff08;在计算机视觉和深度学习中&#xff09;是值得的。这些变体旨在保存参数计数、增强推理并利用目标…

rtsp转webrtc的其他几个项目

1&#xff09; mpromonet/webrtc-streamer &#xff08;c开发&#xff09; 把rtsp转webrtc&#xff0c; 通过 load urls from JSON config file ./webrtc-streamer -C config.json 通过exe文件和docker项目实际测试可以显示&#xff0c;但不太稳定加载慢,有时候出错后很难…

7.15 SpringBoot项目实战 【学生入驻】(上):从API接口定义 到 Mybatis查询 串讲

文章目录 前言一、service层 和 dal层方式一、Example方式方式二、Mybatis XML方式方式三、Mybatis 注解方式 二、web层 StudentController最后 前言 接下来我们实战【学生入驻】&#xff0c;对于C端学生端&#xff0c;一切交互开始于知道 当前学生是否入驻、是否有借阅资格&a…

LeetCode 753. 破解保险箱【欧拉回路,DFS】困难

本文属于「征服LeetCode」系列文章之一&#xff0c;这一系列正式开始于2021/08/12。由于LeetCode上部分题目有锁&#xff0c;本系列将至少持续到刷完所有无锁题之日为止&#xff1b;由于LeetCode还在不断地创建新题&#xff0c;本系列的终止日期可能是永远。在这一系列刷题文章…

java字符串压缩和字符串解压

java字符串压缩和字符串解压 运行效果 java工具类 CompressUtil.java import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.serializer.SerializerFeature; import org.apache.commons.codec.binary.Base64;import java.io.BufferedReader; import java.io.Byte…

el-form自定义规则后表单验证validate不生效的问题

1.首先放出结论&#xff0c;自定义验证规则必须降所有的可能全部都return callback出去&#xff0c;不然不会走validate 错误示范 // template <el-formref"ruleFormRef":model"ruleForm":rules"rules"label-width"120px"class&qu…

力扣:105. 从前序与中序遍历序列构造二叉树(Python3)

题目&#xff1a; 给定两个整数数组 preorder 和 inorder &#xff0c;其中 preorder 是二叉树的先序遍历&#xff0c; inorder 是同一棵树的中序遍历&#xff0c;请构造二叉树并返回其根节点。 来源&#xff1a;力扣&#xff08;LeetCode&#xff09; 链接&#xff1a;力扣&am…

Visio——绘制倾斜线段

一、形状 -> 图表和数学图形 -> 多行 二、放置多行线&#xff0c;可以发现存在两个折点 三、选择多行线&#xff0c;右键选择删除点&#xff0c;即可得到倾斜线段

C进阶-数据的存储

数据类型介绍 内置类型&#xff1a; //数据类型中的内置类型 // char //字符数据类型 // short //短整型 // int //整型 // long //长整型 // long long //更长的整型 // float //单精度浮点数 // double //双精度浮点数 //数据类型中的内置类型 单位是字节 // char //字…

PyCharm 手动下载插件

插件模块一直加载失败&#xff0c;报错信息&#xff1a; Marketplace plugins are not loaded. Check the internet connection and refresh. 尝试了以下方法&#xff0c;均告失败&#xff1a; pip 换源Manage Plugin Repositories...HTTP 代理设置...关闭三个防火墙 最后选…

《动手学深度学习 Pytorch版》 7.4 含并行连接的网络(GoogLeNet)

import torch from torch import nn from torch.nn import functional as F from d2l import torch as d2l7.4.1 Inception块 GoogLNet 中的基本卷积块叫做 Inception 块&#xff08;大概率得名于盗梦空间&#xff09;&#xff0c;由 4 条并行路径组成。 前 3 条路径使用窗口…

华为云云耀云服务器L实例评测|使用云耀云服务器L实例,自行搭建NextCloud

背景 本人是搞分布式存储的&#xff0c;一心想搭个网盘&#xff0c;发现L实例里面就有个网盘实例最低配置是2核4G&#xff0c;为了节省成本&#xff0c;选个【2核2G】的配置&#xff0c;一样可以搭起来。 简介 Nextcloud是一款开源免费的私有云存储网盘项目&#xff0c;可以让…

clickhouse学习之路----clickhouse的特点及安装

clickhouse学习笔记 反正都有学不完的技术&#xff0c;不如就学一学clickhouse吧 文章目录 clickhouse学习笔记clickhouse的特点1.列式存储2. DBMS 的功能3.多样化引擎4.高吞吐写入能力5.数据分区与线程级并行 clickhouse安装1.关闭防火墙2.CentOS 取消打开文件数限制3.安装依…

DATE和LocalDateTime在Java中有什么区别

在Java中&#xff0c;Date和LocalDateTime是两个表示日期和时间的类&#xff0c;它们有以下区别&#xff1a; 类型&#xff1a;Date是Java旧版提供的日期和时间类&#xff0c;而LocalDateTime是Java 8引入的新日期和时间API中的类。 不可变性&#xff1a;Date是可变类&#x…

深入理解C#中委托的使用及不同类型委托的应用示例

在C#中&#xff0c;委托是一种强大而灵活的机制&#xff0c;可以引用一个或多个方法&#xff0c;并允许以类似函数指针的方式进行调用。委托在事件处理、回调函数和多线程编程等场景中非常有用。本文将深入探讨C#中委托的使用&#xff0c;并介绍不同类型委托的应用示例。 目录…

Fiddler抓包工具配置+Jmeter基本使用

一、Fiddler抓包工具的配置和使用 在编写网关自动化脚本之前&#xff0c;得先学会如何抓包&#xff0c;这里以Fiddler为例。会抓包的同学可以跳过这一步&#xff0c;当然看看也是没坏处的…… 局域网络配置 将要进行抓包的手机与电脑连入同一局域网&#xff0c;电脑才能够抓到…

【深度学习实验】前馈神经网络(八):模型评价(自定义支持分批进行评价的Accuracy类)

目录 一、实验介绍 二、实验环境 1. 配置虚拟环境 2. 库版本介绍 三、实验内容 0. 导入必要的工具包 1. __init__(构造函数) 2. update函数(更新评价指标) 5. accumulate(计算准确率) 4. reset(重置评价指标) 5. 构造数据进行测试 6. 代码整合 一、实验介绍 本文将实…

react多条件查询

1、声明一个filter常量 2.filter接受&#xff08;condition,data&#xff09;两个参数 3、调用data里面的filter进行筛选 4、任意一个item当筛选条件 5、使用object.key获取对象所有key 6、对每个key使用Array.prototype.every&#xff08;&#xff09;方法判断是否满足条…

华为数通方向HCIP-DataCom H12-821题库(单选题:361-380)

第361题 如图所示是一台路由器的BGP输出信息。那么以下关于这段信息的描述,错误的是哪一项? <Huawei>display bgp error Error Type: Peer Error Peer Address:10.1.1.2 VRFName:Public Error Info: Router-ID conflictA、该路由器邻居地址是10.1.1.2 B、Error Type显…