PyTorch 模型性能分析和优化 - 第 6 部分

玩具模型

为了方便我们的讨论,我们使用流行的 timm python 模块(版本 0.9.7)定义了一个简单的基于 Vision Transformer (ViT) 的分类模型。我们将模型的 patch_drop_rate 标志设置为 0.5,这会导致模型在每个训练步骤中随机丢弃一半的补丁。使用 torch.use_definistic_algorithms 函数和 cuBLAS 环境变量 CUBLAS_WORKSPACE_CONFIG 对训练脚本进行编程,以最大限度地减少不确定性。请参阅下面的代码块以获取完整的模型定义:

import torch, time, os
import torch.optim
import torch.profiler
import torch.utils.data
from timm.models.vision_transformer import VisionTransformer
from torch.utils.data import Dataset

# use the GPU
device = torch.device("cuda:0")

# configure PyTorch to use reproducible algorithms
torch.manual_seed(0)
os.environ[
        "CUBLAS_WORKSPACE_CONFIG"
    ] = ":4096:8"
torch.use_deterministic_algorithms(True)

# define the ViT-backed classification model
model = VisionTransformer(patch_drop_rate=0.5).cuda(device)
# define the loss function
loss_fn = torch.nn.CrossEntropyLoss()
# define the training optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# use random data
class FakeDataset(Dataset):
    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        rand_image = torch.randn([3224224], dtype=torch.float32)
        label = torch.tensor(data=[index % 1000], dtype=torch.int64)
        return rand_image, label

train_set = FakeDataset()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128
                                           num_workers=8, pin_memory=True)


t0 = time.perf_counter()
summ = 0
count = 0
model.train()

# training loop wrapped with profiler object
with torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=1, warmup=4, active=3, repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler('/tmp/perf')
as prof:
    for step, data in enumerate(train_loader):
        inputs = data[0].to(device=device, non_blocking=True)
        label = data[1].squeeze(-1).to(device=device, non_blocking=True)
        with torch.profiler.record_function('forward'):
            outputs = model(inputs)
            loss = loss_fn(outputs, label)
        optimizer.zero_grad(set_to_none=True)
        with torch.profiler.record_function('backward'):
            loss.backward()
        with torch.profiler.record_function('optimizer_step'):
            optimizer.step()
        prof.step()
        batch_time = time.perf_counter() - t0
        if step > 1:  # skip first step
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step > 500:
            break

    print(f'average step time: {summ/count}')

我们将在 Amazon EC2 g5.2xlarge 实例(包含 NVIDIA A10G GPU 和 8 个 vCPU)上运行实验,并使用官方 AWS PyTorch 2.0 Docker 映像。

初始性能结果

在下图中,我们捕获了 TensorBoard 插件跟踪视图中显示的性能结果:

alt

虽然训练步骤的前向传递中的操作在顶部线程中聚集在一起,但在底部线程的向后传递中似乎出现了性能问题。在那里我们看到单个操作 GatherBackward 占据了跟踪的很大一部分。仔细观察,我们可以看到底层操作包括“to”、“copy_”和“cudaStreamSynchronize”。

这时你自然会问:为什么会出现这种情况?我们的模型定义的哪一部分导致了它? GatherBackward 跟踪提示可能涉及 torch.gather 操作,但它来自哪里以及为什么会导致同步事件?

在我们之前的文章中(例如,此处),我们提倡使用带标签的 torch.profiler.record_function 上下文管理器来查明性能问题的根源。这里的问题是性能问题发生在我们无法控制的向后传递中!特别是,我们无法使用上下文管理器将单个操作包装在向后传递中。

理论上,可以通过对跟踪视图的深入分析以及将后向传递中的每个片段与其前向传递中的相应操作进行匹配来识别有问题的模型操作。然而,这不仅非常乏味,而且还需要深入了解模型训练步骤的所有低级操作。

使用 torch.profiler.record_function 标签的优点是它使我们能够轻松地定位模型的有问题的部分。理想情况下,我们希望即使在向后传递中出现性能问题的情况下也能够保留相同的功能。

使用 PyTorch Backward Hooks 进行性能分析

尽管 PyTorch 不允许您包装单独的向后传递操作,但它确实允许您使用其钩子支持来添加和/或附加自定义功能。 PyTorch 支持将钩子注册到 torch.Tensors 和 torch.nn.Modules。尽管我们在本文中提出的技术将依赖于将向后钩子注册到模块,但张量钩子注册可以类似地用于替换或增强基于模块的方法。

在下面的代码块中,我们定义了一个包装函数,它接受一个模块并注册一个 full_backward_hook 和一个 full_backward_pre_hook (尽管实际上一个就足够了)。每个钩子都被编程为使用 torch.profiler.record_function 函数简单地将消息添加到捕获的分析跟踪中。

