Pytorch单机多卡分布式训练

Pytorch单机多卡分布式训练

数据并行:

DP和DDP

这两个都是pytorch下实现多GPU训练的库,DP是pytorch以前实现的库,现在官方更推荐使用DDP,即使是单机训练也比DP快。

  1. DataParallel(DP)

    • 只支持单进程多线程,单一机器上进行训练。
    • 模型训练开始的时候,先把模型复制到四个GPU上面,然后把数据分配给四个GPU进行前向传播,前向传播之后再汇总到卡0上面,然后在卡0上进行反向传播,参数更新,再将更新好的模型复制到其他几张卡上。

    在这里插入图片描述

  2. DistributedDataParallel(DDP)

    • 支持多线程多进程,单一或者多个机器上进行训练。通常DDP比DP要快。

    • 先把模型载入到四张卡上,每个GPU上都分配一些小批量的数据,再进行前向传播,反向传播,计算完梯度之后再把所有卡上的梯度汇聚到卡0上面,卡0算完梯度的平均值之后广播给所有的卡,所有的卡更新自己的模型,这样传输的数据量会少很多。

      在这里插入图片描述

DDP代码写法

  1. 初始化

    import torch.distributed as dist
    import torch.utils.data.distributed# 进行初始化,backend表示通信方式,可选择的有nccl(英伟达的GPU2GPU的通信库,适用于具有英伟达GPU的分布式训练)、gloo(基于tcp/ip的后端,可在不同机器之间进行通信,通常适用于不具备英伟达GPU的环境)、mpi(适用于支持mpi集群的环境)
    # init_method: 告知每个进程如何发现彼此,默认使用env://
    dist.init_process_group(backend='nccl', init_method="env://")
    
  2. 设置device

    device = torch.device(f'cuda:{args.local_rank}')	# 设置device,local_rank表示当前机器的进程号,该方式为每个显卡一个进程
    torch.cuda.set_device(device)	# 设定device
    
  3. 创建dataloader之前要加一个sampler

    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    data_set = torchvision.datasets.MNIST("./", train=True, transform=trans, target_transform=None, download=True)
    train_sampler = torch.utils.data.distributed.DistributedSampler(data_set)	# 加一个sampler
    data_loader_train = torch.utils.data.DataLoader(dataset=data_set, batch_size=256, sampler=train_sampler)
    
  4. torch.nn.parallel.DistributedDataParallel包裹模型(先to(device)再包裹模型)

    net = torchvision.models.resnet101(num_classes=10)
    net.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)
    net = net.to(device)
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], output_device=[device])	# 包裹模型
    
  5. 真正训练之前要set_epoch(),否则将不会shuffer数据

    for epoch in range(10):train_sampler.set_epoch(epoch)		# set_epochfor step, data in enumerate(data_loader_train):images, labels = dataimages, labels = images.to(device), labels.to(device)opt.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()opt.step()if step % 10 == 0:print("loss: {}".format(loss.item()))
    
  6. 模型保存

    if args.local_rank == 0:		# local_rank为0表示master进程torch.save(net, "my_net.pth")
    
  7. 运行

    if __name__ == "__main__":parser = argparse.ArgumentParser()# local_rank参数是必须的,运行的时候不必自己指定,DDP会自行提供parser.add_argument("--local_rank", type=int, default=0)args = parser.parse_args()main(args)
    
  8. 运行命令

    python -m torch.distributed.launch --nproc_per_node=2 多卡训练.py	# --nproc_per_node=2表示当前机器上有两个GPU可以使用
    

完整代码

