模型训练中的存储消耗
1 存储分类
首先,在大模型训练的过程中,GPU都需要存什么内容:存储主要分为两大块:Model States和Residual States
- Model State:指和模型本身息息相关的,必须存储的内容,具体包括:
- Optimizer States:Adam优化算法中的momentum(动量)和variance(方差)
- Gradients:模型梯度 G G G
- Parameters:模型参数 W W W
- Residual States:指并非模型必须的,但在训练过程中会额外产生的内容,具体包括:
- Activation:激活值。在流水线并行中曾详细介绍过。在backward过程中使用链式法则计算梯度时会用到。有了它算梯度会更快,但它不是必须存储的,因为可以通过重新做Forward来算它。
- Temporary Buffers:临时存储。例如把梯度发送到某块GPU上做加总聚合时产生的存储。
- Unusable Fragment Memory:碎片化的存储空间。虽然总存储空间是够的,但是如果取不到连续的存储空间,相关的请求也会被fail掉。对这类空间浪费可以通过内存整理来解决。
2 精度混合训练
知道了存储分类,进一步,我们想知道,假设模型的参数 W W W大小是 Φ \Phi Φ(为了方便后面绝对存储空间大小的说明,这里的大小表示数据的个数,即为有 Φ \Phi Φ个参数),那么每一类存储具体占了多大的空间呢?
在分析这个问题前,我们需要来了解精度混合训练。
对于模型,我们肯定希望其参数越精准越好,也即我们用fp32(单精度浮点数,存储占4byte)来表示参数 W W W。但是在forward和backward的过程中,fp32的计算开销也是庞大的。那么能否在计算的过程中,引入fp16或bf16(半精度浮点数,存储占2byte),来减轻计算压力呢?于是,混合精度训练就产生了,它的步骤如下图:
- 存储一份fp32的parameter,momentum和variance(统称model states)
- (1)在forward开始之前,额外开辟一块存储空间,将fp32 parameter减半到fp16 parameter
- (2)正常做forward和backward,在此之间产生的activation和gradients,都用fp16进行存储
- (3)用fp16 gradients去更新fp32下的model states
- 当模型收敛后,fp32的parameter就是最终的参数输出
通过这种方式,混合精度训练在计算开销和模型精度上做了权衡。如果不了解fp32,fp16和bf16的细节也没关系,不影响下文的阅读。只要记住它们所占的存储空间和精度表达上的差异即可。
3 存储大小估计
现在,我们可以来计算模型在训练时需要的存储大小了,假设模型的参数 W W W大小是 Φ \Phi Φ (此处可以理解为参数数量),以byte为单位,存储如下:
- 必存(共计 12 Φ 12\Phi 12Φ):
- Parameters(FP32占4个字节,共 Φ \Phi Φ个)= 4 Φ 4\Phi 4Φ
- momentum(FP32占4个字节,共 Φ \Phi Φ个)= 4 Φ 4\Phi 4Φ
- variance(FP32占4个字节,共 Φ \Phi Φ个) = 4 Φ 4\Phi 4Φ
- 中间值(共计 4 Φ 4\Phi 4Φ):
- Parameters(FP16) = 2 Φ 2\Phi 2Φ
- Gradients(FP16) = 2 Φ 2\Phi 2Φ
因为采用了Adam优化,所以才会出现momentum和variance,当然你也可以选择别的优化办法。因此这里为了更通用些,记模型必存的数据大小为 K Φ K\Phi KΦ 。因此最终内存开销为: 2 Φ + 2 Φ + K Φ 2\Phi+2\Phi+K\Phi 2Φ+2Φ+KΦ
另外,这里暂不将activation纳入统计范围,原因是:
- activation不仅与模型参数相关,还与batch size相关
- activation的存储不是必须的。存储activation只是为了在用链式法则做backward的过程中,计算梯度更快一些。但你永远可以通过只保留最初的输入 X X X,重新做forward来得到每一层的activation(虽然实际中并不会这么极端)。
- 因为activation的这种灵活性,纳入它后不方便衡量系统性能随模型增大的真实变动情况。因此在这里不考虑它,在后面会单开一块说明对activation的优化。