分布式训练:(Pytorch)

分布式训练是将机器学习模型的训练过程分散到多个计算节点或设备上,以提高训练速度和效率,尤其是在处理大规模数据和模型时。分布式训练主要分为数据并行模型并行两种主要策略:

1. 数据并行 (Data Parallelism)

数据并行是最常见的分布式训练方式。在这种方法中,模型副本会被复制到多个计算设备上,每个设备处理不同的批次(batch)数据。

工作流程:
  • 每个设备上都有一个完整的模型副本。
  • 数据集被分割成多个部分(mini-batches),每个设备处理其中一部分。
  • 每个设备独立计算模型的前向传播和反向传播,计算出梯度。
  • 通过某种方式(如梯度聚合),将所有设备的梯度平均化,并更新全局模型参数。
  • 同步方式可分为同步训练和异步训练:
    • 同步训练:所有设备都在同一个时刻更新模型参数。
    • 异步训练:各设备独立更新参数,可能导致一些参数不一致。
# Replicate module to devices in device_ids
replicas = nn.parallel.replicate(module, device_ids)
# Distribute input to devices in device_ids
inputs = nn.parallel.scatter(input, device_ids)
# Apply the models to corresponding inputs
outputs = nn.parallel.parallel_apply(replicas, inputs)
# Gather result from all devices to output_device
result = nn.parallel.gather(outputs, output_device)
优点:
  • 易于实现,特别是在GPU集群或云端平台中。
  • 可以在大规模数据集上显著加快训练过程。
缺点:
  • 通信开销较大,特别是在梯度同步阶段,可能会成为训练速度的瓶颈。
  • 对大模型的扩展性有限,因为每个设备都需要存储完整的模型。

2. 模型并行 (Model Parallelism)

模型并行将一个大型模型拆分到多个设备上,以便更好地利用计算资源,尤其适用于内存消耗较大的模型。

工作流程:
  • 模型被拆分成多个部分,每个设备负责模型的一个子集。
  • 输入数据在各设备间传递,完成前向传播和反向传播。
  • 各设备独立计算梯度并更新自己负责的模型参数。
优点:
  • 适合超大规模模型,尤其是单个设备无法存储整个模型的情况。
  • 内存使用效率较高。
缺点:
  • 由于模型的不同部分在不同设备上进行计算,存在大量的通信开销,尤其是在前向传播和反向传播时需要设备间频繁交互。
  • 难以实现模型的负载均衡,部分设备可能成为性能瓶颈。

常用的分布式训练框架

  • TensorFlow:支持多设备、多机器的分布式训练,通过 tf.distribute.Strategy 轻松实现。
  • PyTorch:通过 torch.distributed 提供原生支持,还支持基于 Horovod 等第三方工具的分布式训练。
  • Horovod:Uber 开源的分布式深度学习库,支持 TensorFlow、Keras、PyTorch 等。

关键挑战

  • 同步和通信开销:在数据并行训练中,梯度的同步可能成为瓶颈。
  • 负载均衡:在模型并行训练中,确保各设备之间的负载均衡非常重要,以避免性能瓶颈。
  • 容错性:分布式训练中节点故障可能导致训练过程中断,需要具备一定的容错机制。

常用的 API 有两个:

  • torch.nn.DataParallel(DP)
  • torch.nn.DistributedDataParallel(DDP)

torch.nn.DataParallel(简称 DP)是 PyTorch 提供的一个简单的并行化工具,主要用于在多个 GPU 上进行数据并行训练。DataParallel 通过将输入数据批次(batch)切分成多个小批次,并将其分发到多个 GPU 上,进行并行处理。它会自动处理梯度的同步和模型参数的更新。

