DeepSpeed: 大模型训练框架 | 京东云技术团队

背景:

目前,大模型的发展已经非常火热,关于大模型的训练、微调也是各个公司重点关注方向。但是大模型训练的痛点是模型参数过大,动辄上百亿,如果单靠单个GPU来完成训练基本不可能。所以需要多卡或者分布式训练来完成这项工作。

一、分布式训练

1.1 目前主流的大模型分布式训练主要包括两种:

  • 数据并行训练
  • 模型并行训练

二、DeepSpeed

DeepSpeed是由Microsoft提供的分布式训练工具,旨在支持更大规模的模型和提供更多的优化策略和工具。对于更大模型的训练来说,DeepSpeed提供了更多策略,例如:Zero、Offload等。

2.1 基础组件

分布式训练需要掌握分布式环境中的基础配置,包括节点变化、全局进程编号、局部进程编号、全局总进程数、主节点等。这些组件都跟分布式训练紧密相关,同时组件之间也有非常大的联系,例如通信联系等。

2.2 通信策略

既然是分布式训练,那机器之间必须要保持通信,这样才可以传输模型参数,梯度参数等信息。

DeepSpeed提供了mpi、gioo、nccl等通信策略

通信策略通信作用
mpi它是一种跨界点的通信库,经常用于CPU集群的分布式训练
gloo它是一种高性能的分布式训练框架,可以支持CPU或者GPU的分布式训练
nccl它是nvidia提供的GPU专用通信库,广泛用于GPU上的分布式训练

我们在使用DeepSpeed进行分布式训练的时候,可以根据自身的情况选择合适的通信库,通常情况下,如果是GPU进行分布式训练,可以选择nccl。

2.3 Zero(零冗余优化器)

Microsoft开发的Zero可以解决分布式训练过程中数据并行和模型并行的限制。比如: Zero通过在数据并行过程中划分模型状态(优化器、梯度、参数),来解决数据并行成可能出现内存冗余的情况(正常数据并行训练,模型全部参数是复制在各个机器上的);同时可以在训练期间使用动态通信计划,在分布式设备之间共享重要的状态变量,这样保持计算粒度和数据并行的通信量。

Zero是用于大规模模型训练优化的技术,它的主要目的是减少模型的内存占用,让模型可以在显卡上训练,内存占用主要分为Model StatesActivation两个部分,Zero主要解决的是Model States的内存占用问题。

Zero将模型参数分成三个部分:

状态作用
Optimizer States优化器在进行梯度更新的时候需要用到的数据
Gradient在反向转播过程中产生的数据,其决定参数的更新方向
Model Parameter模型参数,在模型训练过程中通过数据“学习”的信息

Zero的级别如下:

级别作用
Zero-0不使用所有类型的分片,仅使用DeepSpeed作为DDP
Zero-1分割Optimizer States, 减少4倍内存,通信容量和数据并行性相同
Zero-2分割Optimizer States和Gradients,减少8倍内存,通信容量和数据并行性相同
Zero-3分割Optimizer States、gradients、Parametes,内存减少与数据并行度呈线性关系。例如,在64个GPU(Nd=64)之间进行拆分将产生64倍的内存缩减。通信量有50%的适度增长
Zero-InfinityZero-Infinity是Zero-3的扩展,它允许通过使用 NVMe 固态硬盘扩展 GPU 和 CPU 内存来训练大型模型

2.4 Zero-Offload:

相比GPU,CPU就相对比较廉价,所以Zero-Offload思想是将训练阶段的某些模型状态放(offload)到内存以及CPU计算。

Zero-Offload不希望为了最小化显存占用而让系统计算效率下降,但如果使用CPU也需要考虑通信和计算的问题(通信:GPU和CPU的通信;计算:CPU占用过多计算就会导致效率降低)。

Zero-Offload想做的是把计算节点和数据节点分布在GPU和CPU上,计算节点落到哪个设备上,哪个设备就执行计算,数据节点落到哪个设备上,哪个设备就负责存储。

Zero-Offload切分思路:

下图中有四个计算类节点:FWD、BWD、Param update和float2half,前两个计算复杂度大致是 O(MB), B是batch size,后两个计算复杂度是 O(M)。为了不降低计算效率,将前两个节点放在GPU,后两个节点不但计算量小还需要和Adam状态打交道,所以放在CPU上,Adam状态自然也放在内存中,为了简化数据图,将前两个节点融合成一个节点FWD-BWD Super Node,将后两个节点融合成一个节点Update Super Node。如下图右边所示,沿着gradient 16和parameter 16两条边切分。

