「Pytorch」BF16 Mixed Precision Training

在深度学习领域,神经网络的训练性能瓶颈常常出现在 GPU显存的使用上。主要表现为两方面:

  1. 单卡上可容纳的模型和数据量有限;
  2. 显存与计算单元之间的带宽和延迟限制了运算速度;

为了解决显卡瓶颈的问题,涌现了不同的解决方法。


模型参数量估计

为了更好地估算模型所要占用的显存,首先需要分析模型训练过程中有哪些部分需要消耗存储空间。在 “ZeRO: Memory Optimizations Toward Training Trillion Parameter Model” 中提出,模型在训练时,主要有两大部分的空间占用。

  1. 对于大模型来说,主要的空间占用是模型状态,包括优化器状态(eg:Adam优化器的动量和方差)、模型参数和模型参数的梯度;
  2. 剩余的空间主要被模型训练中间激活值、临时缓冲区和不可用的内存碎片占用,统称为剩余状态。

以 3.5B 大模型为例。①35亿个参数,如果使用 FP16进行存储的话,即 70亿个字节,约7GB左右;②前向传播的激活值和反向传播的梯度大小跟模型参数保持一致,约 7GB;③以 Adam优化器为例,包括三部分,分别为 FP32格式模型参数的备份,FP32的动量和方差,加起来约28GB;因此,从理论上,要微调此模型的话,至少需要 49GB的空间。


1、BF16 半精度浮点数

双精度浮点 Float64;单浮点精度 Float32;半浮点精度 Float16,被广泛应用于model推理
为了进一步降低计算量和显存占用,可以考虑整数int4和int8量化推理

之前深度学习模型的训练通常都采用 Float32(FP32)的精度,而作者发现,使用较低的精度来进行模型的训练也是可行的,并且能够显著提升速度。通过采用混合精度训练,一般可以实现 2~3 倍的速度提升,极大的优化了模型的训练流程。

  • FP32 整体长度为4个字节,即32位,其中有8位的指数位宽,23位的尾数精度和1位的符号位,能够表示的数值范围是 1 × 2 − 126 ∼ ( 2 − ϵ ) × 2 127 1\times 2^{-126} \sim (2-\epsilon)\times 2^{127} 1×2126(2ϵ)×2127
  • 在一些不太需要高精度计算的应用中,eg:图像处理和神经网络中,32位的空间其实有一些浪费,因此又出现了新的数据类型,半精度浮点数 FP16,使用16位(2个字节)来存储浮点值,有5位的指数位宽,10位的尾数精度和1位的符号位,能够表示的数值范围是 1 × 2 − 14 ∼ ( 2 − ϵ ) × 2 15 1\times 2^{-14} \sim (2-\epsilon)\times 2^{15} 1×214(2ϵ)×215
格式位数/位指数位宽/位尾数精度/位符号位/位数值范围
FP32328231 1 × 2 − 126 ∼ ( 2 − ϵ ) × 2 127 1\times 2^{-126} \sim (2-\epsilon)\times 2^{127} 1×2126(2ϵ)×2127
FP16165101 1 × 2 − 14 ∼ ( 2 − ϵ ) × 2 15 1\times 2^{-14} \sim (2-\epsilon)\times 2^{15} 1×214(2ϵ)×215
BP3216871 1 × 2 − 126 ∼ ( 2 − ϵ ) × 2 127 1\times 2^{-126} \sim (2-\epsilon)\times 2^{127} 1×2126(2ϵ)×2127

