大模型训练显存优化推理加速方案

当前的深度学习框架大都采用的都是fp32来进行权重参数的存储,比如Python float的类型为双精度浮点数fp64,pytorch Tensor的默认类型为单精度浮点数fp32。随着模型越来越大,加速训练模型的需求就产生了。在深度学习模型中使用fp32主要存在几个问题,第一模型尺寸大,训练的时候对显卡的显存要求高;第二模型训练速度慢;第三模型推理速度慢。其解决方案就是使用低精度计算对模型进行优化。本文主要讲解几种优化显存存储的方法。

1. fp32、fp16、bf16混合精度训练

  • FP32 是单精度浮点数,1位符号位,8位指数,23位表示小数,总共32位
  • BF16 是对FP32单精度浮点数截断数据,即用8bit 表示指数,7bit 表示小数
  • FP16 半精度浮点数,用5bit 表示指数,10bit 表示小数;
    请添加图片描述
    与32位相比,采用BF16/FP16吞吐量可以翻倍,内存需求可以减半。但是这两者精度上差异不一样,BF16 可表示的整数范围更广泛,但是尾数精度较小;FP16 表示整数范围较小,但是尾数精度较高。

1.1 混合精度训练

直接使用半精度进行计算会导致的两个问题的处理:舍入误差(Rounding Error)和溢出错误(Grad Overflow / Underflow)

  • 舍入误差
    float16 的最大舍入误差约为 2 − 10 ~2 ^{-10}  210,比 float32 的最大舍入误差 2 − 23 ~2 ^{-23}  223 要大不少。 对足够小的浮点数执行的任何操作都会将该值四舍五入到零。在反向传播中很多梯度更新值都非常小,但不为零,在反向传播中舍入误差累积可以把这些数字变成0或者nan, 这会导致不准确的梯度更新,影响网络的收敛

  • 溢出错误
    由于 float16 的有效的动态范围(正数部分,负数部分与正数对应)约为 5.96 × 1 0 − 8 ∼ 6.55 × 10 4 5.96\times10^{-8} \sim 6.55\times10{^4} 5.96×1086.55×104,比单精度的 float32 的动态范围 1.4 × 1 0 − 45 ∼ 1.7 × 1 0 38 1.4\times10^{-45} \sim 1.7 \times10^{38} 1.4×10451.7×1038要狭窄很多,精度下降会导致得到的值大于或者小于fp16的有效动态范围,也就是上溢出或者下溢出。在深度学习中,由于激活函数的的梯度往往要比权重梯度小,更易出现下溢出的情况

针对以上两种情况的解决方法是混合精度训练(Mixed Precision)和损失缩放(Loss Scaling)

  • 混合精度训练
    混合精度训练是一种通过在FP16上执行尽可能多的操作来大幅度减少神经网络训练时间的技术,在像线性层或是卷积操作上,FP16运算较快,但像Reduction运算又需要 FP32的动态范围。通过混合精度训练的方式,便可以在部分运算操作使用FP16,另一部分则使用 FP32,混合精度功能会尝试为每个运算使用相匹配的数据类型,在内存中用FP16做储存和乘法从而加速计算,用FP32做累加避免舍入误差。这样在权重更新的时候就不会出现舍入误差导致更新失败,混合精度训练的策略有效地缓解了舍入误差的问题

  • 损失缩放
    尽管使用了混合精度训练,还是会存在无法收敛的情况,原因是激活梯度的值太小,造成了下溢出。损失缩放是指在执行反向传播之前,将损失函数的输出乘以某个标量数(论文建议从8开始)。 乘性增加的损失值产生乘性增加的梯度更新值,提升许多梯度更新值到超过FP16的安全阈值2^-24。 只要确保在应用梯度更新之前撤消缩放,并且不要选择一个太大的缩放以至于产生inf权重更新(上溢出) ,从而导致网络向相反的方向发散

bf16/fp32 混合训练因为两种格式在范围上对齐了,并且 bf16 比 fp16 的范围更大,所以要比 fp16/fp32 混合训练稳定性更高

2. gradient checkpointing

