浅谈PyTorch中的DP和DDP

目录

  • 1. 引言
  • 2. PyTorch 数据并行(Data Parallel, DP)
    • 2.1 DP 的优缺点
    • 2.2 DP 实现代码示例
  • 3. PyTorch 分布式数据并行(Distributed Data Parallel, DDP)
    • 3.1 DDP 的优缺点
    • 3.2 分布式基本概念
    • 3.3 DDP 的应用流程
    • 3.5 DDP 实现代码示例
  • 4. DP和DDP的对比

1. 引言

在现代深度学习中,随着模型规模的不断增大以及数据量的快速增长,模型训练所需的计算资源也变得愈加庞大。尤其是在大型深度学习模型的训练过程中,单张 GPU 显存往往难以满足需求,因此,如何高效利用多 GPU 进行并行训练,成为了加速模型训练的关键手段。PyTorch 作为目前最受欢迎的深度学习框架之一,提供了多种并行训练的方式,其中最常用的是 数据并行(Data Parallel, DP)分布式数据并行(Distributed Data Parallel, DDP)

⚠️ 无论是DP还是DDP都只支持数据并行。

2. PyTorch 数据并行(Data Parallel, DP)

数据并行(Data Parallel, DP) 是 PyTorch 中一种简单的并行训练方式,它的主要思想是将数据拆分为多个子集,然后将这些子集分别分配给不同的 GPU 进行计算。DP 的工作原理如下:

  1. 在前向传播时,首先将模型的参数复制到每个 GPU 上。
  2. 每个 GPU 独立计算一部分数据的前向传播和损失值,并将计算结果返回到主 GPU。
  3. 主 GPU 汇总每个 GPU 计算的损失,并计算出梯度。
  4. 通过反向传播,将计算得到的梯度更新主 GPU 的模型参数,然后再将更新后的参数广播到其他 GPU 上。

2.1 DP 的优缺点

优点

  • 实现简单,使用 PyTorch 提供的 torch.nn.DataParallel 接口即可轻松实现。
  • 对于小规模的模型和数据集,DP 能够在单机多卡的场景下提供良好的加速效果。

缺点

  • DP 在每个 batch 中需要在 GPU 之间传递模型参数和数据,参数更新时也需要将梯度传递回主 GPU,这会造成大量的通信开销。
  • 由于梯度的计算和模型参数的更新都是在主 GPU 上完成的,主 GPU 的负载会显著增加,导致 GPU 资源无法得到充分利用。

2.2 DP 实现代码示例

使用 torch.nn.DataParallel 实现数据并行非常简单。我们只需要将模型封装到 DataParallel 中,然后传入多个 GPU 即可。下面我们通过代码示例展示如何使用 DP 进行并行训练。

