大模型训练(2):内存开销

模型训练中的存储消耗

1 存储分类

首先,在大模型训练的过程中,GPU都需要存什么内容:存储主要分为两大块:Model States和Residual States

  • Model State:指和模型本身息息相关的,必须存储的内容,具体包括:
    • Optimizer States:Adam优化算法中的momentum(动量)和variance(方差)
    • Gradients:模型梯度 G G G
    • Parameters:模型参数 W W W
  • Residual States:指并非模型必须的,但在训练过程中会额外产生的内容,具体包括:
    • Activation:激活值。在流水线并行中曾详细介绍过。在backward过程中使用链式法则计算梯度时会用到。有了它算梯度会更快,但它不是必须存储的,因为可以通过重新做Forward来算它。
    • Temporary Buffers:临时存储。例如把梯度发送到某块GPU上做加总聚合时产生的存储。
    • Unusable Fragment Memory:碎片化的存储空间。虽然总存储空间是够的,但是如果取不到连续的存储空间,相关的请求也会被fail掉。对这类空间浪费可以通过内存整理来解决。

2 精度混合训练

知道了存储分类,进一步,我们想知道,假设模型的参数 W W W大小是 Φ \Phi Φ(为了方便后面绝对存储空间大小的说明,这里的大小表示数据的个数,即为有 Φ \Phi Φ个参数),那么每一类存储具体占了多大的空间呢?

在分析这个问题前,我们需要来了解精度混合训练。

对于模型,我们肯定希望其参数越精准越好,也即我们用fp32(单精度浮点数,存储占4byte)来表示参数 W W W。但是在forward和backward的过程中,fp32的计算开销也是庞大的。那么能否在计算的过程中,引入fp16或bf16(半精度浮点数,存储占2byte),来减轻计算压力呢?于是,混合精度训练就产生了,它的步骤如下图:

image-20250104022637095

  • 存储一份fp32的parameter,momentum和variance(统称model states)
  • (1)在forward开始之前,额外开辟一块存储空间,将fp32 parameter减半到fp16 parameter
  • (2)正常做forward和backward,在此之间产生的activation和gradients,都用fp16进行存储
  • (3)用fp16 gradients去更新fp32下的model states
  • 当模型收敛后,fp32的parameter就是最终的参数输出

通过这种方式,混合精度训练在计算开销和模型精度上做了权衡。如果不了解fp32,fp16和bf16的细节也没关系,不影响下文的阅读。只要记住它们所占的存储空间和精度表达上的差异即可。

3 存储大小估计

现在,我们可以来计算模型在训练时需要的存储大小了,假设模型的参数 W W W大小是 Φ \Phi Φ (此处可以理解为参数数量),以byte为单位,存储如下:

  • 必存(共计 12 Φ 12\Phi 12Φ):
    • Parameters(FP32占4个字节,共 Φ \Phi Φ个)= 4 Φ 4\Phi
    • momentum(FP32占4个字节,共 Φ \Phi Φ个)= 4 Φ 4\Phi
    • variance(FP32占4个字节,共 Φ \Phi Φ个) = 4 Φ 4\Phi
  • 中间值(共计 4 Φ 4\Phi ):
    • Parameters(FP16) = 2 Φ 2\Phi
    • Gradients(FP16) = 2 Φ 2\Phi

因为采用了Adam优化,所以才会出现momentum和variance,当然你也可以选择别的优化办法。因此这里为了更通用些,记模型必存的数据大小为 K Φ K\Phi KΦ 。因此最终内存开销为: 2 Φ + 2 Φ + K Φ 2\Phi+2\Phi+K\Phi ++KΦ

另外,这里暂不将activation纳入统计范围,原因是:

  • activation不仅与模型参数相关,还与batch size相关
  • activation的存储不是必须的。存储activation只是为了在用链式法则做backward的过程中,计算梯度更快一些。但你永远可以通过只保留最初的输入 X X X,重新做forward来得到每一层的activation(虽然实际中并不会这么极端)。
  • 因为activation的这种灵活性,纳入它后不方便衡量系统性能随模型增大的真实变动情况。因此在这里不考虑它,在后面会单开一块说明对activation的优化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/337.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据结构高阶】B-树

目录 一、常见的搜索结构 二、B树 2.1 B树的概念 2.2 B树插入数据的分析 2.3 B树的性能分析 2.4 模拟实现B树 2.4.1 B树节点的定义 2.4.2 B树数据的查找 2.4.3 B树节点的数据插入 2.4.4 B树的遍历 2.4.5 模拟实现B树实现的完整代码 三、B树 3.1 B树的概念 3.2 B树…

java项目之房屋租赁系统源码(springboot+mysql+vue)

项目简介 房屋租赁系统实现了以下功能: 房屋租赁系统的主要使用者分为: 系统管理:个人中心、房屋信息管理、预约看房管理、合同信息管理、房屋报修管理、维修处理管理、房屋评价管理等模块的查看及相应操作; 房屋信息管理&#…

maven的简单介绍

目录 1、maven简介2、maven 的主要特点3、maven的下载与安装4、修改配置文件5、私服(拓展) 1、maven简介 Maven 是一个广泛使用的项目管理和构建工具,主要应用于 Java 项目。Maven 由 Apache 软件基金会开发和维护,它提供了一种简洁且一致的方法来构建、…

搭建prometheus+grafana监控系统抓取Linux主机系统资源数据

Prometheus 和 Grafana 是两个非常流行的开源工具,通常结合使用来实现监控、可视化和告警功能。它们在现代 DevOps 和云原生环境中被广泛使用。 1. Prometheus 定义:Prometheus 是一个开源的系统监控和告警工具包,最初由 SoundCloud 开发&am…

