c++通过tensorRT调用模型进行推理

模型来源
算法工程师训练得到的onnx模型

c++对模型的转换
拿到onnx模型后,通过tensorRT将onnx模型转换为对应的engine模型,注意:训练用的tensorRT版本和c++调用的tensorRT版本必须一致。

如何转换:

  1. 算法工程师直接转换为.engine文件进行交付。
  2. 自己转换,进入tensorRT安装目录\bin目录下,将onnx模型拷贝到bin目录,地址栏中输入cmd回车弹出控制台窗口,然后输入转换命令,如:

trtexec --onnx=model.onnx --saveEngine=model.engine --workspace=1024 --optShapes=input:1x13x512x640 --fp16

然后回车,等待转换完成,完成后如图所示:
在这里插入图片描述
并且在bin目录下生成.engine模型文件。

c++对.engine模型文件的调用和推理
首先将tensorRT对模型的加载及推理进行封装,命名为CTensorRT.cpp,老样子贴代码:

//CTensorRT.cpp
class Logger : public nvinfer1::ILogger {void log(Severity severity, const char* msg) noexcept override {if (severity <= Severity::kWARNING)std::cout << msg << std::endl;}
};Logger logger;
class CtensorRT
{
public:CtensorRT() {}~CtensorRT() {}private:std::shared_ptr<nvinfer1::IExecutionContext> _context;std::shared_ptr<nvinfer1::ICudaEngine> _engine;nvinfer1::Dims _inputDims;nvinfer1::Dims _outputDims;
public:void cudaCheck(cudaError_t ret, std::ostream& err = std::cerr){if (ret != cudaSuccess){err << "Cuda failure: " << cudaGetErrorString(ret) << std::endl;abort();}}bool loadOnnxModel(const std::string& filepath){auto builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(logger));if (!builder){return false;}const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);auto network = std::unique_ptr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));if (!network){return false;}auto config = std::unique_ptr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());if (!config){return false;}auto parser = std::unique_ptr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, logger));if (!parser){return false;}parser->parseFromFile(filepath.c_str(), static_cast<int32_t>(nvinfer1::ILogger::Severity::kWARNING));std::unique_ptr<IHostMemory> plan{ builder->buildSerializedNetwork(*network, *config) };if (!plan){return false;}std::unique_ptr<IRuntime> runtime{ createInferRuntime(logger) };if (!runtime){return false;}_engine = std::shared_ptr<nvinfer1::ICudaEngine>(runtime->deserializeCudaEngine(plan->data(), plan->size()));if (!_engine){return false;}_context = std::shared_ptr<nvinfer1::IExecutionContext>(_engine->createExecutionContext());if (!_context){return false;}int nbBindings = _engine->getNbBindings();assert(nbBindings == 2); // 输入和输出,一共是2个for (int i = 0; i < nbBindings; i++){if (_engine->bindingIsInput(i))_inputDims = _engine->getBindingDimensions(i);    // (1,3,752,752)else_outputDims = _engine->getBindingDimensions(i);}return true;}bool loadEngineModel(const std::string& filepath){std::ifstream file(filepath, std::ios::binary);if (!file.good()){return false;}std::vector<char> data;try{file.seekg(0, file.end);const auto size = file.tellg();file.seekg(0, file.beg);data.resize(size);file.read(data.data(), size);}catch (const std::exception& e){file.close();return false;}file.close();auto runtime = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(logger));_engine = std::shared_ptr<nvinfer1::ICudaEngine>(runtime->deserializeCudaEngine(data.data(), data.size()));if (!_engine){return false;}_context = std::shared_ptr<nvinfer1::IExecutionContext>(_engine->createExecutionContext());if (!_context){return false;}int nbBindings = _engine->getNbBindings();assert(nbBindings == 2); // 输入和输出,一共是2个// 为输入和输出创建空间for (int i = 0; i < nbBindings; i++){if (_engine->bindingIsInput(i))_inputDims = _engine->getBindingDimensions(i);    //得到输入结构else_outputDims = _engine->getBindingDimensions(i);//得到输出结构}return true;}void ONNX2TensorRT(const char* ONNX_file, std::string save_ngine){// 1.创建构建器的实例nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);// 2.创建网络定义uint32_t flag = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);nvinfer1::INetworkDefinition* network = builder->createNetworkV2(flag);// 3.创建一个 ONNX 解析器来填充网络nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, logger);// 4.读取模型文件并处理任何错误parser->parseFromFile(ONNX_file, static_cast<int32_t>(nvinfer1::ILogger::Severity::kWARNING));for (int32_t i = 0; i < parser->getNbErrors(); ++i){std::cout << parser->getError(i)->desc() << std::endl;}// 5.创建一个构建配置,指定 TensorRT 应该如何优化模型nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();// 7.指定配置后,构建引擎nvinfer1::IHostMemory* serializedModel = builder->buildSerializedNetwork(*network, *config);// 8.保存TensorRT模型std::ofstream p(save_ngine, std::ios::binary);p.write(reinterpret_cast<const char*>(serializedModel->data()), serializedModel->size());// 9.序列化引擎包含权重的必要副本,因此不再需要解析器、网络定义、构建器配置和构建器,可以安全地删除delete parser;delete network;delete config;delete builder;// 10.将引擎保存到磁盘,并且可以删除它被序列化到的缓冲区delete serializedModel;}uint32_t getElementSize(nvinfer1::DataType t) noexcept{switch (t){case nvinfer1::DataType::kINT32: return 4;case nvinfer1::DataType::kFLOAT: return 4;case nvinfer1::DataType::kHALF: return 2;case nvinfer1::DataType::kBOOL:case nvinfer1::DataType::kINT8: return 1;}return 0;}int64_t volume(const nvinfer1::Dims& d){return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies<int64_t>());}bool infer(unsigned char* input, int real_input_size, cv::Mat& out_mat){tensor_custom::BufferManager buffer(_engine);cudaStream_t stream;cudaStreamCreate(&stream); // 创建异步cuda流int binds = _engine->getNbBindings();for (int i = 0; i < binds; i++){if (_engine->bindingIsInput(i)){size_t input_size;float* host_buf = static_cast<float*>(buffer.getHostBufferData(i, input_size));memcpy(host_buf, input, real_input_size);break;}}// 将输入传递到GPUbuffer.copyInputToDeviceAsync(stream);// 异步执行bool status = _context->enqueueV2(buffer.getDeviceBindngs().data(), stream, nullptr);if (!status)return false;buffer.copyOutputToHostAsync(stream);for (int i = 0; i < binds; i++){if (!_engine->bindingIsInput(i)){size_t output_size;float* tmp_out = static_cast<float*>(buffer.getHostBufferData(i, output_size));//do your something herebreak;}}cudaStreamSynchronize(stream);cudaStreamDestroy(stream);return true;}
};

