使用Pytorch导出自定义ONNX算子

在实际部署模型时有时可能会遇到想用的算子无法导出onnx,但实际部署的框架是支持该算子的。此时可以通过自定义onnx算子的方式导出onnx模型(注:自定义onnx算子导出onnx模型后是无法使用onnxruntime推理的)。下面给出个具体应用中的示例:需要导出pytorch的affine_grid算子,但在pytorch的2.0.1版本中又无法正常导出该算子,故可通过如下自定义算子代码导出。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypesclass CustomAffineGrid(Function):@staticmethoddef forward(ctx, theta: torch.Tensor, size: torch.Tensor):grid = F.affine_grid(theta=theta, size=size.cpu().tolist())return grid@staticmethoddef symbolic(g: torch.Graph, theta: torch.Tensor, size: torch.Tensor):return g.op("AffineGrid", theta, size)class MyModel(nn.Module):def __init__(self) -> None:super().__init__()def forward(self, x: torch.Tensor, theta: torch.Tensor, size: torch.Tensor):grid = CustomAffineGrid.apply(theta, size)x = F.grid_sample(x, grid=grid, mode="bilinear", padding_mode="zeros")return xdef main():with torch.inference_mode():custum_model = MyModel()x = torch.randn(1, 3, 224, 224)theta = torch.randn(1, 2, 3)size = torch.as_tensor([1, 3, 512, 512])torch.onnx.export(model=custum_model,args=(x, theta, size),f="custom.onnx",input_names=["input0_x", "input1_theta", "input2_size"],output_names=["output"],dynamic_axes={"input0_x": {2: "h0", 3: "w0"},"output": {2: "h1", 3: "w1"}},opset_version=16,operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)if __name__ == '__main__':main()

在上面代码中,通过继承torch.autograd.Function父类的方式实现导出自定义算子,继承该父类后需要用户自己实现forward以及symbolic两个静态方法,其中forward方法是在pytorch正常推理时调用的函数,而symbolic方法是在导出onnx时调用的函数。对于forward方法需要按照正常的pytorch语法来实现,其中第一个参数必须是ctx但对于当前导出onnx场景可以不用管它,后面的参数是实际自己传入的参数。对于symbolic方法的第一个必须是g,后面的参数任为实际自己传入的参数,然后通过g.op方法指定具体导出自定义算子的名称,以及输入的参数(注:上面示例中传入的都是Tensor所以可以直接传入,对与非Tensor的参数可见下面一个示例)。最后在使用时直接调用自己实现类的apply方法即可。使用netron打开自己导出的onnx文件,可以看到如下所示网络结构。
在这里插入图片描述

有时按照使用的推理框架导出自定义算子时还需要设置一些参数(非Tensor)那么可以参考如下示例,例如要导出int型的参数k那么可以通过传入k_i来指定,要导出float型的参数scale那么可以通过传入scale_f来指定,要导出string型的参数clockwise那么可以通过传入clockwise_s来指定:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypesclass CustomRot90AndScale(Function):@staticmethoddef forward(ctx, x: torch.Tensor):x = torch.rot90(x, k=1, dims=(3, 2))  # clockwise 90x *= 1.2return x@staticmethoddef symbolic(g: torch.Graph, x: torch.Tensor):return g.op("Rot90AndScale", x, k_i=1, scale_f=1.2, clockwise_s="yes")class MyModel(nn.Module):def __init__(self) -> None:super().__init__()def forward(self, x: torch.Tensor):return CustomRot90AndScale.apply(x)def main():with torch.inference_mode():custum_model = MyModel()x = torch.randn(1, 3, 224, 224)torch.onnx.export(model=custum_model,args=(x,),f="custom_rot90.onnx",input_names=["input"],output_names=["output"],dynamic_axes={"input": {2: "h0", 3: "w0"},"output": {2: "w0", 3: "h0"}},opset_version=16,operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)if __name__ == '__main__':main()

使用netron打开自己导出的onnx文件,可以看到如下所示信息。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/269121.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

重构笔记系统:Docker Compose在微服务架构中的应用与优化

虽然我的笔记系统的开发是基于微服务的思想,但是在服务的配置和编排上感觉还是不太合理,具体来说,在开发上的配置和在生产上的配置差别太大。现在规模小,后面规模变大,估计这一块会成为系统生长的瓶颈。 因此&#xff…

推特API(Twitter API)对接说明,用户code To Token换取

前期准备 提前准备、说明:目前对接推特api开发门户分为3个版本,分别是免费的,100美金一个月的基础版以及5000美金一个月的企业版,免费的目前就两个接口可以调用,所以想要对接和使用推特最基本的也需要付100美元一个月…

ETAS工具链ISOLAR-AB重要概念,RTE配置,ECU抽取

RTE配置界面,包含ECU抽取关联 首次配置RTE,出现需要勾选的抽取EXTRACT 创建System System制作SWC到ECU的Mapping System制作System Data 的Mapping

LabVIEW眼结膜微血管采集管理系统

LabVIEW眼结膜微血管采集管理系统 开发一套基于LabVIEW的全自动眼结膜微血管采集管理系统,以提高眼结膜微血管临床研究的效率。系统集成了自动化图像采集、图像质量优化和规范化数据管理等功能,有效缩短了图像采集时间,提高了图像质量&#…

