深入解析大语言模型显存占用:训练与推理

深入解析大语言模型显存占用:训练与推理

  • 文章脉络
  • 估算模型保存大小
  • 估算模型在训练时占用显存的大小
    • 全量参数训练
    • PEFT训练
  • 估算模型在推理时占用显存的大小
  • 总结

对于NLP领域的从业者和研究人员来说,有没有遇到过这样一个场景,你的领导(或者导师)突然冷不丁来一句:“最近马斯克又新出了个Grok模型,小王你看看怎么放到我们的业务里来?”
——然而尴尬的是你只知道Grok是个3000亿参数的模型,很大!但是具体要用多少资源你也不知道,这个时候你想拒绝你的领导,但是又怕他追问一些你答不上来的问题,于是只好沉默。

  本篇文章将帮助你优雅又快速地拒绝老板。看完本篇《深入解析大语言模型显存占用:训练与推理》,你将对模型占用显存的问题有个透彻的理解。主要介绍:

  如何有效估算一个模型加载后的显存值?


文章脉络

模型显存占用

图1 模型显存占用分类情况图

  如上面的图1所示,本文将从三个方面介绍大语言模型显存占用估算方法。

  第一部分是模型保存大小。以BERT模型举例,BERT的预训练参数保存为.bin文件后的大小是可以用公式估算出来的。

  第二部分是模型在训练时占用显存的估算。估算模型在训练时占用的显存是每个NLPer必备的能力。这部分会介绍全量参数微调时占用的显存,而且由于PEFT技术(LoRA、P-Tuning)最近比较火,还会介绍使用PEFT方法训练时的显存估算。

  第三部分是模型在推理时占用显存的估算

估算模型保存大小

  话不多说先上公式:

模型大小 ( G B ) = 模型参数量 ∗ 参数类型所占字节数 102 4 3 模型大小(GB) = \frac{模型参数量 * 参数类型所占字节数}{1024^3} 模型大小(GB)=10243模型参数量参数类型所占字节数

  在上面的公式中,模型的参数量我们一般事先会知道(参数量大概成为现在各大AI公司的吹点了),如果实在不知道参数量,在代码里面加载一下模型,然后打印一下也能很快知道,打印模型参数量的代码示例如下:

from transformers import AutoModelForSeq2SeqLM# 导入模型
model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_path)
print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')"""
会得到类似于这样的输出:The model has 582,401,280 trainable parameters
"""

  在上面的公式中,不同的参数类型所占的字节有个对比表,如下:

类型所占字节
FP324
FP162
INT81

  不同参数类型往往代表着资源的消耗,理想情况下,我们用FP16训练模型,可以比FP32少用整整一半的显存,确实很香。但是FP16会损失一定的精度,降低模型的表现与效果。而INT8就只能在推理的时候用用了,而且效果也很不理想。

【注意】关于FP32、FP16、INT4、INT8这些类型,其实只是对数值计数的比特位的不同而导致能表达的数值范围不同。和《计算机组成原理》最开始学的浮点数那边的知识差不多。感兴趣可以去了解一下。

  所以说,模型精度和模型速度是个取乎平衡的问题。虽然用FP16训练模型会更快,但是所有组织都会开源FP32的模型参数,甚至我们用混合精度训练模型最终保存的参数也是FP32格式的,这是因为用FP32的模型效果更好。所以一般我们在计算模型保存的大小时,就默认每个参数所占字节为4就好了。

  同样以BERT模型举例,它大概是1亿1千万个参数,然后是以FP32形式保存的,每个参数占4个字节,那么BERT模型的参数保存下来占的大小为:

0.40978 ( G B ) = 419.62 ( M B ) = 11000 , 0000 ∗ 4 102 4 3 0.40978(GB) =419.62(MB)= \frac{11000,0000 * 4}{1024^3} 0.40978(GB)=419.62(MB)=1024311000,00004

  可以看到图2中BERT模型的bin文件的大小为420MB,和我们预测的419.62MB是差不多的。

在这里插入图片描述

图2 BERT模型参数大小

  同理,mT5模型的参数量大概有58200,0000,所以它的大小为:

2.168 ( G B ) = 58200 , 0000 ∗ 4 102 4 3 2.168(GB) = \frac{58200,0000 * 4}{1024^3} 2.168(GB)=1024358200,00004

在这里插入图片描述

图3 mT5模型参数大小

  图3中mT5模型的bin文件的大小为2.16GB,和我们预测的2.168GB是差不多的。

估算模型在训练时占用显存的大小

  这个部分我们分全量参数训练和PEFT训练两个部分来讲,先讲最基本的全量参数训练。

全量参数训练

在这里插入图片描述

图4 模型训练时显存开销类型(引自知乎kaiyuan)

  如图4所示,通常来说,模型在训练阶段,显存开销主要有四个部分:模型本身参数(图中绿色)、优化器状态(图中黄色)、梯度(图中蓝色)、正向传播的中间计算结果(图中红色、黑色)。

  模型本身参数:静态值,训练时保持不变。模型参数需要加载到显存中参与计算,此部分开销为1倍模型参数量。根据模型的参数类型,最后在算字节的时候,FP32要乘4,FP16要乘2。

  优化器状态:静态值,训练时保持不变。主要看优化器的种类,通常Adamw再带一阶动量的话则需要1倍模型参数量,带二阶动量则需要2倍模型参数量。默认用的Adam都是带2阶动量的。根据优化器状态的参数类型,最后在算字节的时候,FP32要乘4,FP16要乘2。

  梯度:动态值,训练时呈现波状。一个参数对应一个梯度值,所以是1倍。根据梯度的参数类型,最后在算字节的时候,FP32要乘4,FP16要乘2。

【注意】一般情况下是不会只用FP16训练模型的,会使用混合精度。在混合精度下,模型参数是FP32,而梯度是FP16,并且不同的方法对不同变量的类型设置也略有不同,所以采用混合精度时需要大家自己看情况估算。

  正向传播的中间计算结果:动态值,训练时呈现波状。反向传播中需要对中间层的计算图求导,所以中间层的输出不会被释放。这里的占用资源显然与batch大小、seq长度有关,但难以计算,常用的方法是改batch、seq看显存差值然后去估算。

  还有其他零零散散的部分就不纳入考虑了。

【注意】对于正向传播的中间计算结果,其实也是可以估算的,论文Reducing Activation Recomputation in Large Transformer Models有给出Transformer-based模型正向传播的中间计算结果的估算公式:
1、无重计算的激活内存:(s * b * h * l) * (10+24/t+5 * a * s/ h / t) (Byte)
2、选择性重计算的激活内存:(s * b * h * l) * (10+24/t) (Byte)
3、全部重计算的激活内存:2 * (s * b * h * l) (Byte)
————————
s 是token 长度
b 是 每个GPU的batch size
h 是 每个hidden layer的维度
l 是 模型的隐层数
a 是 transformer 模型中注意力头 (attention heads) 的个数
t 是张量并行度 (如果无张量并行,则为 1)
————————
但是既然都估算了,为什么还要去套这么复杂的公式呢,所以我比较喜欢直接改batch_size的大小来估算。。。

  于是我们可以粗略的估算出,在FP32的情况下,mT5有58000,0000个参数,那么光是模型本身+优化器+梯度,就要4倍参数的显存大小,也就是8.672GB。那么正常的60系显卡(8GB显存)就玩不动了。

8.672 ( G B ) = 58200 , 0000 ∗ 4 102 4 3 ∗ 4 8.672(GB) = \frac{58200,0000 * 4}{1024^3} * 4 8.672(GB)=1024358200,000044

  然后再来看下正向传播的中间计算结果大概占用的显存,参考图5,我设置batch=32,seq=10,在翻译任务训练时显存保持在12.2GB,那么减去8.672可以得到正向传播的中间计算结果占用3.528GB的显存。

在这里插入图片描述

图5 mT5模型全量参数训练时显存消耗截图(batch=32,seq=10)

PEFT训练

  GPT-2、T5以及后续的语言模型都已经证明了,在模型的输入中加入前缀,能够让模型去适应不同的任务,完成NLP界的“大一统”。

  具体来说,像T5模型,加上前缀“translate English to German”,会自动输出英德翻译的结果,加上前缀“Answer the following yes/no question.”即可完成二分类任务,还有其他的前缀此处不再举例。感兴趣的可以在T5模型页右侧接口API那里自己玩一下https://huggingface.co/google/flan-t5-base。

  PEFT 是 “Prompt Engineering for Few-shot Tuning” 的缩写,是一种做few-shot微调的技术,比较火热有Prefix Tuning、P-Tuning、LoRA等。

【注意】PEFT的相关技术此处不详细介绍了,现在已经有相应的peft库了,完美兼容transformers库,接口简单,十分好用。

  举个形象生动的例子:拿BERT来举例,常规的全量微调是在BERT最后接入一个fc层并且更新所有的参数来做文本分类;但是PEFT是在BERT内部插入一些fc层,再在BERT最后接入一个fc层,同时冻结BERT的参数,训练这些额外增加的fc层参数。

请添加图片描述

图6 传统全量微调和PEFT微调方法的简单示例

  如图6所示,图中红色的部分表示参与梯度更新,白色表示冻结参数,各种PEFT的方法都证明了图6中的方法可以达到很好的效果,而且语言模型本身的参数被冻结了,训练成本将极低。当然PEFT技术并不是简单的在LLM内部插入一些fc层,这里只是举个例子,技术细节还是推荐去看对应的论文。

  OK!回归正题,在这种情况下,模型在训练时占用显存的大小如何呢?

  同样的,对于一次FP32全量参数微调,假设使用Adam二阶动量,使用了PEFT之后,会增加N个可以训练的参数,原始模型的K个参数会冻结住,那么:

  模型本身参数:为K+N。虽然原始模型参数已经被冻结,但是还是需要加载到显存中的。
  优化器状态:2N。只有新增的参数可以梯度更新,并采用二阶动量。
  梯度:N。只有新增的参数有梯度。

  因此,与全量参数训练相比,PEFT节省了3K-4N个可训练参数。

  我自己实验了一下,对于mT5模型来说,K=58200,0000,使用了Lora后N=884,736,所以两者显存差值(3K-4N >> 6.49GB),所以PEFT对于资源的节省还是非常非常可观的。

  对于正向传播的中间计算结果大概占用的显存,仍然是把实验跑起来,人工来计算,采用与之前相同的配置,用PEFT的LoRA方法实验结果如下图7所示。在翻译任务训练时显存保持在5.3GB,那么减去PEFT方法占用的显存(K+4N >> 2.18GB)可以得到正向传播的中间计算结果占用3.12GB的显存。

在这里插入图片描述

图7 mT5模型PEFT训练时显存消耗截图(batch=32,seq=10)

  综合来看,PEFT能显著减少优化器状态和梯度这两个方面的显存开销,在正向传播的中间计算结果显存开销略微减少。

估算模型在推理时占用显存的大小

  在推理时,占用显存的只有模型本身参数正向传播的中间计算结果

  我用图7中的mT5模型PEFT方法训练好的模型进行了batch=1的推理,观测到显存最大占用为3.3GB。

  说明,推理时正向传播的中间计算结果的显存=3.3-2.18=1.12GB。所以推理时正向传播的中间计算结果的显存也不能简单的拿训练时的占用情况来除以batch_size。可能还是有缓存(Flash Attention之类的)、波束搜索等等其他原因,这要看transformers库的具体实现代码了。

  总之,推理时显存占用是很少的。

总结

  🏆在这篇博客中,我们深入探讨了大型语言模型在训练和推理过程中对显存的占用问题。

  ⭐介绍了如何估算模型保存后的大小。
  ⭐讨论了全量参数训练和PEFT训练两种情况下模型显存占用的估算方法。并且以BERT和mT5模型为例子,解释了如何计算模型参数量和不同参数类型所占的字节数。
  ⭐最后分析了模型在推理时占用显存的大小。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/291990.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言例1-11:语句 while(!a); 中的表达式 !a 可以替换为

A. a!1 B. a!0 C. a0 D. a1 答案&#xff1a;C while()成真才执行&#xff0c;所以!a1 &#xff0c;也就是 a0 原代码如下&#xff1a; #include<stdio.h> int main(void) {int a0;while(!a){a;printf("a\n");} return 0; } 结果如…

平台介绍-搭建赛事运营平台(8)

平台介绍-搭建赛事运营平台&#xff08;5&#xff09;提到了字典是分级的&#xff0c;本篇具体介绍实现。 平台级别的代码是存储在核心库中&#xff0c;品牌级别的代码是存储在品牌库中&#xff08;注意代码类是一样的&#xff09;。这部分底层功能封装为jar包&#xff0c;然后…

算法打卡day21(开始回溯)

今日任务&#xff1a; 1&#xff09;77.组合 77.组合 题目链接&#xff1a;77. 组合 - 力扣&#xff08;LeetCode&#xff09; 文章讲解&#xff1a;代码随想录 (programmercarl.com) 视频讲解&#xff1a;带你学透回溯算法-组合问题&#xff08;对应力扣题目&#xff1a;77…

Stable Diffusion之核心基础知识和网络结构解析

Stable Diffusion核心基础知识和网络结构解析 一. Stable Diffusion核心基础知识1.1 Stable Diffusion模型工作流程1. 文生图(txt2img)2. 图生图3. 图像优化模块 1.2 Stable Diffusion模型核心基础原理1. 扩散模型的基本原理2. 前向扩散过程详解3. 反向扩散过程详解4. 引入Late…

axios+springboot上传图片到本地(vue)

结果&#xff1a; 前端文件&#xff1a; <template> <div> <input type"file" id"file" ref"file" v-on:change"handleFileUpload()"/> <button click"submitFile">上传</button> </div&g…

3D汽车模型线上三维互动展示提供视觉盛宴

VR全景虚拟看车软件正在引领汽车展览行业迈向一个全新的时代&#xff0c;它不仅颠覆了传统展览的局限&#xff0c;还为参展者提供了前所未有的高效、便捷和互动体验。借助于尖端的vr虚拟现实技术、逼真的web3d开发、先进的云计算能力以及强大的大数据处理&#xff0c;这一在线展…

某东推荐的十大3C热榜第一名!2024随身wifi靠谱品牌推荐!2024随身wifi怎么选?

一、鼠标金榜&#xff1a;戴尔 商务办公有线鼠标 售价:19.9&#xffe5; 50万人好评 二、平板电脑金榜&#xff1a;Apple iPod 10.2英寸 售价:2939&#xffe5; 200万人好评 三、随身WiFi金榜&#xff1a;格行随身WiFi 售价:69&#xffe5; 15万人好评 四、游戏本金榜&#xff…

Android 自定义坐标曲线图(二)

Android 自定义坐标曲线图_android 自定义曲线图-CSDN博客 继上一篇文章&#xff0c;点击折线图上的点&#xff0c;显示提示信息进行修改&#xff0c;之前通过回调&#xff0c;调用外部方法&#xff0c;使用popupwindow或dialog来显示&#xff0c;但是这种方法对于弹框显示的位…

Mysql or与in的区别

创建一个表格 内涵一千万条数据 这张表中&#xff0c;只有id有建立索引&#xff0c;且其余都没有 测试1&#xff1a;使用or的情况下&#xff0c;根据主键进行查询 可以看到根据主键id进行or查询 花费了30-114毫秒&#xff0c;后面30多毫秒可能是因为Mysql的Buffer Pool缓冲池的…

Java并查集详解(附Leetcode 547.省份数量讲解)

一、并查集概念 并查集是一种树型的数据结构&#xff0c;用于处理一些不相交集合的合并及查询问题。 并查集的思想是用一个数组表示了整片森林&#xff08;parent&#xff09;&#xff0c;树的根节点唯一标识了一个集合&#xff0c;我们只要找到了某个元素的的树根&#xff0c;…

Unreal的Quixel Bridge下载速度过慢、下载失败

从Quixel Bridge下载MetaHuman模型&#xff0c;速度非常慢&#xff0c;而且经常下载失败&#xff0c;从头下载。 可以从Quixel Bridge的右上角我的图标->Support->Show Logs打开日志目录 downloaded-assets目录下为下载的资源 bridge-plugin.log文件记录了下载URL和下载…

Webpack生成企业站静态页面 - 项目搭建

现在Web前端流行的三大框架有Angular、React、Vue&#xff0c;很多项目经过这几年的洗礼&#xff0c;已经都 转型使用这三大框架进行开发&#xff0c;那为什么还要写纯静态页面呢&#xff1f;比如Vue中除了SPA单页面开发&#xff0c;也可以使用nuxt.js实现SSR服务端渲染&#x…

RabbitMQ基础笔记

视频链接&#xff1a;【黑马程序员RabbitMQ入门到实战教程】 文章目录 1.初识MQ1.1.同步调用1.2.异步调用1.3.技术选型 2.RabbitMQ2.1.安装2.1.1 Docker2.1.1 Linux2.1.1 Windows 2.2.收发消息2.2.1.交换机2.2.2.队列2.2.3.绑定关系2.2.4.发送消息 2.3.数据隔离2.3.1.用户管理2…

Go 之 Gin 框架

Gin 是一个 Go (Golang) 编写的轻量级 web 框架&#xff0c;运行速度非常快&#xff0c;擅长 Api 接口的高并发&#xff0c;如果项目的规模不大&#xff0c;业务相对简单&#xff0c;这个时候我们也推荐您使用 Gin&#xff0c;特别适合微服务框架。 简单路由配置 package mai…

STM32CubeIDE基础学习-USART串口通信实验(中断方式)

STM32CubeIDE基础学习-USART串口通信实验&#xff08;中断方式&#xff09; 文章目录 STM32CubeIDE基础学习-USART串口通信实验&#xff08;中断方式&#xff09;前言第1章 硬件介绍第2章 工程配置2.1 工程外设配置部分2.2 生成工程代码部分 第3章 代码编写第4章 实验现象总结 …

AMD hipcc 生成各个gpu 微架构汇编语言代码的方法示例

1&#xff0c;gpu vectorAdd 示例 为了简化逻辑&#xff0c;故假设 vector 的 size 与运行配置的thread个熟正好一样多&#xff0c;比如都是512之类的. 1.1 源码 vectorAdd.hip #include <stdio.h> #include <hip/hip_runtime.h>__global__ void vectorAdd(con…

Linux shell编程学习笔记43:cut命令

0 前言 在 Linux shell编程学习笔记42&#xff1a;md5sum 中&#xff0c;md5sum命令计算md5校验值后返回信息的格式是&#xff1a; md5校验值 文件名 包括两项内容&#xff0c;前一项是md5校验值 &#xff0c;后一项是文件名。 如果我们只想要前面的md5 校验值&#xff0c…

Golang生成UUID

安装依赖 go get -u github.com/google/uuid文档 谷歌UUID文档 示例 函数签名func NewV7() ( UUID ,错误) func (receiver *basicUtils) GenerateUUID() uuid.UUID {return uuid.Must(uuid.NewV7()) } uid : GenerateUUID()

ssm009毕业生就业信息统计系统+vue

毕业生就业信息统计系统 摘 要 随着移动应用技术的发展&#xff0c;越来越多的学生借助于移动手机、电脑完成生活中的事务&#xff0c;许多的行业也更加重视与互联网的结合&#xff0c;以提高快捷、高效、安全&#xff0c;可以帮助更多有需求的人。针对传统毕业生就业信息统计…

echarts 图表/SVG 图片指定位置截取

echarts 图表/SVG 图片指定位置截取 1.前期准备2.图片截取3.关于drawImage参数 需求&#xff1a;如下图所示&#xff0c;需要固定头部legend信息 1.前期准备 echarts dom渲染容器 <div :id"barchart id" class"charts" ref"barchart">&…