混合精度训练,即在模型训练时同时采用 FP32 和 FP16 两种精度。在实践过程中,研究人员发现在大语言模型的训练中直接使用 FP16会有一些问题,在训练过程中 loss 会非常不稳定,因此使用 FP16 训练大模型非常困难。问题在于 FP16的指数位宽只有 5位,能表示的最大整数为 65504,一旦权重超过这个值就会发生溢出,因此只能进行较小数的乘法,eg:可以计算 250 × 250 = 62500 250\times250=62500 250×250=62500,但如果计算 255 × 255 = 65025 255\times 255=65025 255×255=65025 就会溢出,这是导致训练出现问题的主要原因。这也意味着模型权重必须保持很小。一种成为损失缩放的技术可以缓解这个问题,但是当模型变得非常大时,FP16 较小的数值范围依旧是一个问题。

  • 为了更好地解决 FP16的问题,谷歌开发了一种新的浮点数格式 BF16(Brain Floating Point, 2个字节),用于降低存储需求,提高机器学习算法的计算速度。BF16 的指数位宽为8位,于 FP32相同,尾数精度采用7位,因此当使用 BF16时,精度非常差。然而,在训练模型时一般采用随机梯度下降法及其变体,其过程像蹒跚而行,即使某一步没有找到最优方向也没关系,模型会在后续调整纠正。

将模型参数类型从 FP16换为 BF16,训练的大模型 loss值的下降也会变得更加稳定。

这种低精度和 混合精度训练的方法逐渐被广泛接受和应用,深度学习框架、GPU以及 神经网络加速器的设计也因此受到了深渊的影响。可以说,混合精度训练的提出,对深度学习领域起到了关键的推动作用,有效地解决了 GPU显存的使用问题,提升了模型训练的效率。


2、混合精度训练

paper:Mixed Precision Training

  1. 维护一个权重的单精度副本,在每个优化器步骤后累计梯度(对于前向和反向传播,此副本四舍五入到半精度);
  2. 提出了损失缩放来保持小幅度的梯度值;
  3. 使用半精度算法,该算法累积为单精度输出,在存储到内存之前将其转化为半精度;

在这里插入图片描述

FP32 为主副本权重,

在混合精度训练时,权重、激活函数、梯度被保存为 FP16,为了与 FP32模型的精度相匹配,在optimizer step时,维持 FP32权重为主线,并使用权重梯度进行更新。在每次迭代时,主权重的 FP16副本用于前向和反向传播,将 FP32训练所需的存储和带宽减半,如上图所示。

虽然对 FP32住权重的需求并不普遍,但许多模型还需要 FP32的两个可能原因是:

  1. 权重更新变得太小,无法在 FP16中表示,任何大小小于 2 − 24 2^{-24} 224 的值在 FP16中都将变为 零,当与学习率相乘时,这些小值梯度在优化器中都会变为零,并对模型的准确性产生不利影响。使用单精度进行更新可以解决这一问题;
  2. 权重值 与 权重更新的比例非常大。在这种情况下,即使权重更新可以在 FP16中表示,当加法操作将其右移以使二进制点与权重对齐时,它仍然可能变为零。当归一化权重值的幅度比权重更新的新幅度大 至少2048倍,就会发生这种情况。由于 FP16有10位尾数,隐式位必须右移11位或更多位置,才能潜在地创建一个零。在比例大于2048时,隐式比特将右移12位或更多位。这将导致权重更新变得无法恢复的零。更大的比例将导致非标注化数字的效果。同样,可以通过计算 FP32中的更新来抵消这种影响。

图2-a所示,在 FP16前后传播更新 FP32权重主线时,匹配FP32训练结果,而更新FP16权重会导致 80%的相对精度损失。

由于更大的 batch-size 和每层的激活被保存以在反向传播过程中重复使用,因此训练内存消耗主要由激活决定。由于激活也以半精度格式存储,因此训练深度神经网络的整体内存消耗大约减半。

2.1 损失缩放

FP16指数偏差将归一化指数的范围集中到 [-14,15],而实践中的梯度值往往由小幅度(负指数)主导,如图3所示,显示了Multibox SSD模型的 FP32训练期间在所有层上的急活梯度值的直方图,FP16 可表示范围的大部分未必使用,而许多值低于最小可表示范围 变为指数为0。放大梯度将使它们占据更多的可表示范围,并保留否则会丢失为0的值。当梯度未被缩放时,这个特定的网络会发散,但将其压缩8倍(指数为3)就足以匹配 FP32训练所达到的精度。这说明激活

在这里插入图片描述

在这里插入图片描述

2.2运算精度

