Pytorch使用教程(12)-如何进行并行训练?

在使用GPU训练大模型时,往往会面临单卡显存不足的情况。这时,通过多卡并行的形式来扩大显存是一个有效的解决方案。PyTorch主要提供了两个类来实现多卡并行:数据并行torch.nn.DataParallel(DP)和模型并行torch.nn.DistributedDataParallel(DDP)。本文将详细介绍这两种方法。

一、数据并行(torch.nn.DataParallel)

  1. 基本原理
    数据并行是一种简单的多GPU并行训练方式。它通过多线程的方式,将输入数据分割成多个部分,每个部分在不同的GPU上并行处理,最后将所有GPU的输出结果汇总,计算损失和梯度,更新模型参数。
    在这里插入图片描述

  2. 使用方法
    使用torch.nn.DataParallel非常简单,只需要一行代码就可以实现。以下是一个示例:

import torch
import torch.nn as nn# 检查是否有多个GPU可用
if torch.cuda.device_count() > 1:print("Let's use", torch.cuda.device_count(), "GPUs!")# 将模型转换为DataParallel对象model = nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
  1. 优缺点
    ‌优点‌:代码简单,易于使用,对小白比较友好。
    ‌缺点‌:GPU会出现负载不均衡的问题,一个GPU可能占用了大部分负载,而其他GPU却负载较轻,导致显存使用不平衡。

二、模型并行(torch.nn.DistributedDataParallel)

  1. 基本原理
    torch.nn.DistributedDataParallel(DDP)是一种真正的多进程并行训练方式。每个进程对应一个独立的训练过程,且只对梯度等少量数据进行信息交换。每个进程包含独立的解释器和GIL(全局解释器锁),因此可以充分利用多GPU的优势,实现更高效的并行训练。
    在这里插入图片描述

  2. 使用方法

    使用torch.nn.DistributedDataParallel需要进行一些额外的配置,包括初始化GPU通信方式、设置随机种子点、使用DistributedSampler分配数据等。以下是一个详细的示例:

初始化环境

import torch
import torch.distributed as dist
import argparsedef parse():parser = argparse.ArgumentParser()parser.add_argument('--local_rank', type=int, default=0)args = parser.parse_args()return argsdef main():args = parse()torch.cuda.set_device(args.local_rank)dist.init_process_group('nccl', init_method='env://')device = torch.device(f'cuda:{args.local_rank}')

设置随机种子点

import numpy as np# 固定随机种子点
seed = np.random.randint(1, 10000)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)使用DistributedSampler分配数据
python
Copy Code
from torch.utils.data.distributed import DistributedSamplertrain_dataset = ...  # 你的数据集
train_sampler = DistributedSampler(train_dataset, shuffle=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opts.batch_size, sampler=train_sampler
)

初始化模型

model = mymodel().to(device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])训练循环
python
Copy Code
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()for ep in range(total_epoch):train_sampler.set_epoch(ep)for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()
  1. 优缺点
  • 优点‌:每个进程对应一个独立的训练过程,显存使用更均衡,性能更优。
  • 缺点‌:代码相对复杂,需要进行一些额外的配置。

三、对比与选择

  1. 对比
特点torch.nn.DataParalleltorch.nn.DistributedDataParallel
并行方式多线程多进程
显存使用可能不均衡更均衡
性能一般更优
代码复杂度简单复杂
  1. 选择建议
  • 对于初学者或快速实验,可以选择torch.nn.DataParallel,因为它代码简单,易于使用。
  • 对于需要高效并行训练的场景,建议选择torch.nn.DistributedDataParallel,因为它可以充分利用多GPU的优势,实现更高效的训练。

四、小结

通过本文的介绍,相信读者已经对PyTorch的多GPU并行训练有了更深入的了解。在实际应用中,可以根据模型的复杂性和数据的大小选择合适的并行训练方式,并调整batch size和学习率等参数以优化模型的性能。希望这篇文章能帮助你掌握PyTorch的多GPU并行训练技术。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/4488.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电脑换固态硬盘

参考: https://baijiahao.baidu.com/s?id1724377623311611247 一、根据尺寸和缺口可以分为以下几种: 1、M.2 NVME协议的固态 大部分笔记本是22x42MM和22x80MM nvme固态。 在京东直接搜: M.2 2242 M.2 2280 2、msata接口固态 3、NGFF M.…

利用免费GIS工具箱实现高斯泼溅切片,将 PLY 格式转换为 3dtiles

在地理信息系统(GIS)和三维数据处理领域,不同数据格式有其独特应用场景与优势。PLY(Polygon File Format)格式常用于存储多边形网格数据,而 3DTiles 格式在 Web 端三维场景展示等方面表现出色。将 PLY 格式…

【华为路由/交换机的ftp文件操作】

华为路由/交换机的ftp文件操作 PC:10.0.1.1 R1:10.0.1.254 / 10.0.2.254 FTP:10.0.2.1 S1:无配置 在桌面创建FTP-Huawei文件夹,里面创建config/test.txt。 点击上图中的“启动”按钮。 然后ftp到server,…

基于微信小程序的安心陪诊管理系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏:…

利用rsync备份全网服务器数据

一、项目描述 某公司里有一台Web服务器,里面的数据很重要,但是如果硬盘坏了数据就会丢失,现在领导要求把数据做备份,这样Web服务器数据丢失在可以进行恢复,要求如下: 1、备份要求 每天晚上00点整在Web服…