torch.nn.DataParallel 的工作机制

  1. 模型复制DataParallel 会将模型复制到多个 GPU 上,每个 GPU 上有一个模型副本。
  2. 数据分割:输入数据会被划分成多个小批次(mini-batches),并分别分发给各个 GPU。
  3. 并行执行:每个 GPU 独立进行前向传播和反向传播,计算梯度。
  4. 梯度汇总:主设备(默认是 cuda:0)会收集所有 GPU 计算出的梯度,并将它们平均化,更新模型的全局参数。

使用 torch.nn.DataParallel

使用 DataParallel 非常简单,通常只需要将模型用 DataParallel 包裹,然后像普通模型一样使用即可。

import torch
import torch.nn as nn
import torch.optim as optim# 定义模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 5)def forward(self, x):return self.fc(x)# 初始化模型和数据
model = SimpleModel()# 将模型并行化
if torch.cuda.device_count() > 1:print("Using", torch.cuda.device_count(), "GPUs")model = nn.DataParallel(model)model = model.cuda()# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 模拟输入数据
inputs = torch.randn(32, 10).cuda()  # 一个 32 样本的 batch,每个样本 10 个特征
targets = torch.randn(32, 5).cuda()  # 对应的目标输出# 前向传播
outputs = model(inputs)# 计算损失
loss = criterion(outputs, targets)# 反向传播
optimizer.zero_grad()
loss.backward()# 更新模型参数
optimizer.step()

DistributedDataParallel (简称 DDP) 是 PyTorch 用于分布式训练的高级并行化工具,它的效率和灵活性比 DataParallel 更高,特别适合在多个 GPU 甚至跨多个节点(机器)上进行分布式训练。与 DataParallel 不同,DDP 在每个设备(GPU)上独立处理模型的前向传播和反向传播,并且避免了主设备的瓶颈问题。

DistributedDataParallel 的工作原理

  1. 模型的分发:与 DataParallel 类似,DDP 会在每个 GPU 上保留一份模型副本。但与 DataParallel 不同的是,DDP 不需要将数据集中在主设备上,而是让每个 GPU 独立完成自己的工作。
  2. 前向和反向传播:每个 GPU 上的模型执行前向传播和反向传播,并计算梯度。
  3. 梯度同步:每个设备上计算的梯度通过 all-reduce 操作在所有设备之间同步,确保所有模型副本的梯度相同。这个过程是并行进行的,不会像 DataParallel 那样集中在主设备上,因此通信效率更高。
  4. 参数更新:每个设备独立地应用梯度更新全局模型参数。

DistributedDataParallel 的优点

  • 高效的通信和同步:梯度的同步是在所有设备之间并行进行的,避免了主设备成为通信瓶颈的问题,因此在多 GPU 或跨节点时表现更加优异。
  • 可扩展性强DDP 支持跨多台机器的训练,适合超大规模模型或需要跨节点的分布式训练。
  • 无锁设计DDP 实现了无锁的梯度同步,不会因锁机制造成性能损失。

DistributedDataParallel 的使用

DataParallel 类似,DDP 也需要对模型进行包装,但它需要更多的设置,特别是在多机环境下,还需要配置通信后端。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP# 初始化分布式环境
def setup(rank, world_size):dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)# 销毁分布式环境
def cleanup():dist.destroy_process_group()# 定义模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 5)def forward(self, x):return self.fc(x)# 初始化模型、优化器和数据
def main(rank, world_size):setup(rank, world_size)model = SimpleModel().cuda(rank)ddp_model = DDP(model, device_ids=[rank])criterion = nn.MSELoss()optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)# 模拟输入数据inputs = torch.randn(32, 10).cuda(rank)targets = torch.randn(32, 5).cuda(rank)# 前向传播outputs = ddp_model(inputs)loss = criterion(outputs, targets)# 反向传播optimizer.zero_grad()loss.backward()# 更新模型参数optimizer.step()cleanup()# 假设有两个GPU,可以这样启动分布式训练
if __name__ == "__main__":world_size = 2  # GPU数torch.multiprocessing.spawn(main, args=(world_size,), nprocs=world_size, join=True)
特性DataParallel (DP) DistributedDataParallel (DDP)
通信模式主设备负责梯度同步所有设备间并行同步梯度
性能通信开销大,主设备瓶颈通信开销小,性能更高
可扩展性适用于单机多 GPU适用于单机或多机多 GPU
使用场景小规模并行大规模或跨节点分布式训练

