Meta-Llama-3-8B-Instruct 模型的混合精度训练显存需求:AdamW优化器(中英双语)

深入分析 Meta-Llama-3-8B-Instruct 模型的混合精度训练显存需求

Meta-Llama-3-8B-Instruct 是一个 8B(80亿)参数的大型语言模型,适用于指令微调任务。与之前的 7B 模型相比,它在计算和存储方面会有更高的需求。为了提高训练效率并减少显存占用,混合精度训练(使用 BF16 格式存储权重,同时使用 FP32 格式存储更新副本和优化器参数)成为一种重要的技术手段。本博客将详细计算 Meta-Llama-3-8B-Instruct 在混合精度训练中的显存需求,并分析为什么权重更新副本和优化器参数使用 FP32 存储。

1. Meta-Llama-3-8B-Instruct 模型的基本参数

Meta-Llama-3-8B-Instruct 模型包含 80 亿参数。根据参数存储格式的不同,显存需求有所不同。我们将计算在混合精度训练下,模型权重、梯度、优化器参数等的显存占用。

存储格式:

  • BF16:每个参数占用 2 字节。
  • FP32:每个参数占用 4 字节。

2. 显存需求的详细计算

在混合精度训练中,显存需求主要由以下几个部分组成:

1)模型权重

  • 前向传播和反向传播
    模型权重以 BF16 存储。
    模型权重(BF16) = 8 × 1 0 9 × 2 bytes = 16 GB \text{模型权重(BF16)} = 8 \times 10^9 \times 2 \, \text{bytes} = 16 \, \text{GB} 模型权重(BF16)=8×109×2bytes=16GB

  • 权重更新副本
    模型权重的 FP32 副本用于更新。
    模型权重(FP32) = 8 × 1 0 9 × 4 bytes = 32 GB \text{模型权重(FP32)} = 8 \times 10^9 \times 4 \, \text{bytes} = 32 \, \text{GB} 模型权重(FP32)=8×109×4bytes=32GB

2)梯度

梯度与模型权重的规模相同,通常使用 BF16 格式进行存储:
梯度(BF16) = 8 × 1 0 9 × 2 bytes = 16 GB \text{梯度(BF16)} = 8 \times 10^9 \times 2 \, \text{bytes} = 16 \, \text{GB} 梯度(BF16)=8×109×2bytes=16GB

3)优化器动量参数

AdamW 优化器需要存储一阶动量和二阶动量,通常以 FP32 格式存储:

  • 一阶动量(( m \mathbf{m} m))
    一阶动量(FP32) = 8 × 1 0 9 × 4 bytes = 32 GB \text{一阶动量(FP32)} = 8 \times 10^9 \times 4 \, \text{bytes} = 32 \, \text{GB} 一阶动量(FP32)=8×109×4bytes=32GB

  • 二阶动量(( v \mathbf{v} v))
    二阶动量(FP32) = 8 × 1 0 9 × 4 bytes = 32 GB \text{二阶动量(FP32)} = 8 \times 10^9 \times 4 \, \text{bytes} = 32 \, \text{GB} 二阶动量(FP32)=8×109×4bytes=32GB

4)激活值

反向传播过程中需要保留前向传播的激活值,通常以 BF16 存储。假设激活值占总权重的 30%:
激活值(BF16) ≈ 0.3 × 16 GB = 4.8 GB \text{激活值(BF16)} \approx 0.3 \times 16 \, \text{GB} = 4.8 \, \text{GB} 激活值(BF16)0.3×16GB=4.8GB


3. 显存需求总结

组件存储格式显存需求
模型权重(BF16)BF1616 GB
权重更新副本(FP32)FP3232 GB
梯度(BF16)BF1616 GB
一阶动量(FP32)FP3232 GB
二阶动量(FP32)FP3232 GB
激活值(BF16)BF164.8 GB
总计132.8 GB

4. 深入分析

为什么优化器参数使用 FP32 存储?

尽管 BF16 格式在计算中可以显著减少显存需求,但其精度不足以保证优化器参数的稳定性,特别是在大规模模型的训练中。由于 AdamW 优化器需要对梯度和动量进行多次更新操作,使用 FP32 可以确保数值精度和更新的稳定性,避免因精度问题导致训练失败。

动量参数显存需求为何如此高?

