PyTorch数据并行(DP/DDP)浅析

一直以来都是用的单机单卡训练模型,虽然很多情况下已经足够了,但总有一些情况得上分布式训练:

  • 模型大到一张卡放不下;
  • 单张卡batch size不敢设太大,训练速度慢;
  • 当你有好几张卡,不想浪费;
  • 展示一下技术

由于还没遇到过一张显卡放不下整个模型的情况,本文的分布式训练仅限数据并行。主要从数据并行的原理和一些简单的实践例子进行说明。

文章目录

    • 原理介绍
    • DataParallel
      • 小样
    • DistributedDataParallel
      • 小样
      • 一些概念
    • DDP与DP的区别
    • 参考

原理介绍

与每个step一个batch数据相比,数据并行是指每个step用更多的数据(多个batch)进行计算——即多个batch的数据并行进行前向计算。既然是并行,那么就涉及到多张卡一起计算。单卡和多卡训练过程如下图1所示,主要有三个过程:

  • 各卡分别计算损失和梯度,即图中红线部分;
  • 所以梯度整合到主device,即图中蓝线部分;
  • 主device进行参数更新,并将新模型拷贝到其他device上,即图中绿线部分。
../_images/ps.svg
左图是单GPU训练;右图是多GPU训练的一个变体:(1)计算损失和梯度,(2)所有梯度聚合在一个GPU上,(3)发生参数更新,并将参数重新广播给所有GPU

如果不使用数据并行,在显存足够的情况下,我们可以将batch_size设大,这和数据并行的区别在哪呢?如果只将batch_size设大,计算还是在一张卡上完成,速度相对来说是不如将数据均分后放在不同卡上并行计算的。当然,考虑到卡之间的通信问题,要发挥多卡并行的力量需要进行一定权衡。

torch中主要有两种数据并行方式:DP和DDP。

DataParallel

DP是较简单的一种数据并行方式,直接将模型复制到多个GPU上并行计算,每个GPU计算batch中的一部分数据,各自完成前向和反向后,将梯度汇总到主GPU上。其基本流程:

  1. 加载模型、数据至内存;
  2. 创建DP模型;
  3. DP模型的forward过程:
    1. 一个batch的数据均分到不同device上;
    2. 为每个device复制一份模型;
    3. 至此,每个device上有模型和一份数据,并行进行前向传播;
    4. 收集各个device上的输出;
  4. 每个device上的模型反向传播后,收集梯度到主device上,更新主device上的模型,将模型广播到其他device上;
  5. 3-4循环。

在DP中,只有一个主进程,主进程下有多个线程,每个线程管理一个device的训练。因此,DP中内存中只存在一份数据,各个线程间是共享这份数据的。DP和Parameter Server的方式很像。

小样

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset# 假设我们有一个简单的数据集类
class SimpleDataset(Dataset):def __init__(self, data, target):self.data = dataself.target = targetdef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.target[idx]# 假设我们有一个简单的神经网络模型
class SimpleModel(nn.Module):def __init__(self, input_dim):super(SimpleModel, self).__init__()self.fc = nn.Linear(input_dim, 1)def forward(self, x):return torch.sigmoid(self.fc(x))# 假设我们有一些数据
n_sample = 100
n_dim = 10
batch_size = 10
X = torch.randn(n_sample, n_dim)
Y = torch.randint(0, 2, (n_sample, )).float()dataset = SimpleDataset(X, Y)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# ===== 注意:刚创建的模型是在 cpu 上的 ===== #
device_ids = [0, 1, 2]
model = SimpleModel(n_dim).to(device_ids[0])
model = nn.DataParallel(model, device_ids=device_ids)optimizer = optim.SGD(model.parameters(), lr=0.01)for epoch in range(10):for batch_idx, (inputs, targets) in enumerate(data_loader):inputs, targets = inputs.to('cuda'), targets.to('cuda')outputs = model(inputs)loss = nn.BCELoss()(outputs, targets.unsqueeze(1))optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}')

其中最重要的一行便是:

model = nn.DataParallel(model, device_ids=device_ids)

注意,模型的参数和缓冲区都要放在device_ids[0]上。在执行forward函数时,模型会被复制到各个GPU上,对模型的属性进行更新并不会产生效果,因为前向完后各个卡上的模型就被销毁了。只有在device_ids[0]上对模型的参数或者buffer进行的更新才会生效!2

DistributedDataParallel