2. 并行数据加载

在深度学习任务中,数据加载通常是训练过程中的一个瓶颈,特别是当数据量很大时。使用多个进程来并行加载数据,并将数据从可分页内存(虚拟内存)转移到固定内存(GPU 内存)可以显著提高训练效率。

工作流程

  1. 数据加载

    • 使用多个进程并行从磁盘读取数据。每个进程负责加载不同的数据批次,减少了磁盘 I/O 操作的等待时间。
  2. 生产者-消费者模式

    • 数据加载进程(生产者)将读取的数据批次放入队列中,而主线程(消费者)从队列中取出数据批次进行训练。这样可以在数据加载和模型训练过程中实现并行化,减少数据加载对训练速度的影响。
  3. 固定内存的使用

    • 将数据从主机的可分页内存转移到固定内存。数据被加载到固定内存中后,转移到 GPU 的速度会更快,因为固定内存中的数据可以快速传输。

参数解释

  1. num_workers

    • 这个参数指定了数据加载的进程数量。将 num_workers 设置为大于 0 的值可以让 DataLoader 使用多个子进程来并行加载数据。
    • 例如,num_workers=4 表示使用 4 个进程来加载数据。这可以显著提高数据加载速度,因为多个进程可以同时从磁盘读取不同的数据批次。
  2. pin_memory

    • 这个参数用于将数据从主机内存(CPU 内存)固定到页面锁定内存(pinned memory)。固定内存可以让数据传输到 GPU 更加高效。
    • pin_memory=True 时,DataLoader 会将数据从可分页的内存(虚拟内存)传输到固定内存中,这样在将数据转移到 GPU 时,数据传输速度会更快,因为固定内存可以避免页面交换的开销。

总结

  • 数据加载:使用多个进程来并行加载和预处理数据,通过流水线处理减少数据加载的延迟。
  • 数据传输:利用 CUDA 流优化从固定内存到 GPU 的数据传输。
  • 数据并行性:使用数据并行和 NCCL 等通信库实现高效的梯度同步和模型参数更新,优化训练过程。

这种方法结合了数据加载、数据传输和数据并行处理的优化,能够显著提升深度学习模型的训练效率和速度。

import torch
from torch.utils.data import DataLoader, Dataset
import numpy as npclass CustomDataset(Dataset):def __init__(self, size):self.data = np.random.rand(size, 3, 224, 224).astype(np.float32)self.labels = np.random.randint(0, 2, size).astype(np.int64)def __len__(self):return len(self.data)def __getitem__(self, idx):return torch.tensor(self.data[idx]), torch.tensor(self.labels[idx])dataset = CustomDataset(size=10000)
dataloader = DataLoader(dataset,batch_size=64,shuffle=True,num_workers=4,      # 使用 4 个子进程加载数据pin_memory=True     # 将数据转移到固定内存
)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")for inputs, labels in dataloader:inputs, labels = inputs.to(device), labels.to(device)# 模型训练代码# ...

 参考文章:

Pytorch 分布式训练(DP/DDP)_pytorch分布式训练-CSDN博客icon-default.png?t=O83Ahttps://blog.csdn.net/ytusdc/article/details/122091284?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522CC589E02-BBE1-4F15-BDC0-CA76EBF6C160%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=CC589E02-BBE1-4F15-BDC0-CA76EBF6C160&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~top_positive~default-1-122091284-null-null.142^v100^control&utm_term=%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83&spm=1018.2226.3001.4187

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/425808.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【网络安全】逻辑漏洞之购买商品