AdamW 优化器的两种动量参数(( m \mathbf{m} m) 和 ( v \mathbf{v} v))的显存占用与权重相同,因此显存需求非常高。每个动量参数都以 FP32 存储,和模型的参数数量成正比。对 LLaMA 3 8B 模型而言,动量参数的显存需求非常可观。

如何优化显存使用?

  • ZeRO 优化:使用 DeepSpeed 的 ZeRO 技术来分片优化器参数和梯度,能显著降低单张 GPU 的显存占用。
  • 混合精度优化:继续使用 BF16 存储权重和梯度,减少内存消耗,同时通过 FP32 存储动量和权重更新副本,保证数值稳定。

5. 代码示例:DeepSpeed 混合精度训练

下面是如何使用 DeepSpeed 训练 Meta-Llama-3-8B-Instruct 模型的代码示例:

import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer# 加载模型和分词器
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)# DeepSpeed 配置
ds_config = {"fp16": {"enabled": True  # 启用混合精度训练 (BF16)},"optimizer": {"type": "AdamW",  # 使用 AdamW 优化器"params": {"lr": 1e-5,"betas": [0.9, 0.999],"eps": 1e-8,"weight_decay": 0.01}},"zero_optimization": {"stage": 2,  # ZeRO Stage 2,分片优化器参数"contiguous_gradients": True,"overlap_comm": True},"gradient_accumulation_steps": 4,"train_micro_batch_size_per_gpu": 1,"gradient_clipping": 1.0
}# 启动 DeepSpeed
model_engine, optimizer, _, _ = deepspeed.initialize(model=model,config_params=ds_config
)# 示例数据
inputs = tokenizer("Hello, DeepSpeed!", return_tensors="pt")
outputs = model_engine(**inputs, labels=inputs["input_ids"])
loss = outputs.loss# 反向传播和优化
model_engine.backward(loss)
model_engine.step()

6. 总结

  1. 混合精度训练通过使用 BF16 存储模型权重和梯度,在保证计算精度的同时显著降低了显存占用。然而,由于 AdamW 优化器需要更高精度的更新操作,动量参数和权重更新副本仍然以 FP32 格式存储。
  2. 对于 Meta-Llama-3-8B-Instruct 模型,混合精度训练的显存需求约为 132.8 GB,其中动量参数和优化器副本占据了显存的大部分。
  3. 通过 DeepSpeed 等技术,如 ZeRO 优化,可以进一步优化显存占用,为大规模模型的训练提供更高效的解决方案。

In-Depth Analysis of Memory Requirements for Mixed Precision Training of Meta-Llama-3-8B-Instruct Model

Meta-Llama-3-8B-Instruct is an 8 billion parameter language model, fine-tuned for instruction-based tasks. Compared to smaller models, it requires more computational resources and memory. Mixed precision training, which stores weights in BF16 format while keeping the update copies of weights and optimizer parameters in FP32 format, helps reduce memory usage and speeds up training. In this blog, we will calculate the memory requirements for the Meta-Llama-3-8B-Instruct model under mixed precision training and explain the rationale behind using FP32 for optimizer parameters despite using BF16 for weights and gradients.

1. Meta-Llama-3-8B-Instruct Model Parameters

The Meta-Llama-3-8B-Instruct model consists of 8 billion parameters. The memory requirements vary depending on the precision format used for different components (weights, gradients, optimizer parameters). Let’s calculate the memory needed for various components of the model under mixed precision training.

Precision Formats:

  • BF16: 2 bytes per parameter.
  • FP32: 4 bytes per parameter.

2. Detailed Calculation of Memory Requirements

Mixed precision training involves storing model weights and gradients in BF16, while optimizer parameters (momentum) and the weight update copies are stored in FP32. Let’s break down the memory requirements for the different components:

1) Model Weights

  • Forward and Backward Propagation:
    Weights are stored in BF16.
    Model Weights (BF16) = 8 × 1 0 9 × 2 bytes = 16 GB \text{Model Weights (BF16)} = 8 \times 10^9 \times 2 \, \text{bytes} = 16 \, \text{GB} Model Weights (BF16)=8×109×2bytes=16GB

  • Weight Update Copies:
    The FP32 copy of the weights is used for updates.
    Model Weights (FP32) = 8 × 1 0 9 × 4 bytes = 32 GB \text{Model Weights (FP32)} = 8 \times 10^9 \times 4 \, \text{bytes} = 32 \, \text{GB} Model Weights (FP32)=8×109×4bytes=32GB