Zero-Offload计算思路:

在GPU上面进行前向和后向计算,将梯度传给CPU,进行参数更新,再将更新后的参数传给GPU。为了提高效率,可以将计算和通信并行起来,GPU在反向传播阶段,可以待梯度值填满bucket后,一遍计算新的梯度一遍将bucket传输给CPU,当反向传播结束,CPU基本上已经有最新的梯度值了,同样的,CPU在参数更新时也同步将已经计算好的参数传给GPU,如下图所示。

2.5 混合精度:

混合精度训练是指在训练过程中同时使用FP16(半精度浮点数)和FP32(单精度浮点数)两种精度的技术。使用FP16可以大大减少内存占用,从而可以训练更大规模的模型。但是,由于FP16的精度较低,训练过程中可能会出现梯度消失和模型坍塌等问题。

DeepSpeed支持混合精度的训练,可以在config.json配置文件中设置来启动混合精度(“fp16.enabled”:true)。在训练的过程中,DeepSpeed会自动将一部分操作转化为FP16格式,并根据需要动态调整精度缩放因子,来保证训练的稳定性和精度。

在使用混合精度训练时,需要注意一些问题,例如梯度裁剪(Gradient Clipping)和学习率调整(Learning Rate Schedule)等。梯度裁剪可以防止梯度爆炸,学习率调整可以帮助模型更好地收敛。

三、总结

DeepSpeed方便了我们在机器有限的情况下来训练、微调大模型,同时它也有很多优秀的性能来使用,后期可以继续挖掘。

目前主流的达模型训练方式: GPU + PyTorch + Megatron-LM + DeepSpeed

优势

  1. 存储效率:DeepSpeed提供了一种Zero的新型解决方案来减少训练显存的占用,它与传统的数据并行不同,它将模型状态和梯度进行分区来节省大量的显存;
  2. 可扩展性:DeepSpeed支持高效的数据并行、模型并行、pipeline并行以及它们的组合,这里也称3D并行;
  3. 易用性: 在训练阶段,只需要修改几行代码就可以使pytorch模型使用DeepSpeed和Zero。

参考:

1. http://wed.xjx100.cn/news/204072.html?action=onClick

2. https://zhuanlan.zhihu.com/p/513571706

作者:京东物流 郑少强

来源:京东云开发者社区 转载请注明来源

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/170612.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux--进程等待

1.什么是进程等待 1.通过系统调用wait/waitid,来对子进程进行进行检测和回收的功能。 2.为什么有进程等待 1.对于每个进程来说,如果子进程终止,父进程没有停止,就会形成僵尸进程,导致内存泄露,为了防止僵尸进程的形成…

【JAVA基础】多线程与线程池

多线程与线程池 文章目录 多线程与线程池1. 相关概念1.1 线程调度1.2 守护线程 2. 生命周期3. 同步机制/同步锁3.1 synchronized3.2 lock3.3 synchronized 与 Lock 的对比 4. 死锁5. 线程通信5.1 线程间的通信5.2 等待唤醒机制5.3 举例5.4 调用 wait 和 notify 需注意的细节5.5…

进阶课4——随机森林

1.定义 随机森林是一种集成学习方法,它利用多棵树对样本进行训练并预测。 随机森林指的是利用多棵树对样本进行训练并预测的一种分类器,每棵树都由随机选择的一部分特征进行训练和构建。通过多棵树的集成,可以增加模型的多样性和泛化能力。…

c++视觉检测------Shi-Tomasi 角点检测

Shi-Tomasi 角点检测 :goodFeaturesToTrack() goodFeaturesToTrack() 函数是 OpenCV 中用于角点检测的功能函数。它的主要作用是检测图像中的良好特征点,通常用于计算机视觉任务中的光流估算、目标跟踪等。 函数签名: void goodFeaturesTo…

pytorch实战---IMDB情感分析

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢…

Pytorch公共数据集、tensorboard、DataLoader使用

本文将主要介绍torchvision.datasets的使用,并以CIFAR-10为例进行介绍,对可视化工具tensorboard进行介绍,包括安装,使用,可视化过程等,最后介绍DataLoader的使用。希望对你有帮助 Pytorch公共数据集 torc…

vscode安装可以打开docx文件的插件

文章目录 vscode安装可以打开docx文件的插件去插件商城搜索并安装安装后打开一个word vscode安装可以打开docx文件的插件 去插件商城搜索并安装 安装后 打开一个word

