【编译系列】Torch.compile()训练编译——算子融合逻辑 工程化

1. 背景:

torch.compile()中,Dynamo作为前端负责计算图的捕获,后端有inductor、tvm等进行编译优化。

  • Dynamo:在Python字节码层面注入pass,实现bytecode-to-bytecode的优化,通过对bytecode逐行进行解析构建FX Graph
  • Inductor:负责对FX Graph进行AOTAutograd生成joint-graph、decompose成PrimTorch基础op、基于硬件生成对应的kernel代码

本次主要分享以下几个方面:

  • TorchInductor中的算子融合逻辑
  • 如何在torch.compile中自定义融合算子

2. TorchInductor算子融合逻辑:

在这里插入图片描述
在TorchInductor中负责算子融合的主要有三类:

  1. FX Graph上进行的算子融合:FX IR的target可能还是torch.ops级别,属于比较粗粒度的融合,通常在推理场景下生效,如推理场景下会对Conv+BN进行算子融合,而训练场景下因为权重更新问题不会生效。
  2. GraphLowering过程中的inline:GraphLowering主要负责将FX Graph转为Inductor IR,在转换成Inductor IR的过程中对那些纯计算的中间结果进行inline实现融合效果。
  3. Inductor IR上的算子融合:Scheduler对GraphLowering后所有内存分配的Inductor IR(Inductor里面称为buffer)中有共享内存访问的算子进行融合。

2.1. FX Graph上的算子融合:

此阶段存在尚未decompose的op,且还需要进行AOTAutograd构建反传计算图,因此训练场景下进行算子融合的话会比较受限(通常需要提供算子的反传函数),但相对而言能在更顶层上进行融合收益也会更明显。以Conv+BN代码为例,

# Conv+BN
# code
import torch
from typing import List
import torch._dynamo as dynamo
import torch
import torch.nn as nnclass ConvNet(nn.Module):def __init__(self, num_classes=10):super(ConvNet, self).__init__()self.conv_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=2) # assuming input is 3-channel imageself.bn_1 = nn.BatchNorm2d(16)self.relu_1 = nn.ReLU()def forward(self, x):out = self.conv_1(x)out = self.bn_1(out)out = self.relu_1(out)return outdef test():data = torch.randn((2,3,128,128),requires_grad=True,device="cuda")model = ConvNet().to("cuda")model.eval()model = dynamo.optimize("inductor")(model)output = model(data)test()

在训练场景下,可以发现并没有在FX IR上进行任何算子融合的操作,就是先计算conv—》计算均值方差—》计算BN。