import torch
import torch.nn as nn
import torchvisionBATCH_SIZE = 256
EPOCHS = 5
NUM_CLASSES = 10
INPUT_SHAPE = (3, 224, 224)  # ResNet-18 的输入尺寸# 1. 创建模型
net = torchvision.models.resnet18(pretrained=False, num_classes=NUM_CLASSES)
net = nn.DataParallel(net)
net = net.cuda()# 2. 生成随机数据
total_steps = 100  # 假设每个 epoch 有 100 个步骤
inputs = torch.randn(BATCH_SIZE, *INPUT_SHAPE).cuda()
targets = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,)).cuda()# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.02, momentum=0.9, weight_decay=0.0001, nesterov=True
)# 4. 开始训练
net.train()
for ep in range(1, EPOCHS + 1):train_loss = correct = total = 0for idx in range(total_steps):outputs = net(inputs)loss = criterion(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()total += targets.size(0)correct += torch.eq(outputs.argmax(dim=1), targets).sum().item()if (idx + 1) % 25 == 0 or (idx + 1) == total_steps:print(f"Epoch [{ep}/{EPOCHS}], Step [{idx + 1}/{total_steps}], Loss: {train_loss / (idx + 1):.3f}, Acc: {correct / total:.3%}")

在这个代码示例中,我们使用了随机生成的输入和标签数据,以简化代码并专注于并行训练的实现。通过将模型封装在 DataParallel 中,我们可以在多个 GPU 上进行并行计算。然而,由于 DP 存在较大的通信开销以及主 GPU 的计算瓶颈,因此在更大规模的训练中,我们更推荐使用分布式数据并行(DDP)来加速训练。

3. PyTorch 分布式数据并行(Distributed Data Parallel, DDP)

分布式数据并行(Distributed Data Parallel, DDP) 是 PyTorch 中推荐使用的多 GPU 并行训练方式,特别适合大规模训练任务。与 DP 不同,DDP 是一种多进程并行方式,避免了 Python 全局解释器锁(GIL)的限制,可以在单机或多机多卡环境中实现更高效的并行计算。DDP的工作原理如下:

  1. 在每个 GPU 上运行一个独立的进程,每个进程都有自己的一份模型副本和数据。
  2. 各个进程独立执行前向传播、计算损失和反向传播,得到各自的梯度。
  3. 在反向传播阶段,各个 GPU 的进程通过通信将梯度汇总,平均后更新每个进程中的模型参数。
  4. 每个进程的模型参数在整个训练过程中保持一致,避免了 DP 中由于参数广播导致的通信开销。

3.1 DDP 的优缺点

优点

  • 由于各个 GPU 上的进程独立计算梯度,更新模型参数时只需要同步梯度而非整个模型,通信开销较小,性能大幅提升。
  • DDP 可以在多机多卡环境下使用,支持大规模的分布式训练,适合深度学习模型的高效扩展。

缺点

  • 代码实现相对 DP 较为复杂,需要手动管理进程的初始化和同步。

3.2 分布式基本概念

在使用 DDP 进行分布式训练时,我们需要理解以下几个基本概念:

  1. node(节点):物理节点,一台机器即为一个节点。
  2. nnodes(节点数量):表示参与训练的物理节点数量。
  3. node rank(节点序号):节点的编号,用于区分不同的物理节点。
  4. nproc per node(每节点的进程数量):表示每个物理节点上启动的进程数量,通常等于 GPU 的数量。
  5. world size(全局进程数量):表示全局并行的进程总数,等于 nnodes * nproc_per_node
  6. rank(进程序号):表示每个进程的唯一编号,用于进程间通信,rank=0 的进程为主进程。
  7. local rank(本地进程序号):在某个节点上的进程的序号,local_rank=0 表示该节点的主进程。

3.3 DDP 的应用流程

使用 DDP 进行分布式训练的步骤如下:

  1. 初始化分布式训练环境:通过 torch.distributed.init_process_group 初始化进程组,指定通信后端和相关配置。
  2. 创建分布式模型:将模型封装到 torch.nn.parallel.DistributedDataParallel 中,进行并行训练。
  3. 生成或加载数据:在每个进程中加载数据,并确保数据在不同进程间的分布,如使用 DistributedSampler
  4. 执行训练脚本:在每个节点的每个进程上启动训练脚本,进行模型训练。

3.5 DDP 实现代码示例

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDPBATCH_SIZE = 256
EPOCHS = 5
NUM_CLASSES = 10
INPUT_SHAPE = (3, 224, 224)  # ResNet-18 的输入尺寸if __name__ == "__main__":# 1. 设置分布式变量,初始化进程组rank = int(os.environ["RANK"])local_rank = int(os.environ["LOCAL_RANK"])torch.cuda.set_device(local_rank)dist.init_process_group(backend="nccl")device = torch.device("cuda", local_rank)print(f"[init] == local rank: {local_rank}, global rank: {rank} ==")# 2. 创建模型net = torchvision.models.resnet18(pretrained=False, num_classes=NUM_CLASSES)net = net.to(device)net = DDP(net, device_ids=[local_rank], output_device=local_rank)# 3. 生成随机数据total_steps = 100  # 假设每个 epoch 有 100 个步骤inputs = torch.randn(BATCH_SIZE, *INPUT_SHAPE).to(device)targets = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,)).to(device)# 4. 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = torch.optim.SGD(net.parameters(), lr=0.02, momentum=0.9, weight_decay=0.0001, nesterov=True)# 5. 开始训练net.train()for ep in range(1, EPOCHS + 1):train_loss = correct = total = 0for idx in range(total_steps):outputs = net(inputs)loss = criterion(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()total += targets.size(0)correct += torch.eq(outputs.argmax(dim=1), targets).sum().item()if rank == 0 and ((idx + 1) % 25 == 0 or (idx + 1) == total_steps):print("   == step: [{:3}/{}] [{}/{}] | loss: {:.3f} | acc: {:6.3f}%".format(idx + 1,total_steps,ep,EPOCHS,train_loss / (idx + 1),100.0 * correct / total,))if rank == 0:print("\n            =======  Training Finished  ======= \n")

在以上代码中,我们使用了随机生成的输入和标签数据,以简化代码并专注于 DDP 的实现细节。通过在每个进程中初始化分布式环境,并将模型封装在 DistributedDataParallel 中,我们可以在多个 GPU 上高效地进行并行训练。

需要注意的是,DDP 的实现需要在每个进程中正确设置设备和初始化过程,这样才能确保模型和数据在对应的 GPU 上进行计算。

4. DP和DDP的对比

DP 是单进程多线程的分布式方法,主要用于单机多卡的场景。它的工作方式是在每个批处理期间,将模型参数分发到所有 GPU,各 GPU 计算各自的梯度后将结果汇总到 GPU0,再由 GPU0 完成参数更新,然后将更新后的模型参数广播回其他 GPU。由于 DP 只广播模型的参数,速度较慢,尤其是在多个 GPU 协同工作时,GPU 利用率低,通常效率不如 DDP。

相比之下,DDP 使用多进程架构,既支持单机多卡,也支持多机多卡,并避免了 GIL(全局解释器锁)带来的性能损失。每个进程独立计算梯度,计算完成后各进程汇总并平均梯度,更新参数时各进程均独立完成。这种方式减少了通信开销,只在初始化时广播一次模型参数,并且在每次更新后只传递梯度。由于各进程独立更新参数,且更新过程中模型参数保持一致,DDP 在效率和速度上大大优于 DP。

数据并行(DP)分布式数据并行(DDP)
实现复杂度使用 nn.DataParallel,实现简单,代码改动较少。需要设置分布式环境,使用 torch.distributed,代码实现相对复杂,需要手动管理进程和同步。
通信开销通信开销较大,参数和梯度需要在主 GPU 和其他 GPU 之间频繁传递。通信开销较小,只在反向传播时同步梯度,各 GPU 之间直接通信,无需通过主 GPU。
扩展性扩展性有限,适用于单机多卡,不支持多机训练。扩展性强,支持单机多卡和多机多卡,适合大规模分布式训练。
性能主 GPU 负载重,可能成为瓶颈,GPU 资源利用率较低。各 GPU 负载均衡,资源利用率高,训练速度更快。
适用场景适合小规模模型和数据集的单机多卡训练。适合大规模模型和数据集的单机或多机多卡训练。
梯度同步方式梯度在主 GPU 上汇总和更新,需要从其他 GPU 收集梯度。梯度在各 GPU 间直接同步,通常使用 All-Reduce 操作,效率更高。
模型参数广播每次前向传播都需要将模型参数从主 GPU 复制到其他 GPU。初始化时各进程各自持有一份模型副本,参数更新后自动同步,无需频繁复制。
对 Python GIL 的影响受限于 Python 全局解释器锁(GIL),因为是单进程多线程,无法充分利用多核 CPU。采用多进程方式,不受 GIL 影响,能够充分利用多核 CPU 和多 GPU 进行并行计算。
容错性主 GPU 故障会导致整个训练中断,容错性较差。各进程相对独立,某个进程出错不会影响其他进程,容错性较好。
调试难度由于是单进程,调试相对容易。多进程调试较为复杂,需要注意进程间的通信和同步问题。
代码修改量只需在模型外层加上 nn.DataParallel 封装,代码改动少。需要在代码中添加进程初始化、模型封装、设备设置等步骤,修改量较大。
数据加载方式使用常规的数据加载方式,无需特殊处理。需要使用 DistributedSampler 等工具,确保各进程加载不同的数据子集,避免数据重复。
资源占用主 GPU 内存和计算资源占用较高,其他 GPU 资源可能未被充分利用。各 GPU 资源均衡占用,能够最大化利用多 GPU 的计算能力。
训练结果一致性由于参数更新在主 GPU 上进行,可能存在精度损失或不一致的情况。各进程的模型参数同步更新,训练结果一致性更好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/443904.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FreeRTOS学习总结

背景:在裸机开发上,有时候我们需要等待某个信号或者需要延迟时,CPU的运算是白白浪费掉了的,CPU的利用率并不高,我们希望当一个函数在等待的时候,可以去执行其他内容,提高CPU的效率,同…

视频格式不支持播放怎么办?几招教你转换成mp4格式

视频已成为我们生活中不可或缺的一部分,无论是学习、娱乐还是工作交流,视频都扮演着重要角色。然而,在享受视频带来的便利时,我们时常会遇到一个令人头疼的问题——视频格式不支持播放。不同设备、平台和软件对视频格式的支持各不…

什么是组态软件?Web组态软件又是什么?

从事相关工作的对“组态软件”应该都不陌生,那Web组态软件又是什么呢?本文将对Web组态可视化软件(下称“Web组态软件”)做简单介绍,可视化编辑器是Web组态软件中的一个重要功能模块。除了编辑器,还有哪些功能模块?又…

leetcode---素数,最小质因子,最大公约数

1 判断一个数是不是质数(素数) 方法1&#xff1a;依次判断能否被n整除即可&#xff0c;能够整除则不是质数&#xff0c;否则是质数 方法2&#xff1a;假如n是合数&#xff0c;必然存在非1的两个约数p1和p2&#xff0c;其中p1<sqrt(n)&#xff0c;p2>sqrt(n)。 方法3&…

医院管理新思维:Spring Boot技术应用

5系统详细实现 5.1 医生模块的实现 5.1.1 病床信息管理 医院管理系统的医生可以管理病床信息&#xff0c;可以对病床信息添加修改删除操作。具体界面的展示如图5.1所示。 图5.1 病床信息管理界面 5.1.2 药房信息管理 医生可以对药房信息进行添加&#xff0c;修改&#xff0c;…

Java中System类和RunTime类的Api

目录 System 类 1)out 2)err 3)in 4)currentTimeMillis() 5)nanoTime() 6)arraycopy(Object 要从里面复制东西的数组, int 要从里面复制东西数组的索引起始位置, Object 获得复制元素的数组, int 获得复制元素数组的起始索引, int 要复制东西的个数) 7)gc() 8)exit(int status)…

