【深度学习框架格式转化】【CPU】Pytorch模型转ONNX模型格式流程详解【入门】

【深度学习框架格式转化】【GPU】Pytorch模型转ONNX模型格式流程详解【入门】

提示:博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论

文章目录

  • 【深度学习框架格式转化】【GPU】Pytorch模型转ONNX模型格式流程详解【入门】
  • 前言
  • PyTorch模型环境搭建(CPU)
  • 安装onnx和onnxruntime(CPU)
  • pytorch2onnx
  • 总结


前言

神经网络的模型通常在深度学习框架(PyTorc、TensorFlow和Caffe等)下训练得到,这些特定环境的深度学习框架依赖较多,规模较大,不适合在生产环境中安装,onnx支持大多数框架下模型的转换,便于整合模型,并且深度学习模型需要大量的算力才能满足实时运行需求,需要优化模型的运行效率,onnx并则能带来稳定的提速。
onnx还能再转化成TensorRT(GPU)格式和OpenVINO(CPU)格式进行推理,进一步提升速度

CPU模式下的格式转化,无论Pytorch还是ONNX搭建流程都十分简便,适合入门学习,也对极其适合对硬件要求很低的轻量级模型的运行。

后续可以学习【GPU】Pytorch模型转ONNX格式流程详解


PyTorch模型环境搭建(CPU)

博主以伪装对象分割(COS)之PFNet算法为例进行详解:【PFNet-pytorch代码】。
用PyTorch运行一个伪装对象分割模型PFNet,并把模型部署到ONNX Runtime这个推理引擎上。
博主在win10环境下装anaconda环境,搭建PFNet模型运行的PyTorch环境(官网下载地址)

# 创建虚拟环境
conda create -n pytorch2onnx_cpu python=3.10 -y
# 激活环境
activate pytorch2onnx_cpu 
# 下载githup源代码到合适文件夹,并cd到代码文件夹内(科学上网)
git clone https://github.com/Mhaiyang/CVPR2021_PFNet.git
# 安装pytorch(cpu)
pip3 install torch torchvision torchaudio

博主在这里不会详细讲解代码内容,只关注代码的使用,即代码的测试过程。源码作者提供了预训练权重和测试数据,博主整理到了【百度云,提取码:a660】上供大家下载。
下载resnet50-19c8e357.pth放置到CVPR2021_PFNet\backbone\resnet下:

下载PFNet.pth放置到CVPR2021_PFNet下:

下载测试数据集CAMO_TestingDataset.zip、CHAMELEON_TestingDataset.zip和COD10K_TestingDataset.zip解压重命名放置到CVPR2021_PFNet\data\test中:

使用预训练权重进行测试,修改infer.py文件内容

