【深度学习】pytorch快速得到mobilenet_v2 pth 和onnx

在linux执行这个程序:

import torch
import torch.onnx
from torchvision import transforms, models
from PIL import Image
import os# Load MobileNetV2 model
model = models.mobilenet_v2(pretrained=True)
model.eval()# Download an example image from the PyTorch website
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
try:os.system(f"wget {url} -O {filename}")
except Exception as e:print(f"Error downloading image: {e}")# Preprocess the input image
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])input_image = Image.open(filename)
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension# Perform inference on CPU
with torch.no_grad():output = model(input_tensor)# Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
print(output[0])# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities)# Download ImageNet labels using wget
os.system("wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")# Read the categories
with open("imagenet_classes.txt", "r") as f:categories = [s.strip() for s in f.readlines()]# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):print(categories[top5_catid[i]], top5_prob[i].item())# Save the PyTorch model
torch.save(model.state_dict(), "mobilenet_v2.pth")# Convert the PyTorch model to ONNX with specified input and output names
dummy_input = torch.randn(1, 3, 224, 224)
onnx_path = "mobilenet_v2.onnx"
input_names = ['input']
output_names = ['output']
torch.onnx.export(model, dummy_input, onnx_path, input_names=input_names, output_names=output_names)print(f"PyTorch model saved to 'mobilenet_v2.pth'")
print(f"ONNX model saved to '{onnx_path}'")# Load the ONNX model
import onnx
import onnxruntimeonnx_model = onnx.load(onnx_path)
onnx_session = onnxruntime.InferenceSession(onnx_path)# Convert input tensor to ONNX-compatible format
input_tensor_onnx = input_tensor.numpy()# Perform inference on ONNX with the correct input name
onnx_output = onnx_session.run(['output'], {'input': input_tensor_onnx})
onnx_probabilities = torch.nn.functional.softmax(torch.tensor(onnx_output[0]), dim=1)# Show top categories per image for ONNX
onnx_top5_prob, onnx_top5_catid = torch.topk(onnx_probabilities, 5)
print("\nTop categories for ONNX:")
for i in range(onnx_top5_prob.size(1)):print(categories[onnx_top5_catid[0][i]], onnx_top5_prob[0][i].item())

得到:

在这里插入图片描述
用本地pth推理:

import torch
from torchvision import transforms, models
from PIL import Image# Load MobileNetV2 model
model = models.mobilenet_v2()
model.load_state_dict(torch.load("mobilenet_v2.pth", map_location=torch.device('cpu')))
model.eval()# Preprocess the input image
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# Load the example image
input_image = Image.open("dog.jpg")
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension# Perform inference on CPU
with torch.no_grad():output = model(input_tensor)# Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
# print(output[0])# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)# Load ImageNet labels
categories = []
with open("imagenet_classes.txt", "r") as f:categories = [s.strip() for s in f.readlines()]# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):print(categories[top5_catid[i]], top5_prob[i].item())

用onnx推理:

import torch
import onnxruntime
from torchvision import transforms
from PIL import Image# Load the ONNX model
onnx_path = "mobilenet_v2.onnx"
onnx_session = onnxruntime.InferenceSession(onnx_path)# Preprocess the input image
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# Load the example image
input_image = Image.open("dog.jpg")
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension# Convert input tensor to ONNX-compatible format
input_tensor_onnx = input_tensor.numpy()# Perform inference on ONNX
onnx_output = onnx_session.run(None, {'input': input_tensor_onnx})
onnx_probabilities = torch.nn.functional.softmax(torch.tensor(onnx_output[0]), dim=1)# Load ImageNet labels
categories = []
with open("imagenet_classes.txt", "r") as f:categories = [s.strip() for s in f.readlines()]# Show top categories per image for ONNX
onnx_top5_prob, onnx_top5_catid = torch.topk(onnx_probabilities, 5)
print("Top categories for ONNX:")
for i in range(onnx_top5_prob.size(1)):print(categories[onnx_top5_catid[0][i]], onnx_top5_prob[0][i].item())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/197276.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

