【PyTorch】图像多分类项目部署

【PyTorch】图像多分类项目

【PyTorch】图像多分类项目部署

如果需要在独立于训练脚本的新脚本中部署模型,这种情况模型和权重在内存中不存在,因此需要构造一个模型类的对象,然后将存储的权重加载到模型中。

加载模型参数,验证模型的性能,并在测试数据集上部署模型

from torch import nn
from torchvision import models# 定义一个resnet18模型,不使用预训练参数
model_resnet18 = models.resnet18(pretrained=False)
# 获取模型的全连接层的输入特征数
num_ftrs = model_resnet18.fc.in_features
# 定义分类的类别数
num_classes=10
# 将全连接层的输出特征数改为分类的类别数
model_resnet18.fc = nn.Linear(num_ftrs, num_classes)import torch 
path2weights="./models/resnet18_pretrained.pt"
# 加载预训练的ResNet18模型权重
model_resnet18.load_state_dict(torch.load(path2weights))
# 将ResNet-18模型设置为评估模式
model_resnet18.eval();
# 检查CUDA是否可用
if torch.cuda.is_available():# 如果可用,将设备设置为CUDAdevice = torch.device("cuda")# 将模型移动到CUDA设备上model_resnet18=model_resnet18.to(device)def deploy_model(model,dataset,device, num_classes=10,sanity_check=False):# 获取数据集的长度len_data=len(dataset)# 初始化输出张量y_out=torch.zeros(len_data,num_classes)# 初始化真实标签张量y_gt=np.zeros((len_data),dtype="uint8")# 将模型移动到指定设备model=model.to(device)# 初始化时间列表elapsed_times=[]with torch.no_grad():for i in range(len_data):# 获取数据集中的一个样本x,y=dataset[i]# 将真实标签存入张量y_gt[i]=y# 记录开始时间start=time.time()    # 将输入数据传入模型进行预测yy=model(x.unsqueeze(0).to(device))# 将预测结果存入张量y_out[i]=torch.softmax(yy,dim=1)# 计算预测时间elapsed=time.time()-start# 将预测时间存入列表elapsed_times.append(elapsed)# 如果进行完整性检查,则跳出循环if sanity_check is True:break# 计算平均预测时间inference_time=np.mean(elapsed_times)*1000# 打印平均预测时间print("average inference time per image on %s: %.2f ms " %(device,inference_time))# 返回预测结果和真实标签return y_out.numpy(),y_gt
from torchvision import datasets
import torchvision.transforms as transforms# 数据转换
data_transformer = transforms.Compose([transforms.ToTensor()])path2data="./data"# 加载数据
test0_ds=datasets.STL10(path2data, split='test', download=True,transform=data_transformer)
print(test0_ds.data.shape)

from sklearn.model_selection import StratifiedShuffleSplit# 创建StratifiedShuffleSplit对象,设置分割次数为1,测试集大小为0.2,随机种子为0
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)# 获取test0_ds的索引
indices=list(range(len(test0_ds)))# 获取test0_ds的标签
y_test0=[y for _,y in test0_ds]# 对索引和标签进行分割
for test_index, val_index in sss.split(indices, y_test0):# 打印测试集和验证集的索引print("test:", test_index, "val:", val_index)# 打印测试集和验证集的大小print(len(val_index),len(test_index))

