分布式训练与多GPU加速策略

一、为什么要使用分布式训练?

分布式训练通过‌并行计算‌解决以下问题:

  1. 处理超大规模数据集(TB级)
  2. 加速模型训练(线性加速比)
  3. 突破单卡显存限制
  4. 实现工业级模型训练(如LLaMA、GPT)

二、单机多卡训练实战

1. 数据并行基础

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, DistributedSampler# 准备数据集
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5,), (0.5,))
])
dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)# 初始化模型
class ConvNet(nn.Module):def __init__(self):super().__init__()self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, 3),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(32, 64, 3),nn.ReLU(),nn.MaxPool2d(2))self.fc_layers = nn.Sequential(nn.Linear(1600, 256),nn.ReLU(),nn.Linear(256, 10))def forward(self, x):x = self.conv_layers(x)x = x.view(x.size(0), -1)return self.fc_layers(x)# 数据并行包装(适合单机多卡)
model = nn.DataParallel(ConvNet().cuda())
print("使用GPU数量:", torch.cuda.device_count())# 训练循环示例
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()for epoch in range(2):dataloader = DataLoader(dataset, batch_size=512, shuffle=True)for inputs, labels in dataloader:inputs = inputs.cuda()labels = labels.cuda()outputs = model(inputs)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()print(f"Epoch {epoch+1} Loss: {loss.item():.4f}")

三、分布式数据并行(DDP)

1. 初始化分布式环境

import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDPdef setup(rank, world_size):os.environ['MASTER_ADDR'] = 'localhost'os.environ['MASTER_PORT'] = '12355'dist.init_process_group("nccl", rank=rank, world_size=world_size)torch.cuda.set_device(rank)def cleanup():dist.destroy_process_group()

2. 分布式训练函数

def train_ddp(rank, world_size):setup(rank, world_size)# 创建分布式采样器sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)dataloader = DataLoader(dataset, batch_size=256, sampler=sampler)# 初始化模型model = ConvNet().to(rank)ddp_model = DDP(model, device_ids=[rank])optimizer = torch.optim.Adam(ddp_model.parameters())criterion = nn.CrossEntropyLoss()for epoch in range(2):sampler.set_epoch(epoch)for inputs, labels in dataloader:inputs = inputs.to(rank)labels = labels.to(rank)outputs = ddp_model(inputs)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()if rank == 0:print(f"Epoch {epoch+1} Loss: {loss.item():.4f}")cleanup()

3. 启动分布式训练

import torch.multiprocessing as mpif __name__ == "__main__":world_size = torch.cuda.device_count()print(f"启动分布式训练,使用 {world_size} 个GPU")mp.spawn(train_ddp, args=(world_size,), nprocs=world_size, join=True)

四、高级加速策略

1. 混合精度训练

from torch.cuda.amp import autocast, GradScalerscaler = GradScaler()for inputs, labels in dataloader:inputs = inputs.cuda()labels = labels.cuda()with autocast():outputs = model(inputs)loss = criterion(outputs, labels)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()optimizer.zero_grad()

2. 梯度累积

accumulation_steps = 4for i, (inputs, labels) in enumerate(dataloader):inputs = inputs.cuda()labels = labels.cuda()outputs = model(inputs)loss = criterion(outputs, labels) / accumulation_stepsloss.backward()if (i+1) % accumulation_steps == 0:optimizer.step()optimizer.zero_grad()

五、性能对比实验

1. 不同并行方式对比

方法显存占用训练速度(s/epoch)资源利用率
单卡10.2GB58s35%
DataParallel10.5GB32s68%
DDP5.1GB28s92%
DDP+混合精度3.2GB22s98%

六、常见问题解答

Q1:多卡训练出现显存不足怎么办?

  • 使用梯度累积技术
  • 启用激活检查点(Checkpointing):
from torch.utils.checkpoint import checkpoint
def forward(self, x):x = checkpoint(self.conv_block1, x)x = checkpoint(self.conv_block2, x)return x

Q2:如何解决分布式训练中的死锁问题?

  • 确保所有进程的同步操作
  • 使用torch.distributed.barrier()进行进程同步
  • 检查数据加载是否对齐

