pytorch自定义算子导出onnx

文章目录

      • 1、为什么要自定义算子?
      • 2、如何自定义算子
      • 3、自定义算子导出onnx
      • 4、example
        • 1、重写一个pytorch 自定义算子(实现自定义激活函数)
        • 2、现有算子上封装pytorch 自定义算子(实现动态放大超分辨率模型)

1、为什么要自定义算子?

1、没有现成可用的算子,需要根据自己的接口重写。
2、现有的算子接口不兼容,需要在原有的算子上进行封装。

2、如何自定义算子

继承torch.autograd.Function类,实现其forward()backward()方法,就可以成为一个pytorch自定义算子。就可以在模型训练推理中完成前向推理和反向传播。
forward() 函数的第一个参数必须是ctx, 后面是输入。
在工程部署上,一般为了加快计算,自定义算子需要用cuda 实现forward()、backward()kernel 函数。

3、自定义算子导出onnx

实现其symbolic 静态方法,当我们调用torch.onnx.export()时,就可以导出onnx 算子。
symbolic是符号函数,通常在其内部返回一个g.op()对象。g.op() 把一个 PyTorch 算子映射成一个或多个 ONNX 算子,或者是自定义的 ONNX 算子。
symbolic函数的第一个参数必须是g, 后面是和forward()对应的输入。
g.op() 做算子映射,g.op 的参数:
1、第一个参数为算子名字
2、后面参数与forward() 输入对应
3、往后可以是一些算子自带常量和属性值。常量视为输入,属性值需要用 字段_s/i/f = 默认值表示。_s 表示字符串,_i 表示 int64, _f 表示 float32。常量用类似 g.op(“Constant”, value_t=torch.tensor([3, 2, 1], dtype=torch.float32))表示

4、example

1、重写一个pytorch 自定义算子(实现自定义激活函数)

实现自己的激活函数MYSELU 算子。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx
import torch.autograd#继承torch.autograd.Function
class MYSELUImpl(torch.autograd.Function): @staticmethoddef symbolic(g, x, p):return g.op("MYSELU", x, p,  # 表示onnx算子的名称为MYSELU,参数与forward()对应# 给算子传一个常数参数g.op("Constant", value_t=torch.tensor([3, 2, 1], dtype=torch.float32)),attr1_s="这是字符串属性", # s表示字符串attr2_i=[1, 2, 3], # i表示整数attr3_f=222  # f表示浮点数)@staticmethoddef forward(ctx, x, p): # 前行推理return x * 1 / (1 + torch.exp(-x))class MYSELU(nn.Module): def __init__(self, n):super().__init__()self.param = nn.parameter.Parameter(torch.arange(n).float())def forward(self, x):return MYSELUImpl.apply(x, self.param) #推理调用class Model(nn.Module):def __init__(self):super().__init__()self.conv = nn.Conv2d(1, 1, 3, padding=1)self.myselu = MYSELU(3)self.conv.weight.data.fill_(1)self.conv.bias.data.fill_(0)def forward(self, x):x = self.conv(x)x = self.myselu(x)return x
2、现有算子上封装pytorch 自定义算子(实现动态放大超分辨率模型)

实现动态放大超分辨率模型。我们希望实现:
forward(self, x, upscale_factor)
这样一个接口,x 为图像输入,upscale_factor为动态放大倍数。
pytorch 现有放大算子有nn.Upsample 和 interpolate, 但是nn.Upsample 在初始化阶段固化了放大倍数,而 PyTorch 的 interpolate 插值算子可以在运行阶段选择放大倍数。

class SuperResolutionNet(nn.Module): def forward(self, x, upscale_factor): x = interpolate(x, scale_factor=upscale_factor.item(), mode='bicubic', align_corners=False) 
... 
# Inference 
# Note that the second input is torch.tensor(3) 
torch_output = model(torch.from_numpy(input_img), torch.tensor(3)).detach().numpy() 
... 
with torch.no_grad(): torch.onnx.export(model, (x, torch.tensor(3)), "srcnn2.onnx", opset_version=11, input_names=['input', 'factor'], output_names=['output']) 

尝试使用以上方法导出onnx 时,虽然没有报错能成功导出onnx,但是有TraceWarning 的警告,说明导出onnx有追踪失败。这是由于我们使用了 torch.Tensor.item() 把数据从 Tensor 里取出来,而导出 ONNX 模型时这个操作是无法被记录的,只好报了一条 TraceWarning。

因此我们需要自定义算子,让onnx在追踪时刻能work。我们看到nn.Upsample 和 interpolate在转onnx时都映射到了Resize 操作。所以自定义算子在Resize 操作上进行封装即可。
在这里插入图片描述

Resize 操作有三个输入,x, roi, scale, 我们就是要动态输入scale。展开 scales,可以看到 scales 是一个长度为 4 的一维张量,其内容为 [1, 1, 3, 3],
如果我们能够自己生成一个 ONNX 的 Resize 算子,让 scales 成为一个可变量而不是常量,就像它上面的 X 一样,那这个超分辨率模型就能动态缩放了。

import torch 
from torch import nn 
from torch.nn.functional import interpolate 
import torch.onnx 
import cv2 
import numpy as np 
class NewInterpolate(torch.autograd.Function): @staticmethod def symbolic(g, input, scales): return g.op("Resize", input, g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)), scales, coordinate_transformation_mode_s="pytorch_half_pixel", cubic_coeff_a_f=-0.75, mode_s='cubic', nearest_mode_s="floor") @staticmethod def forward(ctx, input, scales): scales = scales.tolist()[-2:] return interpolate(input, scale_factor=scales, mode='bicubic', align_corners=False) class StrangeSuperResolutionNet(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4) self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0) self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2) self.relu = nn.ReLU() def forward(self, x, upscale_factor): x = NewInterpolate.apply(x, upscale_factor) out = self.relu(self.conv1(x)) out = self.relu(self.conv2(out)) out = self.conv3(out) return out 

