【教程】DGL单机多卡分布式GCN训练

转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn]

如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~

        PyTorch中的DDP会将模型复制到每个GPU中。

        梯度同步默认使用Ring-AllReduce进行,重叠了通信和计算。

        示例代码:

视频:https://youtu.be/Cvdhwx-OBBo

代码:multigpu.py

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoaderimport torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import osimport dgl
from dgl.data import RedditDataset
from dgl.nn.pytorch import GraphConvdef ddp_setup(rank, world_size):"""DDP初始化设置。参数:rank (int): 当前进程的唯一标识符。world_size (int): 总进程数。"""os.environ["MASTER_ADDR"] = "localhost"  # 设置主节点地址os.environ["MASTER_PORT"] = "12355"      # 设置主节点端口init_process_group(backend="nccl", rank=rank, world_size=world_size)  # 初始化进程组torch.cuda.set_device(rank)  # 设置当前进程使用的GPU设备class GCN(torch.nn.Module):def __init__(self, in_feats, h_feats, num_classes):"""初始化图卷积网络(GCN)。参数:in_feats (int): 输入特征的维度。h_feats (int): 隐藏层特征的维度。num_classes (int): 输出类别的数量。"""super(GCN, self).__init__()self.conv1 = GraphConv(in_feats, h_feats)  # 第一层图卷积self.conv2 = GraphConv(h_feats, num_classes)  # 第二层图卷积def forward(self, g, in_feat):"""前向传播。参数:g (DGLGraph): 输入的图。in_feat (Tensor): 输入特征。返回:Tensor: 输出的logits。"""h = self.conv1(g, in_feat)  # 进行第一层图卷积h = F.relu(h)  # ReLU激活h = self.conv2(g, h)  # 进行第二层图卷积return hclass Trainer:def __init__(self,model: torch.nn.Module,train_data: DataLoader,optimizer: torch.optim.Optimizer,gpu_id: int,save_every: int,) -> None:"""初始化训练器。参数:model (torch.nn.Module): 要训练的模型。train_data (DataLoader): 训练数据的DataLoader。optimizer (torch.optim.Optimizer): 优化器。gpu_id (int): GPU ID。save_every (int): 每隔多少个epoch保存一次检查点。"""self.gpu_id = gpu_idself.model = model.to(gpu_id)  # 将模型移动到指定GPUself.train_data = train_dataself.optimizer = optimizerself.save_every = save_everyself.model = DDP(model, device_ids=[gpu_id])  # 使用DDP包装模型def _run_batch(self, batch):"""运行单个批次。参数:batch: 单个批次的数据。"""self.optimizer.zero_grad()  # 梯度清零graph, features, labels = batchgraph = graph.to(self.gpu_id)  # 将图移动到GPUfeatures = features.to(self.gpu_id)  # 将特征移动到GPUlabels = labels.to(self.gpu_id)  # 将标签移动到GPUoutput = self.model(graph, features)  # 前向传播loss = F.cross_entropy(output, labels)  # 计算交叉熵损失loss.backward()  # 反向传播self.optimizer.step()  # 更新模型参数def _run_epoch(self, epoch):"""运行单个epoch。参数:epoch (int): 当前epoch号。"""print(f"[GPU{self.gpu_id}] Epoch {epoch} | Steps: {len(self.train_data)}")for batch in self.train_data:self._run_batch(batch)  # 运行每个批次def _save_checkpoint(self, epoch):"""保存训练检查点。参数:epoch (int): 当前epoch号。"""ckp = self.model.module.state_dict()  # 获取模型的状态字典PATH = "checkpoint.pt"  # 定义检查点路径torch.save(ckp, PATH)  # 保存检查点print(f"Epoch {epoch} | Training checkpoint saved at {PATH}")def train(self, max_epochs: int):"""训练模型。参数:max_epochs (int): 总训练epoch数。"""for epoch in range(max_epochs):self._run_epoch(epoch)  # 运行当前epochif self.gpu_id == 0 and epoch % self.save_every == 0:self._save_checkpoint(epoch)  # 保存检查点def load_train_objs():"""加载训练所需的对象:数据集、模型和优化器。返回:tuple: 数据集、模型和优化器。"""data = RedditDataset(self_loop=True)  # 加载Reddit数据集,并添加自环graph = data[0]  # 获取图train_mask = graph.ndata['train_mask']  # 获取训练掩码features = graph.ndata['feat']  # 获取特征labels = graph.ndata['label']  # 获取标签model = GCN(features.shape[1], 128, data.num_classes)  # 初始化GCN模型optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)  # 初始化优化器train_data = [(graph, features, labels)]  # 准备训练数据return train_data, model, optimizerdef prepare_dataloader(dataset, batch_size: int):"""准备DataLoader。参数:dataset: 数据集。batch_size (int): 批次大小。返回:DataLoader: DataLoader对象。"""return DataLoader(dataset,batch_size=batch_size,pin_memory=True,shuffle=True,collate_fn=lambda x: x[0]  # 自定义collate函数,解包数据集中的单个元素)def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):"""主训练函数。参数:rank (int): 当前进程的唯一标识符。world_size (int): 总进程数。save_every (int): 每隔多少个epoch保存一次检查点。total_epochs (int): 总训练epoch数。batch_size (int): 批次大小。"""ddp_setup(rank, world_size)  # DDP初始化设置dataset, model, optimizer = load_train_objs()  # 加载训练对象train_data = prepare_dataloader(dataset, batch_size)  # 准备DataLoadertrainer = Trainer(model, train_data, optimizer, rank, save_every)  # 初始化训练器trainer.train(total_epochs)  # 开始训练destroy_process_group()  # 销毁进程组if __name__ == "__main__":import argparseparser = argparse.ArgumentParser(description='Simple distributed training job')parser.add_argument('--total_epochs', default=50, type=int, help='Total epochs to train the model')parser.add_argument('--save_every', default=10, type=int, help='How often to save a snapshot')parser.add_argument('--batch_size', default=8, type=int, help='Input batch size on each device (default: 32)')args = parser.parse_args()world_size = torch.cuda.device_count()  # 获取可用GPU的数量mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)  # 启动多个进程进行分布式训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/352381.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

