I. Overview
- torch.jit.trace code walkthrough
- onnx internals
- Checking whether an operator is an aten operator
- aten operator implementation
- torch.autograd.Function operator implementation
- Custom operator implementation
- Finding unimplemented nodes
- Discovering all unimplemented aten operators at once
II. Implementation
- torch.jit.trace code walkthrough
1. torch.jit.script(): compiles a module or function into a runnable script. The converted script can be called like an ordinary Python function, and can also be saved to disk and executed in an environment without a Python dependency (e.g., loaded from C++ via LibTorch).
2. torch.jit.trace: records the execution path taken for the given example input tensors, so when the traced module is used for inference, the inputs must have the same shapes and dtypes as those used during tracing. A minimal sketch contrasting the two follows below.
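For illustration, here is a small snippet (my own sketch, with a made-up function f, not from the original post) showing how trace bakes in a single control-flow path while script preserves it:

import torch

def f(x):
    # Data-dependent control flow: script preserves it, trace bakes in one path
    if x.sum() > 0:
        return x * 2
    return x + 1

scripted = torch.jit.script(f)                 # compiles the Python source
traced = torch.jit.trace(f, (torch.ones(2),))  # records only the sum > 0 path

neg = torch.full((2,), -1.0)
print(scripted(neg))  # tensor([0., 0.])   -- takes the else branch
print(traced(neg))    # tensor([-2., -2.]) -- replays the traced branch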
- Checking whether an operator is an aten operator
import torch

print(torch.jit.trace(
    torch.nn.ELU(),   # module
    torch.ones(1),    # example input
).graph)
Tracing the module and printing its graph shows which operators it lowers to; node kinds prefixed with aten:: are aten operators.
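To list just the operator kinds rather than the whole graph, one can iterate over the traced graph's nodes (a small sketch using the TorchScript graph API):

import torch

graph = torch.jit.trace(torch.nn.ELU(), torch.ones(1)).graph
for node in graph.nodes():
    # Kinds prefixed with "aten::" are ATen operators; each needs a symbolic
    # function (built in or custom-registered) to be exportable to ONNX.
    print(node.kind())  # e.g. aten::elu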
- aten operator implementation
1. Check the torch interface definition in torch/nn/functional.pyi
2. Check the ONNX operator naming: https://github.com/onnx/onnx/blob/main/docs/Operators.md
3. Check how registration functions are written in symbolic_opset9.py
import torch
import torch.onnx.verification

torch.manual_seed(0)
opset_version = 15

# Define a custom symbolic function for aten::relu.
# Interface definition from torch/nn/functional.pyi:
#   def relu(input: Tensor) -> Tensor: ...
def correct_relu_symbolic_function(g, input):
    return g.op("Relu", input)  # the corresponding ONNX operator

# Register the symbolic function for aten::relu
torch.onnx.register_custom_op_symbolic(
    "aten::relu",
    correct_relu_symbolic_function,
    opset_version=opset_version,
)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )

    def forward(self, x):
        return self.layers(x)

# Verify that the exported ONNX graph matches the PyTorch outputs
graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)
- torch.autograd.Function operator implementation
If the operator is implemented as a subclass of torch.autograd.Function, it can be exported by adding a static symbolic method to that class.
import torch
import torch.onnx.verification

torch.manual_seed(0)
opset_version = 15

class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        # Export clamp(min=0) as ONNX Clip with a constant lower bound of 0
        return g.op(
            "Clip",
            input,
            g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)),
        )

myrelu = MyRelu.apply  # key step: call through .apply so symbolic() is picked up

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.Linear(4, 5),
            torch.nn.Linear(5, 6),
        )

    def forward(self, x):
        return myrelu(self.layers(x))

graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)
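Note the design choice in symbolic: it mirrors the clamp(min=0) forward as an ONNX Clip node with a constant lower bound; mapping to g.op("Relu", input) would be an equivalent and simpler alternative.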
- Custom operator implementation
1. ONNX operator implementation: write a custom C++ operator (following the "Extending TorchScript with Custom C++ Operators" tutorial) and then register a symbolic function for it, as sketched below.
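A minimal Python-side sketch, assuming a hypothetical C++ extension that registers custom::my_relu via TORCH_LIBRARY and builds to libmy_ops.so (both names are placeholders):

import torch
from torch.onnx import register_custom_op_symbolic

# Load the (hypothetical) compiled C++ extension that defines custom::my_relu
torch.ops.load_library("libmy_ops.so")

def my_relu_symbolic(g, input):
    # Map the custom TorchScript operator onto the standard ONNX Relu node
    return g.op("Relu", input)

# The "<namespace>::<op>" string must match the TORCH_LIBRARY registration
register_custom_op_symbolic("custom::my_relu", my_relu_symbolic, opset_version=15)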
- Finding unimplemented nodes
import torch
import torch.onnx.verification

torch.manual_seed(0)
opset_version = 15

# Define a custom symbolic function for aten::relu.
# This symbolic function is deliberately incorrect (it returns its input
# unchanged), so find_mismatch will flag every aten::relu node.
def incorrect_relu_symbolic_function(g, self):
    return self

torch.onnx.register_custom_op_symbolic(
    "aten::relu",
    incorrect_relu_symbolic_function,
    opset_version=opset_version,
)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )

    def forward(self, x):
        return self.layers(x)

graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)
find_mismatch prints a report like the following:

===================== Mismatch info for graph partition : ======================
================================ Mismatch error ================================
Tensor-likes are not close!

Mismatched elements: 12 / 12 (100.0%)
Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed)
Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed)
==================================== Tree: =====================================
5 X   __2 X    __1 ✓
id:  |  id: 0 |  id: 00
     |        |
     |        |__1 X (aten::relu)
     |           id: 01
     |
     |__3 X    __1 ✓
        id: 1 |  id: 10
              |
              |__2 X    __1 X (aten::relu)
                 id: 11 |  id: 110
                        |
                        |__1 ✓
                           id: 111
=========================== Mismatch leaf subgraphs: ===========================
['01', '110']
============================= Mismatch node kinds: =============================
{'aten::relu': 2}
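Beyond the printed report, the returned GraphInfo object can be queried programmatically; a short sketch (method names per the torch.onnx.verification.GraphInfo API in recent PyTorch releases, so verify against your version):

# Walk the leaf partitions that still mismatch and report their node kinds
for leaf in graph_info.all_mismatch_leaf_graph_info():
    print(leaf.essential_node_kinds())  # e.g. {'aten::relu'}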
After the fix:
Method 1: aten operator implementation.
Register the correct symbolic function for aten::relu exactly as in the aten operator implementation example above; with correct_relu_symbolic_function in place, find_mismatch no longer reports any mismatch.
Method 2: C++ custom operator style (registering a symbolic function for a TorchScript operator with @parse_args).
import torch
import torch.onnx.verification
from torch.onnx import register_custom_op_symbolic  # registers a symbolic function for a TorchScript operator
from torch.onnx.symbolic_helper import parse_args

torch.manual_seed(0)
opset_version = 15

# The @parse_args decorator: symbolic functions for TorchScript operators must
# declare the type of every input argument. "v" denotes a torch Value (generally
# used for tensors), "i" an int, "f" a float, and "none" an empty argument; the
# full list of type codes is in torch/onnx/symbolic_helper.py.
@parse_args("v")  # one descriptor per argument after g
def correct_relu_symbolic_function(g, input):
    return g.op("Relu", input)

# Register the symbolic function
register_custom_op_symbolic(
    "aten::relu",
    correct_relu_symbolic_function,
    opset_version=opset_version,
)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )

    def forward(self, x):
        return self.layers(x)

graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)
- Discovering all unimplemented aten operators at once
import torch
import torch.onnx.verification

torch.manual_seed(0)
opset_version = 15

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )

    def forward(self, x):
        return self.layers(x)

# Returns the TorchScript graph plus the list of ops with no ONNX symbolic
torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
    Model(), (torch.randn(2, 3),), opset_version=opset_version
)
print(set(unconvertible_ops))  # empty set when every operator is exportable
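Unlike find_mismatch, unconvertible_ops only checks whether each operator in the TorchScript graph has a symbolic function registered at the given opset; it does not execute the exported model or compare outputs, so it serves as a cheap first pass before running the full mismatch search.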