【YOLOv5】源码(train.py)

train.py是YOLOv5中用于模型训练的脚本文件,其主要功能是读取配置文件、设置训练参数、构建模型结构、加载数据、训练/验证模型、保存模型权重文件、输出日志等 参考笔记: 【YOLOv3】源码(train.py)_yolo原始代码-CSDN博客 【y…

MySQL存储引擎、索引、索引失效

MySQL Docker 安装 MySQL8.0,安装见docker-compose.yaml 操作类型 SQL 程序语言有四种类型,对数据库的基本操作都属于这四种类,分为 DDL、DML、DQL、DCL DDL(Dara Definition Language 数据定义语言),是负责数据结构定义与数据…

如何看待Akamai 退出中国市场进行转型?

Akamai宣布退出中国市场并进行战略转型,这一举措引发了广泛的关注和讨论。从多个角度来看,这一决策既反映了Akamai自身的业务调整需求,也与中国市场环境的变化密切相关。 Akamai的退出是其全球战略调整的一部分。Akamai近年来一直在推进业务…

【Linux】网络层

目录 IP协议 协议头格式 网段划分 2中网段划分的方式 为什么要进行网段划分 特殊的IP地址 IP地址的数量限制 私有IP地址和公有IP地址 路由 IP协议 在通信时,主机B要把数据要给主机C,一定要经过一条路径选择,为什么经过路由器G后&…

多线程面试相关

线程基础知识 线程与进程的区别 并行和并发的区别 创建线程的方式 Runnable和Callable有什么区别 run()方法和start()方法的区别 小结 线程包含哪些状态,各个状态之间如何变化 线程按顺序执行 notify()和notifyAll()的区别 Java中的wait方法和sleep方法的不同 如何…

Unity + Firebase + GoogleSignIn 导入问题

我目前使用 Unity版本:2021.3.33f1 JDK版本为:1.8 Gradle 版本为:6.1.1 Firebase 版本: 9.6.0 Google Sign In 版本为: 1.0.1 问题1 :手机点击登录报错 apk转化成zip,解压,看到/lib/armeabi-v…

【Rust自学】11.10. 集成测试

喜欢的话别忘了点赞、收藏加关注哦(加关注即可阅读全文),对接下来的教程有兴趣的可以关注专栏。谢谢喵!(・ω・) 11.10.1. 什么是集成测试 在Rust里,集成测试完全位于被测试库的外部。集成测试…

【Word_笔记】Word的修订模式内容改为颜色标记

需求如下:请把修改后的部分直接在原文标出来,不要采用修订模式 步骤1:打开需要转换的word后,同时按住alt和F11 进入(Microsoft Visual Basic for Appliations) 步骤2:插入 ---- 模块 步骤3&…

深入Android架构(从线程到AIDL)_21 IPC的Proxy-Stub设计模式03

目录 3、包裝IBinder接口 -- 使用Proxy-Stub设计模式 EIT造型的双层组合 4、 谁来写Proxy及Stub类呢? -- 地头蛇(App开发者)自己写 范例 定义一个新接口: IPlayer 撰写一个Stub类: PlayerStub 撰写mp3Binder类 撰写mp3RemoteService类 3、…

数据在内存的存储

数据类型介绍 前面我们已经学习了基本的内置类型: char //字符数据类型 1字节 打印%c short //短整型 2字节 打印%hd int //整形 4字节 打印%d long long int //长整型 4/8字节 打印%ld l…

springboot 默认的 mysql 驱动版本

本案例以 springboot 3.1.12 版本为例 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>3.1.12</version><relativePath/> </parent> 点击 spring-…

直流无刷电机控制(FOC):电流模式

目录 概述 1 系统框架结构 1.1 硬件模块介绍 1.2 硬件实物图 1.3 引脚接口定义 2 代码实现 2.1 软件架构 2.2 电流检测函数 3 电流环功能实现 3.1 代码实现 3.2 测试代码实现 4 测试 概述 本文主要介绍基于DengFOC的库函数&#xff0c;实现直流无刷电机控制&#x…

在 Flownex 网络中创建传热元件

本文探讨了在 Flownex 仿真中集成新传热元件的步骤&#xff0c;以增强热系统管理能力。 了解 Flownex 中的传热元件 Flownex 中的传热元件复制传导、对流和辐射&#xff0c;从而深入了解系统的热流体行为。准确建模复杂系统需要了解这些单元的特性和功能。Flownex 的高级功能…

《OpenCV计算机视觉实战项目》——银行卡号识别

文章目录 项目任务及要求项目实现思路项目实现及代码导入模块设置参数对模版图像中数字的定位处理银行卡的图像处理读取输入图像&#xff0c;预处理找到数字边框使用模版匹配&#xff0c;计算匹配得分 画出并打印结果 项目任务及要求 任务书&#xff1a; 要为某家银行设计一套…

PyCharm文档管理

背景&#xff1a;使用PyCharmgit做文档管理 需求&#xff1a;需要PyCharm自动识别docx/xslx/vsdx等文件类型&#xff0c;并在PyCharm内点击文档时唤起系统内关联应用(如word、excel、visio) 设置步骤&#xff1a; 1、file -》 settings -》file types 2、在Files opened i…

景联文科技提供高质量多模态数据处理服务,驱动AI新时代

在当今快速发展的AI时代&#xff0c;多模态数据标注成为推动人工智能技术进步的关键环节。景联文科技作为行业领先的AI数据服务提供商&#xff0c;专注于为客户提供高质量、高精度的多模态数据标注服务&#xff0c;涵盖图像、语音、文本、视频及3D点云等多种类型的数据。通过专…