运维工具之ansible

Ansible 1.什么是ansible? ​ ansible是基于ssh架构的自动化运维工具&#xff0c;由python语言实现&#xff0c;通过ansible可以远程批量部署等。 2.部署前提 ​ 控制端需要安装ansible,被控制端要开启ssh服务&#xff0c;并允许远程登录&#xff0c;被管理主机需要安装py…

探讨Facebook在全球社交网络中的技术优势

Facebook作为全球最大的社交网络之一&#xff0c;其技术优势在于多个方面&#xff0c;这些优势不仅塑造了用户体验&#xff0c;也影响了整个社交媒体生态。 个性化用户体验 Facebook通过分析用户的行为和兴趣&#xff0c;提供个性化的内容推荐。利用机器学习算法&#xff0c;平…

仅用一分钟,AI如何帮你构建完整的论文初稿?揭秘背后科技!

大家好&#xff01;在今天的分享中&#xff0c;我们将深入探讨一项令人兴奋的技术进展&#xff1a;仅用一分钟&#xff0c;AI如何帮助你构建一篇完整的论文初稿。这项技术不仅节省了研究人员和学生的宝贵时间&#xff0c;还改变了我们对学术写作的传统认知。 首先&#xff0c;…

【读书笔记·VLSI电路设计方法解密】问题10:从概念到硅片开发SoC芯片的主要任务