神经网络模型分为三类:vector dot-products, reductions, point-wise operations。当涉及到降低精度的算法时,这些类别受益于不同的处理,为了保持模型的准确性,发现一些模型要求 FP16矢量点积将部分累加成 FP32,在写入内存之前将其转换为 FP16。如果 FP32中没有这种积累,一些 FP16模型与基线模型的精度不匹配。

之前的GPU只支持 FP16 乘/加法运算,而 NVIDIA Volta GPU引入了 Tensor Core,可以将 FP16输入矩阵相乘,并将乘积累加到 FP16或 FP32输出中。

FP32中应进行大幅缩减(向量各元素之和)。在累积统计数据和 softmax层时,这种建好主要出现在 batch-normalization 层中。在两种层类型种,仍然从内存中读取和写入FP16张量,在 FP32中执行算术运算,这并没有减缓训练过程,因为这些层的内存带宽有限,对算术速度不敏感。

逐点操作,如非线性和逐像素矩阵运算,是内存带宽有限,由于算术精度不影响这些运算的速度,因此可以使用FP16 或 FP32.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/397615.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Arduino控制带编码器的直流电机速度

Arduino DC Motor Speed Control with Encoder, Arduino DC Motor Encoder 作者 How to control dc motor with encoder:DC Motor with Encoder Arduino, Circuit Diagram:Driving the Motor with Encoder and Arduino:Control DC motor using Encoder feedback loop: How …

深度学习碎碎念——碎片知识1

1、什么叫模型收敛?什么叫模型欠拟合和过拟合? 什么叫模型收敛?——模型收敛是指在训练过程中,模型的损失函数逐渐减小并且趋于稳定的状态。简而言之,当模型的训练过程达到一个稳定的点,使得进一步的训练不…

CV党福音:YOLOv8实现语义分割(一)

前面我们得知YOLOv8不但可以实现目标检测任务,还包揽了分类、分割、姿态估计等计算机视觉任务。在上一篇博文中,博主已经介绍了YOLOv8如何实现分类,在这篇博文里,博主将介绍其如何将语义分割给收入囊中。 YOLOv8语义分割架构图 …

【C++】特殊类的设计与类型转换

文章目录 1. 特殊类的设计1.1 不能被拷贝的类1.2 只能在堆上创建对象的类1.3 只能在栈上创建对象的类1.4 不能被继承的类1.5 只能创建一个对象的类(单列模式) 2. 类型转换2.1 C/C的类型转换2.2 C规定的四种类型转换2.2.1 static_cast2.2.2 reinterpret_c…

【吊打面试官系列-Elasticsearch面试题】对于 GC 方面,在使用 Elasticsearch 时要注意什么?

大家好,我是锋哥。今天分享关于 【对于 GC 方面,在使用 Elasticsearch 时要注意什么?】面试题,希望对大家有帮助; 对于 GC 方面,在使用 Elasticsearch 时要注意什么? 1、SEE 2、倒排词典的索引需…

vue3使用pnpm运行项目但是运行不起来

运行项目的时候发现根本运行不起来了 尝试过创建.npmr文件 删除node_modules重新下 但是都出现问题了 创建.npmr:不管用 删除node_modules重新下:文字编译乱码,utf-8可能解析处理问题 最后解决方法: 重新创建项目&#xff0…

网络科技公司官网电商软件开发小程序网站pbootcms模板带手机端

免费授权可商用网站模板 PC端移动端后台测试数据 所有页面均都能完全自定义标题/关键词/描述,PHP程序,安全、稳定、快速,响应式同一个后台,数据即时同步,简单适用,附带测试数据!!

物流仓库安全视频智能管理方案:构建全方位、高效能的防护体系

一、背景分析 随着物流行业的快速发展和仓储需求的日益增长,仓库安全成为企业运营中不可忽视的重要环节。传统的人工监控方式不仅效率低下,且难以做到全天候、无死角覆盖,给仓库资产和人员安全带来潜在风险。因此,引入仓库安全视…

了解细胞外基质:它是啥?有啥作用?

了解细胞外基质:它是啥?有啥作用? 大家好,今天我们来阅读这篇Biofabrication methods for reconstructing extracellular matrix mimetics发表于《Bioactive Materials》上的文章。细胞外基质在人体中起着至关重要的作用&#xff…