from torch.utils.data import Subset# 从test0_ds中选取val_index索引的子集,赋值给val_ds
val_ds=Subset(test0_ds,val_index)
# 从test0_ds中选取test_index索引的子集,赋值给test_ds
test_ds=Subset(test0_ds,test_index)
# 定义均值
mean=[0.4467106, 0.43980986, 0.40664646]
# 定义标准差
std=[0.22414584,0.22148906,0.22389975]
# 定义一个名为test0_transformer的变量,用于将一系列的图像变换操作组合在一起
test0_transformer = transforms.Compose([# 将图像转换为Tensor类型transforms.ToTensor(),# 对图像进行归一化操作,使用mean和std作为均值和标准差transforms.Normalize(mean, std),])   
# 将test0_transformer赋值给test0_ds的transform属性
test0_ds.transform=test0_transformer
import time
import numpy as np# 调用deploy_model函数,传入model_resnet18,val_ds,device和sanity_check参数,返回y_out和y_gt
y_out,y_gt=deploy_model(model_resnet18,val_ds,device=device,sanity_check=False)
# 打印y_out和y_gt的形状
print(y_out.shape,y_gt.shape)

from sklearn.metrics import accuracy_score# 将y_out中的最大值索引赋值给y_pred
y_pred = np.argmax(y_out,axis=1)
# 打印y_pred和y_gt的形状
print(y_pred.shape,y_gt.shape)# 计算并打印y_pred和y_gt的准确率
acc=accuracy_score(y_pred,y_gt)
print("accuracy: %.2f" %acc)

 

# 部署模型,得到预测结果和真实标签
y_out,y_gt=deploy_model(model_resnet18,test_ds,device=device)# 取出预测结果中概率最大的类别
y_pred = np.argmax(y_out,axis=1)# 计算准确率
acc=accuracy_score(y_pred,y_gt)# 打印准确率
print(acc)

from torchvision import utils
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
np.random.seed(1)# 定义一个函数,用于显示图像
def imshow(inp, title=None):# 定义图像的均值和标准差mean=[0.4467106, 0.43980986, 0.40664646]std=[0.22414584,0.22148906,0.22389975]# 将图像从tensor转换为numpy数组,并转置inp = inp.numpy().transpose((1, 2, 0))# 将均值和标准差转换为numpy数组mean = np.array(mean)std = np.array(std)# 将图像的像素值进行归一化inp = std * inp + mean# 将像素值限制在0和1之间inp = np.clip(inp, 0, 1)# 显示图像plt.imshow(inp)# 如果有标题,则显示标题if title is not None:plt.title(title)# 暂停0.001秒plt.pause(0.001) # 定义网格大小
grid_size=16
# 随机生成4个索引
rnd_inds=np.random.randint(1,len(test_ds),grid_size)
# 打印随机生成的索引
print("image indices:",rnd_inds)# 根据索引获取对应的图像和标签
x_grid_test=[test_ds[i][0] for i in rnd_inds]
y_grid_test=[(y_pred[i],y_gt[i]) for i in rnd_inds]# 将图像转换为网格
x_grid_test=utils.make_grid(x_grid_test, nrow=4, padding=2)
# 打印网格的形状
print(x_grid_test.shape)# 设置图像的大小
plt.rcParams['figure.figsize'] = (10, 10)
# 显示网格
imshow(x_grid_test,y_grid_test)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/383014.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NFTScan 浏览器现已支持 .mint 域名搜索功能!

近日,NFT 数据基础设施 NFTScan 浏览器现已支持用户输入 .mint 域名进行 Mint Blockchain 网络钱包地址的搜索查询, NFTScan 用户能够轻松地使用域名追踪 NFT 交易,为 NFT 钱包地址相关的搜索查询功能增加透明度和便利性。 NFTScan explorer…

goenv丝滑控制多版本go

安装 先装下goenv brew install goenv去 ~/.bash_profile 添加一下 export GOENV_ROOT"$HOME/.goenv" export PATH"$GOENV_ROOT/bin:$PATH" eval "$(goenv init -)"执行一下让配置生效 source ~/.bash_profile插一嘴,如果之前是在…

VScode 批量操作

VScode 批量操作 批量修改 按住 alt/option 键, 选择需要批量操作的位置 如果是多行,则按住 altshift 键 可以直接操作 但是有时候比如变量命名,可能需要递增操作的命名 需要下载插件 Increment Selection 按照1的方法多选光标之后&am…

Java | 自制AWT单词猜一猜小游戏(测试版)

目录 游戏标题 开发过程 开发想法 技术栈 代码呈现 导包 核心代码 游戏标题 探索知识的迷宫,体验自制AWT单词猜一猜小游戏 在数字时代,学习可以是多彩的,游戏可以是智慧的。我们自豪地推出“单词猜猜猜”是一款结合了教育与娱乐的自制…

MQTTX连接华为云IoTDA

目录 华为IoTDA平台 MQTTX连接参数的设置 物模型的构建 属性上报 基本数据格式 时戳 我以前上课都是用巴法云服务器来演示MQTT的,前几天因为测试工业互联网关使用了华为的IoTDA,觉得也不算太复杂,今天尝试用MQTTX连接华为云&#xff0c…

Python读取grib数据获取变量推荐姿势

前情提要 最近使用的EC和GFS预报数据给的都是grib2格式的,之前用惯nc格式的,python读取grib2数据的时候还走了些弯路,看到很多博客上给的教程其实不能满足我的需求,现在搞明白了分享一下 pygrib安装 第一个问题就是我电脑上pyg…

HTTPS 的加密过程 详解

HTTP 由于是明文传输,所以安全上存在以下三个风险: 窃听风险,比如通信链路上可以获取通信内容。篡改风险,比如通信内容被篡改。冒充风险,比如冒充网站。 HTTPS 在 HTTP 与 TCP 层之间加入了 SSL/TLS 协议&#xff0c…

1 go语言环境的搭建

本专栏将从基础开始,循序渐进,由浅入深讲解Go语言,希望大家都能够从中有所收获,也请大家多多支持。 查看相关资料与知识库 专栏地址:Go专栏 如果文章知识点有错误的地方,请指正!大家一起学习,…

腾讯云k8s相关

1.某个服务腾讯云内网地址? 比如:spiderflow-web正式环境:http://spiderflow-web.sd-backend:30001 试一试:

Linux shell实现执行任务进度条动态显示(回旋镖)-3

process_bar.sh #!/bin/bash #display boomerangwhile : doclearfor i in {1..20}doecho -e "\033[2;${i}H*"sleep 0.1doneclearfor i in {20..1}doecho -e "\033[2;${i}H*"sleep 0.1doneclear done验证:

机器学习笔记-01-初识基础(问题-解答自查版)

前言 以下问题以Q&A形式记录,基本上都是笔者在初学一轮后,掌握不牢或者频繁忘记的点 Q&A的形式有助于学习过程中时刻关注自己的输入与输出关系,也适合做查漏补缺和复盘。 本文对读者可以用作自查,答案在后面&#xff0…

(35)远程识别(又称无人机识别)(二)

文章目录 前言 4 ArduRemoteID 5 终端用户数据的设置和使用 6 测试 7 为OEMs添加远程ID到ArduPilot系统的视频教程 前言 在一些国家,远程 ID 正在成为一项法律要求。以下是与 ArduPilot 兼容的设备列表。这里(here)有一个关于远程 ID 的很好解释和常见问题列表…

ubuntu那些ppa源在哪

Ubuntu中的 PPA 终极指南 - UBUNTU粉丝之家 什么是PPA PPA 代表个人包存档。 PPA 允许应用程序开发人员和 Linux 用户创建自己的存储库来分发软件。 使用 PPA,您可以轻松获取较新的软件版本或官方 Ubuntu 存储库无法提供的软件。 为什么使用PPA? 正如…

VLC输出NDI媒体流

目录 1. 下载安装VLC Play 2. 首先在电脑上安装NDI Tools 3. 运行VLC进行输出配置 4. 播放视频 5. 验证 (1)用Studio Monitor验证 (2)用OBS验证 NDI(Network Device Interface)即网络设备接口,是由美国 NewTek 公司开发的免费标准,它可使兼容的视频产品以高质量…

视图、存储过程、触发器

一、视图 视图是从一个或者几个基本表(或视图)导出的表。它与基 本表不同,是一个虚表,视图只能用来从查询,不能做增删改(虚拟的表) 1.创建视图 创建视图的语法: create view 视图名【view_xxx / v_xxx】 a…

基于CALMET诊断模型的高时空分辨率精细化风场模拟

原文链接:基于CALMET诊断模型的高时空分辨率精细化风场模拟https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&mid2247610033&idx7&sn999fb0fa3a0e57acebdfe209587ce7f3&chksmfa826f56cdf5e640f7dba429a9213a38d1222415eccd8660f4cf9fb46fa1a5ab3c5…

Netty:基于NIO的 Java 网络应用编程框架

Netty 是一个被广泛使用的,基于NIO的 Java 网络应用编程框架,Netty框架可以帮助开发者快速、简单的实现客户端和服务端的网络应用程序。“快速”和“简单”并不用产生维护性或性能上的问题。Netty 利用 Java 语言的NIO网络编程的能力,并隐藏其…

JavaWeb笔记_JSPEL

一.JSP相关技术 1.1 JSP由来 当我们需要向页面输出大量的HTML代码的时候,我们需要通过response对象写多次来输出HTML代码 response.getWriter().write("<font>文本</font>"); 页面的展示和servlet密不可分,不利于后期代码维护,因此推出一种可以…

记录一仿真错误,波形缩放有脉冲高信号,放大看不到信号了

是因为信号拉高的时间太短&#xff0c;拉高之后又把它拉低了&#xff0c;需要仔细看一下信号生成的代码。 错误代码与正确代码##正确代码always (posedge clk or negedge rst_n)begin if(!rst_n)beginwr_en < 1d0;wr_data < 8h0;endelse if(state_c DATASEND …

前端JS特效第52波:鼠标经过文字标题百叶窗动画特效风格切换图片轮播js效果

鼠标经过文字标题百叶窗动画特效风格切换图片轮播js效果&#xff0c;先来看看效果&#xff1a; 部分核心的代码如下&#xff1a; <html><head><meta charset"utf-8"><title>鼠标经过文字标题百叶窗动画特效风格切换图片轮播js效果</titl…