import os
import argparse
import torch
import torchvision
import torch.distributed as dist
import torch.utils.data.distributedfrom torchvision import transforms
from torch.multiprocessing import Processdef main(args):# nccl: 后端基于NVIDIA的GPU-to-GPU通信库,适用于具有NVIDIA GPU的分布式训练# gloo: 后端是一个基于TCP/IP的后端,可在不同机器之间进行通信,通常适用于不具备NVIDIA GPU的环境。# mpi: 后端使用MPI实现,适用于具备MPI支持的集群环境。# init_method: 告知每个进程如何发现彼此,如何使用通信后端初始化和验证进程组。 默认情况下,如果未指定 init_method,PyTorch 将使用环境变量初始化方法 (env://)。dist.init_process_group(backend='nccl', init_method="env://") # nccl比较推荐device = torch.device(f'cuda:{args.local_rank}')torch.cuda.set_device(device)trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])data_set = torchvision.datasets.MNIST("./", train=True, transform=trans, target_transform=None, download=True)train_sampler = torch.utils.data.distributed.DistributedSampler(data_set)data_loader_train = torch.utils.data.DataLoader(dataset=data_set, batch_size=256, sampler=train_sampler)net = torchvision.models.resnet101(num_classes=10)net.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)net = net.to(device)net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], output_device=[device])criterion = torch.nn.CrossEntropyLoss()opt = torch.optim.Adam(params=net.parameters(), lr=0.001)for epoch in range(10):train_sampler.set_epoch(epoch)for step, data in enumerate(data_loader_train):images, labels = dataimages, labels = images.to(device), labels.to(device)opt.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()opt.step()if step % 10 == 0:print("loss: {}".format(loss.item()))if args.local_rank == 0:torch.save(net, "my_net.pth")if __name__ == "__main__":parser = argparse.ArgumentParser()# must parse the command-line argument: ``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by DDPparser.add_argument("--local_rank", type=int, default=0)args = parser.parse_args()main(args)

参考:

https://zhuanlan.zhihu.com/p/594046884
https://zhuanlan.zhihu.com/p/358974461

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/142698.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

3、嵌入式系统的启动过程(BoodLoader)

1、系统启动过程 通电 - > 执行BootLoader - > 加载内核 - > 挂在根文件系统 - > 执行应用程序 Windows的启动过程: 通电 - > 执行BIOS - > 加载WinNT内核 - > 挂在文件系统 - > 执行应用程序 二、嵌入式系统的结构 BootLoader 1、BootL…

ElasticSearch - 基于 DSL 、JavaRestClient 实现数据聚合

目录 一、数据聚合 1.1、基本概念 1.1.1、聚合分类 1.1.2、特点 1.2、DSL 实现 Bucket 聚合 1.2.1、Bucket 聚合基础语法 1.2.2、Bucket 聚合结果排序 1.2.3、Bucket 聚合限定范围 1.3、DSL 实现 Metrics 聚合 1.4、基于 JavaRestClient 实现聚合 1.4.1、组装请求 1…

Tomcat多实例、负载均衡、动静分离

Tomcat多实例部署 安装jdk [rootlocalhost ~]#systemctl stop firewalld.service [rootlocalhost ~]#setenforce 0 [rootlocalhost ~]#cd /opt [rootlocalhost opt]#ls apache-tomcat-8.5.16.tar.gz jdk-8u91-linux-x64.tar.gz rh [rootlocalhost opt]#tar xf jdk-8u91-linu…

春招秋招,大学生求职容易遇到哪些问题?

每到毕业季就有大批大学生从校园出来,他们怀抱梦想,希望能做出一番成绩。但现实总归是残酷的,有些人找不到工作,有一些人频繁跳槽,也有一些人最终找到的工作与自己的专业没有一点关系,迷茫好几年才找到方向…

钡铼BL302与PLC:提升酿酒业效率与品质的利器

啤酒是人类非常古老的酒精饮料,是水和茶之后世界上消耗量排名第三的饮料。 啤酒在生产过程中主要有制造麦芽、粉碎原料、糖化、发酵、贮酒後熟、过滤、灌装包装等工序流程。需要用到风选机、筛分机、糖化锅、发酵设备、过滤机、灌装机、包装机等食品机械设备。这些食…

安全远程访问工具

什么是安全远程访问 安全远程访问是指一种 IT 安全策略,允许对企业网络、任务关键型系统或任何机密数据进行授权、受控访问。它使 IT 团队能够根据员工和第三方的角色和工作职责为其提供不同级别的访问权限,安全的远程访问方法可保护系统和应用程序&…

前缀和实例5(连续数组)