# 1.修改infer.py,只保留在test中有的数据集
to_test = OrderedDict([('CHAMELEON', chameleon_path),('CAMO', camo_path),('COD10K', cod10k_path),# ('NC4K', nc4k_path)])# 2.修改infer.py,删除/注释所有使用gpu相关代码
# device_ids = [0]
# torch.cuda.set_device(device_ids[0])# net = PFNet(backbone_path).cuda(device_ids[0])
net = PFNet(backbone_path)# img_var = Variable(img_transform(img).unsqueeze(0)).cuda(device_ids[0])
img_var = Variable(img_transform(img).unsqueeze(0))# 3.修改config.py中的内容
# datasets_root = '../data/NEW'修改成datasets_root = './data              

在CVPR2021_PFNet\results可以查看效果:

数据量比较大,运行速度也不算快。

到这里PyTorch模型环境搭建(CPU)完毕。


安装onnx和onnxruntime(CPU)

需要在anaconda虚拟环境安装onnx和onnxruntime

# 激活环境
activate pytorch2onnx_cpu 
# 安装onnx
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple onnx
# 安装CPU版
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple onnxruntime

获取ONNX Runtime的版本信息

import onnxruntime as ort
print("ONNX Runtime version:", ort.__version__)

pytorch2onnx

在CVPR2021_PFNet目录下新建pytorch2onnx.py文件并执行文件

import onnx
from onnx import numpy_helper
import torch
from PFNet import PFNet
backbone_path = './backbone/resnet/resnet50-19c8e357.pth'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
example = torch.randn(1,3, 416, 416).to(device)     # 1 3 416 416
print(example.dtype)
model = PFNet(backbone_path)                        # PFNet网络模型model.load_state_dict(torch.load(r'PFNet.pth'))     # 加载训练好的模型
model = model.to(device)                            # 模型放到cpu上
model.eval()torch.onnx.export(model, example, r"PFNet.onnx")     	# 导出模型
model_onnx = onnx.load(r"PFNet.onnx")                   # onnx加载保存的onnx模型
onnx.checker.check_model(model_onnx)                    # 检查模型是否有问题
print(onnx.helper.printable_graph(model_onnx.graph))    # 打印onnx网络

pytorch模型转化成onnx模型成功。
现在抛开任何pytorch相关的依赖,使用onnx模型完成测试,新建run_onnx.py,代码是参考源代码的推理部分infer.py改写来的。

import onnxruntime as ort
import numpy as np
from collections import OrderedDict
from config import *
from PIL import Image
from numpy import mean
import time
import datetimedef composed_transforms(image):mean = np.array([0.485, 0.456, 0.406])  # 均值std = np.array([0.229, 0.224, 0.225])  # 标准差# transforms.Resize是双线性插值resized_image = image.resize((args['scale'], args['scale']), resample=Image.BILINEAR)# onnx模型的输入必须是np,并且数据类型与onnx模型要求的数据类型保持一致resized_image = np.array(resized_image)normalized_image = (resized_image/255.0 - mean) / stdreturn np.round(normalized_image.astype(np.float32), 4)def check_mkdir(dir_name):if not os.path.exists(dir_name):os.makedirs(dir_name)to_test = OrderedDict([# ('CHAMELEON', chameleon_path),# ('CAMO', camo_path),('COD10K', cod10k_path),])
args = {'scale': 416,'save_results': True
}def main():# 保存检测结果的地址results_path = './results2'exp_name = 'PFNet'providers = ["CPUxecutionProvider"]ort_session = ort.InferenceSession("PFNet.onnx", providers=providers)  # 创建一个推理sessioninput_name = ort_session.get_inputs()[0].name# 输出有四个output_names = [output.name for output in ort_session.get_outputs()]start = time.time()for name, root in to_test.items():time_list = []image_path = os.path.join(root, 'image')if args['save_results']:check_mkdir(os.path.join(results_path, exp_name, name))img_list = [os.path.splitext(f)[0] for f in os.listdir(image_path) if f.endswith('jpg')]for idx, img_name in enumerate(img_list):img = Image.open(os.path.join(image_path, img_name + '.jpg')).convert('RGB')w, h = img.size#  对原始图像resize和归一化img_var = composed_transforms(img)# np的shape从[w,h,c]=>[c,w,h]img_var = np.transpose(img_var, (2, 0, 1))# 增加数据的维度[c,w,h]=>[bathsize,c,w,h]img_var = np.expand_dims(img_var, axis=0)start_each = time.time()prediction = ort_session.run(output_names, {input_name: img_var})time_each = time.time() - start_eachtime_list.append(time_each)# 除去多余的bathsize维度,NumPy变会PIL同样需要变换数据类型# *255替换pytorch的to_pilprediction = (np.squeeze(prediction[3])*255).astype(np.uint8)if args['save_results']:(Image.fromarray(prediction).resize((w, h)).convert('L').save(os.path.join(results_path, exp_name, name, img_name + '.png')))print(('{}'.format(exp_name)))print("{}'s average Time Is : {:.3f} s".format(name, mean(time_list)))print("{}'s average Time Is : {:.1f} fps".format(name, 1 / mean(time_list)))end = time.time()print("Total Testing Time: {}".format(str(datetime.timedelta(seconds=int(end - start)))))
if __name__ == '__main__':main()

在CVPR2021_PFNet\results2可以查看效果:
在这里插入图片描述

到这里读者将代码迁移到新机器时,可以不再安装pytorch相关依赖就能使用模型的预测功能,这可以极大的减少所依赖环境的大小。


总结

尽可能简单、详细的介绍CPU模式下Pytorch模型转ONNX格式的流程,后续介绍GPU版本的格式转化,学习难度只是有略微提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/136700.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

中国汽车供应商远赴德国,中国智驾方案能否远渡重洋?

作者|Amy 编辑|德新 今年的上海车展,中国智能汽车的进步有目共睹,吸引了大批外企高管和研发人员的关注,甚至引发了海外车企一系列的动作和调整。 而在刚刚结束的慕尼黑车展,中国车企及汽车供应链把「肌肉」秀到了现代汽车起源地…

大模型如何赋能智能客服

2022年,大模型技术的出色表现让人们瞩目。随着深度学习和大数据技术的发展,大模型在很多领域的应用已经成为可能。许多公司开始探索如何将大模型技术应用于自己的业务中,智能客服也不例外。 智能客服是现代企业中非常重要的一部分&#xff0…

Python 图形化界面基础篇:创建工具栏

Python 图形化界面基础篇:创建工具栏 引言 Tkinter 库简介步骤1:导入 Tkinter 模块步骤2:创建 Tkinter 窗口步骤3:创建工具栏步骤4:向工具栏添加工具按钮步骤5:处理工具按钮的点击事件步骤6:启动…

基于matlab实现的卡尔曼滤波匀加速直线运动仿真

完整程序: clear clc %% 初始化参数 delta_t 0.1; %采样时间 T 8; %总运行时长 t 0:delta_t:T; %时间序列 N length(t); %序列的长度 x0 0; %初始位置 u0 0; %初速度 U 10; %控制量、加速度 F [1 delta_t 0 1]; %状态转移矩阵 B …

【c#-Nuget 包“在此源中不可用”】 Nuget package “Not available in this source“

标题c#-Nuget 包“在此源中不可用”…但 VS 仍然知道它吗? (c# - Nuget package “Not available in this source”… but VS still knows about it?) 听起来您的公司有一个发布包的内部 NuGet feed,而不是公共 NuGet.org feed。您应该向您的同事询问…

CentOS 7 安装踩坑

CentOS与Ubuntu并称为Linux最著名的两个发行版,但由于笔者主要从事深度学习图像算法工作,Ubuntu作为谷歌和多数依赖库的亲儿子占据着最高生态位。但最近接手的一个项目里,甲方指定需要在CentOS7上运行项目代码,笔者被迫小小cos了一…

在华为云服务器上CentOS 7安装单机版Redis

https://redis.io/是官网地址。 点击右上角的Download。 可以进入https://redis.io/download/——Redis官网下载最新版的网址。 然后在https://redis.io/download/页面往下拉,点击下图超链接这里。 进入https://download.redis.io/releases/下载自己需要的安装…

【二叉树】的顺序存储(堆的实现)

📙作者简介: 清水加冰,目前大二在读,正在学习C/C、Python、操作系统、数据库等。 📘相关专栏:C语言初阶、C语言进阶、C语言刷题训练营、数据结构刷题训练营、有感兴趣的可以看一看。 欢迎点赞 &#x1f44d…

64位Ubuntu20.04.5 LTS系统安装32位运行库

背景: 在ubutu(版本为20.04.5 LTS)中运行./arm-none-linux-gnueabi-gcc -v 后提示“no such device”。 经多方查证,是ubutu的版本是64位的,而需要运行的编译工具链是32位的,因此会不兼容。 解决方法就是在…

十分钟理解OSPF路由协议

十分钟理解OSPF路由协议 1.RIP的缺陷以跳数为度量值最大跳数为15更新路由表采用全更新收敛速度慢 2.RIP与OSPF比较OSPF概述运行OSPF协议之前运行OSPF协议之后 3.OSPF协议工作过程1.发现邻居2.建立邻接关系3.传递链路状态信息4.计算路由 4.OSPF分区域管理 有RIP协议,…

Bootstrap 框架学习笔记(基础)

来自于 Twitter,基于 HTML、CSS、JavaScript。 有关网站:Bootstrap中文网Bootstrap是Twitter推出的一个用于前端开发的开源工具包。它由Twitter的设计师Mark Otto和Jacob Thornton合作开发,是一个CSS/HTML框架。目前,Bootstrap最…

Java Semaphore使用例子和流程

目录 Semaphore例子代码和输出semaphore.acquire();semaphore.release(); Semaphore semaphore : 英[ˈseməfɔː(r)] 美[ˈseməfɔːr] n. 旗语; 信号标; v. 打旗语; (用其他类似的信号系统)发信号; [例句]Semaphore was widely used at sea, before the advent of electr…

ssh登录时间久或登陆后报错

情况1 问题描述: ssh登录时间很久,登录后出现abrt-cli status timed out 的报错 问题原因: .lock文件被锁导致 执行systemctl status abrtd.service可以看到被锁的.lock 处理方式: ps -ef | grep pid 找到被锁的进程kill掉…

图片格式大全

青春不能回头,青春也没有终点。 大全介绍 图片格式有多种,每种格式都有其独特的特性和用途。以下是一些常见的图片格式以及它们的介绍: JPEG(Joint Photographic Experts Group): 文件扩展名:…

1786_MTALAB代码生成把通用函数生成独立文件

全部学习汇总: GitHub - GreyZhang/g_matlab: MATLAB once used to be my daily tool. After many years when I go back and read my old learning notes I felt maybe I still need it in the future. So, start this repo to keep some of my old learning notes…

计算机竞赛 深度学习 python opencv 火焰检测识别

文章目录 0 前言1 基于YOLO的火焰检测与识别2 课题背景3 卷积神经网络3.1 卷积层3.2 池化层3.3 激活函数:3.4 全连接层3.5 使用tensorflow中keras模块实现卷积神经网络 4 YOLOV54.1 网络架构图4.2 输入端4.3 基准网络4.4 Neck网络4.5 Head输出层 5 数据集准备5.1 数…

QT记事本+登陆界面的简单实现

主体头文件 #ifndef JSB_H #define JSB_H#include <QMainWindow> #include <QMenuBar>//菜单栏 #include <QToolBar>//工具栏 #include <QStatusBar>//状态栏 #include <QTextEdit>//文本 #include <QLabel>//标签 #include <QDebug&g…

Android之MediaCodec::PostAndAwaitResponse消息原理(四十三)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生从来没有捷径,只有行动才是治疗恐惧和懒惰的唯一良药. 更多原创,欢迎关注:Android…

企业内部安全与风控管理图解

企业内部安全说外部安全&#xff0c;企业领导者都非常关注&#xff0c;由于各方面原因&#xff0c;。。。力不从心&#xff0c;妥协&#xff01; 方向&#xff1a; 1、制度 结合企业实情&#xff0c;编制企业安全管理制度 2、硬件 处理常规硬件外观&#xff0c;加壳与锁定、…

【力扣每日一题】2023.9.10 打家劫舍Ⅳ

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 题目翻译有些烂&#xff0c;我来二次翻译一下&#xff0c;找出数组中k个两两互不相邻的数&#xff0c;求出它们的最大值。要求最大值尽可…