从概念到硅片的SoC芯片开发过程可分为以下四个任务&#xff1a;设计、验证、实现和软件开发。 设计&#xff1a;通常从市场调研和产品定义开始&#xff0c;然后进行系统设计&#xff0c;最后以RTL编码结束。验证&#xff1a;确保芯片按照设计规格能够准确执行功能&#xff0c;…

深度学习500问——Chapter17:模型压缩及移动端部署(4)

文章目录 17.9 常用的轻量级网络有哪些 17.9.1 SequeezeNet 17.9.2 MobileNet 17.9.3 MobileNet-v2 17.9.4 Xception 17.9 常用的轻量级网络有哪些 17.9.1 SequeezeNet SqueezeNet出自 F.N.landola, S.Han等人发表的论文《SqueezeNet&#xff1a;ALexNet-level accuracy with…

目标检测中的损失函数

损失函数是用来衡量模型与数据的匹配程度的&#xff0c;也是模型权重更新的基础。计算损失产生模型权重的梯度&#xff0c;随后通过反向传播算法&#xff0c;模型权重得以更新进而更好地适应数据。一般情况下&#xff0c;目标损失函数包含两部分损失&#xff0c;一个是目标框分…

RandLA-Net PB 模型 测试

tensorflow ckpt 模型 转换 pb 模型, 测试模型是否正确, 后续实现 c++ 部署。 Code: https://github.com/QingyongHu/RandLA-Net 测试PB 模型 RandLANetConvert.py import tensorflow.compat.v1 as tf tf.disable_v2_behavior

