[AI部署-tensorRT] customlayer定义添加过程解析

文章目录

  • tensorRT添加自定义层步骤
  • 1. trt如何解析onnx的? 整体流程图
    • 2. builtin_op_importor是干什么的?
    • 3. 怎么添加trt plugin
    • 4. 如何进行量化collection过程
  • references

  • nvidia 官方plugin文档: https://www.nvidia.cn/content/dam/en-zz/zh_cn/assets/webinars/2020/feb21/TensorRT_7-0_plugin.pdf
  • 比较旧了

tensorRT添加自定义层步骤

  • 下载源码 onnx-tensorrt
  • 参考tensorRT官网源码中plugin中的instanceNormalizationPlugin,写好自己customlayer.h和customlayer.cpp的实现。(都是官方写好的自定义op的示例,是op的逻辑)
  • 在builtin_op_importers.cpp中使用DEFINE_BUILTIN_OP_IMPORTER添加对自己注册Op的使用(解析器)。
  • 在CMakeLists.txt中,set(IMPORTER_SOURCES… 下面将自己的customlayer.cpp加进去。
  • 按照教程,重新编译自己的onnx-tensorRT, 生成onnx2trt工具,然后用这个工具可以将onnx转化为trt文件
onnx2trt ./onnx/customer_op.onnx -v

tensorRT-customlayer

1. trt如何解析onnx的? 整体流程图

整个tensorRT解析onnx开始的解析入口: parser->parseFromFile …
onnxparser_process

  • 上图描述了如何构建解析自定义onnx node算子的解析器,并且将其保存到std::unordered_map<string, T>数据结构的builtin_op_importers中,该结构在ModelImporter解析器中被使用,而ModelImporter类是被createParser通过createNvOnnxParser_INTERNAL调用的; 这样整个链路就解释通了。
  • 这个自定义onnx node算子的解析器是用来解析onnx中自定义算子中的保存的不变值,如attributes,weights等,解析出来,应该是通过PluginFieldCollection,传到tensorRT的plugin的执行体中,在执行的时候被使用, 具体这个挂钩过程如何实现的呢? 看如下介绍builtin_op_importor是干什么的?

2. builtin_op_importor是干什么的?

  • 这个cpp主要完成将自定义onnx node的解析器注册到builtin_op_importers中
  • 并且构建解析器,解析自定义node中的attributes和weights等内容,并将属性值通过PluginFieldCollection, 传到tensorRT的PluginCreator的createPlugin中,并且通过类的实例化,将必要的属性参数传递给plugin类.

注册过程是怎样的?


// 通过DEFINE_BUILTIN_OP_IMPORTER进行注册
// 这个op的名字对应这onnx中node的type, 而函数体中的pluginName对应的是你在tensorRT中plugin构建的时候,getPluginName中返回的名字#define DEFINE_BUILTIN_OP_IMPORTER(op)                                                                                 \NodeImportResult import##op(                                                                                       \IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, std::vector<TensorOrWeights>& inputs);         \static const bool op##_registered_builtin_op = registerBuiltinOpImporter(#op, import##op);                         \IGNORE_UNUSED_GLOBAL(op##_registered_builtin_op);                                                                  \NodeImportResult import##op(                                                                                       \IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, std::vector<TensorOrWeights>& inputs)bool registerBuiltinOpImporter(std::string op, NodeImporter const& importer)
{bool inserted = getBuiltinOpImporterMap().insert({op, importer}).second;assert(inserted);return inserted;
}string_map<NodeImporter>& getBuiltinOpImporterMap()
{static string_map<NodeImporter> builtin_op_importers;return builtin_op_importers;
}

**如何解析onnx node? **

DEFINE_BUILTIN_OP_IMPORTER(op_name) {# ... ... 省略一部分内容
OnnxAttrs attrs(node, ctx);  // 获取属性
float epsilon = attrs.get("epsilon", 1e-5f);// Populate instanceNormalization plugin properties.
const std::string pluginName = "InstanceNormalization_TRT";
const std::string pluginVersion = "1";
std::vector<nvinfer1::PluginField> f;
f.emplace_back("epsilon", &epsilon, nvinfer1::PluginFieldType::kFLOAT32, 1);
f.emplace_back("scales", scale_weights.values, nvinfer1::PluginFieldType::kFLOAT32, scale_weights.count());
f.emplace_back("bias", bias_weights.values, nvinfer1::PluginFieldType::kFLOAT32, bias_weights.count());// Create plugin from registry
nvinfer1::IPluginV2* plugin = createPlugin(node.name(), importPluginCreator(pluginName, pluginVersion), f);ASSERT(plugin != nullptr && "InstanceNormalization plugin was not found in the plugin registry!",
ErrorCode::kUNSUPPORTED_NODE);auto* layer = ctx->network()->addPluginV2(&tensorPtr, 1, *plugin);  // 自定义节点和层
ctx->registerLayer(layer, node.name());//
// // Map Quantization node to a scale node
auto layer = ctx->network()->addScale(input, mode, shift, scale, power);... ... 
}

在onnx_tensorRT库中的builtin_op_importers.cpp中有很多解析器的例子,可以仿照的写自己的

带参数的onnx的weight是如何传到Plugin中被执行的?

1. 属性如何传过来的? 使用clip layer的例子


// 通过这个语句,将f传递给下面的IPluginV2* ClipPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) 的fc
// Create plugin from registry (DEFINE_BUILTIN_OP_IMPORTER(op_name)中的过程)
nvinfer1::IPluginV2* plugin = createPlugin(node.name(), importPluginCreator(pluginName, pluginVersion), f);// fc中保存的onnx的属性,传递给ClipPlugin
IPluginV2* ClipPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept
{float clipMin, clipMax;const PluginField* fields = fc->fields;// Parse fields from PluginFieldCollectionassert(fc->nbFields == 2);for (int i = 0; i < fc->nbFields; i++){if (strcmp(fields[i].name, "clipMin") == 0){assert(fields[i].type == PluginFieldType::kFLOAT32);clipMin = *(static_cast<const float*>(fields[i].data));}else if (strcmp(fields[i].name, "clipMax") == 0){assert(fields[i].type == PluginFieldType::kFLOAT32);clipMax = *(static_cast<const float*>(fields[i].data));}}return new ClipPlugin(name, clipMin, clipMax);
}//ClipPlugin又传递给自己的成员变量, 在执行enqueue的时候,成员变量就能被用了
ClipPlugin::ClipPlugin(const std::string name, float clipMin, float clipMax): mLayerName(name), mClipMin(clipMin), mClipMax(clipMax)
{
}

2. weights 如何传过来的?

// 也是将weights 通过传给PluginField 然后传递给Plugin enqueue进行使用, 和属性一致

3. 怎么添加trt plugin

  • 继承IPluginV2的一些子类,然后实现一些成员函数,主要执行体是enqueue函数;成员函数的解释看文章[AI部署-TensorRT] IPluginV2的解析
  • 构建继承IPluginCreator类的子类,并用REGISTER_TENSORRT_PLUGIN将自定义层注册到tensorRT中
class OnnxPoolPluginV2Creator : public IPluginCreator
{
public:const char* getPluginName() const noexcept override{return "MaxPool";}const char* getPluginVersion() const noexcept override{return "2";}const PluginFieldCollection* getFieldNames() noexcept override{return &mFieldCollection;}IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override{auto* plugin = new OnnxPoolPluginV2(*fc);mFieldCollection = *fc;mPluginName = name;return plugin;}IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override{auto* plugin = new OnnxPoolPluginV2(serialData, serialLength);mPluginName = name;return plugin;}void setPluginNamespace(const char* libNamespace) noexcept override{mNamespace = libNamespace;}const char* getPluginNamespace() const noexcept override{return mNames集成pace.c_str();}private:std::string mNamespace;std::string mPluginName;PluginFieldCollection mFieldCollection{0, nullptr};
};REGISTER_TENSORRT_PLUGIN(OnnxPoolPluginV2Creator);

4. 如何进行量化collection过程

  • plugin中enqueue的输入InputDesc中存在scale变量,这个应该是用于plugin在做PTQ的时候在collection 量化scale时需要使用和更新的,让后在PTQ过程输出cache的时候,根据这个scale导出到文档。
  • TODO

references


  • TensorRT5.1.5.0 实践 onnx-TensorRT的自定义op
  • tensorRT samples
  • tensorRT部署教程-bilibili
  • tensorRT部署教程-sourcecode

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/2270.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[Do374]Ansible一键搭建sftp实现用户批量增删

[Do374]Ansible一键搭建sftp实现用户批量增删 1. 前言2. 思路3. sftp搭建及用户批量新增3.1 配置文件内容3.2 执行测试3.3 登录测试3.4 确认sftp服务器配置文件 4. 测试删除用户 1. 前言 最近准备搞一下RHCA LV V,外加2.9之后的ansible有较大变化于是练习下Do374的课程内容. 工…

易语言文字识别OCR

一.引言 文字识别&#xff0c;也称为光学字符识别&#xff08;Optical Character Recognition, OCR&#xff09;&#xff0c;是一种将不同形式的文档&#xff08;如扫描的纸质文档、PDF文件或数字相机拍摄的图片&#xff09;中的文字转换成可编辑和可搜索的数据的技术。随着技…

Docker 镜像制作原理 做一个自己的docker镜像

一.手动制作镜像 启动容器进入容器定制基于容器生成镜像 1.启动容器 启动容器之前我们首先要有一个镜像&#xff0c;这个镜像可以是从docker拉取&#xff0c;例如&#xff1a;现在pull一个ubuntu镜像到本机。 docker pull ubuntu:22.04 我们接下来可以基于这个容器进行容器…

网络编程 - - TCP套接字通信及编程实现

概述 TCP&#xff08;Transmission Control Protocol&#xff0c;传输控制协议&#xff09;是一种面向连接的、可靠的传输层协议。在网络编程中&#xff0c;TCP常用于实现客户端和服务器之间的可靠数据传输。本文将基于C语言实现TCP服务端和客户端建立通信的过程。 三次握手 在…

近红外简单ROI分析matlab(NIRS_SPM)

本次笔记主要想验证上篇近红外分析是否正确&#xff0c;因为叠加平均有不同的计算方法&#xff0c;一种是直接将每个通道的5分钟实时长单独进行叠加平均&#xff0c;另一种是将通道划分为1分钟的片段&#xff0c;将感兴趣的通道数据进行对应叠加平均&#xff0c;得到一个总平均…

从玩具到工业控制--51单片机的跨界传奇【2】

咱们在上一篇博客里面讲解了什么是单片机《单片机入门》&#xff0c;让大家对单片机有了初步的了解。我们今天继续讲解一些有关单片机的知识&#xff0c;顺便也讲解一下我们单片机用到的C语言知识。如果你对C语言还不太了解的话&#xff0c;可以看看博主的C语言专栏哟&#xff…

Python调用go语言编译的库

要在 Python 中调用用 Go 语言编写的库&#xff0c;可以使用 Go 语言的 cgo 特性将 Go 代码编译成共享库&#xff08;如 .so 文件&#xff09;&#xff0c;然后在 Python 中通过 ctypes 或 cffi 模块加载和调用这个共享库。 新建main.go文件&#xff0c;使用go语言编写如下代码…

学成在线_内容管理模块_创建模块工程

学成在线模块工程 1.各个微服务依赖基础工程2.每个微服务都是一个前后端分离的项目3.xuecheng-plus-content&#xff1a;内容管理模块工程xuecheng-plus-content-modelxuecheng-plus-content-servicexuecheng-plus-content-api 1.各个微服务依赖基础工程 2.每个微服务都是一个前…

1️⃣Java中的集合体系学习汇总(List/Map/Set 详解)

目录 01. Java中的集合体系 02. 单列集合体系​ 1. Collection系列集合的遍历方式 &#xff08;1&#xff09;迭代器遍历&#xff08;2&#xff09;增强for遍历​编辑&#xff08;3&#xff09;Lambda表达式遍历 03.List集合详解 04.Set集合详解 05.总结 Collection系列…

智能科技与共情能力加持,哈曼重新定义驾乘体验

2025年1月6日&#xff0c;拉斯维加斯&#xff0c;2025年国际消费电子展——想象一下&#xff0c;当您步入一辆汽车&#xff0c;它不仅能响应您的指令&#xff0c;更能理解您的需求、适应您的偏好&#xff0c;并为您创造一个独特且专属的交互环境。作为汽车科技领域的知名企业和…

Unity中实现倒计时结束后干一些事情

问题描述&#xff1a;如果我们想实现在一个倒计时结束后可以执行某个方法&#xff0c;比如挑战成功或者挑战失败&#xff0c;或者其他什么的比如生成boss之类的功能&#xff0c;而且你又不想每次都把代码复制一遍&#xff0c;那么就可以用下面这种方法。 结构 实现步骤 创建一…

从0开始学习搭网站第二天

前言&#xff1a;今天比较惭愧&#xff0c;中午打铲吃了一把&#xff0c;看着也到钻二了&#xff0c;干脆顺手把这个赛季的大师上了&#xff0c;于是乎一直到网上才开始工作&#xff0c;同样&#xff0c;今天的学习内容大多来自mdn社区mdn 目录 怎么把文件上传到web服务器采用S…

STM32 FreeRTOS时间片调度---FreeRTOS任务相关API函数---FreeRTOS时间管理

目录 时间片调度简介 FreeRTOS任务相关API函数介绍 延时函数介绍 时间片调度简介 在FreeRTOS中&#xff0c;同等优先级的任务会轮流分享相同的CPU时间&#xff0c;这个时间被称为时间片。在这里&#xff0c;一个时间片的长度等同于SysTick中断的周期。 FreeRTOS任务相关API…

VM(虚拟机)和Linux的安装

文章目录 1.虚拟机1.1 VM的安装和删除1.1.1 安装前提1.1.2 安装步骤 1.2 虚拟机快照1.3 虚拟机的克隆 2.Linux的安装2.1 CentOS2.2 Ubuntu 1.虚拟机 &#xff08;1&#xff09;Linux系统的安装方式 ①物理机安装&#xff1a;直接将操作系统安装到服务器硬件上 ②虚拟机安装&am…

C++算法第十五天

复习周终于结束了&#xff0c;这也是复习周结束后的第一篇文章&#xff0c;请各位小伙伴们细细品尝&#xff0c;废话不多说&#xff0c;我们开始今天的讲解。 第一题 题目链接 918. 环形子数组的最大和 - 力扣&#xff08;LeetCode&#xff09; 题目解析 代码原理 注意&…

mysql-5.7.18保姆级详细安装教程

本文主要讲解如何安装mysql-5.7.18数据库&#xff1a; 将绿色版安装包mysql-5.7.18-winx64解压后目录中内容如下图&#xff0c;该例是安装在D盘根目录。 在mysql安装目录中新建my.ini文件&#xff0c;文件内容及各配置项内容如下图&#xff0c;需要先将配置项【skip-grant-tab…

<OS 有关>Ubuntu 24 安装 openssh-server, tailscale+ssh 慢增加

更新日志&#xff1a; Created on 14Jan.2025 by Dave , added openssh-server, tailescape Updated on 15Jan.2025, added "tailescape - tailscape ssh" 前期准备&#xff1a; 1. 更新可用软件包的数据库 2. 升级系统中所有已安装的软件包到最新版本 3. 安装 cur…

STM32-keil安装时遇到的一些问题以及解决方案

前言&#xff1a; 本人项目需要使用到STM32,故需配置keil 5&#xff0c;在配置时遇到了以下问题&#xff0c;并找到相应的解决方案&#xff0c;希望能够为遇到相同问题的道友提供一些解决思路 1、提示缺少&#xff08;missing&#xff09;version 5编译器 step1&#xff1a;找…

HTTP1.0/1.1/2.0/3.0 的区别?

HTTP&#xff08;Hypertext Transfer Protocol&#xff09;是用于传输超文本的协议。各版本的主要区别体现在性能优化、数据传输方式以及支持的功能上。 每一次协议的更新都是对旧协议的改进&#xff1a; 1. HTTP1.0 发布于1996年 无连接&#xff08;Connectionless&#…

蓝桥杯_B组_省赛_2022(用作博主自己学习)

题目链接算法11.九进制转十进制 - 蓝桥云课 进制转换 21.顺子日期 - 蓝桥云课 时间与日期 31.刷题统计 - 蓝桥云课 时间与日期 41.修剪灌木 - 蓝桥云课 思维 51.X 进制减法 - 蓝桥云课 贪心 61.统计子矩阵 - 蓝桥云课 二维前缀和 71.积木画 - 蓝桥云课 动态规划 82.扫雷 - 蓝桥…