9.WEB渗透测试-Linux基础知识-Linux用户权限管理(上)

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 内容参考于: 易锦网校会员专享课 上一个内容:8.WEB渗透测试-Linux基础知识-Linux基础操作(二)-CSDN博客 用户管…

Laravel框架: Call to a member function connect() on null 异常报错处理

Laravel框架: Call to a member function connect() on null 异常报错处理 Date: 2024.03.01 21:03:11 author: lijianzhan 原文链接: https://learnku.com/laravel/t/63721 问题: local.ERROR: Call to a member function connect() on null {"…

【办公类-21-08】三级育婴师 多个二级文件夹的docx合并成PDF

背景需求: 前期制作了单题文件夹 【办公类-21-07】新建文件夹 三级育婴师操作参考题目-CSDN博客文章浏览阅读439次,点赞7次,收藏10次。【办公类-21-07】新建文件夹 三级育婴师操作参考题目https://blog.csdn.net/reasonsummer/article/details/1363360…

Spring Boot项目中不使用@RequestMapping相关注解,如何动态发布自定义URL路径

一、前言 在Spring Boot项目开发过程中,对于接口API发布URL访问路径,一般都是在类上标识RestController或者Controller注解,然后在方法上标识RequestMapping相关注解,比如:PostMapping、GetMapping注解,通…

Gitlab: PHP项目CI/CD实践

目录 1 说明 2 CI/CD 2.1 部署方式一:增量部署 2.1.1 目标服务器准备 2.2.2 Gitlab及Envoy脚本 2.2 部署方式二:镜像构建与部署 2.2.1 推送到私有化容器仓库 准备工作 脚本 要点 2.2.2 推送到hub.docker.com 准备工作 脚本 3 参考&#x…

鸿蒙实战应用开发:【拨打电话】功能

概述 本示例通过输入电话,进行电话拨打,及电话相关信息的显示。 样例展示 涉及OpenHarmony技术特性 网络通信 基础信息 拨打电话 介绍 本示例使用call相关接口实现了拨打电话并显示电话相关信息的功能 效果预览 使用说明 1.输入电话号码后&#…

「滚雪球学Java」:多线程(章节汇总)

咦咦咦,各位小可爱,我是你们的好伙伴——bug菌,今天又来给大家普及Java SE相关知识点了,别躲起来啊,听我讲干货还不快点赞,赞多了我就有动力讲得更嗨啦!所以呀,养成先点赞后阅读的好…

STM32F1 - SPI读写Flash

Serial peripheral interface 1> 实验概述2> SPI硬件框图初始化程序 3> STM32的SPI通信时序3.1> 时序图3.2> 文字描述3.3> 注意事项3.4> 流程图表示3.5> 程序表示接收程序:发送程序: 4> SPI的4种模式5> W25Q128存储结构块…

【学习心得】网站运行时间轴(爬虫逆向)

一、网站运行时间轴 掌握网站运行时间轴,有助于我们对“请求参数加密”和“响应数据加密”这两种反爬手段的深入理解。 二、从网站运行的时间轴角度来理解两种反爬手段 1、加载HTML: 这是浏览器访问网站时的第一步,服务器会返回基础…

docker基线安全修复和容器逃逸修复

一、docker安全基线存在的问题和修复建议 1、将容器的根文件系统挂载为只读 修复建议: 添加“ --read-only”标志,以允许将容器的根文件系统挂载为只读。 可以将其与卷结合使用,以强制容器的过程仅写入要保留的位置。 可以使用命令&#x…

基于Springboot的助农管理系统(有报告)。Javaee项目,springboot项目。

演示视频: 基于Springboot的助农管理系统(有报告)。Javaee项目,springboot项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构&…

台式电脑电源各线的电压和电流输出和输出电流

台式电脑电源是电脑硬件的重要组成部分。 它为计算机的各个部件提供所需的电压和电流。 不同的硬件设备和组件有不同的电压和电流输出。 下面详细介绍台式电脑电源各线的电压,包括3.3V、5V、12V、-12V、-5V和5VSB,以及它们的输出电流和用途。 3.3V&#…

Java递归生成本地文件目录树形结构

Java递归生成本地文件目录(树行结构) 1.读取txt文件保存的文件目录结构 2.递归生成本地文件目录树形结构,并修改目录文件前缀进行递增 3.结果截图 4.代码 package com.zfi.server.device;import io.swagger.annotations.Api; import org.springframework.web.bind…

Leaflet 加载高德地图

前言 在前面的文章中,我们学习了如何使用 Leaflet 创建一个基本的地图。在本文中,我们将学习如何在 Leaflet 中加载高德地图,并结合实际应用构建地图点击事件。 一、介绍 高德地图是一款由高德软件提供的数字地图服务,在国内使用…

数学建模【灰色关联分析】

一、灰色关联分析简介 一般的抽象系统,如社会系统、经济系统、农业系统、生态系统、教育系统等都包含有许多种因素,多种因素共同作用的结果决定了该系统的发展态势。人们常常希望知道在众多的因素中,哪些是主要因素,哪些是次要因素;哪些因素…

Apache SeaTunnel 2.3.4 版本发布:功能升级,性能提升

​Apache SeaTunnel团队自豪地宣布2.3.4版本正式发布!本次更新聚焦于增强核心功能,改善用户体验,并进一步优化文档质量。 此次版本发布带来了多项重要更新和功能增强,包括核心与API的修复、文档的全面优化、Catalog支持的引入&…