DDP,顾名思义,即分布式的数据并行,每个进程独立进行训练,每个进程会加载完整的数据,但是读取不重叠的数据。DDP执行流程3

  • 准备阶段

    • 环境初始化
      • 在各张卡上初始化进程并建立进程间通信,对应代码:init_process_group
    • 模型广播
      • 将模型parameter、buffer广播到各节点,对应代码:model = DDP(model).to(local_rank)
    • 创建管理器reducer,给每个参数注册梯度平均hook。
  • 准备数据

    • 加载数据集,创建适用于分布式场景的数据采样器,以防不同节点使用的数据不重叠。
  • 训练阶段

    • 前向传播
      • 同步各进程状态(parameter和buffer);
      • 当DDP参数find_unused_parametertrue时,其会在forward结束时,启动一个回溯,标记未用到的参数,提前将这些设置为ready
    • 计算梯度
      • reducer外面:各进程各自开始反向计算梯度;
      • reducer外面:当某个参数的梯度计算好了,其之前注册的grad hook就会触发,在reducer里把这个参数的状态标记为ready
      • reducer里面:当某个bucket的所有参数都是ready时,reducer开始对这个bucket的所有参数开始一个异步的all-reduce梯度平均操作;
      • reducer里面:当所有bucket的梯度平均都结束后,reducer把得到的平均梯度正式写入到parameter.grad里。
    • 优化器应用梯度更新参数。

小样

import argparse
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Datasetimport torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP# 1. 基础模块 ### 
class SimpleModel(nn.Module):def __init__(self, input_dim):super(SimpleModel, self).__init__()self.fc = nn.Linear(input_dim, 1)cnt = torch.tensor(0)self.register_buffer('cnt', cnt)def forward(self, x):self.cnt += 1# print("In forward: ", self.cnt, "Rank: ", self.fc.weight.device)return torch.sigmoid(self.fc(x))class SimpleDataset(Dataset):def __init__(self, data, target):self.data = dataself.target = targetdef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.target[idx]# 2. 初始化我们的模型、数据、各种配置  ####
## DDP:从外部得到local_rank参数。从外面得到local_rank参数,在调用DDP的时候,其会自动给出这个参数
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=-1, type=int)
FLAGS = parser.parse_args()
local_rank = FLAGS.local_rank## DDP:DDP backend初始化
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')## 假设我们有一些数据
n_sample = 100
n_dim = 10
batch_size = 25
X = torch.randn(n_sample, n_dim)  # 100个样本,每个样本有10个特征
Y = torch.randint(0, 2, (n_sample, )).float()dataset = SimpleDataset(X, Y)
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)## 构造模型
model = SimpleModel(n_dim).to(local_rank)
## DDP: Load模型要在构造DDP模型之前,且只需要在master上加载就行了。
ckpt_path = None
if dist.get_rank() == 0 and ckpt_path is not None:model.load_state_dict(torch.load(ckpt_path))## DDP: 构造DDP model —————— 必须在 init_process_group 之后才可以调用 DDP
model = DDP(model, device_ids=[local_rank], output_device=local_rank)## DDP: 要在构造DDP model之后,才能用model初始化optimizer。
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
loss_func = nn.BCELoss().to(local_rank)# 3. 网络训练  ###
model.train()
num_epoch = 100
iterator = tqdm(range(100))
for epoch in iterator:# DDP:设置sampler的epoch,# DistributedSampler需要这个来指定shuffle方式,# 通过维持各个进程之间的相同随机数种子使不同进程能获得同样的shuffle效果。data_loader.sampler.set_epoch(epoch)# 后面这部分,则与原来完全一致了。for data, label in data_loader:data, label = data.to(local_rank), label.to(local_rank)optimizer.zero_grad()prediction = model(data)loss = loss_func(prediction, label.unsqueeze(1))loss.backward()iterator.desc = "loss = %0.3f" % lossoptimizer.step()# DDP:# 1. save模型的时候,和DP模式一样,有一个需要注意的点:保存的是model.module而不是model。#    因为model其实是DDP model,参数是被`model=DDP(model)`包起来的。# 2. 只需要在进程0上保存一次就行了,避免多次保存重复的东西。if dist.get_rank() == 0 and epoch == num_epoch - 1:torch.save(model.module.state_dict(), "%d.ckpt" % epoch)

结合上面的代码,一个简化版的DDP流程:

  1. 读取DDP相关的配置,其中最关键的就是:local_rank
  2. DDP后端初始化:dist.init_process_group
  3. 创建DDP模型,以及数据加载器。注意要为加载器创建分布式采样器(DistributedSampler);
  4. 训练。

DDP的通常启动方式:

CUDA_VISIBLE_DEVICES="0,1" python -m torch.distributed.launch --nproc_per_node 2 ddp.py

一些概念

