pytest学习-pytorch单元测试

pytorch单元测试

  • 一.公共模块[common.py]
  • 二.普通算子测试[test_clone.py]
  • 三.集合通信测试[test_ccl.py]
  • 四.测试命令
  • 五.测试报告

希望测试pytorch各种算子、block、网络等在不同硬件平台,不同软件版本下的计算误差、耗时、内存占用等指标.

本文基于torch.testing._internal

一.公共模块[common.py]

import torch
from torch import nn
import math
import torch.nn.functional as F
import time
import os
import socket
import sys
from datetime import datetime
import numpy as np
import collections
import math
import json
import copy
import traceback
import subprocess
import unittest
import torch
import inspect
from torch.testing._internal.common_utils import TestCase, run_tests,parametrize,instantiate_parametrized_tests
from torch.testing._internal.common_distributed import MultiProcessTestCase
import torch.distributed as distos.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANDOM_SEED"] = "0" device="cpu"
device_type="cpu"
device_name="cpu"try:if torch.cuda.is_available():     device_name=torch.cuda.get_device_name().replace(" ","")device="cuda:0"device_type="cuda"ccl_backend='nccl'
except:passhost_name=socket.gethostname()    
sdk_version=os.getenv("SDK_VERSION","")   						 #从环境变量中获取sdk版本号
metric_data_root=os.getenv("TORCH_UT_METRICS_DATA","./ut_data")  #日志存放的目录
device_count=torch.cuda.device_count()if not os.path.exists(metric_data_root):os.makedirs(metric_data_root)def device_warmup(device):'''设备warmup,确保设备已经正常工作,排除设备初始化的耗时'''left = torch.rand([128,512], dtype = torch.float16).to(device)right = torch.rand([512,128], dtype = torch.float16).to(device)out=torch.matmul(left,right)torch.cuda.synchronize()torch.manual_seed(1) 
np.random.seed(1)def loop_decorator(loops,rank=0):'''循环装饰器,用于统计函数的执行时间,内存占用等'''def decorator(func):def wrapper(*args,**kwargs):latency=[]memory_allocated_t0=torch.cuda.memory_allocated(rank)for _ in range(loops):input_copy=[x.clone() for x in args]beg= datetime.now().timestamp() * 1e6pred= func(*input_copy)gt=kwargs["golden"]torch.cuda.synchronize()end=datetime.now().timestamp() * 1e6mse = torch.mean(torch.pow(pred.cpu().float()- gt.cpu().float(), 2)).item()latency.append(end-beg)memory_allocated_t1=torch.cuda.memory_allocated(rank)avg_latency=np.mean(latency[len(latency)//2:]).round(3)first_latency=latency[0]return { "first_latency":first_latency,"avg_latency":avg_latency,"memory_allocated":memory_allocated_t1-memory_allocated_t0,"mse":mse}return wrapperreturn decoratorclass TorchUtMetrics:'''用于统计测试结果,比较之前的最小值'''def __init__(self,ut_name,thresold=0.2,rank=0):self.ut_name=f"{ut_name}_{rank}"self.thresold=thresoldself.rank=rankself.data={"ut_name":self.ut_name,"metrics":[]}self.metrics_path=os.path.join(metric_data_root,f"{self.ut_name}_{self.rank}.jon")try:with open(self.metrics_path,"r") as f:self.data=json.loads(f.read())except:passdef __enter__(self):self.beg= datetime.now().timestamp() * 1e6return selfdef __exit__(self, exc_type, exc_val, exc_tb):        self.report()self.save_data()def save_data(self):with open(self.metrics_path,"w") as f:f.write(json.dumps(self.data,indent=4))def set_metrics(self,metrics):self.end=datetime.now().timestamp() * 1e6item=collections.OrderedDict()item["time"]=datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')item["sdk_version"]=sdk_versionitem["device_name"]=device_nameitem["host_name"]=host_nameitem["metrics"]=metricsitem["metrics"]["e2e_time"]=self.end-self.begself.cur_item=itemself.data["metrics"].append(self.cur_item)def get_metric_names(self):return self.data["metrics"][0]["metrics"].keys()def get_min_metric(self,metric_name,devicename=None):min_value=0min_value_index=-1for idx,item in enumerate(self.data["metrics"]):if devicename and (devicename!=item['device_name']):                continue            val=float(item["metrics"][metric_name])if min_value_index==-1 or val<min_value:min_value=valmin_value_index=idxreturn min_value,min_value_indexdef get_metric_info(self,index):metrics=self.data["metrics"][index]return f'{metrics["device_name"]}@{metrics["sdk_version"]}'def report(self):assert len(self.data["metrics"])>0for metric_name in self.get_metric_names():min_value,min_value_index=self.get_min_metric(metric_name)min_value_same_dev,min_value_index_same_dev=self.get_min_metric(metric_name,device_name)cur_value=float(self.cur_item["metrics"][metric_name])print(f"-------------------------------{metric_name}-------------------------------")print(f"{cur_value}#{device_name}@{sdk_version}")if min_value_index_same_dev>=0:print(f"{min_value_same_dev}#{self.get_metric_info(min_value_index_same_dev)}")if min_value_index>=0:print(f"{min_value}#{self.get_metric_info(min_value_index)}")

二.普通算子测试[test_clone.py]

from common import *
class TestCaseClone(TestCase):#如果不满足条件,则跳过这个测试@unittest.skipIf(device_count>1, "Not enough devices") def test_todo(self):print(".TODO")#框架会自动遍历以下参数组合@parametrize("shape", [(10240,20480),(128,256)])@parametrize("dtype", [torch.float16,torch.float32])def test_clone(self,shape,dtype):#让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量@loop_decorator(loops=5)def run(input_dev):output=input_dev.clone()return output#记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2) as m:input_host=torch.ones(shape,dtype=dtype)*np.random.rand()input_dev=input_host.to(device)metrics=run(input_dev,golden=input_host.cpu())m.set_metrics(metrics)assert(metrics["mse"]==0)instantiate_parametrized_tests(TestCaseClone)if __name__ == "__main__":run_tests()

三.集合通信测试[test_ccl.py]

from common import *
class TestCCL(MultiProcessTestCase):'''CCL测试用例'''def _create_process_group_vccl(self, world_size, store):dist.init_process_group(ccl_backend, world_size=world_size, rank=self.rank, store=store)        pg = dist.distributed_c10d._get_default_group()return pgdef setUp(self):super().setUp()self._spawn_processes()def tearDown(self):super().tearDown()try:os.remove(self.file_name)except OSError:pass@propertydef world_size(self):return 4#框架会自动遍历以下参数组合@unittest.skipIf(device_count<4, "Not enough devices") @parametrize("op",[dist.ReduceOp.SUM])@parametrize("shape", [(1024,8192)])@parametrize("dtype", [torch.int64])def test_allreduce(self,op,shape,dtype):if self.rank >= self.world_size:returnstore = dist.FileStore(self.file_name, self.world_size)pg = self._create_process_group_vccl(self.world_size, store)if not torch.distributed.is_initialized():returntorch.cuda.set_device(self.rank)device = torch.device(device_type,self.rank)device_warmup(device)#让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量@loop_decorator(loops=5,rank=self.rank)def run(input_dev):dist.all_reduce(input_dev, op=op)return input_dev#记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2,rank=self.rank) as m:input_host=torch.ones(shape,dtype=dtype)*(100+self.rank)gt=[torch.ones(shape,dtype=dtype)*(100+i) for i in range(self.world_size)]gt_=gt[0]for i in range(1,self.world_size):gt_=gt_+gt[i]input_dev=input_host.to(device)metrics=run(input_dev,golden=gt_)m.set_metrics(metrics)assert(metrics["mse"]==0)dist.destroy_process_group(pg)instantiate_parametrized_tests(TestCCL)if __name__ == "__main__":run_tests()

四.测试命令

# 运行所有的测试
pytest -v -s -p no:warnings --html=torch_report.html --self-contained-html --capture=sys ./# 运行某一个测试
python3 test_clone.py -k "test_clone_shape_(128, 256)_float32"

五.测试报告

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/313059.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

春藤实业启动SAP S/4HANA Cloud Public Edition项目,与工博科技携手数字化转型之路

3月11日&#xff0c;广东省春藤实业有限公司&#xff08;以下简称“春藤实业”&#xff09;SAP S/4HANA Cloud Public Edition&#xff08;以下简称“SAP ERP公有云”&#xff09;项目正式启动。春藤实业董事长陈董、联络协调项目经理慕总、内部推行项目经理陈总以及工博董事长…

【函数式接口使用✈️✈️】通过具体的例子实现函数结合策略模式的使用

目录 前言 一、核心函数式接口 1. Consumer 2. Supplier 3. Function,> 二、场景模拟 1.面向对象设计 2. 策略接口实现&#xff08;以 Function 接口作为策略&#xff09; 三、对比 前言 在 Java 8 中引入了Stream API 新特性&#xff0c;这使得函数式编程风格进…

Chatgpt掘金之旅—有爱AI商业实战篇|品牌故事业务|(十六)

演示站点&#xff1a; https://ai.uaai.cn 对话模块 官方论坛&#xff1a; www.jingyuai.com 京娱AI 一、AI技术创业在品牌故事业务有哪些机会&#xff1f; 人工智能&#xff08;AI&#xff09;技术作为当今科技创新的前沿领域&#xff0c;为创业者提供了广阔的机会和挑战。随…

OpenCV从入门到精通实战(七)——探索图像处理:自定义滤波与OpenCV卷积核

本文主要介绍如何使用Python和OpenCV库通过卷积操作来应用不同的图像滤波效果。主要分为几个步骤&#xff1a;图像的读取与处理、自定义卷积函数的实现、不同卷积核的应用&#xff0c;以及结果的展示。 卷积 在图像处理中&#xff0c;卷积是一种重要的操作&#xff0c;它通过…

生成人工智能体:人类行为的交互式模拟论文与源码架构解析(2)——架构分析 - 核心思想环境搭建技术选型

4.架构分析 4.1.核心思想 超越一阶提示&#xff0c;通过增加静态知识库和信息检索方案或简单的总结方案来扩展语言模型。 将这些想法扩展到构建一个代理架构&#xff0c;该架构处理检索&#xff0c;其中过去的经验在每个时步动态更新&#xff0c;并混合与npc当前上下文和计划…

c++ qt6.5 打包sqlite组件无法使用,尽然 也需要dll支持!这和开发php 有什么区别!

运行 程序会默认使用当前所在文件夹中的 dll 文件&#xff0c;若文件不存在&#xff0c;会使用系统环境变量路径中的文件&#xff1b;又或者是需要在程序源代码中明确指定使用的 dll 的路径。由于我安装 Qt 时将相关 dll 文件路径都添加到了系统环境变量中&#xff0c;所以即使…

.net反射(Reflection)

文章目录 一.概念&#xff1a;二.反射的作用&#xff1a;三.代码案例&#xff1a;四.运行结果&#xff1a; 一.概念&#xff1a; .NET 反射&#xff08;Reflection&#xff09;是指在运行时动态地检查、访问和修改程序集中的类型、成员和对象的能力。通过反射&#xff0c;你可…

C语言通过键盘输入给结构体内嵌的结构体赋值——指针法

1 需求 以录入学生信息&#xff08;姓名、学号、性别、出生日期&#xff09;为例&#xff0c;首先通过键盘输入需要录入的学生的数量&#xff0c;再依次输入这些学生的信息&#xff0c;输入完成后输出所有信息。 2 代码 #include<stdio.h> #include<stdlib.h>//…

React - 基础学习

React基础 React更新视图的流程 是 一层一层查找 到对应的视图做更新 如何生成React工程 // 生成简单的react npx create-react-app react-app// 生成typescript的react npx create-react-app react-app-ts --template typescriptReact的基本能力 父子组件 // 父组…

openGauss学习笔记-265 openGauss性能调优-TPCC性能调优测试指导-操作系统配置

文章目录 openGauss学习笔记-265 openGauss性能调优-TPCC性能调优测试指导-操作系统配置265.1安装openEuler操作系统265.2 修改操作系统内核PAGESIZE为64KB。265.3 关闭CPU中断的服务irqbalance openGauss学习笔记-265 openGauss性能调优-TPCC性能调优测试指导-操作系统配置 本…

OpenCV基本图像处理操作(八)——光流估计

光流估计 光流估计是一种用于检测图像序列中像素点运动的技术。它基于这样的假设&#xff1a;在连续的视频帧之间&#xff0c;一个物体的移动会导致像素强度的连续性变化。通过分析这些变化&#xff0c;光流方法可以估计每个像素点的运动速度和方向。 光流估计通常用于多种应…

idea 将项目上传到gitee远程仓库具体操作

目录标题 一、新建仓库二、初始化项目三、addcommit四、配置远程仓库五、拉取远程仓库内容六、push代码到仓库 一、新建仓库 新建仓库教程 注意&#xff1a;远程仓库的初始文件不要与本地存在名字一样的文件&#xff0c;不然拉取会因为冲突而失败。可以把远程一样的初始文件删…

Labview2024安装包(亲测可用)

目录 一、软件简介 二、软件下载 一、软件简介 LabVIEW是一种由美国国家仪器&#xff08;NI&#xff09;公司开发的程序开发环境&#xff0c;它显著区别于其他计算机语言&#xff0c;如C和BASIC。传统的计算机语言是基于文本的语言来产生代码&#xff0c;而LabVIEW则采用图形化…

HEF4046BT功能参数及避免使用的场景、应用前置放大器

制造商:NXP 产品种类:锁相环 PLL 类型:PLL 电路数量:1 电源电压 最大:15 V 电源电压 最小:3 V 最大工作温度: 85 C 安装风格:SMD/SMT 封装:SO-16 封装:Bulk 商标:NXP Semiconductors 最小工作温度:- 40 C 工作电源电压:3.3 V, 5 V, 9 V, 12 V HEF4046BT 是一种 CMO…

LINUX中使用cron定时任务被隐藏,咋回事?

一、问题现象 线上服务器运行过程中&#xff0c;进程有莫名进程被启动&#xff0c;怀疑是有定时任务自动启动&#xff0c;当你用常规方法去查看&#xff0c;比如使用crontab去查看定时器任务&#xff0c;提示no crontab for root 或者使用cat到/var/spool/cron目录下去查看定时…

Linux编辑器-vim的使用

vim的基本概念 vim的三种模式(其实有好多模式&#xff0c;目前掌握这3种即可),分别是命令模式&#xff08;command mode&#xff09;、插 入模式&#xff08;Insert mode&#xff09;和底行模式&#xff08;last line mode&#xff09;&#xff0c;各模式的功能区分如下&#…

027——从GUI->Client->Server->driver实现对SR501的控制

目录 1、修改显示界面 2、 添加对SR501显示的处理和tcp消息的处理 3、 在服务器程序中添加对SR501的处理 4、 编写驱动句柄 5、 修改底层驱动 1、修改显示界面 有个奇怪的问题这里的注释如果用 就会报错不知道为啥&#xff0c;只能用#来注释 我把显示这里需要显示的器件的…

nginx部署上线

1. windows配置nginx 打包命令 npm run build:prod 1. 安装 nginx mac windows 2. mac / windows 环境下ngnix部署启动项目 2. nginx 解决 history 的 404 问题 3. nginx配置代理解决生产环境跨域问题

Docker构建Golang项目常见问题

Docker构建Golang项目常见问题 1 dockerfile报错&#xff1a;failed to read expected number of bytes: unexpected EOF2 go mod tidy: go.mod file indicates go 1.21, but maximum supported version is 1.17 1 dockerfile报错&#xff1a;failed to read expected number o…

rhce.定时任务和延迟任务项目

一 . 在系统中设定延迟任务要求如下&#xff1a; 在系统中建立 easylee 用户&#xff0c;设定其密码为 easylee 延迟任务由 root 用户建立 要求在 5 小时后备份系统中的用户信息文件到/backup中 确保延迟任务是使用非交互模式建立 确保系统中只有 root 用户和easylee用户可以…