PyTorch 分布式训练(Distributed Data Parallel, DDP)简介

PyTorch 分布式训练(Distributed Data Parallel, DDP)

一、DDP 核心概念

torch.nn.parallel.DistributedDataParallel

1. DDP 是什么?

Distributed Data Parallel (DDP) 是 PyTorch 提供的分布式训练接口,DistributedDataParallel相比 DataParallel 具有以下优势:

  • 多进程而非多线程:避免 Python GIL 限制
  • 更高的效率:每个 GPU 有独立的进程,减少通信开销
  • 更好的扩展性:支持多机多卡训练
  • 更均衡的负载:无主 GPU 瓶颈问题

2. 核心组件

  • 进程组 (Process Group):管理进程间通信
  • NCCL 后端:NVIDIA 优化的 GPU 通信库
  • Ring-AllReduce:高效的梯度同步算法

在这里插入图片描述

二、完整 DDP 训练 Demo

  • 官方DDP Dem参考

1. 基础训练脚本 (ddp_demo.py)

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from torch.cuda.amp import GradScalerdef setup(rank, world_size):"""初始化分布式环境"""os.environ['MASTER_ADDR'] = 'localhost'os.environ['MASTER_PORT'] = '12355'dist.init_process_group("nccl", rank=rank, world_size=world_size)def cleanup():"""清理分布式环境"""dist.destroy_process_group()class SimpleModel(nn.Module):"""简单的CNN模型"""def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.fc = nn.Linear(9216, 10)def forward(self, x):x = torch.relu(self.conv1(x))x = torch.max_pool2d(x, 2)x = torch.relu(self.conv2(x))x = torch.max_pool2d(x, 2)x = torch.flatten(x, 1)return self.fc(x)def prepare_dataloader(rank, world_size, batch_size=32):"""准备分布式数据加载器"""transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)return loaderdef train(rank, world_size, epochs=2):"""训练函数"""setup(rank, world_size)# 设置当前设备torch.cuda.set_device(rank)# 初始化模型、优化器等model = SimpleModel().to(rank)ddp_model = DDP(model, device_ids=[rank])optimizer = optim.Adam(ddp_model.parameters())scaler = GradScaler()  # 混合精度训练criterion = nn.CrossEntropyLoss()train_loader = prepare_dataloader(rank, world_size)for epoch in range(epochs):ddp_model.train()train_loader.sampler.set_epoch(epoch)  # 确保每个epoch有不同的shufflefor batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(rank), target.to(rank)optimizer.zero_grad()# 混合精度训练with torch.autocast(device_type='cuda', dtype=torch.float16):output = ddp_model(data)loss = criterion(output, target)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()if batch_idx % 100 == 0:print(f"Rank {rank}, Epoch {epoch}, Batch {batch_idx}, Loss {loss.item():.4f}")cleanup()if __name__ == "__main__":# 单机多卡启动时,torchrun会自动设置这些环境变量rank = int(os.environ['LOCAL_RANK'])world_size = int(os.environ['WORLD_SIZE'])train(rank, world_size)

2. 启动训练

使用 torchrun 启动分布式训练(推荐 PyTorch 1.9+):

# 单机4卡训练
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=12355 ddp_demo.py

3. 关键组件解析

3.1 分布式数据采样 (DistributedSampler)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
  • 确保每个 GPU 处理不同的数据子集
  • 自动处理数据分片和 epoch 间的 shuffle
3.2 模型包装 (DDP)
ddp_model = DDP(model, device_ids=[rank])
  • 自动处理梯度同步
  • 透明地包装模型,使用方式与普通模型一致
3.3 混合精度训练 (AMP)
scaler = GradScaler()
with torch.autocast(device_type='cuda', dtype=torch.float16):# 前向计算
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
  • 减少显存占用,加速训练
  • 自动管理 float16/float32 转换