backward_pre_hook 被编程为打印“之前”消息,backward_hook 被编程为打印“之后”消息。附加可选的详细信息字符串以区分同一模块类型的多个实例。

def backward_hook_wrapper(module, details=None):
    
    # define register_full_backward_pre_hook function
    def bwd_pre_hook_print(self, output):
        message = f'before backward of {module.__class__.__qualname__}'
        if details:
            message = f'{message}{details}'
        with torch.profiler.record_function(message):
            return output

    # define register_full_backward_hook function
    def bwd_hook_print(self, input, output):
        message = f'after backward of {module.__class__.__qualname__}'
        if details:
            message = f'{message}{details}'
        with torch.profiler.record_function(message):
            return input

    # register hooks
    module.register_full_backward_pre_hook(bwd_pre_hook_print)
    module.register_full_backward_hook(bwd_hook_print)
    return module

使用backward_hook_wrapper函数,我们可以开始定位性能问题的根源。我们首先仅包装模型和损失函数,如下面的代码块所示:

model = backward_hook_wrapper(model)
loss_fn = backward_hook_wrapper(loss_fn)

使用 TensorBoard 插件 Trace View 的搜索框,我们可以识别“之前”和“之后”消息的位置,并推断出模型和损失的反向传播的开始和结束位置。这使我们能够得出结论,性能问题发生在模型的向后传递中。下一步是使用 back_hook_wrapper 函数包装 Vision Transformer 的内部模块:

model.patch_embed = backward_hook_wrapper(model.patch_embed)
model.pos_drop = backward_hook_wrapper(model.pos_drop)
model.patch_drop = backward_hook_wrapper(model.patch_drop)
model.norm_pre = backward_hook_wrapper(model.norm_pre)
model.blocks = backward_hook_wrapper(model.blocks)
model.norm = backward_hook_wrapper(model.norm)
model.fc_norm = backward_hook_wrapper(model.fc_norm)
model.head_drop = backward_hook_wrapper(model.head_drop)

在上面的代码块中,我们指定了每个内部模块。包装所有模型第一级模块的另一种方法是迭代其named_children:

for submodule in model.named_children():
    submodule = backward_hook_wrapper(submodule)

下面的图像捕获显示在有问题的 GatherBackward 操作之前存在“before back of PatchDropout”消息:

alt

我们的性能分析表明,性能问题的根源是 PathDropout 模块。检查模块的forward函数,我们确实可以看到对torch.gather的调用。

就我们的玩具模型而言,我们只需要进行两次分析迭代即可找到性能问题的根源。在实践中,可能需要对该方法进行额外的迭代。

请注意,PyTorch 包含 torch.nn.modules.module.register_module_full_backward_hook 函数,该函数将在一次调用中将钩子附加到训练步骤中的所有模块。尽管这在简单情况下(例如我们的玩具示例)可能就足够了,但它无法使人区分同一模块类型的不同实例。

现在我们知道了性能问题的根源,我们可以开始尝试修复它。

优化建议:尽可能使用索引而不是收集

现在我们知道问题的根源在于 DropPatches 模块的 torch.gather 操作,我们可以研究长主机设备同步事件的触发因素可能是什么。我们的调查让我们回到 torch.use_definistic_algorithms 函数的文档,该函数告诉我们,当在需要 grad 的 CUDA 张量上调用时,torch.gather 会表现出非确定性行为,除非在模式设置为 True 的情况下调用 torch.use_definistic_algorithms。

换句话说,通过将脚本配置为使用确定性算法,我们修改了 torch.gather 向后传递的默认行为。事实证明,正是这种变化导致需要同步事件。事实上,如果我们删除此配置,性能问题就会消失!问题是,我们能否保持算法的确定性而不需要付出性能损失。

在下面的代码块中,我们提出了 PathDropout 模块前向函数的替代实现,该实现使用 torch.Tensor 索引而不是 torch.gather 产生相同的输出。修改后的代码行已突出显示。

from timm.layers import PatchDropout

class MyPatchDropout(PatchDropout):
    def forward(self, x):
        prefix_tokens = x[:, :self.num_prefix_tokens]
        x = x[:, self.num_prefix_tokens:]
        B = x.shape[0]
        L = x.shape[1]
        num_keep = max(1, int(L * (1. - self.prob)))
        keep_indices = torch.argsort(torch.randn(B, L, device=x.device),
                                     dim=-1)[:, :num_keep]

        # The following three lines were modified from the original
        # to use PyTorch indexing rather than torch.gather
        stride = L * torch.unsqueeze(torch.arange(B, device=x.device), 1)
        keep_indices = (stride + keep_indices).flatten()
        x = x.reshape(B * L, -1)[keep_indices].view(B, num_keep, -1)

        x = torch.cat((prefix_tokens, x), dim=1)
        return x