docker-compose启动oracle11、并使用navicat进行连接

一、docker-compose.yml version: 3.9 services:oracle:image: registry.cn-hangzhou.aliyuncs.com/helowin/oracle_11grestart: alwaysprivileged: truecontainer_name: oracle11gvolumes:- ./data:/u01/app/oracleports:- 1521:1521network_mode: "host"logging:d…

Mac下载了docker,在终端使用docker命令时用不了

问题:在mac使用docker的时候,拉取docker镜像失败 原因:docker是需要用app使用的 ,所以在使用的时候必须打开这个桌面端软件才可以在终端上使用docker命令!!!

腾讯云EdgeOne对比普通CDN的分别

EdgeOne架构图 普通CDN架构图 ​​​​​​​ 腾讯云EdgeOne对比普通CDN的不同点 服务范围和集成度 腾讯云EdgeOne是一体化的综合平台,不仅提供内容分发功能,还包括安全防护、性能优化和边缘计算等服务。EdgeOne提供了DDoS防护、WAF(Web应…

通过语言大模型来学习tensorflow框架训练模型(三)

一、模型训练5步骤走 1.数据获取,2,数据处理,3.模型创建与训练,4 模型测试与评估,5.模型预测 二、tensorflow数据获取 在TensorFlow中,数据获取和预处理是构建深度学习模型的重要步骤。TensorFlow提供了多…

C语言王国——数组的旋转(轮转数组)三种解法

目录 一、题目 二、分析 2.1 暴力求解法 2.2 找规律 2.3 追求时间效率,以空间换时间 三、结论 一、题目 给定一个整数数组 nums,将数组中的元素向右轮转 k 个位置,其中 k 是非负数。 示例 1: 输入: nums [1,2,3,4,5,6,7], k 3 输出…

树莓派4B_OpenCv学习笔记6:OpenCv识别已知颜色_运用掩膜

今日继续学习树莓派4B 4G:(Raspberry Pi,简称RPi或RasPi) 本人所用树莓派4B 装载的系统与版本如下: 版本可用命令 (lsb_release -a) 查询: Opencv 版本是4.5.1: 学了这些OpenCv的理论性知识,不进行实践实在…

数据库管理-第205期 换个角度看23ai(20240617)

数据库管理205期 2024-06-17 数据库管理-第205期 换个角度看23ai(20240617)1 规范应用开发2 融合总结 数据库管理-第205期 换个角度看23ai(20240617) 作者:胖头鱼的鱼缸(尹海文) Oracle ACE Pro…

11.5.k8s中pod的调度-cordon,drain,delete

目录 一、概念 二、使用 1.cordon 停止调度 1.1.停止调度 1.2.解除恢复 2.drain 驱逐节点 2.1.驱逐节点 2.2.参数介绍 2.3.解除恢复 3.delete 删除节点 一、概念 cordon节点,drain驱逐节点,delete 节点,在对k8s集群节点执行维护&am…

vivado NODE、PACKAGE_PIN

节点是Xilinx部件上用于路由连接或网络的设备对象。它是一个 WIRE集合,跨越多个瓦片,物理和电气 连接在一起。节点可以连接到单个SITE_, 而是简单地将NETs携带进、携带出或携带穿过站点。节点可以连接到 任何数量的PIP,并且也可以…

Science | 稀土开采威胁马来西亚的生物多样性

马来西亚是一个生物多样性热点地区,拥有超过17万种物种,其中1600多种处于濒临灭绝的风险。马来西亚的热带雨林蕴藏了大部分的生物多样性,并为全球提供重要的生态系统效益,同时为土著社区带来经济和文化价值。同时马来西亚具有可观…

04 远程访问及控制

1、SSH远程管理 SSH是一种安全通道协议,主要用来实现字符界面的远程登录、远程复制等功能。 SSH协议对通信双方的数据传输进行了加密处理(包括用户登陆时输入得用户口令)。 终端:接收用户的指令 TTY终端不能远程,它…

Python界面编辑器Tkinter布局助手 使用体验

一、发现 我今天在网上搜关于Python Tkinter方面的信息时,发现了Python界面编辑器 Tkinter布局助手 的使用说明。 https://blog.csdn.net/weixin_52777652/article/details/135291731?spm1001.2014.3001.5506 这个编辑器是个开源的项目,个人用户可以…

大模型KV Cache节省神器MLA学习笔记(包含推理时的矩阵吸收分析)

首先,本文回顾了MHA的计算方式以及KV Cache的原理,然后深入到了DeepSeek V2的MLA的原理介绍,同时对MLA节省的KV Cache比例做了详细的计算解读。接着,带着对原理的理解理清了HuggingFace MLA的全部实现,每行代码都去对应…

从中概回购潮,看互联网的未来

王兴的饭否语录里有这样一句话:“对未来越有信心,对现在越有耐心。” 而如今的美团,已经不再掩饰对未来的坚定信心。6月11日,美团在港交所公告,计划回购不超过20亿美元的B类普通股股份。 而自从港股一季度财报季结束…

【吉林大学Java程序设计】第9章:并发控制

第9章:并发控制 1.线程的基本概念2.线程的创建与启动3.线程的调度与优先级线程的状态线程的生命周期线程控制的基本方法线程优先级 4.线程的协作多线程存在的问题同步区域(临界区)生产者与消费者问题(互斥与同步问题)哲…

线程池吞掉异常的case:源码阅读与解决方法

1. 问题背景 有一天给同事CR,看到一段这样的代码 try {for (param : params) {//并发处理,func无返回值ThreadPool.submit(func(param));} } catch (Exception e) {log.info("func抛异常啦,参数是:{}", param) } 我:你这段代码是…

【数据结构与算法 刷题系列】求带环链表的入环节点(图文详解)

💓 博客主页:倔强的石头的CSDN主页 📝Gitee主页:倔强的石头的gitee主页 ⏩ 文章专栏:《数据结构与算法 经典例题》C语言 期待您的关注 ​ 目录 一、问题描述 二、解题思路 方法一:数学公式推导法 方法…

苏州辰安塑业携塑料托盘、塑料物流箱解决方案亮相2024杭州快递物流展

苏州辰安塑业携塑料托盘、吹塑托盘、塑料卡板箱、塑料周转箱、塑料物流箱、塑料垃圾桶解决方案盛装亮相2024杭州快递物流展! 展位号:3C馆A51 苏州辰安塑业有限公司,是一家专业从事塑料托盘、吹塑托盘、塑料卡板箱、塑料周转箱、塑料物流箱、…

【前端】Nesj 学习笔记

1、前置知识 1.1 装饰器 装饰器的类型 declare type ClassDecorator <TFunction extends Function>(target: TFunction) > TFunction | void; declare type PropertyDecorator (target: Object, propertyKey: string | symbol) > void; declare type MethodDe…

大模型应用开发技术:Multi-Agent框架流程、源码及案例实战(二)

LlaMA 3 系列博客 基于 LlaMA 3 LangGraph 在windows本地部署大模型 &#xff08;一&#xff09; 基于 LlaMA 3 LangGraph 在windows本地部署大模型 &#xff08;二&#xff09; 基于 LlaMA 3 LangGraph 在windows本地部署大模型 &#xff08;三&#xff09; 基于 LlaMA…