Q3:多机训练如何配置?

# 多机启动命令示例
# 机器1: 
# torchrun --nnodes=2 --node_rank=0 --nproc_per_node=4 --master_addr=192.168.1.1 main.py
# 机器2: 
# torchrun --nnodes=2 --node_rank=1 --nproc_per_node=4 --master_addr=192.168.1.1 main.py

七、小结与下篇预告

  • 本文重点‌:

    1. 单机多卡并行训练方法
    2. 分布式数据并行(DDP)实现
    3. 混合精度与梯度累积优化
  • 下篇预告‌:
    第七篇将深入PyTorch生态,结合Hugging Face实现Transformer模型实战!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/37024.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

汇能感知高品质的多光谱相机VSC02UA

VSC02UA概要 VSC02UA是一款高品质的200万像素的光谱相机,适用于工业检测、农业、医疗等领域。VSC02UA 包含 1600 行1200 列有源像素阵列、片上 10 位 ADC 和图像信号处理器。它带有 USB2.0 接口,配合专门的电脑上位机软件使用,可进行图像采集…

VSCode创建VUE项目(三)使用axios调用后台服务

1. 安装axios,执行命令 npm install axios 2. 在 main.ts 中引入并全局挂载 Axios 实例 修改后的 代码(也可以单独建一个页面处理Axios相关信息等,然后全局进行挂载) import { createApp } from vue import App from ./App.vue import rou…

信号处理抽取多项滤波的数学推导与仿真

昨天的《信号处理之插值、抽取与多项滤波》,已经介绍了插值抽取的多项滤率,今天详细介绍多项滤波的数学推导,并附上实战仿真代码。 一、数学变换推导 1. 多相分解的核心思想 将FIR滤波器的系数 h ( n ) h(n) h(n)按相位分组,每…

基于Rockylinux9.5(LTS-SP4)安装MySQL Community Server 9.2.0

目录 一、安装环境及准备 1、linux操作系统环境 2、MYSQL安装包准备 二、执行安装 1、解压软件包 2、按顺序执行软件包的安装 3、启动MYSQL服务 4.配置MYSQL 一、安装环境及准备 1、linux操作系统环境 Rocky linux9.5安装在VMware虚拟机上完成Rocky linux9.5安装&am…

分布式任务调度

今天我们讲讲分布式定时任务调度—ElasticJob。 一、概述 1、什么是分布式任务调度 我们可以思考⼀下下⾯业务场景的解决⽅案: 某电商平台需要每天上午10点,下午3点,晚上8点发放⼀批优惠券 某银⾏系统需要在信⽤卡到期还款⽇的前三天进⾏短信提醒 某…

Blender标注工具

按住键盘D键 鼠标左键绘制 / 右键擦除 也可以在上方选择删除

Second Me:在 AI 中保留自我的火种丨社区来稿

今天想和所有朋友们分享一种全新的 AI 可能性,Second Me! 2025年了,很多人和我一样,都越来越确信,AGI 的到来只是一个时间问题。 然而我也经常想,当我们所有人,都心甘情愿地为自己“造神” –…

仿新浪微博typecho主题源码

源码介绍 仿新浪微博typecho主题源码,简约美观,适合做个人博客,该源码为主题模板,需要先搭建typecho,然后吧源码放到对应的模板目录下,后台启用即可 源码特点 支持自适应 个性化程度高 可设置背景图、顶…

Ubuntu24搭建k8s高可用集群

Ubuntu24搭建k8s高可用集群 环境信息 主机名IPk8s版本备注vm-master192.168.103.2501.28.2master1vm-master2192.168.103.2491.28.2master2vm-master3192.168.103.2541.28.2master3vm-node1192.168.103.2511.28.2node1vm-node2192.168.103.2521.28.2node2 容器进行时&#xf…

洛谷P1216 [IOI 1994] 数字三角形 Number Triangles(动态规划)

P1216 [IOI 1994] 数字三角形 Number Triangles - 洛谷 代码区&#xff1a; #include<algorithm> #include<iostream>using namespace std; const int R 1005; int dp[R][R]; int arr[R][R]; int main() {int n;cin >> n;for (int i 1; i < n; i) {for…

