记录pytorch实现自定义算子并转onnx文件输出

概览:记录了如何自定义一个算子,实现pytorch注册,通过C++编译为库文件供python端调用,并转为onnx文件输出

整体大概流程:

  • 定义算子实现为torch的C++版本文件
  • 注册算子
  • 编译算子生成库文件
  • 调用自定义算子

一、编译环境准备

1,在pytorch官网下载如下C++的libTorch package,下载完成后解压文件,是一个libtorch文件夹。

2,提前准备好python,以及pytorch

3,本示例使用了opencv库,所以需要提前安装好opencv。

二、自定义算子的实现

1,实现自定义算子函数

在解压后的libtorch文件夹统计目录,实现自定义算子,用opencv库实现的图像投射函数:warp_perspective。warp_perspective函数后面几行就是实现自定义算子的注册

warpPerspective.cpp文件:

#include "torch/script.h"
#include "opencv2/opencv.hpp"torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {// BEGIN image_matcv::Mat image_mat(/*rows=*/image.size(0),/*cols=*/image.size(1),/*type=*/CV_32FC1,/*data=*/image.data_ptr<float>());// END image_mat// BEGIN warp_matcv::Mat warp_mat(/*rows=*/warp.size(0),/*cols=*/warp.size(1),/*type=*/CV_32FC1,/*data=*/warp.data_ptr<float>());// END warp_mat// BEGIN output_matcv::Mat output_mat;cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{ image.size(0),image.size(1) });// END output_mat// BEGIN output_tensortorch::Tensor output = torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{ image.size(0),image.size(1) });return output.clone();// END output_tensor
}
//static auto registry = torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);  // torch.__version__: 1.5.0torch.__version__ >= 1.6.0  torch/include/torch/library.h
TORCH_LIBRARY(my_ops, m) {m.def("warp_perspective", warp_perspective);
}

2,同级目录创建CMakeList.txt文件