2) Gradients

The gradients have the same scale as the model weights and are usually stored in BF16 format:
Gradients (BF16) = 8 × 1 0 9 × 2 bytes = 16 GB \text{Gradients (BF16)} = 8 \times 10^9 \times 2 \, \text{bytes} = 16 \, \text{GB} Gradients (BF16)=8×109×2bytes=16GB

3) Optimizer Momentum Parameters

AdamW optimizer requires storing first and second moments, which are stored in FP32:

  • First Moment (( m \mathbf{m} m)):
    First Moment (FP32) = 8 × 1 0 9 × 4 bytes = 32 GB \text{First Moment (FP32)} = 8 \times 10^9 \times 4 \, \text{bytes} = 32 \, \text{GB} First Moment (FP32)=8×109×4bytes=32GB

  • Second Moment (( v \mathbf{v} v)):
    Second Moment (FP32) = 8 × 1 0 9 × 4 bytes = 32 GB \text{Second Moment (FP32)} = 8 \times 10^9 \times 4 \, \text{bytes} = 32 \, \text{GB} Second Moment (FP32)=8×109×4bytes=32GB

4) Activations

During backpropagation, activations need to be stored for the forward pass. Activations are typically stored in BF16 format. Assuming activations take up around 30% of the weight size:
Activations (BF16) ≈ 0.3 × 16 GB = 4.8 GB \text{Activations (BF16)} \approx 0.3 \times 16 \, \text{GB} = 4.8 \, \text{GB} Activations (BF16)0.3×16GB=4.8GB


3. Memory Requirements Summary

ComponentPrecision FormatMemory Requirement
Model Weights (BF16)BF1616 GB
Weight Update Copies (FP32)FP3232 GB
Gradients (BF16)BF1616 GB
First Moment (FP32)FP3232 GB
Second Moment (FP32)FP3232 GB
Activations (BF16)BF164.8 GB
Total132.8 GB

4. Detailed Analysis

Why are Optimizer Parameters Stored in FP32?

Even though BF16 offers significant memory savings during calculations, it lacks the precision needed for reliable updates in optimizer parameters. AdamW optimizer requires high precision for momentum updates, as the accuracy of parameter adjustments directly affects model convergence. Storing the momentum parameters and weight update copies in FP32 ensures that numerical stability is maintained during optimization.

Why is the Memory Requirement for Momentum Parameters So High?

Both first and second moment parameters in AdamW are stored in FP32, which results in a significant memory footprint. The memory needed for these parameters scales linearly with the number of model parameters. In the case of LLaMA-3 8B, the optimizer’s memory requirements are substantial, occupying 64 GB just for these momentum terms.

How Can Memory Usage Be Optimized?

  • ZeRO Optimization: DeepSpeed’s ZeRO optimization can help reduce the memory footprint by partitioning the optimizer states and gradients across multiple GPUs, making it feasible to train large models with less memory.
  • Mixed Precision Optimization: Using BF16 for weights, gradients, and activations reduces memory usage while still maintaining enough precision for forward and backward passes, ensuring efficient training.

5. Code Example: DeepSpeed Mixed Precision Training

Here’s how to use DeepSpeed to train the Meta-Llama-3-8B-Instruct model with mixed precision:

import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer# Load model and tokenizer
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)# DeepSpeed configuration
ds_config = {"fp16": {"enabled": True  # Enable mixed precision training (BF16)},"optimizer": {"type": "AdamW",  # Use AdamW optimizer"params": {"lr": 1e-5,"betas": [0.9, 0.999],"eps": 1e-8,"weight_decay": 0.01}},"zero_optimization": {"stage": 2,  # ZeRO Stage 2: Partition optimizer states"contiguous_gradients": True,"overlap_comm": True},"gradient_accumulation_steps": 4,"train_micro_batch_size_per_gpu": 1,"gradient_clipping": 1.0
}# Initialize DeepSpeed
model_engine, optimizer, _, _ = deepspeed.initialize(model=model,config_params=ds_config
)# Example input
inputs = tokenizer("Hello, DeepSpeed!", return_tensors="pt")
outputs = model_engine(**inputs, labels=inputs["input_ids"])
loss = outputs.loss# Backward pass and optimizer step
model_engine.backward(loss)
model_engine.step()

6. Conclusion

  1. Mixed precision training reduces memory usage by storing weights and gradients in BF16, while using FP32 for the momentum terms and weight update copies to ensure numerical stability.
  2. For the Meta-Llama-3-8B-Instruct model, the memory requirements under mixed precision training are about 132.8 GB, with momentum parameters and optimizer copies occupying a significant portion of this memory.
  3. By using techniques like ZeRO optimization, DeepSpeed helps further optimize memory usage, making training large models more feasible and efficient.

后记

2024年12月1日14点46分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/483594.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

单片机学习笔记 12. 定时/计数器_定时

更多单片机学习笔记:单片机学习笔记 1. 点亮一个LED灯单片机学习笔记 2. LED灯闪烁单片机学习笔记 3. LED灯流水灯单片机学习笔记 4. 蜂鸣器滴~滴~滴~单片机学习笔记 5. 数码管静态显示单片机学习笔记 6. 数码管动态显示单片机学习笔记 7. 独立键盘单片机学习笔记 8…

6.824/6.5840(2024)环境配置wsl2+vscode

本文是经过笔者实践得出的最速の环境配置 首先,安装wsl2和vscode 具体步骤参见Mit6.s081环境配置踩坑之旅WSL2VScode_mit6s081-CSDN博客 接下来开始为Ubuntu(笔者使用的版本依然是20.04)配置go的相关环境 1、更新Ubuntu的软件包 sudo apt-get install build-es…

【排序用法】.NET开源 ORM 框架 SqlSugar 系列

💥 .NET开源 ORM 框架 SqlSugar 系列 🎉🎉🎉 【开篇】.NET开源 ORM 框架 SqlSugar 系列【入门必看】.NET开源 ORM 框架 SqlSugar 系列【实体配置】.NET开源 ORM 框架 SqlSugar 系列【Db First】.NET开源 ORM 框架 SqlSugar 系列…

「Mac畅玩鸿蒙与硬件38」UI互动应用篇15 - 猜数字增强版

本篇将带你实现一个升级版的数字猜谜游戏。相比基础版,新增了计分和历史记录功能,用户可以在每次猜测后查看自己的得分和猜测历史。此功能展示了状态管理的进阶用法以及如何保存和显示历史数据。 关键词 UI互动应用数字猜谜状态管理历史记录用户交互 一…

040集——CAD中放烟花(CAD—C#二次开发入门)

效果如下: 单一颜色的烟花: 渐变色的火花: namespace AcTools {public class HH{public static TransientManager tm TransientManager.CurrentTransientManager;public static Random rand new Random();public static Vector3D G new V…

【机器学习】分类任务: 二分类与多分类

二分类与多分类:概念与区别 二分类和多分类是分类任务的两种类型,区分的核心在于目标变量(label)的类别数: 二分类:目标变量 y 只有两个类别,通常记为 y∈{0,1} 或 y∈{−1,1}。 示例&#xff…

Python实现网站资源批量下载【可转成exe程序运行】

Python实现网站资源批量下载【可转成exe程序运行】 背景介绍解决方案转为exe可执行程序简单点说详细了解下 声明 背景介绍 发现 宣讲家网 的PPT很好,作为学习资料使用很有价值,所以想下载网站的PPT课件到本地,但是由于网站限制,一…

基于Matlab卡尔曼滤波的GPS/INS集成导航系统研究与实现

随着智能交通和无人驾驶技术的迅猛发展,精确可靠的导航系统已成为提升车辆定位精度与安全性的重要技术。全球定位系统(GPS)和惯性导航系统(INS)在导航应用中各具优势:GPS提供全球定位信息,而INS…

【计算机网络】实验6:IPV4地址的构造超网及IP数据报

实验 6:IPV4地址的构造超网及IP数据报 一、 实验目的 加深对IPV4地址的构造超网(无分类编制)的了解。 加深对IP数据包的发送和转发流程的了解。 二、 实验环境 • Cisco Packet Tracer 模拟器 三、 实验内容 1、了解IPV4地址的构造超网…

使用ESP32通过Arduino IDE点亮1.8寸TFT显示屏

开发板选择 本次使用开发板模块丝印为ESP32-WROOM-32E 开发板库选择 Arduino IDE上型号选择为ESP32-WROOM-DA Module 显示屏选择 使用显示屏为8针SPI接口显示屏 驱动IC为ST7735S 使用库 使用三个Arduino平台库 分别是 Adafruit_GFXAdafruit_ST7735SPI 代码详解 首…

[C++设计模式] 为什么需要设计模式?

文章目录 什么是设计模式?为什么需要设计模式?GOF 设计模式再次理解面向对象软件设计固有的复杂性软件设计复杂性的根本原因如何解决复杂性?分解抽象 结构化 VS 面向对象(封装)结构化设计代码示例:面向对象设计代码示例&#xff1…

机器学习:精确率与召回率的权衡

高精度意味着如果诊断得了那种罕见病的病人,可能病人确实有,这是一个准确的诊断,高召回率意味着如果有一个还有这种罕见疾病的病人,也许算法会正确的识别他们确实患有这种疾病,事实中,在精确与召回之间往往…

03-13、SpringCloud Alibaba第十三章,升级篇,服务降级、熔断和限流Sentinel

SpringCloud Alibaba第十三章,升级篇,服务降级、熔断和限流Sentinel 一、Sentinel概述 1、Sentinel是什么 随着微服务的流行,服务和服务之间的稳定性变得越来越重要。Sentinel 以流量为切入点,从流量控制、熔断降级、系统负载保…

基于vite6+ vue3 + electron@33 实现的 局域网内互传文件的桌面软件

目录 项目介绍项目部分截图介绍下基础项目搭建先搭建一个vite 前端项目 再安装 electron 相关依赖依赖安装失败解决方案修改 vite配置文件和 ts 配置文件修改packjsonts相关配置项目结构介绍 项目介绍 前端 基于 vue3 ts windicss 后端 就是node 层 项目地址: h…

安装MySQL 5.7 亲测有效

前言:本文是笔者在安装MySQL5.7时根据另一位博主大大的安装教程基础上做了一些修改而成 首先在这里表示对博主大大的感谢 下面附博主大大地址 下面的步骤言简意赅 跟着做就不会出错 希望各位读者耐下心来 慢慢解决安装中出现的问题~MySQL 5.7 安装教程(全…

CSS函数

目录 一、背景 二、函数的概念 1. var()函数 2、calc()函数 三、总结 一、背景 今天我们就来说一说,常用的两个css自定义属性,也称为css函数。本文中就成为css函数。先来看一下官方对其的定义。 自定义属性(有时候也被称作CSS 变量或者级…

6.824/6.5840 Lab 1: MapReduce

宁静的夏天 天空中繁星点点 心里头有些思念 思念着你的脸 ——宁夏 完整代码见: https://github.com/SnowLegend-star/6.824 由于这个lab整体难度实在不小,故考虑再三还是决定留下代码仅供参考 6.824的强度早有耳闻,我终于也是到了挑战这座高…

MongoDB集群分片安装部署手册

文章目录 一、集群规划1.1 集群安装规划1.2 端口规划1.3 目录创建 二、mongodb安装(三台均需要操作)2.1 下载、解压2.2 配置环境变量 三、mongodb组件配置3.1 配置config server的副本集3.1.1 config配置文件3.1.2 config server启动3.1.3 初始化config …

一种多功能调试工具设计方案开源

一种多功能调试工具设计方案开源 设计初衷设计方案具体实现HUB芯片采用沁恒微CH339W。TF卡功能网口功能SPI功能IIC功能JTAG功能下行USB接口 安路FPGA烧录器功能Xilinx FPGA烧录器功能Jlink OB功能串口功能RS232串口RS485和RS422串口自适应接口 CAN功能烧录器功能 目前进度后续计…

三维测量与建模笔记 - 5.3 光束法平差(Bundle Adjustment)

此篇笔记尚未理解,先做笔记。 如上图,在不同位姿下对同一个物体采集到了一系列图像, 例子中有四张图片。物体上某点M,在四幅图像上都能找到其观测点。 上式中的f函数是对使用做投影得到的估计点位置。求解这个方程有几种方法&…