调用方式

int main()
{vector<int> dims = { 1,13,512,640 };vector<float> vall;for (int i=0;i<13;i++){string file = "D:\\xxx\\" + to_string(i) + ".png";cv::Mat mt = imread(file, IMREAD_GRAYSCALE);cv::resize(mt, mt, cv::Size(640,512));mt.convertTo(mt, CV_32F, 1.0 / 255);cv::Mat shape_xr = mt.reshape(1, mt.total() * mt.channels());std::vector<float> vec_xr = mt.isContinuous() ? shape_xr : shape_xr.clone();vall.insert(vall.end(), vec_xr.begin(), vec_xr.end());}cv::Mat mt_4d(4, &dims[0], CV_32F, vall.data());string engine_model_file = "model.engine";CtensorRT cTensor;if (cTensor.loadEngineModel(engine_model_file)){cv::Mat out_mat;if (!cTensor.infer(mt_4d.data, vall.size() * 4, out_mat))std::cout << "infer error!" << endl;elsecv::imshow("out", out_mat);}elsestd::cout << "load model file failed!" << endl;cv::waitKey(0);return 0;
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/127924.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器人制作开源方案 | 桌面级机械臂--应用设计

本节内容将基于机器视觉带着大家进行应用实训。机器视觉是人工智能正在快速发展的一个分支&#xff0c;简单说来机器视觉就是用机器代替人眼来做测量和判断。机器视觉系统是通过机器视觉产品&#xff08;即图像摄取装置&#xff0c;分CMOS和CCD两种&#xff09;将被摄取目标转换…

Android学习之路(14) Context详解

一. 简介 在 Android 开发中、亦或是面试中都离不开四大组件的身影&#xff0c;而在创建或启动这些组件时&#xff0c;并不能直接通过 new 关键字后跟类名来创建实例对象&#xff0c;而是需要有它们各自的上下文环境&#xff0c;也就是本篇文章要讨论的 Context。 1.1 Contex…

ComPtr源码分析

ComPtr源码分析 ComPtr是微软提供的用来管理COM组件的智能指针。DirectX的API是由一系列的COM组件来管理的&#xff0c;形如ID3D12Device&#xff0c;IDXGISwapChain等的接口类最终都继承自IUnknown接口类&#xff0c;这个接口类包含AddRef和Release两个方法&#xff0c;分别用…

Qt6中使用Qt Charts

官方文档&#xff1a;Qt Charts 6.5.2 如果你是使用 CMake 构建的&#xff0c;则应在 CMakeLists.txt 中添加如下两行代码&#xff1a; find_package(Qt6 REQUIRED COMPONENTS Charts)target_link_libraries(mytarget PRIVATE Qt6::Charts) 其中 mytarget 为你的项目名称。一共…

aardio语言的通用数据表维护

import win.ui; /*DSG{{*/ var winform win.form(text"通用数据表维护";right617;bottom427;bgcolor15780518) winform.add( buttonAdd{cls"button";text"增加空行";left469;top40;right564;bottom80;flat1;z2}; buttonDel{cls"button&quo…

应用爆炸式增长,看F5如何做好网络安全防护

近年来&#xff0c;应用的数量呈现爆炸式增长。出行、支付、订单&#xff0c;开会&#xff0c;数字化的形式都在取代传统的消费&#xff0c;业务开展、工作内容都在发生着巨大的变化。随着数字化进程的加速&#xff0c;安全风险、安全问题暴露得越来越多。作为拥有强大安全基因…

【雷达原理】雷达信号级建模与仿真

目录 前言一、LFMCW信号概述1.1 优点1.2 缺点 二、LFMCW信号模型2.1 发射信号模型2.2 接收信号模型2.3 信号混频 三、MATLAB仿真3.1 仿真结果3.2 代码 四、参考文献 前言 雷达信号形式多种多样&#xff0c;按照雷达的体制进行分类&#xff0c;有脉冲雷达和连续波雷达。脉冲雷达…

Nacos docker实现nacos高可用集群项目

目录 Nacos是什么&#xff1f; Nacos在公司里的运用是什么&#xff1f; 使用docker构建nacos容器高可用集群 实验规划图&#xff1a;​编辑 1、拉取nacos镜像 2、创建docker网桥&#xff08;实现集群内的机器的互联互通&#xff08;所有的nacos和mysql&#xff09;&#x…

pytorch代码实现之空间通道重组卷积SCConv

空间通道重组卷积SCConv 空间通道重组卷积SCConv&#xff0c;全称Spatial and Channel Reconstruction Convolution&#xff0c;CPR2023年提出&#xff0c;可以即插即用&#xff0c;能够在减少参数的同时提升性能的模块。其核心思想是希望能够实现减少特征冗余从而提高算法的效…

WebDAV之π-Disk派盘 + 天悦日记

天悦日记是一款清爽简约的日记记录工具,通过天悦日记app随时随地快速写日记,更有智能数据统计分析报表,多端同步多种备份,本地备份和基于WebDAV协议的云端备份。跨平台使用,支持多设备、多平台无差别使用。天悦日记将每一天经历都清晰记录在手机,一目了然知道曾经的经历,…

Linux初探 - 概念上的理解和常见指令的使用

目录 Linux背景 Linux发展史 GNU 应用场景 发行版本 从概念上认识Linux 操作系统的概念 用户的概念 路径与目录 Linux下的文件 时间戳的概念 常规权限 特殊权限 Shell的概念 常用指令 ls tree stat clear pwd echo cd touch mkdir rmdir rm cp mv …

uboot顶层Makefile前期所做工作说明四

一. uboot顶层 Makefile文件 uboot 顶层 Makefile&#xff0c;就是 uboot源码工程的根目录下的 Makefile文件。 本文继续对 uboot顶层 Makefile的前期准备工作进行介绍。续上一篇文章内容的学习&#xff0c;如下&#xff1a; uboot顶层Makefile前期所做工作说明三_凌肖战的博…

DAMO-YOLO训练自己的数据集,使用onnxruntime推理部署

DAMO-YOLO训练自己的数据集&#xff0c;使用onnxruntime推理部署 DAMO-YOLO 是阿里达摩院智能计算实验室开发的一种兼顾速度与精度的目标检测算法&#xff0c;在高精度的同时&#xff0c;保持了很高的推理速度。 DAMO-YOLO 是在 YOLO 框架基础上引入了一系列新技术&#xff0…

Java的环境配置

目录 window系统安装java下载JDK配置环境变量JAVA_HOME 设置PATH设置CLASSPATH 设置测试JDK是否安装成功 Linux&#xff0c;UNIX&#xff0c;Solaris&#xff0c;FreeBSD环境变量设置流行JAVA开发工具使用 Eclipse 运行第一个 Java 程序 window系统安装java 下载JDK 首先我们…

爬虫进阶-反爬破解5(selenium的优势和点击操作+chrome的远程调试能力+通过Chrome隔离实现一台电脑登陆多个账号)

目录 一、selenium的优势和点击操作 二、chrome的远程调试能力 三、通过Chrome隔离实现一台电脑登陆多个账号 一、selenium的优势和点击操作 1.环境搭建 工具&#xff1a;Chrome浏览器chromedriverselenium win用户&#xff1a;chromedriver.exe放在python.exe旁边 MacO…

Unity汉化一个插件 制作插件汉化工具

我是编程一个菜鸟&#xff0c;英语又不好&#xff0c;有的插件非常牛&#xff01;我想学一学&#xff0c;页面全是英文&#xff0c;完全不知所措&#xff0c;我该怎么办啊...尝试在Unity中汉化一个插件 效果&#xff1a; 思路&#xff1a; 如何在Unity中把一个自己喜欢的插件…

新装Ubuntu系统的一些配置

背景&#xff1a; 最近办公要在Ubuntu系统上进行&#xff0c;于是自己安装了一个Ubuntu22.04系统&#xff0c;记录下新系统做的一些基本配置。 环境 &#xff1a; 系统&#xff1a;Ubuntu-22.04内核&#xff1a;6.2.0-26-generic架构&#xff1a;x86_64 一、 配置root密码 新…

Centos7 完全断网离线环境下安装MySQL 8.0.33 图文教程

Centos7 完全断网离线环境安装MySQL 8.0.33 图文教程 1.1前言1.2 下载离线安装包1.3 将下载好的离线安装包上传到Centos 7 服务器1.3.1 方式一:联网环境下可利用rz命令进行文件上传1.3.2 方式二:断网环境下使用 XFtp 等软件工具进行上传1.4 解压安装包1.5 执行安装脚本1.6 重…

Linux TCP和UDP协议

目录 TCP协议TCP协议的面向连接1.三次握手2.四次挥手 TCP协议的可靠性1.TCP状态转移——TIME_WAIT 状态TIME_WAIT 状态存在的意义&#xff1a;&#xff08;1&#xff09;可靠的终止TCP连接。&#xff08;2&#xff09;让迟来的TCP报文有足够的时间被识别并被丢弃。 2.应答确认、…

信息安全技术概论-李剑-持续更新

图片和细节来源于 用户 xiejava1018 一.概述 随着计算机网络技术的发展&#xff0c;与时代的变化&#xff0c;计算机病毒也经历了从早期的破坏为主到勒索钱财敲诈经济为主&#xff0c;破坏方式也多种多样&#xff0c;由早期的破坏网络到破坏硬件设备等等 &#xff0c;这也…