同城门户同城分类信息网站源码discuz插件+pc端+小程序端+49款插件

同城分类信息 同城好店 同城合伙人 同城招聘 同城卡 同城活动 同城优惠抢购 同城商城 同城头条 同城抽奖 同城拼团 同城砍价 同城电话本 同城认证 同城签到 同城拼车 同城红包 同城子站点 同城相亲 同城交友 同城小程序 比较流行的同城信息门户网站源码,基于dz&…

【计算机网络】网络基础概念

目录 计算机网络发展 协议 协议分层 OSI 七层模型 TCP/IP 五层(四层)模型 究竟什么是协议? 网络与操作系统的关系 网络传输基本流程 局域网网络传输流程 认识 MAC 地址 局域网(以太网为例)通信原理 数据包…

【前端设计方案】H5 图片懒加载 SDK

实现思路 定义<img srcloading.png data-srcxxx.png/>页面滚动&#xff0c;图片露出时&#xff0c;将 data-src 赋值给 src 注意事项&#xff1a;滚动要节流 技术要点 获取图片的位置 elem.getBoundingClientRect() 图片 top < window.innerHeight 时&#xff0c;图片…

Install pytorch 使用 torch 的例子

如果不知道怎么开始和安装软件 从这里开始 如果需要GPU版本&#xff0c;请选择CUDA&#xff0c;而不是CPU PyTorchhttps://pytorch.org/ Python 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:10) [GCC 10.3.0] on linux Type "help", &quo…

opencv 深度图视差图可视化案例

参考:https://www.cnblogs.com/zyly/p/9373991.html(图片这里面下载的) https://blog.csdn.net/He3he3he/article/details/101053457 原理 双目摄像头 视差公式: 三角形对应推算 深度距离转换: 这里d是视差Disparity 代码 下面两种计算视差方法: import os impor…

计算机毕业设计Hadoop+Hive居民用电量分析 居民用电量可视化 电量爬虫 机器学习 深度学习 大数据毕业设计 Spark

《Hadoop居民用电量分析》开题报告 一、研究背景与意义 能源问题在全球范围内一直是热点议题&#xff0c;尤其是随着居民生活水平的提高和城市化进程的加快&#xff0c;居民用电量急剧增长&#xff0c;对电力系统的稳定运行和能源管理提出了更高要求。如何科学合理地管理和分…

T9打卡学习笔记

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 import tensorflow as tfgpus tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True) #设置…

红黑树的插入

文章目录 3.红黑树3.1概念3.2 性质3.3 RBTree的实现3.3.1 insert的框架3.3.2 insert的处理3.3.3 中序遍历3.3.4检查是否平衡和获取树的高度 3.4完整代码 3.红黑树 3.1概念 红黑树&#xff0c;是一种二叉搜索树&#xff0c;但在每个结点上增加一个存储位表示结点的颜色&#xf…

07一阶电路和二阶电路的时域分析

一阶电路和二阶电路的时域分析 时域分析、频域分析、复频域分析本应该在信号与系统&#xff0c;或者数字信号处理这一章节里面进行处理的。 但在电路理论中也有这些知识&#xff0c;那就要好好掌握一下&#xff0c;打个底。详细细致的部分放到信号与系统里面去掌握

【单片机开发软件】使用VSCode开发STM32环境搭建

&#x1f48c; 所属专栏&#xff1a;【单片机开发软件技巧】 &#x1f600; 作  者&#xff1a; 于晓超 &#x1f680; 个人简介&#xff1a;嵌入式工程师&#xff0c;专注嵌入式领域基础和实战分享 &#xff0c;欢迎咨询&#xff01; &#x1f496; 欢迎大家&#xff1…

Java Web —— 第四天(HTTP协议,Tomcat)

HTTP-概述 概念:Hyper Text Transfer Protocol&#xff0c;超文本传输协议&#xff0c;规定了浏览器和服务器之间数据传输的规则 特点: 1. 基于TCP协议:面向连接&#xff0c;安全 2.基于请求-响应模型的:一次请求对应一次响应 3. HTTP协议是无状态的协议: 对于事务处理没有…