gradient checkpointing(梯度检查点)的工作原理是在反向传播时重新计算深度神经网络的中间值(通常情况是在前向传播时存储的)。这个策略是用时间(重新计算这些值两次的时间成本)来换空间(提前存储这些值的内存成本)

3. Xformers

Xformers 应该是目前社区知名度最高的优化加速方案了,所谓 Xformers 指的是该库将各种transformer 架构的模型囊括其中。注意,该库仅适用于N卡,特点是加速图片生成并降低显存占用,代价是输出图像不稳定,有可能比不开Xformers略差。各种transformer变体可以参考 A Survey of Transformers.

参考

  • 彻底搞懂float16与float32的计算方式
  • pytorch模型训练之fp16、apm、多GPU模型、梯度检查点(gradient checkpointing)显存优化等
  • facebookresearch/xformers

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/138961.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

el-upload 上传附件(拆解步骤)

目录 1. 看elementui /element-plus 官网案例 2. html部分: 把官网上的搬下来,最好加一个按钮,上传到服务器(后端) 3. js 部分: 3.1 首先,先定义一个变量,files 3.2 当上传图片…

解析数据库的“四世同堂”,畅聊数据前沿技术!

引言 数据库与大数据一直是技术圈的两个常青领域。PC 时代诞生了最早的关系型数据库,之后数据类型越来越多,出现了各种非关系型数据库。云时代拉开序幕的同时,“大数据”一词也被广泛使用,涵盖海量数据的采集、处理、存储、分析和…

华为云云耀云服务器L实例评测|云耀云服务器L实例部署Dashdot服务器仪表盘

华为云云耀云服务器L实例评测|云耀云服务器L实例部署Dashdot服务器仪表盘 一、云耀云服务器L实例介绍二、Dashdot介绍2.1 Dashdot简介2.2 开发环境要求2.3 Yarn介绍 三、本次实践介绍3.1 本次实践简介3.2 本次环境规划 四、检查服务器环境4.1 购买云耀云服务器L实例…

DA5 网站用户没有补全的信息

目录 1.题目描述 2.输入描述 3.输出描述 4.题目分析 5.通过代码 1.题目描述 现有一个Nowcoder.csv文件,它记录了牛客网的部分用户数据,包含如下字段(字段与字段之间以逗号间隔): Nowcoder_ID:用户ID …

SQL模板-用户留存率计算

在这段实习中,我遇到了用户留存率计算的需求,这里做个总结。 首先来讲下,什么是用户留存? 在互联网行业中,用户在某段时间内开始使用应用,经过一段时间后,仍然继续使用该应用的用户。用户留存一…

netty 客户端 实现断开重连

