使用猴子补丁对pytorch的分布式接口进行插桩

训练脚本:

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch import nn
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
import distributed_patch# 设置 NCCL 日志环境变量
'''
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_DEBUG_SUBSYS"] = "ALL"  # 或者 COLL
os.environ["NCCL_LOG_FILE"] = "nccl_log.txt"# 运行 PyTorch 分布式代码
'''class Net(nn.Module):  # 模型定义def __init__(self):super(Net, self).__init__()self.flatten = nn.Flatten()self.seq = nn.Sequential(nn.Linear(28 * 28, 128),nn.ReLU(),nn.Linear(128, 64),nn.ReLU(),nn.Linear(64, 10))def forward(self, x):x = self.flatten(x)return self.seq(x)def main():dist.init_process_group(backend='nccl')  # 【集合通讯】其他进程连master,大家互认rank = dist.get_rank()world_size = dist.get_world_size()device_name = f'cuda:{rank}'checkpoint = None  # 各自加载checkpointtry:checkpoint = torch.load('checkpoint.pth', map_location='cpu')  # checkpoint是cuda:0保存的,加载默认会读到cuda:0,所以明确指定给cpuexcept:passmodel = Net().to(device_name)if checkpoint and rank == 0:  # rank0恢复模型参数model.load_state_dict(checkpoint['model'])model = DDP(model)  # 【集合通讯】rank0广播参数给其他进程optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # model参数一致,则optim会保证其初始状态一致if checkpoint:optimizer.load_state_dict(checkpoint['optimizer'])  # 各自加载checkpointtrain_dataset = MNIST(root='./data', download=True, transform=ToTensor(), train=True)  # 各自加载datasetsampler = DistributedSampler(train_dataset)  # 指派子集给各进程train_dataloader = DataLoader(train_dataset, batch_size=32, sampler=sampler, persistent_workers=True, num_workers=2)val_dataset = MNIST(root='./data', download=True, transform=ToTensor(), train=False)val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True, persistent_workers=True, num_workers=2)for epoch in range(20):sampler.set_epoch(epoch)  # 【集合通讯】生成随机种子,rank0广播给其他进程model.train()for x, y in train_dataloader:x, y = x.to(device_name), y.to(device_name)pred_y = model(x)  # 【集合通讯】rank0广播model buffer给其他进程loss = F.cross_entropy(pred_y, y)optimizer.zero_grad()loss.backward()  # 【集合通讯】每个参数的梯度做all reduce(每个进程会收到其他进程的梯度,并求平均)optimizer.step()dist.reduce(loss, dst=0)  # 【集合通讯】rank0汇总其他进程的lossif rank == 0:train_avg_loss = loss.item() / world_size# evaluateraw_model = model.moduleval_loss = 0with torch.no_grad():for x, y in val_dataloader:x, y = x.to(device_name), y.to(device_name)pred_y = raw_model(x)loss = F.cross_entropy(pred_y, y)val_loss += loss.item()val_avg_loss = val_loss / len(val_dataloader)print(f'train_loss:{train_avg_loss} val_loss:{val_avg_loss}')# checkpointtorch.save({'model': model.module.state_dict(), 'optimizer': optimizer.state_dict()}, '.checkpoint.pth')os.replace('.checkpoint.pth', 'checkpoint.pth')dist.barrier()  # 【集合通讯】等待rank0跑完evalif __name__ == '__main__':main()# torchrun --nproc_per_node 1 pytorch_dis_gpu.py

插桩脚本:

import torch.distributed as dist# 保存原始函数引用
original_functions = {"init_process_group": dist.init_process_group,"all_reduce": dist.all_reduce,"reduce": dist.reduce,"broadcast": dist.broadcast,"barrier": dist.barrier,"get_rank": dist.get_rank,"get_world_size": dist.get_world_size
}# 插桩函数
def patched_init_process_group(*args, **kwargs):print("[distributed] init_process_group called")return original_functions["init_process_group"](*args, **kwargs)def patched_all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, async_op=False):print("[distributed] all_reduce called")return original_functions["all_reduce"](tensor, op, group, async_op)def patched_reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, async_op=False):print("[distributed] reduce called")return original_functions["reduce"](tensor, dst, op, group, async_op)def patched_broadcast(tensor, src, group=None, async_op=False):print("[distributed] broadcast called")return original_functions["broadcast"](tensor, src, group, async_op)def patched_barrier(*args, **kwargs):print("[distributed] barrier called")return original_functions["barrier"](*args, **kwargs)def patched_get_rank(*args, **kwargs):print("[distributed] get_rank called")return original_functions["get_rank"](*args, **kwargs)def patched_get_world_size(*args, **kwargs):print("[distributed] get_world_size called")return original_functions["get_world_size"](*args, **kwargs)# 替换分布式接口函数为插桩版本
dist.init_process_group = patched_init_process_group
dist.all_reduce = patched_all_reduce
dist.reduce = patched_reduce
dist.broadcast = patched_broadcast
dist.barrier = patched_barrier
dist.get_rank = patched_get_rank
dist.get_world_size = patched_get_world_size

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/477254.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AWS 新加坡EC2 VPS 性能、线路评测及免费注意事项

原文论坛给你更好的阅读讨论体验💐: AWS 新加坡EC2 VPS 性能、线路评测及免费注意事项 - VPS - 波波论坛 引言 对于那些习惯薅“羊毛”的朋友来说, AWS 的 免费套餐 可能已经非常熟悉。这台vps是我用外币卡薅的免费的12个月的机器&#xf…

C++ASCII码表和字符操作

目录 1. 引言 2. ASCII码表 2.1 控制字符 2.2 可显示字符 3. 字符操作 3.1 记住几个字符规律 3.2 打印能够显示的ASCII码 3.3 字母大小写转换 3.4 数字转数字字符 1. 引言 在电子计算机中,只能识别由 0 和 1 组成的一串串的二进制数字,为了将人类…

git使用(二)

git使用(二) git常用基本操作命令git clonegit loggit remotegit statusgit addgit commitgit pushgit branchgit pull git常用基本操作命令 git clone 项目开发中项目负责人会在github上创建一个远程仓库,我们需要使用git clone将远程仓库…

密码学11

概论 计算机安全的最核心三个关键目标(指标)/为:保密性 Confidentiality、完整性 Integrity、可用性 Availability ,三者称为 CIA三元组 数据保密性:确保隐私或是秘密信息不向非授权者泄漏,也不被非授权者使…

netstat -tuln | grep 27017(显示所有监听状态的 TCP 和 UDP 端口,并且以数字形式显示地址和端口号)

文章目录 1. 确定占用端口的进程使用 lsof 命令使用 fuser 命令 2. 结束占用端口的进程3. 修改 MongoDB 配置文件4. 检查 MongoDB 日志文件5. 重新启动 MongoDB 服务6. 检查 MongoDB 服务状态总结 [rootlocalhost etc]# netstat -tuln | grep 27017 tcp 0 0 127.0.…

ElasticSearch7.x入门教程之集群安装(一)

文章目录 前言一、es7.x版本集群安装二、elasticsearch-head安装三、Kibana安装总结 前言 在工作中遇到了,便在此记录一下,以防后面会再次遇到。第一次使用是在2020年末,过了很久了,忘了些许部分了。 在工作当中,如果…

I.MX6U 裸机开发18.GPT定时器实现高精度延时

