当前的深度学习框架大都采用的都是fp32来进行权重参数的存储,比如Python float的类型为双精度浮点数fp64,pytorch Tensor的默认类型为单精度浮点数fp32。随着模型越来越大,加速训练模型的需求就产生了。在深度学习模型中使用fp32主要存在几个问题,第一模型尺寸大,训练的时候对显卡的显存要求高;第二模型训练速度慢;第三模型推理速度慢。其解决方案就是使用低精度计算对模型进行优化。本文主要讲解几种优化显存存储的方法。
1. fp32、fp16、bf16混合精度训练
- FP32 是单精度浮点数,1位符号位,8位指数,23位表示小数,总共32位
- BF16 是对FP32单精度浮点数截断数据,即用8bit 表示指数,7bit 表示小数
- FP16 半精度浮点数,用5bit 表示指数,10bit 表示小数;
与32位相比,采用BF16/FP16吞吐量可以翻倍,内存需求可以减半。但是这两者精度上差异不一样,BF16 可表示的整数范围更广泛,但是尾数精度较小;FP16 表示整数范围较小,但是尾数精度较高。
1.1 混合精度训练
直接使用半精度进行计算会导致的两个问题的处理:舍入误差(Rounding Error)和溢出错误(Grad Overflow / Underflow)
-
舍入误差
float16 的最大舍入误差约为 2 − 10 ~2 ^{-10} 2−10,比 float32 的最大舍入误差 2 − 23 ~2 ^{-23} 2−23 要大不少。 对足够小的浮点数执行的任何操作都会将该值四舍五入到零。在反向传播中很多梯度更新值都非常小,但不为零,在反向传播中舍入误差累积可以把这些数字变成0或者nan, 这会导致不准确的梯度更新,影响网络的收敛 -
溢出错误
由于 float16 的有效的动态范围(正数部分,负数部分与正数对应)约为 5.96 × 1 0 − 8 ∼ 6.55 × 10 4 5.96\times10^{-8} \sim 6.55\times10{^4} 5.96×10−8∼6.55×104,比单精度的 float32 的动态范围 1.4 × 1 0 − 45 ∼ 1.7 × 1 0 38 1.4\times10^{-45} \sim 1.7 \times10^{38} 1.4×10−45∼1.7×1038要狭窄很多,精度下降会导致得到的值大于或者小于fp16的有效动态范围,也就是上溢出或者下溢出。在深度学习中,由于激活函数的的梯度往往要比权重梯度小,更易出现下溢出的情况
针对以上两种情况的解决方法是混合精度训练(Mixed Precision)和损失缩放(Loss Scaling)
-
混合精度训练
混合精度训练是一种通过在FP16上执行尽可能多的操作来大幅度减少神经网络训练时间的技术,在像线性层或是卷积操作上,FP16运算较快,但像Reduction运算又需要 FP32的动态范围。通过混合精度训练的方式,便可以在部分运算操作使用FP16,另一部分则使用 FP32,混合精度功能会尝试为每个运算使用相匹配的数据类型,在内存中用FP16做储存和乘法从而加速计算,用FP32做累加避免舍入误差。这样在权重更新的时候就不会出现舍入误差导致更新失败,混合精度训练的策略有效地缓解了舍入误差的问题 -
损失缩放
尽管使用了混合精度训练,还是会存在无法收敛的情况,原因是激活梯度的值太小,造成了下溢出。损失缩放是指在执行反向传播之前,将损失函数的输出乘以某个标量数(论文建议从8开始)。 乘性增加的损失值产生乘性增加的梯度更新值,提升许多梯度更新值到超过FP16的安全阈值2^-24。 只要确保在应用梯度更新之前撤消缩放,并且不要选择一个太大的缩放以至于产生inf权重更新(上溢出) ,从而导致网络向相反的方向发散
bf16/fp32 混合训练因为两种格式在范围上对齐了,并且 bf16 比 fp16 的范围更大,所以要比 fp16/fp32 混合训练稳定性更高
2. gradient checkpointing
gradient checkpointing(梯度检查点)的工作原理是在反向传播时重新计算深度神经网络的中间值(通常情况是在前向传播时存储的)。这个策略是用时间(重新计算这些值两次的时间成本)来换空间(提前存储这些值的内存成本)
3. Xformers
Xformers 应该是目前社区知名度最高的优化加速方案了,所谓 Xformers 指的是该库将各种transformer 架构的模型囊括其中。注意,该库仅适用于N卡,特点是加速图片生成并降低显存占用,代价是输出图像不稳定,有可能比不开Xformers略差。各种transformer变体可以参考 A Survey of Transformers.
参考
- 彻底搞懂float16与float32的计算方式
- pytorch模型训练之fp16、apm、多GPU模型、梯度检查点(gradient checkpointing)显存优化等
- facebookresearch/xformers