脏牛提权 liunx

使用方法 Liunx 普通用户 内核版本 在版本里 我直接脏牛提权 有脚本查看内核版本 上传c脚本 编译 直接执行 获取高权限 提权 Liunx https://github.com/InteliSecureLabs/Linux Exploit Suggester 运行这个脚本 上传到客户端 https://github…

Unity Hub报错:No valid Unity Editor license found. Please activate your license.

最近 遇到一个问题,打开高版本时Hub抛出异常:No valid Unity Editor license found. Please activate your license. 首先你必须排除是否登录Unity Hub,并且激活许可证。 方法一:禁用网络(这个可能无效) …

如何通过卖虚拟资料月入10万?看这几个卖资料案例

我微信好友里,有近4000个是做创业博主的同行。 你可能会好奇,其中60%的人都通过卖虚拟资料起家,这到底说明了什么呢? 嗯,事实上,这就意味着这些人选择了网络赚钱的首选项目,那就是销售各种资料…

python selenium如何带cookie访问网站

python selenium如何带cookie访问网站 要使用Python的Selenium库带有cookie访问网站,你可以按照以下步骤进行操作: 一、流程介绍 安装Selenium库(如果尚未安装): pip install selenium导入Selenium库并启动一个浏览…

导致爬虫无法使用的原因有哪些?

随着互联网的普及和发展,爬虫技术也越来越多地被应用到各个领域。然而,在实际使用中,爬虫可能会遇到各种问题导致无法正常工作。本文将探讨导致爬虫无法使用的原因,并给出相应的解决方法。 一、目标网站反爬虫机制 许多网站为了…

11 个最值得推荐的 Windows 数据恢复软件

您可能已经尝试过许多免费的恢复程序,但它们都不起作用,对吧?这就是您正在寻找最好的数据恢复软件的原因。 个人去过那里。根据个人的经验,大多数免费软件并不能解决这个问题。有时,当个人在 PC 上运行恢复程序时&…

【API篇】七、Flink窗口

文章目录 1、窗口2、分类3、窗口API概览4、窗口分配器 在批处理统计中,可以等待一批数据都到齐后,统一处理。但是在无界流的实时处理统计中,是来一条就得处理一条,那么如何统计最近一段时间内的数据呢? ⇒ 窗口的概念&…

选择合适的项目管理系统来支持专业产品研发团队

专业产品研发团队的公司离不开其严谨的管理和高效的研发流程,为了进一步提升研发效率和管理水平,产研团队需要一个全流程的项目管理系统来支持其研发团队的协同合作。 一、系统需求 IT行业的研发工作涵盖了从立项、项目变更到项目的进程计划等多个环节。…

B-tree(PostgreSQL 14 Internals翻译版)

概览 B树(作为B树访问方法实现)是一种数据结构,它使您能够通过从树的根向下查找树的叶节点中所需的元素。为了明确地标识搜索路径,必须对所有树元素进行排序。B树是为有序数据类型设计的,这些数据类型的值可以进行比较和排序。 下面的机场代…

OpenCV视频车流量识别详解与实践

视频车流量识别基本思想是使用背景消去算法将运动物体从图片中提取出来,消除噪声识别运动物体轮廓,最后,在固定区域统计筛选出来符合条件的轮廓。 基于统计背景模型的视频运动目标检测技术: 背景获取:需要在场景存在…

基于PHP的线上购物商城,MySQL数据库,PHPstudy,原生PHP,前台用户+后台管理,完美运行,有一万五千字论文。

目录 演示视频 基本介绍 论文截图 功能结构 系统截图 演示视频 基本介绍 基于PHP的线上购物商城,MySQL数据库,PHPstudy,原生PHP,前台用户后台管理,完美运行,有一万五千字论文。 现如今,购物网站是商业…

【七】SpringBoot为什么可以打成 jar包启动

SpringBoot为什么可以打成 jar包启动 简介:庆幸的是夜跑的习惯一直都在坚持,正如现在坚持写博客一样。最开始刚接触springboot的时候就觉得很神奇,当时也去研究了一番,今晚夜跑又想起来了这茬事,于是想着应该可以记录一…

Python单元测试

import unittest #必须要导入单元测试的包class Student(object):def __init__(self, name, score):self.name nameself.score scoredef get_grade(self):if self.score > 100:#返回错误不能用return,应该用raise raise ValueError("成绩不能大于100"…