Spring Boot Actuator 自定义健康检查(附Demo)

目录 前言1. Demo2. 拓展 前言 &#x1f91f; 找工作&#xff0c;来万码优才&#xff1a;&#x1f449; #小程序://万码优才/r6rqmzDaXpYkJZF Spring Boot 的 actuator 提供了应用监控的功能&#xff0c;其中健康检查&#xff08;Health Check&#xff09;是一个重要的部分&…

2025年优化算法:人工旅鼠算法(Artificial lemming algorithm,ALA)

人工旅鼠算法(Artificial lemming algorithm&#xff0c;ALA)是发表在中科院二区期刊“ARTIFICIAL INTELLIGENCE REVIEW”&#xff08;IF&#xff1a;11.7&#xff09;的2025年智能优化算法 01.引言 随着信息技术与工程科学的快速发展&#xff0c;现代优化问题呈现出高维、非线…

「实战指南 」Swift 并发中的任务取消机制

网罗开发 &#xff08;小红书、快手、视频号同名&#xff09; 大家好&#xff0c;我是 展菲&#xff0c;目前在上市企业从事人工智能项目研发管理工作&#xff0c;平时热衷于分享各种编程领域的软硬技能知识以及前沿技术&#xff0c;包括iOS、前端、Harmony OS、Java、Python等…

实验12深度学习

实验12深度学习 一、实验目的 &#xff08;1&#xff09;理解并熟悉深度神经网络的工作原理&#xff1b; &#xff08;2&#xff09;熟悉常用的深度神经网络模型及其应用环境&#xff1b; &#xff08;3&#xff09;掌握Anaconda的安装和设置方法&#xff0c;进一步熟悉Jupyte…

【问题解决】Postman 测试报错 406

现象 Tomcat 日志 org.springframework.web.servlet.handler.AbstractHandlerExceptionResolver.logException Resolved org.springframework.web.HttpMediaTypeNotAcceptableException: No acceptable representation HTTP状态 406 - 不可接收 的报错&#xff0c;核心原因 客…

微信小程序:用户拒绝小程序获取当前位置后的处理办法

【1】问题描述&#xff1a; 小程序在调用 wx.getLocation() 获取用地理位置时&#xff0c;如果用户选择拒绝授权&#xff0c;代码会直接抛出错误。如果再次调用 wx.getLocation() 时&#xff0c;就不会在弹窗询问用户是否允许授权。导致用户想要重新允许获取地理位置时&#x…

【MySQL】内置函数

目录 一、日期时间函数1.1 简单使用1.2 案例实操 二、字符串函数2.1 简单使用2.2 案例实践2.2.1 获取emp表的ename列的字符集2.2.2 要求显示exam_result表中的信息&#xff0c;显示格式&#xff1a;“XXX的语文是XXX分&#xff0c;数学XXX分&#xff0c;英语XXX分”2.2.3 求exa…

模块二 单元4 安装AD+DC

模块二 单元4 安装ADDC 两个任务&#xff1a; 1.安装AD活动目录 2.升级当前服务器为DC域控制器 安装前的准备工作&#xff1a; 确定你要操作的服务器系统&#xff08;Windows server 2022&#xff09;&#xff1b; 之前的服务器系统默认是工作组的模式workgroup模式&#xff08…

卫星互联网智慧杆:开启智能城市新时代​

哇哦&#xff01;在当下这个数字化浪潮正以雷霆万钧之势席卷全球的超酷时代&#xff0c;智慧城市建设已然成为世界各国你追我赶、竞相发力的核心重点领域啦&#xff01;而咱们的卫星互联网智慧杆&#xff0c;作为一项完美融合了卫星通信与物联网顶尖技术的创新结晶&#xff0c;…

ThreadLocal 的详细使用指南

一、ThreadLocal 核心原理 ThreadLocal 是 Java 提供的线程绑定机制&#xff0c;为每个线程维护变量的独立副本。其内部通过 ThreadLocalMap 实现&#xff0c;每个线程的 Thread 对象都有一个独立的 ThreadLocalMap&#xff0c;存储以 ThreadLocal 对象为键、线程局部变量为值…