# 训练场景下的Conv+BN的FXGraph
def forward(self, primals, tangents):primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)convolution = torch.ops.aten.convolution.default(primals_8, primals_1, primals_2, [1, 1], [2, 2], [1, 1], False, [0, 0], 1);  primals_2 = Noneconvert_element_type = torch.ops.prims.convert_element_type.default(primals_5, torch.float32)convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_6, torch.float32)add = torch.ops.aten.add.Tensor(convert_element_type_1, 1e-05);  convert_element_type_1 = Nonesqrt = torch.ops.aten.sqrt.default(add);  add = Nonereciprocal = torch.ops.aten.reciprocal.default(sqrt);  sqrt = Nonemul = torch.ops.aten.mul.Tensor(reciprocal, 1);  reciprocal = Noneunsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type, -1);  convert_element_type = Noneunsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, -1);  unsqueeze = Noneunsqueeze_2 = torch.ops.aten.unsqueeze.default(mul, -1);  mul = Noneunsqueeze_3 = torch.ops.aten.unsqueeze.default(unsqueeze_2, -1);  unsqueeze_2 = Nonesub = torch.ops.aten.sub.Tensor(convolution, unsqueeze_1);  unsqueeze_1 = Nonemul_1 = torch.ops.aten.mul.Tensor(sub, unsqueeze_3);  sub = unsqueeze_3 = Noneunsqueeze_4 = torch.ops.aten.unsqueeze.default(primals_3, -1)unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, -1);  unsqueeze_4 = Nonemul_2 = torch.ops.aten.mul.Tensor(mul_1, unsqueeze_5);  mul_1 = unsqueeze_5 = Noneunsqueeze_6 = torch.ops.aten.unsqueeze.default(primals_4, -1);  primals_4 = Noneunsqueeze_7 = torch.ops.aten.unsqueeze.default(unsqueeze_6, -1);  unsqueeze_6 = Noneadd_1 = torch.ops.aten.add.Tensor(mul_2, unsqueeze_7);  mul_2 = unsqueeze_7 = Nonerelu = torch.ops.aten.relu.default(add_1);  add_1 = Nonealias = torch.ops.aten.alias.default(relu)alias_1 = torch.ops.aten.alias.default(alias);  alias = Nonealias_2 = torch.ops.aten.alias.default(alias_1);  alias_1 = Nonealias_3 = torch.ops.aten.alias.default(alias_2);  alias_2 = Nonele = torch.ops.aten.le.Scalar(alias_3, 0);  alias_3 = Nonescalar_tensor = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))where = torch.ops.aten.where.self(le, scalar_tensor, tangents_1);  le = scalar_tensor = tangents_1 = Noneadd_2 = torch.ops.aten.add.Tensor(primals_6, 1e-05);  primals_6 = Nonersqrt = torch.ops.aten.rsqrt.default(add_2);  add_2 = Noneunsqueeze_8 = torch.ops.aten.unsqueeze.default(primals_5, 0);  primals_5 = Noneunsqueeze_9 = torch.ops.aten.unsqueeze.default(unsqueeze_8, 2);  unsqueeze_8 = Noneunsqueeze_10 = torch.ops.aten.unsqueeze.default(unsqueeze_9, 3);  unsqueeze_9 = Nonesum_1 = torch.ops.aten.sum.dim_IntList(where, [0, 2, 3])sub_1 = torch.ops.aten.sub.Tensor(convolution, unsqueeze_10);  convolution = unsqueeze_10 = Nonemul_3 = torch.ops.aten.mul.Tensor(where, sub_1);  sub_1 = Nonesum_2 = torch.ops.aten.sum.dim_IntList(mul_3, [0, 2, 3]);  mul_3 = Nonemul_8 = torch.ops.aten.mul.Tensor(rsqrt, primals_3);  primals_3 = Noneunsqueeze_17 = torch.ops.aten.unsqueeze.default(mul_8, 0);  mul_8 = Noneunsqueeze_18 = torch.ops.aten.unsqueeze.default(unsqueeze_17, 2);  unsqueeze_17 = Noneunsqueeze_19 = torch.ops.aten.unsqueeze.default(unsqueeze_18, 3);  unsqueeze_18 = Nonemul_9 = torch.ops.aten.mul.Tensor(where, unsqueeze_19);  where = unsqueeze_19 = Nonemul_10 = torch.ops.aten.mul.Tensor(sum_2, rsqrt);  sum_2 = rsqrt = Nonesum_3 = torch.ops.aten.sum.dim_IntList(mul_9, [0, 2, 3])convolution_backward = torch.ops.aten.convolution_backward.default(mul_9, primals_8, primals_1, [16], [1, 1], [2, 2], [1, 1], False, [0, 0], 1, [True, True, False]);  mul_9 = primals_8 = primals_1 = Nonegetitem = convolution_backward[0]getitem_1 = convolution_backward[1];  convolution_backward = Nonereturn pytree.tree_unflatten([relu, getitem_1, sum_3, mul_10, sum_1, None, None, None, getitem], self._out_spec)

在推理场景下,可以发现此时的FX Graph发生了变化,会先计算均值方差----》和Conv的weight和bias进行加法----》计算conv,即对应将BN的权重融合到Conv中进行算子融合的操作。