I.MX6U 裸机开发18.GPT定时器实现高精度延时 一、GPT定时器简介1. GPT 功能2. 时钟源3. 框图4. 运行模式(1)Restart mode(2)Free-Run Mode 5. 中断类型(1)溢出中断 Rollover Interrupt(2&#x…

key-value存储实现

文章目录 一、项目简介二、项目流程图三、网络3.1、epoll实现3.2、io_uring实现 四、协议五、存储5.1、array实现5.2、rbtree实现5.3、hash实现 六、测试 一、项目简介 key-value存储其实是一个小型的redis,用户在客户端输入存储相关的指令发送给服务器端&#xff…

大公司如何实现打印机共享的?如何对打印机进行管控或者工号登录后进行打印?异地打印机共享的如何实现可以帮助用户在不同地理位置使用同一台打印机完成打印任务?

大公司如何实现打印机共享的?如何对打印机进行管控或者工号登录后进行打印?异地打印机共享的如何实现可以帮助用户在不同地理位置使用同一台打印机完成打印任务? 如果在局域网内,可以不需要进行二次开发,通过对打印机进…

微软发布Win11 24H2系统11月可选更新KB5046740!

系统之家11月22日报道,微软针对Win11 24H2系统推出2024年11月最新可选更新补丁KB5046740,更新后系统版本后升至26100.2454,此次更新后修复当应用程序以PDF和XLSX格式导出图表对象时停止响应、无法使用API查找旋转信息等问题。以下小编将给大家…

探索 RocketMQ:企业级消息中间件的选择与应用

一、关于RocketMQ RocketMQ 是一个高性能、高可靠、可扩展的分布式消息中间件,它是由阿里巴巴开发并贡献给 Apache 软件基金会的一个开源项目。RocketMQ 主要用于处理大规模、高吞吐量、低延迟的消息传递,它是一个轻量级的、功能强大的消息队列系统&…

李宏毅机器学习课程知识点摘要(6-13集)

pytorch简单的语法和结构 dataset就是数据集,dataloader就是分装好一堆一堆的 他们都是torch.utils.data里面常用的函数,已经封装好了 下面的步骤是把数据集读进来 这里是读进来之后,进行处理 声音信号,黑白照片,红…

Wekan看板安装部署与使用介绍

Wekan看板安装部署与使用介绍 1. Wekan简介 ​ Wekan 是一个开源的看板式项目管理工具,它的配置相对简单,因为大多数功能都是开箱即用的。它允许用户以卡片的形式组织和跟踪任务,非常适合敏捷开发和日常任务管理。Wekan 的核心功能包括看板…

【Mysql】开窗聚合函数----SUM,AVG, MIN,MAX

1、概念 在窗口中,每条记录动态地应用聚合函数(如:SUM(),AVG(),MAX(),MIN(),COUNT(),)可以动态计算在指定的窗口内的各种聚合函数值。 2、操作 以下操作将基于employee表进行操作。 sum() 进行sum的时候,没有order …

EWA Volume Splatting

摘要 本文提出了一种基于椭圆高斯核的直接体绘制新框架,使用了一种投影方法(splatting approach)。为避免混叠伪影(aliasing artifacts),我们引入了一种重采样滤波器的概念,该滤波器结合了重建核…

Vue实训---0-完成Vue开发环境的搭建

1.在官网下载和安装VS Code编辑器 完成中文语言扩展(chinese),安装成功后,需要重新启动VS Code编辑器,中文语言扩展才可以生效。 安装Vue-Official扩展,步骤与安装中文语言扩展相同(专门用于为“…

C# 超链接控件LinkLabel无法触发Alt快捷键

在C#中,为控件添加快捷键的方式有两种,其中一种就是Windows中较为常见的Alt快捷键,比如运行对话框,记事本菜单等。只需要按下 Alt 框号中带下划线的字母即可触发该控件的点击操作。如图所示 在C#开发中,实现类似的操作…

赛氪媒体支持“2024科普中国青年之星创作交流活动”医学专场落幕

2024年11月15日下午,由中国科普作家协会、科普中国发展服务中心主办,什刹海文化展示中心承办,并携手国内产学研一体融合领域的领军者——赛氪网共同支持的“2024科普中国青年之星创作交流活动”医学科普专场,在什刹海文化展示中心…

《现代制造技术与装备》是什么级别的期刊?是正规期刊吗?能评职称吗?

​问题解答 问:《现代制造技术与装备》是不是核心期刊? 答:不是,是知网收录的第二批认定学术期刊。 问:《现代制造技术与装备》级别? 答:省级。主管单位:齐鲁工业大学&#xff0…

(十一)Python字符串常用操作

一、访问字符串值 Python访问子字符串变量,可以使用方括号来截取字符串。与列表的索引一样,字符串索引从0开始。 hh"LaoTie 666" hh[2] mm"床前明月光" mm[3] 字符串的索引值可以为负值。若索引值为负数,则表示由字符…