Pytorch导出onnx模型并在C++环境中调用(含python和C++工程)

Pytorch导出onnx模型并在C++环境中调用(含python和C++工程)

工程下载链接:Pytorch导出onnx模型并在C++环境中调用(python和C++工程)

机器学习多层感知机MLP的Pytorch实现-以表格数据为例-含数据集和PyCharm工程中简单介绍了在python中使用pytorch搭建神经网络模型的步骤和代码工程,此处介绍AI模型的跨平台调用问题,即使用跨平台的ONNX框架,在C++代码中进行模型调用

参考:pytorch导出模型并使用onnxruntime C++部署加载模型推理

目录

  • Pytorch导出onnx模型并在C++环境中调用(含python和C++工程)
    • 1、pkl权重文件转换为onnx模型
      • 1.1、加载保存的模型
      • 1.2、pytorch调用加载的模型进行测试
      • 1.3、导出onnx模型
      • 1.4、在python中调用onnx模型测试
      • 1.5、全部代码
    • 2、C++调用onnx模型
      • 2.1、库的下载安装和官方手册
      • 2.2、C++调用代码实现
      • 2.3、注意,模型文件和onnx的dll要在exe同一级目录
    • 3、运行时遇到的一些问题
      • 编译报错---error: ‘_Frees_ptr_opt_‘ has not been declared
      • 运行报错---The given version [14] is not supported, only version 1 to 10 is supported in this build.

1、pkl权重文件转换为onnx模型

在机器学习多层感知机MLP的Pytorch实现-以表格数据为例-含数据集和PyCharm工程中,我们对训练完成的模型进行了模型的保存:

torch.save(model.state_dict(),'weights/mlp_weights-epoch%d-Total_loss%.4f-val_loss%.4f.pkl' % ((epoch + 1), train_loss, val_loss / (iteration + 1)))

此处我们需要先加载保存的模型,如何再将其导出为onnx格式。

1.1、加载保存的模型

这一步主要是要把保存的模型恢复出来:

import numpy as np
import onnxruntime
import torch
from torch import nn
import torch.nn.functional as F# 定义多层感知机(MLP)类,继承自nn.Module
class MLP(nn.Module):# 类的初始化方法def __init__(self):# 调用父类nn.Module的初始化方法super(MLP, self).__init__()self.hidden1 = nn.Linear(in_features=8, out_features=50, bias=True)self.hidden2 = nn.Linear(50, 50)self.hidden3 = nn.Linear(50, 50)self.hidden4 = nn.Linear(50, 50)self.predict = nn.Linear(50, 1)# 定义前向传播方法def forward(self, x):# x是输入数据,通过第一个隐藏层并应用ReLU激活函数x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = F.relu(self.hidden3(x))x = F.relu(self.hidden4(x))output = self.predict(x)# 将输出展平为一维张量# out = output.view(-1)return output# 检查是否有可用的GPU
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"
# 定义模型并将其转移到GPU
model = MLP().to(device)
model.load_state_dict(torch.load('weights/mlp_weights-epoch1000-Total_loss0.0756-val_loss0.0105.pkl',weights_only=True))
model.eval()

1.2、pytorch调用加载的模型进行测试

先简单测试一下这个模型,定义一个输入全为0的数组作为输入,打印输出的结果:

# 初始化一个输入全为0的数组进行测试
x = (torch.from_numpy(np.array([0, 0, 0, 0, 0, 0, 0, 0]).astype(np.float32)).to(device))
y = model(x).cpu().detach().numpy()[0]
print(f"pytorch直接测试结果为: {y}")

在这里插入图片描述

1.3、导出onnx模型

使用下面的代码将原来的模型model导出为onnx的模型,其中x是上面定义的案例输入