三、DDP 最佳实践

  1. 数据加载

    • 必须使用 DistributedSampler
    • 每个 epoch 前调用 sampler.set_epoch(epoch) 保证 shuffle 正确性
  2. 模型保存

    if rank == 0:  # 只在主进程保存torch.save(model.state_dict(), "model.pth")
    
  3. 多机训练

    # 机器1 (主节点)
    torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=IP1 --master_port=12355 ddp_demo.py# 机器2
    torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr=IP1 --master_port=12355 ddp_demo.py
    
  4. 性能调优

    • 调整 batch_size 使各 GPU 负载均衡
    • 使用 pin_memory=True 加速数据加载
    • 考虑梯度累积减少通信频率

四、常见问题解决

  1. CUDA 内存不足

    • 减少 batch_size
    • 使用梯度累积
    for i, (data, target) in enumerate(train_loader):if i % 2 == 0:optimizer.zero_grad()# 前向和反向...if i % 2 == 1:optimizer.step()
    
  2. 进程同步失败

    • 检查所有节点的 MASTER_ADDRMASTER_PORT 一致
    • 确保防火墙开放相应端口
  3. 精度问题

    • 混合精度训练时出现 NaN:调整 GradScaler 参数
    scaler = GradScaler(init_scale=1024, growth_factor=2.0)
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/42372.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

蓝桥杯[每日一题] 真题:连连看

题目描述 小蓝正在和朋友们玩一种新的连连看游戏。在一个 n m 的矩形网格中,每个格子中都有一个整数,第 i 行第 j 列上的整数为 Ai, j 。玩家需要在这个网格中寻找一对格子 (a, b) − (c, d) 使得这两个格子中的整数 Aa,b 和 Ac,d 相等,且它…

Linux环境下安装部署Docker

windows下连接Linux: 打开终端: //ssh远程连接 ssh root192.168.xx.xx//输入账号密码 root192.168.xx.xxs password: ssh连接成功! 安装Docker: //安装Docker yum install -y yum-utils device-mapper-persistent-data lvm2 …

k近邻算法K-Nearest Neighbors(KNN)

算法核心 KNN算法的核心思想是“近朱者赤,近墨者黑”。对于一个待分类或预测的样本点,它会查找训练集中与其距离最近的K个样本点(即“最近邻”)。然后根据这K个最近邻的标签信息来对当前样本进行分类或回归。 在分类任务中&#…

Appium中元素定位之一个元素定位API

应用场景 想要对按钮进行点击,想要对输入框进行输入,想要获取文本框的内容,定位元素是自动化操作必须要使用的方法。只有获取元素之后,才能对这个元素进行操作。 在 Java 中使用 Appium 定位元素时,可以通过多种方式…

Dify 服务器部署指南

1. 系统要求 在开始部署之前,请确保你的服务器满足以下要求: 操作系统:Linux(推荐使用 Ubuntu 20.04 或更高版本)内存:至少 4GB RAM存储:至少 20GB 可用空间网络:稳定的互联网连接…

Sa-Token

简介 Sa-Token 是一个轻量级 Java 权限认证框架,主要解决:登录认证、权限认证、单点登录、OAuth2.0、分布式Session会话、微服务网关鉴权 等一系列权限相关问题。 官方文档 常见功能 登录认证 本框架 用户提交 name password 参数,调用登…

ADZS-ICE-2000和AD-ICE2000仿真器在线升级固件

作者的话 近期发现有些兄弟的ICE-2000仿真器链接DSP报错,然后test第四步不通过,我就拿我的仿真器也试了一下,发现ADI悄咪咪的在线升级仿真器固件,有些兄弟不会操作,就会导致仿真器升级失败,连不上目标板&a…

C++概述

1 什么是面向对象】 概念上来说:就是以对象(具体的变量)为导向的编程思路 专注于:一个对象具体能实现哪些过程(哪些功能) 面向对象 n * 面向过程 结论:面向对象需要做的事情 1:我们要想清楚,我们现在需要编写一个…

Java 大视界 -- 基于 Java 的大数据隐私计算在医疗影像数据共享中的实践探索(158)

💖亲爱的朋友们,热烈欢迎来到 青云交的博客!能与诸位在此相逢,我倍感荣幸。在这飞速更迭的时代,我们都渴望一方心灵净土,而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识,也…