以上自定义了Resize 算子,将scale 作为算子的一个输入,最后还是调用interpolate。但是scale已经变成自定义输入参数。
参数映射如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/477916.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

构建高效在线教育:SpringBoot课程管理系统

1系统概述 1.1 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及,互联网成为人们查找信息的重要场所,二十一世纪是信息的时代,所以信息的管理显得特别重要。因此,使用计算机来管理在线课程管理系统的相关信息成为必然。开发…

CSS3新特性——字体图标、2D、3D变换、过渡、动画、多列布局

目录 一、Web字体 二、字体图标 三、2D变换 1.位移 (1)浮动 (2)相对定位 (3)绝对定位和固定定位 (4)位移 用位移实现盒子的水平垂直居中 2.缩放 利用缩放调整字体到12px以下&#xff…

python Flask指定IP和端口

from flask import Flask, request import uuidimport json import osapp Flask(__name__)app.route(/) def hello_world():return Hello, World!if __name__ __main__:app.run(host0.0.0.0, port5000)

linux ubuntu的脚本知

目录 一、变量的引用 二、判断指定的文件是否存在 三、判断目录是否存在 四、判断最近一次命令执行是否成功 五、一些比较符号 六、"文件"的读取和写入 七、echo打印输出 八、ubuntu切换到root用户 N、其它可以参考的网址 脚本功能强大,用起来也…

C++(进阶) 第1章 继承

C(进阶) 第1章 继承 文章目录 前言一、继承1.什么是继承2.继承的使用 二、继承方式1.private成员变量的(3种继承方式)继承2. private继承方式3.继承基类成员访问⽅式的变化 三、基类和派生类间的转换1.切片 四、 继承中的作⽤域1.隐藏规则&am…

Load-Balanced-Online-OJ(负载均衡式在线OJ)