# 推理场景下的Conv+BN的FXGraph
def forward(self, primals, tangents):primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)add = torch.ops.aten.add.Tensor(primals_6, 1e-05);  primals_6 = Nonersqrt = torch.ops.aten.rsqrt.default(add);  add = Noneview = torch.ops.aten.view.default(rsqrt, [-1, 1, 1, 1]);  rsqrt = Noneview_1 = torch.ops.aten.view.default(primals_3, [16, 1, 1, 1]);  primals_3 = Nonemul = torch.ops.aten.mul.Tensor(view_1, view);  view_1 = Nonemul_1 = torch.ops.aten.mul.Tensor(primals_1, mul)view_2 = torch.ops.aten.view.default(mul, [16])sub = torch.ops.aten.sub.Tensor(primals_2, primals_5);  primals_2 = primals_5 = Nonemul_2 = torch.ops.aten.mul.Tensor(view_2, sub)add_1 = torch.ops.aten.add.Tensor(primals_4, mul_2);  primals_4 = mul_2 = Noneconvolution = torch.ops.aten.convolution.default(primals_8, mul_1, add_1, [1, 1], [2, 2], [1, 1], False, [0, 0], 1);  add_1 = Nonerelu = torch.ops.aten.relu.default(convolution);  convolution = Nonealias = torch.ops.aten.alias.default(relu)alias_1 = torch.ops.aten.alias.default(alias);  alias 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/11854.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Docker 部署教程jenkins

Docker 部署 jenkins 教程 Jenkins 官方网站 Jenkins 是一个开源的自动化服务器,主要用于持续集成(CI)和持续交付(CD)过程。它帮助开发人员自动化构建、测试和部署应用程序,显著提高软件开发的效率和质量…

2025/2/3 云服务器数据库与idea相连

幸福就摆在你面前,你却把阴影当成山川瀑布,你说你无法幸福。 轻量应用服务器https://swasnext.console.aliyun.com/servers/cn-heyuanhttps://swasnext.console.aliyun.com/servers/cn-heyuanhttps://swasnext.console.aliyun.com/servers/cn-heyuanhttp…

【memgpt】letta 课程1/2:从头实现一个自我编辑、记忆和多步骤推理的代理

llms-as-operating-systems-agent-memory llms-as-operating-systems-agent-memory内存 操作系统的内存管理

6. 【Vue实战--孢子记账--Web 版开发】-- 主币种设置