里面需要修改你自己的python下torch的路径,以及你对应安装python版pytorch是cpu还是gpu的。

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(warp_perspective)set(CMAKE_VERBOSE_MAKEFILE ON)
# >>> build type 
set(CMAKE_BUILD_TYPE "Release")				# 指定生成的版本
set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g2 -ggdb")
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall")set(TORCH_ROOT "/home/xxx/anaconda3/lib/python3.10/site-packages/torch")   
include_directories(${TORCH_ROOT}/include)
link_directories(${TORCH_ROOT}/lib/)# Opencv
find_package(OpenCV REQUIRED)# Define our library target
add_library(warp_perspective SHARED warpPerspective.cpp)# Enable C++14
target_compile_features(warp_perspective PRIVATE cxx_std_17)# libtorch库文件
target_link_libraries(warp_perspective # CPUc10 torch_cpu# GPU# c10_cuda # torch_cuda)# opencv库文件
target_link_libraries(warp_perspective${OpenCV_LIBS}
)add_definitions(-D _GLIBCXX_USE_CXX11_ABI=0)

3,编译生成库文件

同级目录创建build文件夹,进入build文件夹利用CMakeList.txt进行编译,生成libwarp_perspective.so库文件

mkdir build
cd build
cmake ..
make

4,python版pytorch进行自定义算子的测试

注意我的以上代码都是放在了/data/xxx/mylib路径下,所以torch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")就找到库文件的位置。

这里我随便找了一张图片,和直接用python版的opencv做投射变换的结果作为golden对比。如下分别是原图,golden, 自定义pytorch算子的输出。自定义算子的输出不太对,但是图像轮廓和投射效果是对的,后面有时间我再检查一下是什么原因。

测试代码: 

import torch
import cv2
import numpy as nptorch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")im=cv2.imread("/data/xxx/mylib/cat.jpg",0)pst1 = np.float32([[56,65], [368,52], [28,387], [389,390]])
pst2 = np.float32([[100,145], [300,100], [80,290], [310,300]])
#2.2获取透视变换矩阵
T = cv2.getPerspectiveTransform(pst1, pst2)in_data =torch.from_numpy(np.float32(im))
in2_data = torch.Tensor(T)out1=torch.ops.my_ops.warp_perspective(in_data,in2_data)
dst0=np.uint8(out1.numpy())
cv2.imwrite("/data/xxx/mylib/cat_warp.jpg",dst0)dst = cv2.warpPerspective(im, np.float32(T), (im.shape[1], im.shape[0]))
cv2.imwrite("/data/xxx/mylib/cat_warp_gold.jpg",dst)

三、自定义算子导出为onnx文件

将注册的pytorch的自定义算子导出为onnx文件查看,效果图如下:

导出代码文件如下

import torch
import numpy as nptorch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")
class MyNet(torch.nn.Module):def __init__(self, name):super(MyNet, self).__init__()self.model_name = namedef forward(self, in_data, warp_data):return torch.ops.my_ops.warp_perspective(in_data, warp_data)def my_custom(g, in_data, warp_data):return g.op("cus_ops::warp_perspective", in_data, warp_data)
torch.onnx.register_custom_op_symbolic("my_ops::warp_perspective", my_custom, 9)if __name__ == "__main__":net = MyNet("my_ops")in_data = torch.randn((32, 32))warp_data = torch.rand((3, 3))out = net(in_data, warp_data)print("out: ", out)# export onnxtorch.onnx.export(net,(in_data, warp_data),"./my_ops_export_model2.onnx",input_names=["img_data", "warp_mat"],output_names=["out_img"],custom_opsets={"cus_ops": 11},)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/191161.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

dameng数据库数据id decimal类型,精度丢失

问题处理 这一次也是精度丢失&#xff0c;但是问题呢还是不一样&#xff0c;这一次所有的id都被加一了&#xff0c;只有id字段被加一&#xff0c;还有的查询查出来封装成对象之后对象的id字段被减一了&#xff0c;数据库id字段使用的decimal&#xff08;20,6&#xff09;&…

sass 封装媒体查询工具

背景 以往写媒体查询可能是这样的&#xff1a; .header {display: flex;width: 100%; }media (width > 320px) and (width < 480px) {.header {height: 50px;} }media (width > 480px) and (width < 768px) {.header {height: 60px;} }media (width > 768px) …

Linux下MSSQL (SQL Server)数据库无法启动故障处理

有同事反馈一套CentOS7下的mssql server2017无法启动需要我帮忙看看&#xff0c;启动报错情况如下 检查日志并没有更新日志信息 乍一看mssql-server服务有问题&#xff0c;检查mssql也确实没有进程 既然服务有问题&#xff0c;那么我们用一种方式直接手工后台启动mssql引擎来…

集成Line、Facebook、Twitter、Google、微信、QQ、微博、支付宝的三方登录sdk

下载地址&#xff1a; https://githubfast.com/anerg2046/sns_auth 安装方式建议使用composer进行安装 如果linux执行composer不方便的话&#xff0c;可以在本地新建个文件夹&#xff0c;然后执行上面的composer命令&#xff0c;把代码sdk和composer文件一起上传到项目适当位…

使用validator实现枚举类型校验

使用validator实现枚举类型校验 前言&#xff1a; 在前端调用后端接口传递参数的过程中&#xff0c;我们往往需要对前端传递过来的参数进行校验&#xff0c;比如说我们此时需要对用户的状态进行更新&#xff0c;而用户的状态只有正常和已删除&#xff0c;并且是在代码中通过枚…

点大商城V2版 2.5.3全插件开源独立版 百度+支付宝+QQ+头条+小程序端+unipp开源端安装测试教程

点大商城V2是一款采用全新界面设计支持多端覆盖的小程序应用&#xff0c;支持H5、微信公众号、微信小程序、头条小程序、支付宝小程序、百度小程序&#xff0c;本程序是点大商城V2独立版&#xff0c;包含全部插件&#xff0c;代码全开源&#xff0c;并且有VUE全端代码。 适用范…

使用Python和requests库的简单爬虫程序

这是一个使用Python和requests库的简单爬虫程序。我们将使用代理来爬取网页内容。以下是代码和解释&#xff1a; import requests from fake_useragent import UserAgent # 每行代理信息 proxy_host "jshk.com.cn" # 创建一个代理器 proxy {http: http:// proxy_…

Clickhouse学习笔记(10)—— 查询优化

单表查询 Prewhere 替代 where prewhere与where相比&#xff0c;在过滤数据的时候会首先读取指定的列数据&#xff0c;来判断数据过滤&#xff0c;等待数据过滤之后再读取 select 声明的列字段来补全其余属性 简单来说就是先过滤再查询&#xff0c;而where过滤是先查询出对应…

Android Studio真机运行时提示“安装失败”

用中兴手机真机运行没问题&#xff0c;用Vivo运行就提示安装失败。前提&#xff0c;手机已经打开了调试模式。 报错 Android Studio报错提示&#xff1a; Error running app The application could not be installed: INSTALL_FAILED_TEST_ONLY 手机报错提示&#xff1a; 修…

专访|OpenTiny 社区 Mr 栋:结合兴趣,明确定位,在开源中给自己一些技术性挑战

前言 OpenTiny 开源之夏项目终于迎来了圆满的结局。借此机会&#xff0c;我们采访了 TinyReact 的共建者 Mr 栋同学。 Mr 栋同学是一位热衷于前端技术的开发者&#xff0c;对前端开发充满了激情和热爱。同时他也是一位即将毕业的大四在校生。在 OpenTiny 开源项目中&#xff0…

Window安装MongoDB

三种NOSQL的一种,Redis MongoDB ES 应用场景: 1.社交场景:使用Mongodb存储用户信息,以及用户发表的朋友圈信息,通过地理位置索引实现附近的人,地点等功能 2.游戏场景:使用Mongodb存储游戏用户信息,用户的装备,积分等直接以内嵌文档的形式存储,方便查询,高效率存储和访问…

软路由R4S+iStoreOS实现公网远程桌面局域网内电脑

软路由R4SiStoreOS实现公网远程桌面局域网内电脑 文章目录 软路由R4SiStoreOS实现公网远程桌面局域网内电脑简介 一、配置远程桌面公网地址配置隧道 二、家中使用永久固定地址 访问公司电脑具体操作方法是&#xff1a;2.1 登录页面2.2 再次配置隧道2.3 查看访问效果 简介 上篇…

EDA实验-----3-8译码器设计(QuartusII)

目录 一. 实验目的 二. 实验仪器 三. 实验原理及内容 1.实验原理 2.实验内容 四&#xff0e;实验步骤 五. 实验报告 六. 注意事项 七. 实验过程 1.创建Verilog文件&#xff0c;写代码 ​编辑 2.波形仿真 3.连接电路图 4.烧录操作 一. 实验目的 学会Verilog HDL的…

JVM如何运行,揭秘Java虚拟机运行时数据区

目录 一、概述 二、程序计数器 三、虚拟机栈 四、本地方法栈 五、本地方法接口 六、堆 &#xff08;一&#xff09;概述 &#xff08;二&#xff09;堆空间细分 七、方法区 一、概述 不同的JVM对于内存的划分方式和管理机制存在部分差异&#xff0c;后续针对HotSpot虚…

前端案例-css实现ul中对li进行换行

场景描述&#xff1a; 我想要实现&#xff0c;在展示的item个数少于4个的时候&#xff0c;则排成一行&#xff0c;并且均分&#xff08;比如说有3个&#xff0c;则每个的宽度为33.3%&#xff09;&#xff0c;如果item 个数大于4&#xff0c;则进行换行。 效果如下&#xff1a…

网络运维Day14

监控概述 监控的目的 报告系统运行状况每一部分必须同时监控内容包括吞吐量、反应时间、使用率等提前发现问题进行服务器性能调整前&#xff0c;知道调整什么找出系统的瓶颈在什么地方 监控的资源类别 公开数据 Web、FTP、SSH、数据库等应用服务TCP或UDP端口 私有数据 CPU、内…

【Java 进阶篇】JQuery DOM操作:舞动网页的属性魔法

在前端的舞台上&#xff0c;属性操作是我们与HTML元素进行互动的关键步骤之一。而JQuery&#xff0c;这位前端开发的巫师&#xff0c;通过简洁而强大的语法&#xff0c;为我们提供了便捷的属性操作工具。在这篇博客中&#xff0c;我们将深入研究JQuery DOM操作中的属性操作&…

Android Rxjava架构原理与使用的详解解答

简单介绍 Rxjava这个名字&#xff0c;其中java代表java语言&#xff0c;而Rx是什么意思呢&#xff1f;Rx是Reactive Extensions的简写&#xff0c;翻译过来就是&#xff0c;响应式拓展。所以Rxjava的名字的含义就是&#xff0c;对java语言的拓展&#xff0c;让其可以实现对数据…

【论文精读】Pose-Free Neural Radiance Fields via Implicit Pose Regularization

今天读的是一篇发表在ICCV 2023上的文章&#xff0c;作者来自NTU。 文章地址&#xff1a;点击前往 文章目录 Abstract1 Intro2 Related Work3 Preliminary4 Proposed Method4.1 Overall Framework4.2 Scene Codebook Construction4.3 Pose-Guided View Reconstruction4.4 Train…

HTML设置标签栏的图标

添加此图标最简单的方法无需修改内容&#xff0c;只需按以下步骤操作即可&#xff1a; 1.准备一个 ico 格式的图标 2.将该图标命名为 favicon.ico 3.将图标文件置于index.html同级目录即可 为什么我的没有变化&#xff1f; 答曰&#xff1a;ShiftF5强制刷新一下网页就行了