文章目录
- tensorRT添加自定义层步骤
- 1. trt如何解析onnx的? 整体流程图
- 2. builtin_op_importor是干什么的?
- 3. 怎么添加trt plugin
- 4. 如何进行量化collection过程
- references
- nvidia 官方plugin文档: https://www.nvidia.cn/content/dam/en-zz/zh_cn/assets/webinars/2020/feb21/TensorRT_7-0_plugin.pdf
- 比较旧了
tensorRT添加自定义层步骤
- 下载源码 onnx-tensorrt
- 参考tensorRT官网源码中plugin中的instanceNormalizationPlugin,写好自己customlayer.h和customlayer.cpp的实现。(都是官方写好的自定义op的示例,是op的逻辑)
- 在builtin_op_importers.cpp中使用DEFINE_BUILTIN_OP_IMPORTER添加对自己注册Op的使用(解析器)。
- 在CMakeLists.txt中,set(IMPORTER_SOURCES… 下面将自己的customlayer.cpp加进去。
- 按照教程,重新编译自己的onnx-tensorRT, 生成onnx2trt工具,然后用这个工具可以将onnx转化为trt文件
onnx2trt ./onnx/customer_op.onnx -v
1. trt如何解析onnx的? 整体流程图
整个tensorRT解析onnx开始的解析入口: parser->parseFromFile …
- 上图描述了如何构建解析自定义onnx node算子的解析器,并且将其保存到std::unordered_map<string, T>数据结构的builtin_op_importers中,该结构在ModelImporter解析器中被使用,而ModelImporter类是被createParser通过createNvOnnxParser_INTERNAL调用的; 这样整个链路就解释通了。
- 这个自定义onnx node算子的解析器是用来解析onnx中自定义算子中的保存的不变值,如attributes,weights等,解析出来,应该是通过PluginFieldCollection,传到tensorRT的plugin的执行体中,在执行的时候被使用, 具体这个挂钩过程如何实现的呢? 看如下介绍builtin_op_importor是干什么的?
2. builtin_op_importor是干什么的?
- 这个cpp主要完成将自定义onnx node的解析器注册到builtin_op_importers中
- 并且构建解析器,解析自定义node中的attributes和weights等内容,并将属性值通过PluginFieldCollection, 传到tensorRT的PluginCreator的createPlugin中,并且通过类的实例化,将必要的属性参数传递给plugin类.
注册过程是怎样的?
// 通过DEFINE_BUILTIN_OP_IMPORTER进行注册
// 这个op的名字对应这onnx中node的type, 而函数体中的pluginName对应的是你在tensorRT中plugin构建的时候,getPluginName中返回的名字#define DEFINE_BUILTIN_OP_IMPORTER(op) \NodeImportResult import##op( \IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, std::vector<TensorOrWeights>& inputs); \static const bool op##_registered_builtin_op = registerBuiltinOpImporter(#op, import##op); \IGNORE_UNUSED_GLOBAL(op##_registered_builtin_op); \NodeImportResult import##op( \IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, std::vector<TensorOrWeights>& inputs)bool registerBuiltinOpImporter(std::string op, NodeImporter const& importer)
{bool inserted = getBuiltinOpImporterMap().insert({op, importer}).second;assert(inserted);return inserted;
}string_map<NodeImporter>& getBuiltinOpImporterMap()
{static string_map<NodeImporter> builtin_op_importers;return builtin_op_importers;
}
**如何解析onnx node? **
DEFINE_BUILTIN_OP_IMPORTER(op_name) {# ... ... 省略一部分内容
OnnxAttrs attrs(node, ctx); // 获取属性
float epsilon = attrs.get("epsilon", 1e-5f);// Populate instanceNormalization plugin properties.
const std::string pluginName = "InstanceNormalization_TRT";
const std::string pluginVersion = "1";
std::vector<nvinfer1::PluginField> f;
f.emplace_back("epsilon", &epsilon, nvinfer1::PluginFieldType::kFLOAT32, 1);
f.emplace_back("scales", scale_weights.values, nvinfer1::PluginFieldType::kFLOAT32, scale_weights.count());
f.emplace_back("bias", bias_weights.values, nvinfer1::PluginFieldType::kFLOAT32, bias_weights.count());// Create plugin from registry
nvinfer1::IPluginV2* plugin = createPlugin(node.name(), importPluginCreator(pluginName, pluginVersion), f);ASSERT(plugin != nullptr && "InstanceNormalization plugin was not found in the plugin registry!",
ErrorCode::kUNSUPPORTED_NODE);auto* layer = ctx->network()->addPluginV2(&tensorPtr, 1, *plugin); // 自定义节点和层
ctx->registerLayer(layer, node.name());//
// // Map Quantization node to a scale node
auto layer = ctx->network()->addScale(input, mode, shift, scale, power);... ...
}
在onnx_tensorRT库中的builtin_op_importers.cpp中有很多解析器的例子,可以仿照的写自己的
带参数的onnx的weight是如何传到Plugin中被执行的?
1. 属性如何传过来的? 使用clip layer的例子
// 通过这个语句,将f传递给下面的IPluginV2* ClipPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) 的fc
// Create plugin from registry (DEFINE_BUILTIN_OP_IMPORTER(op_name)中的过程)
nvinfer1::IPluginV2* plugin = createPlugin(node.name(), importPluginCreator(pluginName, pluginVersion), f);// fc中保存的onnx的属性,传递给ClipPlugin
IPluginV2* ClipPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept
{float clipMin, clipMax;const PluginField* fields = fc->fields;// Parse fields from PluginFieldCollectionassert(fc->nbFields == 2);for (int i = 0; i < fc->nbFields; i++){if (strcmp(fields[i].name, "clipMin") == 0){assert(fields[i].type == PluginFieldType::kFLOAT32);clipMin = *(static_cast<const float*>(fields[i].data));}else if (strcmp(fields[i].name, "clipMax") == 0){assert(fields[i].type == PluginFieldType::kFLOAT32);clipMax = *(static_cast<const float*>(fields[i].data));}}return new ClipPlugin(name, clipMin, clipMax);
}//ClipPlugin又传递给自己的成员变量, 在执行enqueue的时候,成员变量就能被用了
ClipPlugin::ClipPlugin(const std::string name, float clipMin, float clipMax): mLayerName(name), mClipMin(clipMin), mClipMax(clipMax)
{
}
2. weights 如何传过来的?
// 也是将weights 通过传给PluginField 然后传递给Plugin enqueue进行使用, 和属性一致
3. 怎么添加trt plugin
- 继承IPluginV2的一些子类,然后实现一些成员函数,主要执行体是enqueue函数;成员函数的解释看文章[AI部署-TensorRT] IPluginV2的解析
- 构建继承IPluginCreator类的子类,并用REGISTER_TENSORRT_PLUGIN将自定义层注册到tensorRT中
class OnnxPoolPluginV2Creator : public IPluginCreator
{
public:const char* getPluginName() const noexcept override{return "MaxPool";}const char* getPluginVersion() const noexcept override{return "2";}const PluginFieldCollection* getFieldNames() noexcept override{return &mFieldCollection;}IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override{auto* plugin = new OnnxPoolPluginV2(*fc);mFieldCollection = *fc;mPluginName = name;return plugin;}IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override{auto* plugin = new OnnxPoolPluginV2(serialData, serialLength);mPluginName = name;return plugin;}void setPluginNamespace(const char* libNamespace) noexcept override{mNamespace = libNamespace;}const char* getPluginNamespace() const noexcept override{return mNames集成pace.c_str();}private:std::string mNamespace;std::string mPluginName;PluginFieldCollection mFieldCollection{0, nullptr};
};REGISTER_TENSORRT_PLUGIN(OnnxPoolPluginV2Creator);
4. 如何进行量化collection过程
- plugin中enqueue的输入InputDesc中存在scale变量,这个应该是用于plugin在做PTQ的时候在collection 量化scale时需要使用和更新的,让后在PTQ过程输出cache的时候,根据这个scale导出到文档。
- TODO
references
- TensorRT5.1.5.0 实践 onnx-TensorRT的自定义op
- tensorRT samples
- tensorRT部署教程-bilibili
- tensorRT部署教程-sourcecode