未经授权,不得转载。 文章目录 正文正文 电子商务平台的核心功能,即购买商品功能。因为在这个场景下,任何功能错误都有可能对平台产生重大影响,特别是与商品价格和数量有关的问题。 将商品添加到购物车时拦截请求: 请求包的参数: 解码参数后,并没有发现价格相关的参数,…

Python(TensorFlow和PyTorch)及C++注意力网络导图

🎯要点 谱图神经网络计算注意力分数对比图神经网络、卷积网络和图注意力网络药物靶标建模学习和预测相互作用腹侧和背侧皮质下结构手写字体字符序列文本识别组织病理学图像分析长短期记忆财务模式预测相关性生物医学图像特征学习和迭代纠正 Python注意力机制 对…

AE VM5000 Platform VarioMatch Match Network 手侧

AE VM5000 Platform VarioMatch Match Network 手侧

算法入门-贪心1

第八部分:贪心 409.最长回文串(简单) 给定一个包含大写字母和小写字母的字符串 s ,返回通过这些字母构造成的最长的回文串 的长度。 在构造过程中,请注意 区分大小写 。比如 "Aa" 不能当做一个回文字符串…

Understanding the model of openAI 5 (1024 unit LSTM reinforcement learning)

题意:理解 OpenAI 5(1024 单元 LSTM 强化学习)的模型 问题背景: I recently came across openAI 5. I was curious to see how their model is built and understand it. I read in wikipedia that it "contains a single l…

从0-1 用AI做一个赚钱的小红书账号(不是广告不是广告)

大家好,我是胡广!是不是被标题吸引过来的呢?是不是觉得自己天赋异禀,肯定是那万中无一的赚钱天才。哈哈哈,我告诉你,你我皆是牛马,不要老想着突然就成功了,一夜暴富了,瞬…

【SQL】百题计划:SQL对于空值的比较判断。

[SQL]百题计划 方法&#xff1a; 使用 <> (!) 和 IS NULL [Accepted] 想法 有的人也许会非常直观地想到如下解法。 SELECT name FROM customer WHERE referee_Id <> 2;然而&#xff0c;这个查询只会返回一个结果&#xff1a;Zach&#xff0c;尽管事实上有 4 个…

React js Router 路由 2, (把写过的几个 app 组合起来)

完整的项目&#xff0c;我已经上传了&#xff0c;资源链接. 起因&#xff0c; 目的: 每次都是新建一个 react 项目&#xff0c;有点繁琐。 刚刚学了路由&#xff0c;不如写一个 大一点的 app &#xff0c;把前面写过的几个 app, 都包含进去。 这部分感觉就像是&#xff0c; …

linux网络编程——UDP编程

写在前边 本文是B站up主韦东山的4_8-3.UDP编程示例_哔哩哔哩_bilibili视频的笔记&#xff0c;其中有些部分博主也没有理解&#xff0c;希望各位辩证的看。 UDP协议简介 UDP 是一个简单的面向数据报的运输层协议&#xff0c;在网络中用于处理数据包&#xff0c;是一种无连接的…

借助大模型将文档转换为视频

利用传统手段将文档内容转换为视频&#xff0c;比如根据文档内容录制一个视频&#xff0c;不仅需要投入大量的时间和精力&#xff0c;而且往往需要具备专业的视频编辑技能。使用大模型技术可以更加有效且智能化地解决上述问题。本实践方案旨在依托大语言模型&#xff08;Large …

JDBC导图

思维歹徒 一、使用步骤 二、SQL注入 三、数据库查询&#xff08;查询&#xff09; 四、数据库写入&#xff08;增删改&#xff09; 五、Date日期对象处理 六、连接池使用 创建连接是从连接池拿&#xff0c;释放连接是放回连接池 七、事务和批次插入 八、Apache Commons DBUtil…

Village Exteriors Kit 中世纪乡村房屋场景模型