题目: 给定一个二进制数组 nums , 找到含有相同数量的 0 和 1 的最长连续子数组,并返回该子数组的长度。 示例 1: 输入: nums [0,1] 输出: 2 说明: [0, 1] 是具有相同数量 0 和 1 的最长连续子数组。 示例 2: 输入: nums [0,1,0] 输出: 2 说明: [0…

el-upload实现复制粘贴图片

前言: 在之前的项目中,利用el-upload实现了上传图片视频的预览。项目上线后,经使用人员反馈,上传图片、视频每次要先保存到本地然后再上传,很是浪费时间,公司客服人员时间又很紧迫(因为要响应下…

DAMO-YOLO训练KITTI数据集

1.KITTI数据集准备 DAMO-YOLO支持COCO格式的数据集,在训练KITTI之前,需要将KITTI的标注转换为KITTI格式。KITTI是采取逐个文件标注的方式确定的,即一张图片对应一个label文件。下面是KITTI 3D目标检测训练集的第一个标注文件:000…

JavaScript位运算的妙用

位运算的妙用: 奇偶数, 色值换算,换值, 编码等 位运算的基础知识: 操作数是32位整数自动转化为整数在二进制下进行运算 一.按位与& 判断奇偶数: 奇数: num & 1 1偶数: num & 1 0 基本知识: 用法:操作数1 & 操作数2规则:有 0 则为…

机柜PDU产品采购与安装指南——TOWE精选

机柜PDU指的是Power Distribution Unit,即电源分配单元。它是一种电子设备,通常用于为数据中心、服务器机房等设施中的计算机和其他设备提供电力,是各行业数据中心“标配”构成部分,以确保服务器等用电设备的安全和稳定运行。 数据…

查看Linux系统信息的常用命令

文章目录 1. 机器配置查看2. 常用分析工具3. 常用指令解读3.1 lscpu 4. 定位僵尸进程5. 参考 1. 机器配置查看 # 总核数物理CPU个数x每颗物理CPU的核数 # 总逻辑CPU数物理CPU个数x每颗物理CPU的核数x超线程数 cat /proc/cpuinfo| grep "physical id"| sort| uniq| w…

[Linux]多线程编程

[Linux]多线程编程 文章目录 [Linux]多线程编程pthread_create函数pthread_join函数pthread_exit函数pthread_cancel函数pthread_self函数pthread_detach函数理解线程库和线程id Linux操作系统下,并没有真正意义上的线程,而是由进程中的轻量级进程&#…

在多台服务器上运行相同命令(二)、clush

介绍安装配置互信认证参数含义基本使用节点组拷贝文件 介绍 Clush(Cluster Shell)是一个用于管理和执行集群操作的工具,它允许你在多台远程主机上同时执行命令,以便批量管理服务器。Clush 提供了一种简单而强大的方式来管理大规模…

“押宝高手”乐视视频再出手,看中商业传奇剧《大盛魁》

作为最早开始版权采购的长视频平台,乐视视频一向擅长“押宝”优质内容。从《甄嬛传》到《白鹿原》等,乐视拿下了众多经典古装剧、年代剧的版权。 9月,乐视视频再次出手拿下的历史传奇剧《大盛魁》开始热播。该剧由王新民导演执导&#xff0c…

全渠道客服体验:Rocket.Chat 的无缝互动 | 开源日报 No.41

RocketChat/Rocket.Chat Stars: 36.9k License: NOASSERTION Rocket.Chat 是一个完全可定制的开源通信平台,适用于具有高标准数据保护要求的组织。我们是团队沟通场景下的最终免费开源解决方案,可以实现同事之间、公司之间或客户之间的实时对话。提高生…

SSM - Springboot - MyBatis-Plus 全栈体系(十三)

第三章 MyBatis 一、MyBatis 简介 1. 简介 MyBatis 最初是 Apache 的一个开源项目 iBatis, 2010 年 6 月这个项目由 Apache Software Foundation 迁移到了 Google Code。随着开发团队转投 Google Code 旗下, iBatis3.x 正式更名为 MyBatis。代码于 2013 年 11 月迁…

TS中class类的基本使用

想要创建对象,必须要先定义类,所谓的类可以理解为对象的模型,程序中可以根据类创建所指定类型的对象。 一、使用class关键字定义类 class 类名 { } // 使用class关键字来定义一个类 class Person{}// 使用new关键字创建一个对象 const per …

Pikachu靶场——SSRF 服务端请求伪造

文章目录 1 SSRF 服务端请求伪造1.1 SSRF(curl)1.1.1 漏洞防御 1.2 SSRF(file_get_content)1.2.1 漏洞防御1.2.3 SSRF 防御 1 SSRF 服务端请求伪造 SSRF(Server-Side Request Forgery:服务器端请求伪造) 其形成的原因大都是由于服务端提供了从其他服务器应用获取数据的功能&a…

【红外与可见光图像融合】离散平稳小波变换域中基于离散余弦变换和局部空间频率的红外与视觉图像融合方法(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…