model.patch_drop = MyPatchDropout(
    prob = model.patch_drop.prob,
    num_prefix_tokens = model.patch_drop.num_prefix_tokens
)

在下图中,我们捕获了上述更改后的跟踪视图:

alt

我们可以清楚地看到,冗长的同步事件不再存在。

就我们的玩具模型而言,我们很幸运,torch.gather 操作的使用方式允许将其替换为 PyTorch 索引。当然,情况并非总是如此。 torch.gather 的其他用法可能没有基于索引的等效实现。

结果

在下表中,我们比较了在不同场景下训练玩具模型的性能结果:

alt

在我们的玩具示例中,优化虽然可衡量,但影响不大——性能提升约 2%。有趣的是,可重现模式下的 torch 索引比默认(非确定性)torch.gather 的表现更好。根据这些发现,尽可能评估使用索引而不是 torch.gather 的选项可能是一个好主意。

总结

尽管 PyTorch 因易于调试和跟踪而享有(合理的)声誉,但 torch.autograd 仍然是一个谜,并且分析训练步骤的向后传递可能相当困难。为了应对这一挑战,PyTorch 支持在反向传播的不同阶段插入钩子。在这篇文章中,我们展示了如何在迭代过程中使用 PyTorch 向后钩子以及 torch.profiler.record_function 来识别向后传递中性能问题的根源。我们将此技术应用于一个简单的 ViT 模型,并了解了 torch.gather 操作的一些细微差别。

在这篇文章中,我们讨论了一种非常具体的性能瓶颈类型。请务必查看我们在媒体上发布的其他帖子,其中涵盖了与机器学习工作负载的性能分析和性能优化相关的各种主题。

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/164155.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

软件测试肖sir__python之ui自动化定位方法(2)

Selenium中元素定位方法 一、定位方法 要实现UI自动化,就必须学会定位web页面元素,Selenium核心 webdriver模块提供了9种定位元素方法: 定位方式 提供方法 id定位 find_element_by_id() name定位 find_element_by_name() class定位 find_elem…

开源游戏引擎和模拟器的项目合集 | 开源专题 No.38

yuzu-emu/yuzu Stars: 26.2k License: GPL-3.0 yuzu是一款全球最受欢迎的开源Nintendo Switch模拟器,由Citra创建者编写。它采用C语言编写,并具有可移植性,在Windows和Linux上进行积极维护。该模拟器能够全速运行大多数商业游戏&#xff0c…

TSINGSEE烟火识别算法的技术原理是什么?如何应用在视频监控中?

AI烟火识别算法是基于深度学习技术的一种视觉识别算法,主要用于在视频监控场景中自动检测和识别烟雾、火焰的行为。该技术基于深度学习神经网络技术,可以动态识别烟雾和火焰从有到无、从小到大、从大到小、从小烟到浓烟的状态转换过程。 1、技术原理 1…

图论与网络优化

2.概念与计算 2.1 图的定义 2.1.1 定义 图(graph) G G G 是一个有序的三元组&#xff0c;记作 G < V ( G ) , E ( G ) , ψ ( G ) > G<V(G),E(G),\psi (G)> G<V(G),E(G),ψ(G)>。 V ( G ) V(G) V(G) 是顶点集。 E ( G ) E(G) E(G) 是边集。 ψ ( G ) \…

HTTP基础

HTTP请求报文格式 HTTP 的请求报文分为三个部分 请求行&#xff08;Request Line&#xff09;、请求头&#xff08;Request Header&#xff09;和请求体&#xff08;Request Body&#xff09;。请求体是HTTP请求的核心&#xff0c;其中包含了需要上传服务器的数据。常见的请求…

Open3D(C++) 最小二乘拟合二维直线(拉格朗日乘子法)

目录 一、算法原理二、代码实现三、结果展示本文由CSDN点云侠原创,爬虫网站自重 一、算法原理 平面直线的表达式为: y = k x + b (1) y=kx+

oracle-AWR报告生成方法

AWR报告生成方法 1. 以oracle用户登陆服务器 2. 进入到要保存awr报告的目录 3. 以sysdba身份连接数据库 sqlplus / as sysdba4. 执行生成AWR报告命令 ?/rdbms/admin/awrrpt.sql5. 选择AWR报告的文件格式 6. 选择生成多少天的AWR报告 7. 选择报告的快照起始和结束ID 8. 输入生…

npm ERR! exited with error code: 128