数字化如何赋能食品抽检全流程升级,助力食品安全监管现代化

食品安全是关乎民众健康和社会稳定的重要问题。食品抽检作为保障食品安全的核心监管手段,通过对食品生产、加工、销售等环节的随机抽样检测,及时发现潜在的食品安全问题,防止不合格产品流入市场,同时为政府监管、企业自查和消费者…

HBase入门教程

HBase入门教程 HBase是一个开源的、分布式的、版本化的非关系型数据库,是Apache Hadoop生态系统的重要组成部分。本文将全面介绍HBase的基础知识,帮助你快速入门。 文章目录 HBase入门教程1. HBase简介1.1 什么是HBase?1.2 HBase核心特点 2.…

vscode连接服务器失败问题解决

文章目录 问题描述原因分析解决方法彻底删除VS Code重新安装较老的版本 问题描述 vscode链接服务器时提示了下面问题: 原因分析 这是说明VScode版本太高了。 https://code.visualstudio.com/docs/remote/faq#_can-i-run-vs-code-server-on-older-linux-distribu…

redis常用部署架构之redis分片集群。

redis 3.x版本后开始支持 作用: 1.提升数据读写速度 2..提升可用性 分片集群就是将业务服务器产生的数据储存在不同的机器上。 redis分片集群的架构 如上图所示,会将数据分散存储到不同的服务器上,相比于之前来说,redis要处…

Modbus主站EtherNet/IP转ModbusRTU/ASCII工业EIP网关串口服务器

型号 2路总线EIP网关 MS-A1-2021 4路总线EIP网关 MS-A1-2041 4路总线EIP网关(双网口) MS-A2-2041 8路总线EIP网关 MS-A1-2081 8路总线EIP网关(双网口) MS-A2-2081 EtherNet/IP 串口网关 EtherNet/IP 转 RS485 …

Centos7 安装 TDengine

Centos7 安装 TDengine 1、简介 官网: https://www.taosdata.com TDengine 是一款开源、高性能、云原生的时序数据库(Time Series Database, TSDB), 它专为物联网、车联网、工业互联网、金融、IT 运维等场景优化设计。同时它还带有内建的缓…

基于社交裂变的S2B2C电商模式创新研究——以“颜值PK+礼品卡+AI智能名片“融合生态为例

摘要 本文构建了融合开源AI技术、社交裂变机制与S2B2C商业模式的创新模型。通过开发具备AI智能名片功能的商城小程序,实现用户日均停留时长提升171%、社交转化效率提高2.8倍的实证效果。研究发现:基于GAN的虚拟形象生成技术可降低用户决策成本32%&…

王者荣耀服务器突然崩了

就在刚刚王者荣耀服务器突然崩了 #王者荣耀崩了#的话题毫无预兆地冲上热搜,许多玩家发现游戏登录界面反复弹出异常提示,匹配成功后卡在加载界面,甚至出现对局数据丢失的情况。根据官方公告,目前技术团队已在全力抢修服务器 #王者…

LabVIEW医疗设备备用电源实时监控系统

开发了一个基于LabVIEW的医疗设备备用电源实时监控系统。系统提高医疗设备备用电源的管理效能与使用安全,通过实时监测与数据分析,确保医疗设施在电力供应中断时的可靠运行。 ​ 项目背景 医院中的医疗设备对电源的连续供应有着极高的要求,…

04-SpringBoot3入门-配置文件(多环境配置)

1、简介 在 SpringBoot 中,不同的环境(如开发、测试、生产)可以编写对应的配置文件,例如数据库连接信息、日志级别、缓存配置等。在不同的环境中使用对应的配置文件。 2、配置环境 # 开发环境 zbj:user:username: root # 测试环…

C++链表详解:从基础概念到高级应用

C++链表详解:从基础概念到高级应用 链表是计算机科学中最基础也是最重要的数据结构之一,它在内存管理、算法实现和实际应用中扮演着关键角色。本文将详细介绍链表的概念、类型、C++实现以及实际应用场景,帮助读者全面理解这一重要的数据结构。 文章目录 C++链表详解:从基础…