以上过程中涉及到一些陌生的概念,其实走一遍DDP的过程就会很好理解:每个进程是一个独立的训练流程,不同进程之间共享同一份数据。为了避免不同进程使用重复的数据训练,以及训练后同步梯度,进程间需要同步。因此,其中一个重点就是每个进程序号,或者说使用的GPU的序号。

  • node:节点,可以是物理主机,也可以是容器;
  • ranklocal_rank:都表示进程在整个分布式任务中的编号。rank是进程在全局的编号,local_rank是进程在所在节点上的编号。显然,如果只有一个节点,那么二者是相等的。在启动脚本中的--nproc_per_node即指定一个节点上有多少进程;
  • world_size:即整个分布式任务中进程的数量。

DDP与DP的区别

  • DP是单进程多线程的,只能在单机上工作;DDP是多进程的,可以在多级多卡上工作。DP通常比DDP慢,主要原因有:1)DP是单进程的,受到GIL的限制;2)DP每个step都需要拷贝模型,以及划分数据和收集输出;
  • DDP可以与模型并行相结合;
  • DP的通信成本随着卡数线性增长,DDP支持Ring-AllReduce,通信成本是固定的。

本文利用pytorch进行数据并行训练进行了一个粗浅的介绍。包括DP和DDP的基本原理,以及简单的例子。实际在分布式过程中涉及到的东西还是挺多的,比如DP/DDP中梯度的回收是如何进行的,DDP中数据采样的细节,DDP中的数据同步操作等。更多的还是要基于真实的需求出发才能真的体会得到。

参考


  1. 参数服务器-动手学深度学习2.0. ↩︎

  2. dataparallel ↩︎

  3. Pytorch Distributed Data Parallal. ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/230859.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

YOLOv8改进 | 检测头篇 | ASFF改进YOLOv8检测头(全网首发)

一、本文介绍 本文给大家带来的改进机制是利用ASFF改进YOLOv8的检测头形成新的检测头Detect_ASFF,其主要创新是引入了一种自适应的空间特征融合方式,有效地过滤掉冲突信息,从而增强了尺度不变性。经过我的实验验证,修改后的检测头在所有的检测目标上均有大幅度的涨点效果,…

设计模式 七大原则

1.单一职责原则 单一职责原则(SRP:Single responsibility principle)又称单一功能原则 核心:解耦和增强内聚性(高内聚,低耦合)。 描述: 类被修改的几率很大,因此应该专注…

Vue: 多个el-select不能重复选择相同属性

一、场景 1.需求&#xff1a; 用户可自由选择需要修改的对象并同时修改多个属性&#xff0c;需要校验修改对象不能重复选择&#xff0c;但是可供修改属性是固定的 2.目标效果&#xff1a; 二、实现 1.主要代码&#xff1a; <template><el-selectv-model"se…

开源一套原创文本处理工具:Java+Bat脚本实现自动批量处理对账单工具

原创/朱季谦 这款工具是笔者在2018年初开发完成的&#xff0c;时隔两载&#xff0c;偶然想起这款小工具&#xff0c;于是&#xff0c;决定将其开源&#xff0c;若有人需要做类似Java批处理实现整理文档的工具&#xff0c;可参考该工具逻辑思路来实现。 该工具是运行在windos系统…

vercel部署Gemini pro

一、注册一个vercel账号&#xff08;这个东西类似于第三方的github pages&#xff0c;能部署github中的项目&#xff09; 二、注册结束后&#xff0c;填写github的账号&#xff08;需要事先在该github账号中fork一个gemini的repository&#xff09; 三、babaohuang/GeminiPro…

ssm基于vue框架和elementui组件的手机官网论文

摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本手机官网就是在这样的大环境下诞生&#xff0c;其可以帮助管理者在短时间内处理完毕庞大的数据信息&#x…

用C语言采集游戏平台数据并做行业分析

游戏一直深受90/00后的喜爱&#xff0c;有些人因为对游戏的热爱还专门成立了工作室做游戏赚钱&#xff0c;但是游戏行业赚钱走不好就会被割一波韭菜&#xff0c;那么现在什么游戏挣钱&#xff0c;什么游戏好玩认可度高&#xff1f;带着这样的问题我将利用我毕生所学&#xff0c…

【React系列】Redux(一)管理状态

本文来自#React系列教程&#xff1a;https://mp.weixin.qq.com/mp/appmsgalbum?__bizMzg5MDAzNzkwNA&actiongetalbum&album_id1566025152667107329) 在React的开发过程中&#xff0c;Redux对于我们是非常重要的。 但是对于很多人来说&#xff0c;初次接触redux会感觉r…

2024年HCIE认证有什么用?华为HCIE好考吗?