从这篇文章开始我们将一起实现孢子记账的功能,这篇文章实现主币种设置。这个功能比较简单,因此我们从这个功能开始做。 一、功能 根据项目前期的需求调研,用户需要在设置主币种的时候查看汇率信息(别问为什么有这么个需求&#…

51单片机(STC89C52)开发:点亮一个小灯

软件安装: 安装开发板CH340驱动。 安装KEILC51开发软件:C51V901.exe。 下载软件:PZ-ISP.exe 创建项目: 新建main.c 将main.c加入至项目中: main.c:点亮一个小灯 #include "reg52.h"sbit LED1P2^0; //P2的…

GESP2023年9月认证C++六级( 第三部分编程题(2)小杨的握手问题)

参考程序1&#xff08;暴力枚举&#xff09; #include <iostream> using namespace std;int main() {int n 0;cin >> n; // 读入同学的数量int num[300000]; // 存储同学的学号for (int i 0; i < n; i) {cin >> num[i]; // 读入同学的进入顺序}long…

【C++篇】哈希表

目录 一&#xff0c;哈希概念 1.1&#xff0c;直接定址法 1.2&#xff0c;哈希冲突 1.3&#xff0c;负载因子 二&#xff0c;哈希函数 2.1&#xff0c;除法散列法 /除留余数法 2.2&#xff0c;乘法散列法 2.3&#xff0c;全域散列法 三&#xff0c;处理哈希冲突 3.1&…

GPT与Deepseek等数据驱动AI的缺点

当前数据驱动的AI&#xff08;包括GPT与Deepseek等各种大小模型&#xff09;只进行了数/物理性的初步探索&#xff0c;尚未触及人机环境生态系统的复杂性。也就是说&#xff0c;当前的数据驱动型 AI&#xff0c;虽然在处理大量数据、解决特定任务方面取得了显著进展&#xff0c…

阿里云盘PC端打不开解决办法

阿里云盘服务中心 搜索&#xff1a;PC端无法启动怎么办 选择问题 PC端双击云盘图标没有反应&#xff08;windows系统&#xff09; 下载null.sys&#xff0c;先执行压缩包里面的 null.reg 注册表&#xff0c;再按官方文档操作&#xff0c;不然会报错&#xff0c;搞完建议重启一…

树莓派pico入坑笔记,故障解决:请求 USB 设备描述符失败,故障码(43)

今天心血来潮&#xff0c;拿出吃灰的pico把玩一下&#xff0c;打开thonny&#xff0c;上电&#xff0c;然后...... 上电识别不到端口&#xff0c;windows报错&#xff0c;请求 USB 设备描述符失败&#xff0c;故障码&#xff08;43&#xff09; 一开始以为是坏了&#xff08;磕…

Linux——文件系统

一、从硬件出发 1&#xff09;磁盘的主要构成 通常硬盘是由盘片、主轴、磁头、摇摆臂、马达、永磁铁等部件组成&#xff0c;其中一个硬盘中有多块盘片和多个磁头&#xff0c;堆叠在一起&#xff0c;工作时由盘片旋转和摇摆臂摇摆及逆行寻址从而运作&#xff0c;磁头可以对盘片…

FPGA 时钟多路复用

时钟多路复用 您可以使用并行和级联 BUFGCTRL 的组合构建时钟多路复用器。布局器基于时钟缓存 site 位置可用性查找最佳布局。 如果可能&#xff0c;布局器将 BUFGCTRL 布局在相邻 site 位置中以利用专用级联路径。如无法实现&#xff0c;则布局器将尝试将 BUFGCTRL 从…

C++底层学习预备:模板初阶

文章目录 1.编程范式2.函数模板2.1 函数模板概念2.2 函数模板原理2.3 函数模板实例化2.3.1 隐式实例化2.3.2 显式实例化 2.4 模板参数的匹配原则 3.类模板希望读者们多多三连支持小编会继续更新你们的鼓励就是我前进的动力&#xff01; 进入STL库学习之前我们要先了解有关模板的…

Baklib如何在知识管理领域成为领军者与六款产品的综合评析

内容概要 在知识管理领域&#xff0c;Baklib凭借其卓越的技术和创新的产品线&#xff0c;已经确立了行业的领导地位。作为一个全面的知识管理平台&#xff0c;Baklib为企业提供了高效、便捷的知识存储和管理方案&#xff0c;帮助组织有效整合内外部知识资源。其主要特点包括强…

Baklib阐明企业内容管理与内容中台的本质差异

内容概要 在快速发展的数字时代&#xff0c;企业对信息的管理愈加重视。内容管理在企业日常运营中扮演了重要角色&#xff0c;而随着技术的不断进步&#xff0c;内容中台的概念逐渐走入视野。了解这两者的不同&#xff0c;不仅有助于企业更有效地管理内容&#xff0c;还能提升…

Java 大视界 -- Java 大数据在智能电网中的应用与发展趋势(71)

&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎来到 青云交的博客&#xff01;能与诸位在此相逢&#xff0c;我倍感荣幸。在这飞速更迭的时代&#xff0c;我们都渴望一方心灵净土&#xff0c;而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识&#xff0c;也…

deepseek 本地化部署和小模型微调

安装ollama 因为本人gpu卡的机器系统是centos 7, 直接使用ollama会报 所以ollama使用镜像方式进行部署&#xff0c; 拉取镜像ollama/ollama 启动命令 docker run -d --privileged -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama 查看ollama 是否启动…

【腾讯前端面试】纯css画图形

之前参加腾讯面试&#xff0c;第一轮是笔试&#xff0c;面试官发的试卷里有一题手写css画一个扇形、一个平行四边形……笔试时间还是比较充裕的&#xff0c;但是我对这题完全没有思路&#x1f62d;于是就空着了&#xff0c;最后也没过。 今天偶然翻到廖雪峰大佬的博客里提到了关…

物联网 STM32【源代码形式-ESP8266透传】连接OneNet IOT从云产品开发到底层MQTT实现,APP控制 【保姆级零基础搭建】

一、MQTT介绍 MQTT&#xff08;Message Queuing Telemetry Transport&#xff0c;消息队列遥测传输协议&#xff09;是一种基于发布/订阅模式的轻量级通讯协议&#xff0c;构建于TCP/IP协议之上。它最初由IBM在1999年发布&#xff0c;主要用于在硬件性能受限和网络状况不佳的情…

探秘Linux IO虚拟化:virtio的奇幻之旅

在当今数字化时代&#xff0c;虚拟化技术早已成为推动计算机领域发展的重要力量。想象一下&#xff0c;一台物理主机上能同时运行多个相互隔离的虚拟机&#xff0c;每个虚拟机都仿佛拥有自己独立的硬件资源&#xff0c;这一切是如何实现的呢&#xff1f;今天&#xff0c;就让我…