此模块化工具包就是你一直在寻找的适合建造所有中世纪幻想村庄和城市建筑所需要的工具包。 皇家园区 - 村庄外饰套件的模型和纹理插件资源包 酒馆和客栈、魔法商店、市政大厅、公会大厅、布莱克史密斯锻造厂、百货商店、珠宝商店、药店、草药师、银行、铠甲、弗莱切、马厩、桌…

这个时代唯一“不变“的又是{变}

这个时代唯一不变的就是“变”&#xff0c;所以每个人都得有规划意识&#xff0c;首先要对自己的价值有清晰的认知&#xff0c;你核心卖点是什么。第二&#xff0c;你取得的成绩是通过平台成就的还是通过自身努力取得的&#xff0c;很多人在一家平台待久了之后&#xff0c;身上…

2022高教社杯全国大学生数学建模竞赛C题 问题一(1) Python代码

目录 问题 11.1 对这些玻璃文物的表面风化与其玻璃类型、纹饰和颜色的关系进行分析数据探索 -- 单个分类变量的绘图树形图条形图扇形图雷达图 Cramer’s V 相关分析统计检验列联表分析卡方检验Fisher检验 绘图堆积条形图分组条形图 分类模型Logistic回归随机森林 import matplo…

在STM32工程中使用Mavlink与飞控通信

本文讲述如何在STM32工程中使用Mavlink协议与飞控通信&#xff0c;特别适合自制飞控外设模块的项目。 需求来源&#xff1a; 1、增稳云台里的STM32单片机需要通过串口接收飞控传来的云台俯仰、横滚控制指令和相机拍照控制指令&#xff1b; 2、自制的有害气体采集器需要接收飞…

[Python可视化]数据可视化在医疗领域应用:提高诊断准确性和治疗效果

随着医疗数据的增长&#xff0c;如何从庞大的数据集中快速提取出有用的信息&#xff0c;成为了医疗研究和实践中的一大挑战。数据可视化在这一过程中扮演了至关重要的角色&#xff0c;它能够通过图形的方式直观展现复杂的数据关系&#xff0c;从而帮助医生和研究人员做出更好的…

专题四_位运算( >> , << , , | , ^ )_算法详细总结

目录 位运算 常见位运算总结 1.基础位运算 2.给一个数 n ,确定它的二进制表示中的第 x 位是 0 还是 1 3.运算符的优先级 4.将一个数 n 的二进制表示的第 x 位修改成 1 5.将一个数n的二进制表示的第x位修改成0 6.位图的思想 7.提取一个数&#xff08;n&#xff09;二进…

【嘉立创EDA】画PCB板中为什么要两面铺铜为GND,不能一面GND一面VCC吗?

在新手画板子铺铜时&#xff0c;经常会铺一面GND一面VCC。但一般情况下我们不会这样铺铜。下面将详细分析为什么要两面铺铜为GND&#xff0c;而不是一面GND一面VCC的原因&#xff1a; 提高散热能力 金属导热性&#xff1a;金属具有良好的导热性&#xff0c;铺铜可以有效分散PCB…

引用和指针的区别(面试概念性题型)

个人主页&#xff1a;Jason_from_China-CSDN博客 所属栏目&#xff1a;C系统性学习_Jason_from_China的博客-CSDN博客 所属栏目&#xff1a;C知识点的补充_Jason_from_China的博客-CSDN博客 概念概述 内存占用&#xff1a; 引用&#xff1a;引用一个变量时&#xff0c;实际上并…

2024 年浙江省网络安全行业网络安全运维工程师项目 职业技能竞赛网络安全运维工程师(决赛样题)

2024年浙江省网络安全行业网络安全运维工程师项目 职业技能竞赛网络安全运维工程师&#xff08;决赛样题&#xff09; 应急响应&#xff1a;1 通过流量分析&#xff0c;找到攻击者的 IP 地址2 找到攻击者下载的恶意文件的 32 位小写 md5 值3 找到攻击者登录后台的 URI4 找到攻击…