1.遇到的问题 报错信息&#xff1a;npm ERR! E:\tools\Gitt\Git\cmd\git.EXE ls-remote -h -t https://github.com/nhn/raphael.git npm ERR! npm ERR! fatal: unable to access https://github.com/nhn/raphael.git/: OpenSSL SSL_read: Connection was reset, errno 10054 …

JVM相关面试题

文章目录 什么是 JVM?Java语言的执行原理&#xff1f;Java的字节码文件结构&#xff1f;什么是u2,u4? JVM 加载过程都是什么步骤&#xff1f;什么是类加载器&#xff1f;什么是双亲委派模型如何打破双亲委派机制&#xff1f;什么是 tomcat 类加载机制&#xff1f;什么是JVM内…

C# 给List编个序号

给List编个号 int[] numbers { 5, 4, 1, 3, 9, 8, 6, 7, 2, 0 };int i 0;var q (from n in numbersselect (index:i,number:n)).ToList();foreach (var v in q) {Console.WriteLine($"i {v.index}, v {v.number}"); }

关于刷题时使用数组的小注意事项

&#x1f4af; 博客内容&#xff1a;关于刷题时使用数组的小技巧 &#x1f600; 作  者&#xff1a;陈大大陈 &#x1f680; 个人简介&#xff1a;一个正在努力学技术的准前端&#xff0c;专注基础和实战分享 &#xff0c;欢迎私信&#xff01; &#x1f496; 欢迎大家&#…

Python如何实现多线程编程

目录 线程的创建 线程的管理 线程的同步 线程池 线程同步和锁 总结 Python是一种广泛使用的编程语言&#xff0c;它具有丰富的库和工具&#xff0c;可以用来实现多线程编程。多线程编程是一种并行计算技术&#xff0c;它通过将任务划分为多个独立的任务并利用多个线程同时…

vue实现在页面拖拽放大缩小div并显示鼠标在div的坐标

1、功能要求&#xff1a; 实现在一个指定区域拖拽div,并可以放大缩小&#xff0c;同时显示鼠标在该div里的坐标&#xff0c;如图可示 缩小并拖动 2、实现 <div class"div_content" ref"div_content"><div class"div_image" id"…

C++(Qt)软件调试---linux使用dmesg定位程序崩溃位置(14)

C(Qt)软件调试—linux使用dmesg定位程序崩溃位置&#xff08;14&#xff09; 文章目录 C(Qt)软件调试---linux使用dmesg定位程序崩溃位置&#xff08;14&#xff09;1、前言2、ELF文件3、常用工具4、使用dmesg定位异常位置1.1 异常发生在可执行程序中1.2 异常发生在动态库中 1、…

C++算法:二叉树的序列化与反序列化

#题目 序列化是将一个数据结构或者对象转换为连续的比特位的操作&#xff0c;进而可以将转换后的数据存储在一个文件或者内存中&#xff0c;同时也可以通过网络传输到另一个计算机环境&#xff0c;采取相反方式重构得到原数据。 请设计一个算法来实现二叉树的序列化与反序列化。…

花生好车基于 KubeSphere 的微服务架构实践

公司简介 花生好车成立于 2015 年 6 月&#xff0c;致力于打造下沉市场汽车出行解决方案第一品牌。通过自建直营渠道&#xff0c;瞄准下沉市场&#xff0c;现形成以直租、批售、回租、新能源汽车零售&#xff0c;四大业务为核心驱动力的汽车新零售平台&#xff0c;目前拥有门店…

[自定义 Vue 组件] 小尾巴 Logo 组件 TailLogo

文字归档于&#xff1a;https://www.yuque.com/u27599042/coding_star/apt6y731ybmxgu5g 组件效果 组件依赖 自定义字符串工具函数 stringIsNull https://www.yuque.com/u27599042/coding_star/slncupw7un3ce7cb import {stringIsNull} from "/utils/string_utils.js&q…

移动App安全检测的必要性,app安全测试报告的编写注意事项

随着移动互联网的迅猛发展&#xff0c;移动App已经成为人们日常生活中不可或缺的一部分。然而&#xff0c;虽然App给我们带来了便利和乐趣&#xff0c;但也伴随着一些潜在的安全风险。黑客、病毒、恶意软件等威胁着用户的隐私和财产安全&#xff0c;因此进行安全检测就显得尤为…

acwing算法基础之数据结构--KMP算法

目录 1 知识点2 模板 1 知识点 KMP算法已经集成到string类型的find()方法了&#xff0c; 但这里我们不用这个&#xff0c;我们自己来实现这个方法。 KMP算法的关键步骤&#xff1a; p[N]表示输入模式串&#xff0c;求取该模式串的ne[]数组。ne[i]表示前缀等于后缀的长度&…

UE4 UltrDynamicSky与场景物体进行交互

找到材质 找到其最父类的材质 把这个拖过去连上即可