R语言中的plumber介绍

R语言中的plumber介绍 基本用法常用 API 方法1. GET 方法2. POST 方法3. 带路径参数的 GET 方法 使用 R 对数据进行操作处理 JSON 输入和输出运行 API 的其他选项其他功能 plumber 是个强大的 R 包&#xff0c;用于将 R 代码转换为 Web API&#xff0c;通过使用 plumber&#x…

PowerJob做定时任务调度

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、区别对比二、使用步骤1. 定时任务类型2.PowerJob搭建与部署 前言 提示&#xff1a;这里可以添加本文要记录的大概内容&#xff1a; PowerJob是基于java开…

如何优化抖音直播间数据?

在数字驱动的时代&#xff0c;缺乏精准的数据支撑&#xff0c;任何线上活动都难以形成有效的流量循环。特别是在抖音直播这一领域&#xff0c;深入理解并优化核心数据&#xff0c;是提升直播效果、吸引并留住观众的关键。那么&#xff0c;抖音直播平台在评估一场直播时&#xf…

【重学 MySQL】四十六、创建表的方式

【重学 MySQL】四十六、创建表的方式 使用CREATE TABLE语句创建表使用CREATE TABLE LIKE语句创建表使用CREATE TABLE AS SELECT语句创建表使用CREATE TABLE SELECT语句创建表并从另一个表中选取数据&#xff08;与CREATE TABLE AS SELECT类似&#xff09;使用CREATE TEMPORARY …

安装最新 MySQL 8.0 数据库(教学用)

安装 MySQL 8.0 数据库&#xff08;教学用&#xff09; 文章目录 安装 MySQL 8.0 数据库&#xff08;教学用&#xff09;前言MySQL历史一、第一步二、下载三、安装四、使用五、语法总结 前言 根据 DB-Engines 网站的数据库流行度排名&#xff08;2024年&#xff09;&#xff0…

【Redis】持久化(上)---RDB

文章目录 持久化的概念RDB手动触发自动触发bgsave命令的运行流程RDB文件的处理RDB的优缺点RDB效果展示 持久化的概念 Redis支持AOF和RDB两种持久化机制,持久化功能能有效的避免因进程退出而导致的数据丢失的问题,当下次重启的时候利用之前持久化的文件即可实现数据恢复. 所以此…

一键生成PPT的AI工具-Kimi!

一键生成PPT的AI工具-Kimi&#xff01; 前言介绍Kimi为什么选择Kimi如何使用Kimi在线编辑PPT下载生成的PPT自己编辑 结语 &#x1f600;大家好&#xff01;我是向阳&#x1f31e;&#xff0c;一个想成为优秀全栈开发工程师的有志青年&#xff01; &#x1f4d4;今天不来讨论前后…