随着信息技术的迅速发展&#xff0c;网络工程师的需求越来越高&#xff0c;而HCIE作为华为认证体系中的最高级别认证&#xff0c;备受从业者关注。本文将深入研究2024年HCIE认证的价值、考试难度以及报名费用等方面的信息。 2024年HCIE认证有什么用? 新的一年即将到来&#x…

jmeter关联依赖---三种

1.正则表达式提取器 2.xpath取样器 3.json提取器

听GPT 讲Rust源代码--compiler(11)

File: rust/compiler/rustc_mir_transform/src/simplify.rs 在Rust源代码中&#xff0c;rust/compiler/rustc_mir_transform/src/simplify.rs文件是Rust编译器中一系列进行MIR&#xff08;中间表示&#xff09;简化的转换的实现。MIR是Rust编译器中用于进行优化和代码生成的中间…

QT_02 窗口属性、信号槽机制

QT - 窗口属性、信号槽机制 1. 设置窗口属性 窗口设置 1,标题 2,大小 3,固定大小 4,设置图标在 widget.cpp 文件中&#xff1a; //设置窗口大小,此时窗口是可以拉大拉小的 //1参:宽度 //2参:高度 this->resize(800, 600); //设置窗口标题 this->setWindowTitle("…

2023 IoTDB Summit:清华大学软件学院长聘副教授龙明盛《IoTDB 新组件:内生机器学习》...

12 月 3 日&#xff0c;2023 IoTDB 用户大会在北京成功举行&#xff0c;收获强烈反响。本次峰会汇集了超 20 位大咖嘉宾带来工业互联网行业、技术、应用方向的精彩议题&#xff0c;多位学术泰斗、企业代表、开发者&#xff0c;深度分享了工业物联网时序数据库 IoTDB 的技术创新…

安全狗入选“2023年福建省信息技术应用创新解决方案”名单

近日&#xff0c;福建省数字福建建设领导小组办公室公布了2023年福建省信息技术应用创新解决方案入选项目名单。 作为国内云原生安全领导厂商&#xff0c;安全狗凭借综合且具备突出创新水平的方案入选。 据悉&#xff0c;此次方案征集面向全省信创企业和用户单位&#xff0c;…

HarmonyOS4.0系统性深入开发14AbilityStage组件容器

AbilityStage组件容器 AbilityStage是一个Module级别的组件容器&#xff0c;应用的HAP在首次加载时会创建一个AbilityStage实例&#xff0c;可以对该Module进行初始化等操作。 AbilityStage与Module一一对应&#xff0c;即一个Module拥有一个AbilityStage。 DevEco Studio默…

勒索事件急剧增长,亚信安全发布《勒索家族和勒索事件监控报告》

近期(12.15-12.21)态势快速感知 近期全球共发生了247起攻击和勒索事件&#xff0c;勒索事件数量急剧增长。 近期需要重点关注的除了仍然流行的勒索家族lockbit3以外&#xff0c;还有本周top1勒索组织toufan。toufan是一个新兴勒索组织&#xff0c;本周共发起了108起勒索攻击&a…

电脑视频需要分屏怎么做

在当今数字时代&#xff0c;人们对于视频的需求越来越高。有时候&#xff0c;我们可能想在同一屏幕上同时播放多个视频&#xff0c;进行对比、观看、剪辑或者其他目的。那么&#xff0c;视频分屏应该怎么做呢&#xff1f; 在本篇文章中&#xff0c;我们将会详细的为你介绍视频分…

JS中模块的导入导出

背景 学习js过程中&#xff0c;发现导入导出有的是使用的export 导出&#xff0c;import导入&#xff0c;有的是使用exports或module.exports导出&#xff0c;使用require导入&#xff0c;不清楚使用场景和规则&#xff0c;比较混乱。 经过了解发现&#xff0c;NodeJS 中&…

JAVA基础学习笔记-day13-数据结构与集合源1

JAVA基础学习笔记-day13-数据结构与集合源1 1. 数据结构剖析1.1 研究对象一&#xff1a;数据间逻辑关系1.2 研究对象二&#xff1a;数据的存储结构&#xff08;或物理结构&#xff09;1.3 研究对象三&#xff1a;运算结构1.4 小结 2. 一维数组2.1 数组的特点 3. 链表3.1 链表的…

CSS免费在线字体格式转换器 CSS @font-face 生成器

今天竟意外发现的一款免费的“网页字体生成器”&#xff0c;功能强大又好用~ 工具地址&#xff1a;https://transfonter.org/ 根据你设置生成后的文件预览&#xff1a; 支持TTF、OTF、WOFF、WOFF2 或 SVG字体格式转换生成&#xff0c;每个文件最大15MB。转换完成以后还会生成一…