安卓中轻量级数据存储方案分析探讨

轻量级数据存储功能通常用于保存应用的一些常用配置信息,并不适合需要存储大量数据和频繁改变数据的场景。应用的数据保存在文件中,这些文件可以持久化地存储在设备上。需要注意的是,应用访问的实例包含文件所有数据,这些数据会一…

Qt6版使用Qt5中的类遇到的问题解决方案

如果有需要请关注下面微信公众号,会有更多收获! 1.QLinkedList 是 Qt 中的一个双向链表类。它提供了高效的插入和删除操作,尤其是在中间插入和删除元素时,比 QVector 更加优秀。下面是使用 QLinkedList 的一些基本方法&#xff1a…

微服务学习 | Eureka注册中心

微服务远程调用 在order-service的OrderApplication中注册RestTemplate 在查询订单信息时,需要同时返回订单用户的信息,但是由于微服务的关系,用户信息需要在用户的微服务中去查询,故需要用到上面的RestTemplate来让订单的这个微…

JVM虚拟机:通过日志学习PS+PO垃圾回收器

我们刚才设置参数的时候看到了-XXPrintGCDetails表示输出详细的GC处理日志,那么我们如何理解这个日志呢?日志是有规则的,我们需要按照这个规则来理解日志中的内容,它有两个格式,一个格式是GC的格式(新生代&…

YOLOv8/YOLOv7/YOLOv5/YOLOv4/Faster-rcnn系列算法改进【NO.79】改进损失函数为VariFocal Loss

前言 作为当前先进的深度学习目标检测算法YOLOv8,已经集合了大量的trick,但是还是有提高和改进的空间,针对具体应用场景下的检测难点,可以不同的改进方法。此后的系列文章,将重点对YOLOv8的如何改进进行详细的介绍&…

Egress-TLS-Origination

目录 文章目录 目录本节实战1、出口网关TLS发起2、通过 egress 网关发起双向 TLS 连接关于我最后 本节实战 实战名称🚩 实战:Egress TLS Origination-2023.11.19(failed)🚩 实战:通过 egress 网关发起双向 TLS 连接-2023.11.19(测…

CDN是什么,能起到什么作用

随着互联网的快速发展,用户对于快速、稳定、高效的互联网体验的需求日益增长。为了满足这一需求,内容分发网络(CDN)应运而生,并在近年来得到了广泛应用。CDN通过在全球范围内部署大量的服务器和网络节点,实…

excel怎么能锁住行 和/或 列的自增长,保证粘贴公式的时候不自增长或者只有部分自增长

例如在C4单元格中输入了公式: 现在如果把C4拷贝到C5,D3会自增长为D4: 现在如果想拷贝的时候不自增长,可以先把光标放到C4单元格,然后按F4键,行和列的前面加上了$符号,锁定了: …

QEMU显示虚拟化的几种选项

QEMU可以通过通过命令行"-vga type"选择为客户机模拟的VGA卡的类别,可选择的类型有多个: -vga typeSelect type of VGA card to emulate. Valid values for type arecirrusCirrus Logic GD5446 Video card. All Windows versions starting from Windows 95 should …

深度学习YOLO图像视频足球和人体检测 - python opencv 计算机竞赛

文章目录 0 前言1 课题背景2 实现效果3 卷积神经网络4 Yolov5算法5 数据集6 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 深度学习YOLO图像视频足球和人体检测 该项目较为新颖,适合作为竞赛课题方向,学长非…

Vatee万腾未来科技之航:Vatee创新引领的新纪元

在当今数字化时代,Vatee万腾科技正在开创一段引领未来的全新征程。以其卓越的创新能力和领导地位,Vatee万腾成为数字化领域的引领者。其未来科技之航展现了一种独特的数字化愿景,引领着科技创新进入新的纪元。 Vatee万腾在数字科技领域展现出…

html-网站菜单-点击显示导航栏

一、效果图 1.点击显示菜单栏&#xff0c;点击x号关闭&#xff1b; 2.点击一级菜单&#xff0c;展开显示二级&#xff0c;并且加号变为减号&#xff1b; 3.点击其他一级导航&#xff0c;自动收起展开的导航。 二、代码实现 <!DOCTYPE html> <html><head>&…

【LeetCode刷题日志】225.用队列实现栈

&#x1f388;个人主页&#xff1a;库库的里昂 &#x1f390;C/C领域新星创作者 &#x1f389;欢迎 &#x1f44d;点赞✍评论⭐收藏✨收录专栏&#xff1a;LeetCode 刷题日志&#x1f91d;希望作者的文章能对你有所帮助&#xff0c;有不足的地方请在评论区留言指正&#xff0c;…

浅谈基于云计算的环境智能监控系统

随着经济的飞速发展&#xff0c;环境污染也越来越严重&#xff0c;环境监控成为了政府与社会关注的焦点。本文提出了一种基于云计算的环境智能监控系统——EasyCVR&#xff0c;该系统综合应用了传感器、云计算、大数据、人工智能等技术&#xff0c;具有实时、准确、高效的监控能…

DIY私人图床:使用CFimagehost源码自建无需数据库支持的PHP图片托管服务

文章目录 1.前言2. CFImagehost网站搭建2.1 CFImagehost下载和安装2.2 CFImagehost网页测试2.3 cpolar的安装和注册 3.本地网页发布3.1 Cpolar临时数据隧道3.2 Cpolar稳定隧道&#xff08;云端设置&#xff09;3.3.Cpolar稳定隧道&#xff08;本地设置&#xff09; 4.公网访问测…

JVM 调优指南

文章目录 为什么要学 JVM一、JVM 整体布局二、Class 文件规范三、类加载模块四、执行引擎五、GC 垃圾回收1 、JVM内存布局2 、 JVM 有哪些主要的垃圾回收器&#xff1f;3 、分代垃圾回收工作机制 六、对 JVM 进行调优的基础思路七、 GC 情况分析实例 JVM调优指南 -- 楼兰 ​ JV…

ArcGIS Pro 优化的热点分析【Optimized Hot Spot Analysis】

ArcGIS Pro 优化的热点分析【Optimized Hot Spot Analysis】Optimized Hot Spot Analysis 优化的热点分析https://mp.weixin.qq.com/s/lfoIls8exW5G6PPJ9gtDew em&#xff0c;先给大家推荐一个空间统计分析的学习资源网站 https://spatialstats-analysis-1.hub.arcgis.com/ .…

6.9平衡二叉树(LC110-E)

绝对值函数&#xff1a;abs() 算法&#xff1a; 高度和深度的区别&#xff1a; 节点的高度&#xff1a;节点到叶子节点的距离&#xff08;从下往上&#xff09; 节点的深度&#xff1a;节点到根节点的距离&#xff08;从上往下&#xff09; 逻辑&#xff1a;一个平衡二叉树…

pipeline jenkins流水线

Pipeline 是 Jenkins 中一种灵活且强大的工作流机制&#xff0c;它允许您以代码的形式来定义和管理持续集成和持续交付的流程。 Pipeline 的作用主要体现在以下几个方面&#xff1a; 可编排的构建流程&#xff1a;使用 Pipeline&#xff0c;您可以将一个或多个阶段&#xff08…

Mac安装win程序另一个方案

前言 今天跟大家分享的是mac装win程序的另一个思路&#xff0c;不需要大动干戈的装双系统、虚拟机。最后分享感受&#xff0c;先说过程吧。 一、思路介绍 其实&#xff0c;就是利用CrossOver&#xff0c;这个软件的介绍大家可以自行了解。不过不得不说这玩意还是国外的人思路新…