Mysql 主从复制原理及其工作过程,配置一主两从实验

主从原理:MySQL 主从同步是一种数据库复制技术,它通过将主服务器上的数据更改复制到一个或多个从服务器,实现数据的自动同步。 主从同步的核心原理是将主服务器上的二进制日志复制到从服务器,并在从服务器上执行这些日志中的操作…

Ubuntu 24.04 LTS 空闲硬盘挂载到 文件管理器的 other locations

Ubuntu 24.04 LTS 确认硬盘是否被识别 使用 lsblk 查看信息,其中sda这个盘是我找不到的,途中是挂在好的。 分区和格式化硬盘 如果新硬盘没有分区,你需要先分区并格式化它。假设新硬盘为 /dev/sdb,使用 fdisk 或 parted 对硬盘…

调试Hadoop源代码

个人博客地址:调试Hadoop源代码 | 一张假钞的真实世界 Hadoop版本 Hadoop 2.7.3 调试模式下启动Hadoop NameNode 在${HADOOP_HOME}/etc/hadoop/hadoop-env.sh中设置NameNode启动的JVM参数,如下: export HADOOP_NAMENODE_OPTS"-Xdeb…

JSON-stringify和parse

目录 JSON序列化 JSON反序列化 序列化和反序列化转换 深拷贝 JSON.parse接受参数类型错误导致抛出异常 当有子元素的时候,设置父元素样式的方式 防抖问题 JSON序列化 const obj {name: "John",age: 30,city: "New York",};// 基本用法&…

3 前端(中):JavaScript

文章目录 前言:JavaScript简介一、ECMAscript(JavaScript基本语法)1 JavaScript与html结合方式(快速入门)2 基本知识(1)JavaScript注释(和Java注释一样)(2&am…

服务器一次性部署One API + ChatGPT-Next-Web

服务器一次性部署One API ChatGPT-Next-Web One API ChatGPT-Next-Web 介绍One APIChatGPT-Next-Web docker-compose 部署One API ChatGPT-Next-WebOpen API docker-compose 配置ChatGPT-Next-Web docker-compose 配置docker-compose 启动容器 后续配置 同步发布在个人笔记服…

OSI七层协议——分层网络协议

OSI七层协议,顾名思义,分为七层,实际上七层是不存在的,是人为的进行划分,让人更好的理解 七层协议包括,物理层(我),数据链路层(据),网络层(网),传输层(传输),会话层(会),表示层(表),应用层(用)(记忆口诀->我会用表…

【AI论文】生成式视频模型是否通过观看视频学习物理原理?

摘要:AI视频生成领域正经历一场革命,其质量和真实感在迅速提升。这些进步引发了一场激烈的科学辩论:视频模型是否学习了能够发现物理定律的“世界模型”,或者,它们仅仅是复杂的像素预测器,能够在不理解现实…

【TCP】rfc文档

tcp协议相关rfc有哪些 TCP(传输控制协议)是一个复杂的协议,其设计和实现涉及多个RFC文档。以下是一些与TCP协议密切相关的RFC文档列表,按照时间顺序排列,涵盖了从基础定义到高级特性和优化的各个方面: 基…

VLAN基础理论

VLAN V:Virtual(虚拟) LAN ——局域网 VLAN ——虚拟局域网(虚拟广播域:交换机和路由器协同工作后,将原来的一个广播域,逻辑上切分为多个。) VLAN的配置我们基于以下拓扑进行: PC1-4的IP地址依次为192.168.1.1-192.168…

RabbitMQ实现延迟消息发送——实战篇

在项目中,我们经常需要使用消息队列来实现延迟任务,本篇文章就向各位介绍使用RabbitMQ如何实现延迟消息发送,由于是实战篇,所以不会讲太多理论的知识,还不太理解的可以先看看MQ的延迟消息的一个实现原理再来看这篇文章…

IoTDB 常见问题 QA 第四期

关于 IoTDB 的 Q & A IoTDB Q&A 第四期来啦!我们将定期汇总我们将定期汇总社区讨论频繁的问题,并展开进行详细回答,通过积累常见问题“小百科”,方便大家使用 IoTDB。 Q1:Java 中如何使用 SSL 连接 IoTDB 问题…

【STM32-学习笔记-14-】FLASH闪存

文章目录 FALSH闪存一、FLASH简介二、FLASH基本结构三、FLASH解锁四、使用指针访问存储器五、FLASH擦除以及编程流程Ⅰ、程序存储器全擦除1. 读取FLASH_CR的LOCK位2. 检查LOCK位是否为13. 设置FLASH_CR的MER 1和STRT 1(如果LOCK位0)4. 检查FLASH_SR的B…

CamemBERT:一款出色的法语语言模型

摘要 预训练语言模型在自然语言处理中已无处不在。尽管这些模型取得了成功,但大多数可用模型要么是在英语数据上训练的,要么是在多种语言数据拼接的基础上训练的。这使得这些模型在除英语以外的所有语言中的实际应用非常有限。本文探讨了为其他语言训练…

线性代数概述

矩阵与线性代数的关系 矩阵是线性代数的研究对象之一: 矩阵(Matrix)是一个按照长方阵列排列的复数或实数集合,是线性代数中的核心概念之一。矩阵的定义和性质构成了线性代数中矩阵理论的基础,而矩阵运算则简洁地表示和…