负载均衡式在线OJ 前言1. 项目介绍2. 所用技术与环境所用技术栈开发环境 3. 项目宏观结构3.1 项目核心模块3.2 项目的宏观结构 4. comm公共模块4.1 日志(log.hpp )4.1.1 日志主要内容4.1.2 日志使用方式4.1.2 日志代码 4.2 工具(util.hpp&…

c++->内部类 匿名对象

内部类&#xff1a;&#xff08;例如&#xff1a;b定义在a类中&#xff09; 注意事项&#xff1a; &#xff08;1&#xff09;内部类b可以直接使用外部类的static变量&#xff0c;但是并不属于外部类的友元&#xff01;&#xff01;&#xff01;&#xff01; #include <s…

C++ std::unique_ptr的使用及源码分析

目录 1.简介 2.使用方法 2.1.创建 unique_ptr 2.2.删除对象 2.3.转移所有权 2.4.自定义删除器 2.5.从函数返回 std::unique_ptr 2.6.将 std::unique_ptr 作为函数参数 3.适用场景 4.与原始指针的区别 5.优缺点 6.源码分析 6.1.构造函数 6.2.存储分析 6.3.默认删…

系统思考—关键决策

最近听到一句话特别扎心&#xff1a;“不是环境毁了企业&#xff0c;而是企业误判了环境。” 在大环境变化面前&#xff0c;很多企业的反应是快速调整&#xff0c;但这真的有效吗&#xff1f;其实&#xff0c;太快的动作&#xff0c;往往是误判的开始。 环境变化带来压力&…

【Java 解释器模式】实现高扩展性的医学专家诊断规则引擎

&#x1f9d1; 博主简介&#xff1a;CSDN博客专家&#xff0c;历代文学网&#xff08;PC端可以访问&#xff1a;https://literature.sinhy.com/#/literature?__c1000&#xff0c;移动端可微信小程序搜索“历代文学”&#xff09;总架构师&#xff0c;15年工作经验&#xff0c;…

ES八股相关知识

为什么要使用ElasticSearch&#xff1f;和传统关系数据库&#xff08;如 MySQL&#xff09;有什么不同&#xff1f; 典型回答 数据模型 Elasticsearch 是基于文档的搜索引擎&#xff0c;它使用 JSON 文档来存储数据。在 Elasticsearch 中&#xff0c;相关的数据通常存储在同…

局域网与广域网:探索网络的规模与奥秘(3/10)

一、局域网的特点 局域网覆盖有限的地理范围&#xff0c;通常在几公里以内&#xff0c;具有实现资源共享、服务共享、维护简单、组网开销低等特点&#xff0c;主要传输介质为双绞线&#xff0c;并使用少量的光纤。 局域网一般是方圆几千米以内的区域网络&#xff0c;其特点丰富…

EMD-KPCA-Transformer多变量回归预测!分解+降维+预测!多重创新!直接写核心!

EMD-KPCA-Transformer多变量回归预测&#xff01;分解降维预测&#xff01;多重创新&#xff01;直接写核心&#xff01; 目录 EMD-KPCA-Transformer多变量回归预测&#xff01;分解降维预测&#xff01;多重创新&#xff01;直接写核心&#xff01;效果一览基本介绍程序设计参…

编程之路,从0开始:文件操作(2)

Hello大家好&#xff01;很高兴我们又见面啦&#xff01;给生活添点passion&#xff0c;开始今天的编程之路&#xff01; 今天我们来继续学习C语言的文件操作。 我的博客&#xff1a;<但凡. 我的专栏&#xff1a;编程之路 持续更新高质量内容&#xff0c;欢迎点赞、关注&…

mybatis学习(三)

声明&#xff1a;该内容来源于动力节点&#xff0c;本人在学习mybatis过程中参考该内容&#xff0c;并自己做了部分笔记&#xff0c;但个人觉得本人做的笔记不如动力节点做的好&#xff0c;故使用动力节点的笔记作为后续mybatis的复习。 六、在WEB中应用MyBatis&#xff08;使…

ES6 模块化语法

目录 ES6 模块化语法 分别暴露 统一暴露 ​编辑 默认暴露 ES6 模块化引入方式 ES6 模块化语法 模块功能主要由两个命令构成&#xff1a;export 和 import。 ⚫ export 命令用于规定模块的对外接口&#xff08;哪些数据需要暴露&#xff0c;就在数据前面加上关键字即可…

【Spring boot】微服务项目的搭建整合swagger的fastdfs和demo的编写

文章目录 1. 微服务项目搭建2. 整合 Swagger 信息3. 部署 fastdfsFastDFS安装环境安装开始图片测试FastDFS和nginx整合在Storage上安装nginxnginx安装不成功排查:4. springboot 整合 fastdfs 的demodemo编写1. 微服务项目搭建 版本总结: spring boot: 2.6.13springfox-boot…

无线电磁波在自由空间的衰减

自由空间损耗&#xff0c;指的是电磁波在空气中传播时候的能量损耗&#xff0c;电磁波在穿透任何介质的时候都会有损耗。在传输路径上的损耗&#xff0c;即为路径损耗。 自由空间路径损耗&#xff08;Free Space Path Loss&#xff09;的基本公式&#xff1a; 简化的自由空间损…

UE5实现可销毁对象的淡化销毁

进入对象材质 设置 的不透明蒙版 不透明蒙版见 UE材质不透明蒙版选项-CSDN博客 默认混合模式(不透明)下无法进行设置&#xff0c;将混合模式修改为 混合模式见 UE5材质混合模式-CSDN博客 新添加Texture sample节点 关于Texture sample&#xff1a;UE5材质Texture Sample …

【Linux学习】【Ubuntu入门】1-7 ubuntu下磁盘管理

1.准备一个U盘或者SD卡&#xff08;插上读卡器&#xff09;&#xff0c;将U盘插入主机电脑&#xff0c;右键点击属性&#xff0c;查看U盘的文件系统确保是FAT32格式 2.右键单击ubuntu右下角图标&#xff0c;将U盘与虚拟机连接 参考链接 3. Ubuntu磁盘文件&#xff1a;/dev/s…