PyTorch 分布式训练(Distributed Data Parallel, DDP)
一、DDP 核心概念
torch.nn.parallel.DistributedDataParallel
1. DDP 是什么?
Distributed Data Parallel (DDP) 是 PyTorch 提供的分布式训练接口,DistributedDataParallel
相比 DataParallel
具有以下优势:
- 多进程而非多线程:避免 Python GIL 限制
- 更高的效率:每个 GPU 有独立的进程,减少通信开销
- 更好的扩展性:支持多机多卡训练
- 更均衡的负载:无主 GPU 瓶颈问题
2. 核心组件
- 进程组 (Process Group):管理进程间通信
- NCCL 后端:NVIDIA 优化的 GPU 通信库
- Ring-AllReduce:高效的梯度同步算法
二、完整 DDP 训练 Demo
- 官方DDP Dem参考
1. 基础训练脚本 (ddp_demo.py
)
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from torch.cuda.amp import GradScalerdef setup(rank, world_size):"""初始化分布式环境"""os.environ['MASTER_ADDR'] = 'localhost'os.environ['MASTER_PORT'] = '12355'dist.init_process_group("nccl", rank=rank, world_size=world_size)def cleanup():"""清理分布式环境"""dist.destroy_process_group()class SimpleModel(nn.Module):"""简单的CNN模型"""def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.fc = nn.Linear(9216, 10)def forward(self, x):x = torch.relu(self.conv1(x))x = torch.max_pool2d(x, 2)x = torch.relu(self.conv2(x))x = torch.max_pool2d(x, 2)x = torch.flatten(x, 1)return self.fc(x)def prepare_dataloader(rank, world_size, batch_size=32):"""准备分布式数据加载器"""transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)return loaderdef train(rank, world_size, epochs=2):"""训练函数"""setup(rank, world_size)# 设置当前设备torch.cuda.set_device(rank)# 初始化模型、优化器等model = SimpleModel().to(rank)ddp_model = DDP(model, device_ids=[rank])optimizer = optim.Adam(ddp_model.parameters())scaler = GradScaler() # 混合精度训练criterion = nn.CrossEntropyLoss()train_loader = prepare_dataloader(rank, world_size)for epoch in range(epochs):ddp_model.train()train_loader.sampler.set_epoch(epoch) # 确保每个epoch有不同的shufflefor batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(rank), target.to(rank)optimizer.zero_grad()# 混合精度训练with torch.autocast(device_type='cuda', dtype=torch.float16):output = ddp_model(data)loss = criterion(output, target)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()if batch_idx % 100 == 0:print(f"Rank {rank}, Epoch {epoch}, Batch {batch_idx}, Loss {loss.item():.4f}")cleanup()if __name__ == "__main__":# 单机多卡启动时,torchrun会自动设置这些环境变量rank = int(os.environ['LOCAL_RANK'])world_size = int(os.environ['WORLD_SIZE'])train(rank, world_size)
2. 启动训练
使用 torchrun
启动分布式训练(推荐 PyTorch 1.9+):
# 单机4卡训练
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=12355 ddp_demo.py
3. 关键组件解析
3.1 分布式数据采样 (DistributedSampler)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
- 确保每个 GPU 处理不同的数据子集
- 自动处理数据分片和 epoch 间的 shuffle
3.2 模型包装 (DDP)
ddp_model = DDP(model, device_ids=[rank])
- 自动处理梯度同步
- 透明地包装模型,使用方式与普通模型一致
3.3 混合精度训练 (AMP)
scaler = GradScaler()
with torch.autocast(device_type='cuda', dtype=torch.float16):# 前向计算
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
- 减少显存占用,加速训练
- 自动管理 float16/float32 转换
三、DDP 最佳实践
-
数据加载
- 必须使用
DistributedSampler
- 每个 epoch 前调用
sampler.set_epoch(epoch)
保证 shuffle 正确性
- 必须使用
-
模型保存
if rank == 0: # 只在主进程保存torch.save(model.state_dict(), "model.pth")
-
多机训练
# 机器1 (主节点) torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=IP1 --master_port=12355 ddp_demo.py# 机器2 torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr=IP1 --master_port=12355 ddp_demo.py
-
性能调优
- 调整
batch_size
使各 GPU 负载均衡 - 使用
pin_memory=True
加速数据加载 - 考虑梯度累积减少通信频率
- 调整
四、常见问题解决
-
CUDA 内存不足
- 减少
batch_size
- 使用梯度累积
for i, (data, target) in enumerate(train_loader):if i % 2 == 0:optimizer.zero_grad()# 前向和反向...if i % 2 == 1:optimizer.step()
- 减少
-
进程同步失败
- 检查所有节点的
MASTER_ADDR
和MASTER_PORT
一致 - 确保防火墙开放相应端口
- 检查所有节点的
-
精度问题
- 混合精度训练时出现 NaN:调整
GradScaler
参数
scaler = GradScaler(init_scale=1024, growth_factor=2.0)
- 混合精度训练时出现 NaN:调整