batch_size = 1  # 批处理大小
export_onnx_file = "test.onnx"  # 目的ONNX文件名
torch.onnx.export(model,(x),export_onnx_file,opset_version=10,do_constant_folding=True,  # 是否执行常量折叠优化input_names=["input"],  # 输入名output_names=["output"],  # 输出名dynamic_axes={"input": {0: "batch_size"},  # 批处理变量"output": {0: "batch_size"}})

这个函数的具体定义可以参考:从pytorch转换到onnx

1.4、在python中调用onnx模型测试

使用下面代码加载onnx模型并进行测试:

resnet_session = onnxruntime.InferenceSession(export_onnx_file)
inputs = {resnet_session.get_inputs()[0].name: x.cpu().detach().numpy()}
outs = resnet_session.run(None, inputs)[0][0]
print(f"py onnx直接测试结果为: {outs}")

可以看到pytorch的结果和onnx的结果是基本一致的
在这里插入图片描述

1.5、全部代码

import numpy as np
import onnxruntime
import torch
from torch import nn
import torch.nn.functional as F# 定义多层感知机(MLP)类,继承自nn.Module
class MLP(nn.Module):# 类的初始化方法def __init__(self):# 调用父类nn.Module的初始化方法super(MLP, self).__init__()self.hidden1 = nn.Linear(in_features=8, out_features=50, bias=True)self.hidden2 = nn.Linear(50, 50)self.hidden3 = nn.Linear(50, 50)self.hidden4 = nn.Linear(50, 50)self.predict = nn.Linear(50, 1)# 定义前向传播方法def forward(self, x):# x是输入数据,通过第一个隐藏层并应用ReLU激活函数x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = F.relu(self.hidden3(x))x = F.relu(self.hidden4(x))output = self.predict(x)# 将输出展平为一维张量# out = output.view(-1)return output# 检查是否有可用的GPU
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"
# 定义模型并将其转移到GPU
model = MLP().to(device)
model.load_state_dict(torch.load('weights/mlp_weights-epoch1000-Total_loss0.0756-val_loss0.0105.pkl',weights_only=True))
model.eval()# 初始化一个输入全为0的数组进行测试
x = (torch.from_numpy(np.array([0, 0, 0, 0, 0, 0, 0, 0]).astype(np.float32)).to(device))
y = model(x).cpu().detach().numpy()[0]
print(f"pytorch直接测试结果为: {y}")batch_size = 1  # 批处理大小
export_onnx_file = "test.onnx"  # 目的ONNX文件名
torch.onnx.export(model,(x),export_onnx_file,opset_version=10,do_constant_folding=True,  # 是否执行常量折叠优化input_names=["input"],  # 输入名output_names=["output"],  # 输出名dynamic_axes={"input": {0: "batch_size"},  # 批处理变量"output": {0: "batch_size"}})resnet_session = onnxruntime.InferenceSession(export_onnx_file)
inputs = {resnet_session.get_inputs()[0].name: x.cpu().detach().numpy()}
outs = resnet_session.run(None, inputs)[0][0]
print(f"py onnx直接测试结果为: {outs}")

2、C++调用onnx模型

2.1、库的下载安装和官方手册

这个整个库的下载还是要到官方的github仓库去:microsoft/onnxruntime
具体的使用方式参考英文的手册:https://onnxruntime.ai/docs/】

此处下载的window版本的,下载下来可以得到头文件和库文件:
在这里插入图片描述在这里插入图片描述
因此在实际编程的时候我使用的Cmakelist来链接到相关的库,我是使用VS code + gcc构成的C++编译环境

if (WIN32)include_directories(${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-1.14.0/onnxruntime-win-x64-1.14.0/include)
else()include_directories(${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-1.14.0/onnxruntime-linux-x64-1.14.0/include)
endif()if (WIN32)link_directories(${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-1.14.0/onnxruntime-win-x64-1.14.0/lib)
else()link_directories(${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-1.14.0/onnxruntime-linux-x64-1.14.0/lib)
endif()

实际的工程目录结构如下所示:
在这里插入图片描述

2.2、C++调用代码实现

下面代码实现和1.4、在python中调用onnx模型测试相同的效果,输入是全0的数组,进行计算并返回相关结果:

#include <iostream>
#include <array>
#include <algorithm>
#include "onnxruntime_cxx_api.h"#define ONNX_IN_OUT_SIZE_MAX 20int main(int argc, char* argv[])
{//print helloprintf("hello");std::vector<float> input_matrix_vector={0, 0, 0, 0, 0, 0, 0, 0};int onnx_input_shape = 8;int onnx_output_shape = 1;// --- define model path#if _WIN32const wchar_t* model_path = L"./model.onnx"; // you can use string to wchar_t* function to convert#elseconst char* model_path = "./model.onnx";#endif// --- init onnxruntime envOrt::Env env(ORT_LOGGING_LEVEL_WARNING, "Default");auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);// set optionsOrt::SessionOptions session_option;session_option.SetIntraOpNumThreads(1); // extend the number to do parallelsession_option.SetGraphOptimizationLevel(ORT_ENABLE_ALL);// --- prepare dataconst char* input_names[] = { "input" }; // must keep the same as model exportconst char* output_names[] = { "output" };std::array<float, ONNX_IN_OUT_SIZE_MAX> input_matrix;std::array<float, ONNX_IN_OUT_SIZE_MAX> output_matrix;if(input_matrix_vector.size()>ONNX_IN_OUT_SIZE_MAX){throw std::runtime_error("input_matrix_vector.size()<ONNX_IN_OUT_SIZE_MAX");}std::copy(input_matrix_vector.begin(),input_matrix_vector.end(),input_matrix.begin());// must use int64_t type to match argsstd::array<int64_t, 1> input_shape{ onnx_input_shape };std::array<int64_t, 1> output_shape{ onnx_output_shape };Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_matrix.data(), input_matrix.size(), input_shape.data(), input_shape.size());Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output_matrix.data(), output_matrix.size(), output_shape.data(), output_shape.size());// --- predictOrt::Session session(env, model_path, session_option); // FIXME: must check if model file exist or valid, otherwise this will cause crashsession.Run(Ort::RunOptions{ nullptr }, input_names, &input_tensor, 1, output_names, &output_tensor, 1); // here only use one input output channelstd::vector<float> outputVector(onnx_output_shape);std::copy(output_matrix.begin(),output_matrix.begin()+onnx_output_shape,outputVector.begin());std::cout << "--- predict result ---" << std::endl;// matrix outputstd::cout << "ouput matrix: ";for (int i = 0; i < outputVector.size(); i++)std::cout << outputVector[i] << " ";std::cout << std::endl;// argmax value// int argmax_value = std::distance(output_matrix.begin(), std::max_element(output_matrix.begin(), output_matrix.end()));// std::cout << "output argmax value: " << argmax_value << std::endl;// getchar();return 0;
}

可以看到最终的返回结果为:
在这里插入图片描述
和之前在python中的结果是一致的!!!

2.3、注意,模型文件和onnx的dll要在exe同一级目录

在这里插入图片描述

3、运行时遇到的一些问题

编译报错—error: ‘Frees_ptr_opt‘ has not been declared

在编译器命令行或者代码中定义这些宏,使其在非MSVC环境中被忽略。在代码的开头( onnxruntime_c_api.h 文件中)添加以下代码:

#pragma once
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
//add code here
#ifndef _Frees_ptr_opt_ 
#define _Frees_ptr_opt_ 
#endif 
#ifndef _In_ 
#define _In_ 
#endif 
#ifndef _Out_ 
#define _Out_ 
#endif 
#ifndef _Inout_ 
#define _Inout_
#endif

运行报错—The given version [14] is not supported, only version 1 to 10 is supported in this build.

将onnxruntime-1.14.0\onnxruntime-win-x64-1.14.0\lib的onnxruntime.dll复制一份到exe的目录下面,这是因为路径默认索引的是System32中的老版本库文件:
在这里插入图片描述
System32中存在老版本的onnx动态库:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/618.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025新春烟花代码(二)HTML5实现孔明灯和烟花效果

效果展示 源代码 <!DOCTYPE html> <html lang"en"> <script>var _hmt _hmt || [];(function () {var hm document.createElement("script");hm.src "https://hm.baidu.com/hm.js?45f95f1bfde85c7777c3d1157e8c2d34";var …

[Transformer] The Structure of GPT, Generative Pretrained Transformer

The Structure of Generative Pretrained Transformer Reference: The Transformer architecture of GPT models How GPT Models Work

使用MATLAB正则表达式从文本文件中提取数据

使用MATLAB正则表达式从文本文件中提取数据 使用Python正则表达式从文本文件中提取数据的代码请看这篇文章使用正则表达式读取文本数据【Python】-CSDN博客 文本数据格式 需要提取 V 后面的数据, 并绘制出曲线. index 1V 0.000000W 0.000000E_theta 0.000000UINV 0.0…

电脑提示directx错误导致玩不了游戏怎么办?dx出错的解决方法

想必大家都有过这样的崩溃瞬间&#xff1a;满心欢喜打开心仪的游戏&#xff0c;准备在虚拟世界里大杀四方或者畅游冒险&#xff0c;结果屏幕上突然弹出个 DirectX 错误的提示框&#xff0c;紧接着游戏闪退&#xff0c;一切美好戛然而止。DirectX 作为 Windows 系统下游戏运行的…

python学opencv|读取图像(三十二)使用cv2.getPerspectiveTransform()函数制作透视图-变形的喵喵

【1】引言 前序已经对图像展开了平移、旋转缩放和倾斜拉伸技巧探索&#xff0c;相关链接为&#xff1a; python学opencv|读取图像&#xff08;二十八&#xff09;使用cv2.warpAffine&#xff08;&#xff09;函数平移图像-CSDN博客 python学opencv|读取图像&#xff08;二十…

初学spring 框架(了解spring框架的技术背景,核心体现,入门案例)

目录 技术背景 为什么要学习spring 框架&#xff1f; 学习spring 框架可以解决什么问题&#xff1f; 了解spring框架的核心体现 入门案例 步骤 1 导入 依赖 2 搭建三层架构体现【根据实际情况 构建】 3 添加配置文件 Test 测试类中 从 Ioc 容器 获取 Student 对象 总…

用户界面的UML建模11

然而&#xff0c;在用户界面方面&#xff0c;重要的是要了解《boundary》类是如何与这个异常分层结构进行关联的。 《exception》类的对象可以作为《control》类的对象。因此&#xff0c;《exception》类能够聚合《boundary》类。 参见图12&#xff0c;《exception》Database…

IDEA的常用设置

目录 一、显示顶部工具栏 二、设置编辑区字体按住鼠标滚轮变大变小&#xff08;看需要设置&#xff09; 三、设置自动导包和优化导入的包&#xff08;有的时候还是需要手动导包&#xff09; 四、设置导入同一个包下的类&#xff0c;超过指定个数的时候&#xff0c;合并为*&a…

STM32-笔记39-SPI-W25Q128

一、什么是SPI&#xff1f; SPI是串行外设接口&#xff08;Serial Peripheral Interface&#xff09;的缩写&#xff0c;是一种高速的&#xff0c;全双工&#xff0c;同步的通信总线&#xff0c;并且 在芯片的管脚上只占用四根线&#xff0c;节约了芯片的管脚&#xff0c;同时为…

uniapp小程序中隐藏顶部导航栏和指定某页面去掉顶部导航栏小程序

uniappvue3开发小程序过程中隐藏顶部导航栏和指定某页面去掉顶部导航栏方法 在page.json中 "globalStyle": {"navigationStyle":"custom",}, 如果是指定某个页面关闭顶部导航栏&#xff0c;在style中添加"navigationStyle": "cus…

【电子通识】PWM驱动让有刷直流电机恒流工作

电机的典型驱动方法包括电压驱动、电流驱动以及PWM驱动。本文将介绍采用PWM驱动方式的恒流工作。 首先介绍的是什么是PWM驱动的电机恒流工作&#xff0c;其次是PWM驱动电机恒流工作时电路的工作原理。 PWM驱动 当以恒定的电流驱动电机时&#xff0c;电机会怎样工作呢&#xff1…

Mysql--运维篇--主从复制和集群(主从复制I/O线程,SQL线程,二进制日志,中继日志,集群NDB)

一、主从复制 MySQL的主从复制&#xff08;Master-Slave Replication&#xff09;是一种数据冗余和高可用性的解决方案&#xff0c;它通过将一个或多个从服务器&#xff08;Slave&#xff09;与主服务器&#xff08;Master&#xff09;同步来实现。主从复制的基本原理是&#…

Mac 删除ABC 输入法

参考链接&#xff1a;百度安全验证 Mac下删除系统自带输入法ABC&#xff0c;正解&#xff01;_mac删除abc输入法-CSDN博客 ABC 输入法和搜狗输入法等 英文有冲突~~ 切换后还会在英文状态&#xff0c;可以删除 &#xff1b;可能会对DNS 输入有影响&#xff0c;但是可以通过复…

Mac——Cpolar内网穿透实战

摘要 本文介绍了在Mac系统上实现内网穿透的方法&#xff0c;通过打开远程登录、局域网内测试SSH远程连接&#xff0c;以及利用cpolar工具实现公网SSH远程连接MacOS的步骤。包括安装配置homebrew、安装cpolar服务、获取SSH隧道公网地址及测试公网连接等关键环节。 1. MacOS打开…

Unity中对象池的使用(用一个简单粗暴的例子)

问题描述&#xff1a;Unity在创建和销毁对象的时候是很消耗性能的&#xff0c;所以我们在销毁一个对象的时候&#xff0c;可以不用Destroy&#xff0c;而是将这个物体隐藏后放到回收池里面&#xff0c;当再次需要的时候如果回收池里面有之前回收的对象&#xff0c;就直接拿来用…

【再谈设计模式】模板方法模式 - 算法骨架的构建者

一、引言 在软件工程、软件开发过程中&#xff0c;我们经常会遇到一些算法或者业务逻辑具有固定的流程步骤&#xff0c;但其中个别步骤的实现可能会因具体情况而有所不同的情况。模板方法设计模式&#xff08;Template Method Design Pattern&#xff09;就为解决这类问题提供了…

【大模型】Langchain-Chatchat-v0.3.1 的环境配置

1 Langchahin-chatchat的工程简介 本项目是利用 langchain 思想实现的基于本地知识库的问答应用&#xff0c;目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。 本项目的最新版本中可使用 Xinference、Ollama 等框架接入 GLM-4-Chat、 Qwen2-In…

jenkins的作用以及操作

一 jenkins 1.1 概念 1.2 流程 1.2.1 流程 1.2.2 配置 1.3 jenkins容器自动化部署

【UE5 C++课程系列笔记】29——在UE中使用第三方库的流程

目录 前言 步骤 一、新建插件 二、创建第三方库 三、使用第三方库 前言 主要就是介绍如何将普通C++工程生成的头文件和.dll导入到UE中去使用。 步骤 一、新建插件 1. 打开插件浏览器选项卡 2. 打开插件创建器 3. 选择“第三方库”,这里命名为“MyThirdPartyLibrary…

Mybatis——Mybatis开发经验总结

摘要 本文主要介绍了MyBatis框架的设计与通用性&#xff0c;阐述了其作为Java持久化框架的亮点&#xff0c;包括精良的架构设计、丰富的扩展点以及易用性和可靠性。同时&#xff0c;对比了常见持久层框架&#xff0c;分析了MyBatis在关系型数据库交互中的优势。此外&#xff0…