1、首先引入依赖 <dependency><groupId>io.netty</groupId><artifactId>netty-all</artifactId><version>4.1.6.Final</version> </dependency>2、创建server层代码 2.1、编写服务端代码 public static void main(String[]…

十四、MySql的用户管理

文章目录 一、用户管理二、用户&#xff08;一&#xff09;用户信息&#xff08;二&#xff09;创建用户1.语法&#xff1a;2.案例&#xff1a; &#xff08;三&#xff09; 删除用户1.语法&#xff1a;2.示例&#xff1a; &#xff08;四&#xff09;修改用户密码1.语法&#…

【问题记录】解决“命令行终端”和“Git Bash”操作本地Git仓库时出现 中文乱码 的问题!

环境 Windows 11 家庭中文版git version 2.41.0.windows.1 问题情况 在使用 “命令行终端” 和 “Git Bash” 在本地Git仓库敲击命令时&#xff0c;对中文名称文件显示一连串的数字&#xff0c;如下所示&#xff1a;这种情况通常是由于字符编码设置不正确所引起的 解决办法 设置…

ffmpeg抠图

1.不用png&#xff0c;用AVFrame 2.合流 3.图片抠图透明 (1.)mp4扣yuv图&#xff0c;(2)用1.把一张yuv标记为透明然后av_hwframe_transfer_data到GPU (3)用抠图算法函数对yuv进行处理 (4) qsv的h264_qsv只支持nv12和qsv&#xff0c;但qsv本身并不限制像素格式&#xff0c;比如在…

SpringMVC学习笔记——1

SpringMVC学习笔记——1 一、SpringMVC简介1.1、SpringMVC概述1.2、SpringMVC快速入门1.3、Controller中访问容器中的Bean1.4、SpringMVC关键组件的浅析 二、SpringMVC的请求处理2.1、请求映射路径配置2.2、请求数据的接收2.2.1、键值对方式接收数据2.2.2、封装JavaBean数据2.2…

tomcat架构概览

https://blog.csdn.net/ldw201510803006/article/details/119880100 前言 Tomcat 要实现 2 个核心功能&#xff1a; 处理 Socket 连接&#xff0c;负责网络字节流与 Request 和 Response 对象的转化。加载和管理 Servlet&#xff0c;以及具体处理 Request 请求。 因此 Tomc…

C# 实现数独游戏

1.数独单元 public struct SudokuCell{public SudokuCell() : this(0, 0, 0){}public SudokuCell(int x, int y, int number){X x; Y y; Number number;}public int X { get; set; }public int Y { get; set; }public int Number { get; set; }} 2.数独创建 public class …

elementUI elfrom表单验证无效、不起作用常见原因

今天遇到一个变态的问题&#xff0c;因页面比较复杂&#xff0c;出现几组条件判断&#xff0c;每个template内部又包含很多表单&#xff01;&#xff01; <template v-if"transformTypeValue 1"></template><template v-else-if"transformTypeV…

LeetCode 接雨水 木桶理论、dp预处理

原题链接&#xff1a; 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 题面&#xff1a; 给定 n 个非负整数表示每个宽度为 1 的柱子的高度图&#xff0c;计算按此排列的柱子&#xff0c;下雨之后能接多少雨水。 示例 1&#xff1a; 输入&#xff1a…

如何使用微信文件传输助手?看这里!

微信文件传输助手在哪里&#xff1f;为什么我找不到&#xff1f;有哪位朋友能够告诉我吗&#xff1f; 微信文件传输助手是微信官方推出的一款辅助工具&#xff0c;为用户提供了便捷的文件传输方式。用户在使用微信的过程中&#xff0c;可以随时随地通过该功能在手机和电脑之间任…

【TCP】三次握手 与 四次挥手 详解

三次握手 与 四次挥手 1. 三次握手2. 四次挥手三次握手和四次挥手的区别 在正常情况下&#xff0c;TCP 要经过三次握手建立连接&#xff0c;四次挥手断开连接 1. 三次握手 服务端状态转化&#xff1a; [CLOSED -> LISTEN] 服务器端调用 listen 后进入 LISTEN 状态&#xff…

Flink--4、DateStream API(执行环境、源算子、基本转换算子)

星光下的赶路人star的个人主页 注意力的集中&#xff0c;意象的孤立绝缘&#xff0c;便是美感的态度的最大特点 文章目录 1、DataStream API1.1 执行环境&#xff08;Execution Environment&#xff09;1.1.1 创建执行环境 1.2 执行模式&#xff08;Execution Mode&#xff09;…

0基础学three.js环境搭建(2)

这是0基础学three.js系列中的第二篇&#xff0c;在这篇里面我会带着大家把开发环境搭建起来&#xff0c;关于开发环境&#xff0c;方式很多&#xff0c;如果你没有基础&#xff0c;就跟着我的步骤一步一步来&#xff0c;保你不出错。 首先安装node环境&#xff0c;关于node是干…

【MySQL】 MySQL的增删改查(进阶)--贰

文章目录 &#x1f6eb;新增&#x1f6ec;查询&#x1f334;聚合查询&#x1f6a9;聚合函数&#x1f388;GROUP BY子句&#x1f4cc;HAVING &#x1f38b;联合查询⚾内连接⚽外连接&#x1f9ed;自连接&#x1f3c0;子查询&#x1f3a1;合并查询 &#x1f3a8;MySQL的增删改查(…

C语言的文件操作(炒详解)

⭐回顾回顾文件操作的相关细节⭐ 欢迎大家指正错误 &#x1f4dd;在之前的学习中&#xff0c;不管增加数据&#xff0c;减少数据&#xff0c;当程序退出时&#xff0c;所有的数据都会销毁&#xff0c;等下次运行程序时&#xff0c;又要重新输入相关数据&#xff0c;如果一直像这…