利用Pytorch预训练模型进行图像分类

Use Pre-trained models for Image Classification.

# This post is rectified on the base of https://learnopencv.com/pytorch-for-beginners-image-classification-using-pre-trained-models/# And we have re-orginaized the code script.

预训练模型(Pre-trained models)是在ImageNet等大型基准数据集上训练的神经网络模型。深度学习社区从这些开源模型中受益匪浅。此外,预训练模型也是计算机视觉研究取得快速进展的一个重要因素。其他研究人员和从业人员可以使用这些最先进的模型,而不是从头开始重新训练。

# Here are some examples of classic pre-trained models.

在这里插入图片描述

在详细介绍如何使用预训练模型进行图像分类之前,我们先来看看有哪些预训练模型。我们将以 AlexNet 和 ResNet101 为例进行讨论。这两个网络都在 ImageNet 数据集上训练过。

ImageNet 数据集拥有超过 1400 万张由斯坦福大学维护的图像。它被广泛用于各种与图像相关的深度学习项目。这些图像属于不同的类别或标签。预训练模型(如 AlexNet 和 ResNet101)的目的是将图像作为输入并预测其类别。

这里的 "预训练 "是指,深度学习架构 AlexNet 和 ResNet101 已经在某个(庞大的)数据集上进行过训练,因此带有由此产生的权重和偏差。架构与权重和偏置之间的区别应该非常明显,因为我们将在下一节看到,TorchVision 同时拥有架构和预训练模型。

1.1 Model Inference Process

由于我们将重点讨论如何使用预先训练好的模型来预测输入的类别(标签),因此我们也来讨论一下其中涉及的过程。这个过程被称为模型推理。整个过程包括以下主要步骤:

(1) 读取输入图像;
(2) 对图像进行转换;例如resize、center crop、normalization等;
(3) 前向传递:使用预训练的模型权重来获得输出向量,而输出向量中的每个元素都描述了模型对于输入图像属于特定类别的置信度预测结果;
(4) 预测结果:基于获得的置信度分数,显示预测结果。

1.2 Loading Pre-Trained Network using TorchVision

# [Optinal Step]
# %pip install torchvision
# Load necessary packages.
from PIL import Image
import torch
import torchvision
from torchvision import models
from torchvision import transformsprint(torch.__version__)
print(torchvision.__version__)
2.0.0
0.15.0
# Check the different models and architectures available to us.
dir(models)
['AlexNet','AlexNet_Weights','ConvNeXt','ConvNeXt_Base_Weights','ConvNeXt_Large_Weights','ConvNeXt_Small_Weights','ConvNeXt_Tiny_Weights','DenseNet','DenseNet121_Weights','DenseNet161_Weights','DenseNet169_Weights','DenseNet201_Weights','EfficientNet','EfficientNet_B0_Weights','EfficientNet_B1_Weights','EfficientNet_B2_Weights','EfficientNet_B3_Weights','EfficientNet_B4_Weights','EfficientNet_B5_Weights','EfficientNet_B6_Weights','EfficientNet_B7_Weights','EfficientNet_V2_L_Weights','EfficientNet_V2_M_Weights','EfficientNet_V2_S_Weights','GoogLeNet','GoogLeNetOutputs','GoogLeNet_Weights','Inception3','InceptionOutputs','Inception_V3_Weights','MNASNet','MNASNet0_5_Weights','MNASNet0_75_Weights','MNASNet1_0_Weights','MNASNet1_3_Weights','MaxVit','MaxVit_T_Weights','MobileNetV2','MobileNetV3','MobileNet_V2_Weights','MobileNet_V3_Large_Weights','MobileNet_V3_Small_Weights','RegNet','RegNet_X_16GF_Weights','RegNet_X_1_6GF_Weights','RegNet_X_32GF_Weights','RegNet_X_3_2GF_Weights','RegNet_X_400MF_Weights','RegNet_X_800MF_Weights','RegNet_X_8GF_Weights','RegNet_Y_128GF_Weights','RegNet_Y_16GF_Weights','RegNet_Y_1_6GF_Weights','RegNet_Y_32GF_Weights','RegNet_Y_3_2GF_Weights','RegNet_Y_400MF_Weights','RegNet_Y_800MF_Weights','RegNet_Y_8GF_Weights','ResNeXt101_32X8D_Weights','ResNeXt101_64X4D_Weights','ResNeXt50_32X4D_Weights','ResNet','ResNet101_Weights','ResNet152_Weights','ResNet18_Weights','ResNet34_Weights','ResNet50_Weights','ShuffleNetV2','ShuffleNet_V2_X0_5_Weights','ShuffleNet_V2_X1_0_Weights','ShuffleNet_V2_X1_5_Weights','ShuffleNet_V2_X2_0_Weights','SqueezeNet','SqueezeNet1_0_Weights','SqueezeNet1_1_Weights','SwinTransformer','Swin_B_Weights','Swin_S_Weights','Swin_T_Weights','Swin_V2_B_Weights','Swin_V2_S_Weights','Swin_V2_T_Weights','VGG','VGG11_BN_Weights','VGG11_Weights','VGG13_BN_Weights','VGG13_Weights','VGG16_BN_Weights','VGG16_Weights','VGG19_BN_Weights','VGG19_Weights','ViT_B_16_Weights','ViT_B_32_Weights','ViT_H_14_Weights','ViT_L_16_Weights','ViT_L_32_Weights','VisionTransformer','Weights','WeightsEnum','Wide_ResNet101_2_Weights','Wide_ResNet50_2_Weights','_GoogLeNetOutputs','_InceptionOutputs','__builtins__','__cached__','__doc__','__file__','__loader__','__name__','__package__','__path__','__spec__','_api','_meta','_utils','alexnet','convnext','convnext_base','convnext_large','convnext_small','convnext_tiny','densenet','densenet121','densenet161','densenet169','densenet201','detection','efficientnet','efficientnet_b0','efficientnet_b1','efficientnet_b2','efficientnet_b3','efficientnet_b4','efficientnet_b5','efficientnet_b6','efficientnet_b7','efficientnet_v2_l','efficientnet_v2_m','efficientnet_v2_s','get_model','get_model_builder','get_model_weights','get_weight','googlenet','inception','inception_v3','list_models','maxvit','maxvit_t','mnasnet','mnasnet0_5','mnasnet0_75','mnasnet1_0','mnasnet1_3','mobilenet','mobilenet_v2','mobilenet_v3_large','mobilenet_v3_small','mobilenetv2','mobilenetv3','optical_flow','quantization','regnet','regnet_x_16gf','regnet_x_1_6gf','regnet_x_32gf','regnet_x_3_2gf','regnet_x_400mf','regnet_x_800mf','regnet_x_8gf','regnet_y_128gf','regnet_y_16gf','regnet_y_1_6gf','regnet_y_32gf','regnet_y_3_2gf','regnet_y_400mf','regnet_y_800mf','regnet_y_8gf','resnet','resnet101','resnet152','resnet18','resnet34','resnet50','resnext101_32x8d','resnext101_64x4d','resnext50_32x4d','segmentation','shufflenet_v2_x0_5','shufflenet_v2_x1_0','shufflenet_v2_x1_5','shufflenet_v2_x2_0','shufflenetv2','squeezenet','squeezenet1_0','squeezenet1_1','swin_b','swin_s','swin_t','swin_transformer','swin_v2_b','swin_v2_s','swin_v2_t','vgg','vgg11','vgg11_bn','vgg13','vgg13_bn','vgg16','vgg16_bn','vgg19','vgg19_bn','video','vision_transformer','vit_b_16','vit_b_32','vit_h_14','vit_l_16','vit_l_32','wide_resnet101_2','wide_resnet50_2']

以AlexNet为例,我们可以看到还有一个名称为alexnet的条目。其中,大写的名称是Python类(AlexNet),而alexnet是一个便于操作的函数(convenience function),用于从AlexNet类返回实例化的模型。

这些方便函数也可以有不同的参数集,例如:densenet121、densenet161、densenet169以及densenet201,都是DenseNet的实例,但层数分别为121,161,169和201.

1.3. Using AlexNet for Image Classification

AlexnetNet是图像识别领域早期的一个突破性网络结构,相关文章可以参考Understanding Alexnet。该网络架构如下:

在这里插入图片描述

Step 1: Load the pre-trained model

# Create an instance of the network.
alexnet = models.alexnet(pretrained=True)
/home/wsl_ubuntu/anaconda3/envs/xy_trans/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.warnings.warn(
/home/wsl_ubuntu/anaconda3/envs/xy_trans/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.warnings.warn(msg)
# Note: Pytorch模型的扩展名通常为.pt或.pth
# Check the model details.
print(alexnet)
AlexNet((features): Sequential((0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))(1): ReLU(inplace=True)(2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)(3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(4): ReLU(inplace=True)(5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)(6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(7): ReLU(inplace=True)(8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(9): ReLU(inplace=True)(10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(6, 6))(classifier): Sequential((0): Dropout(p=0.5, inplace=False)(1): Linear(in_features=9216, out_features=4096, bias=True)(2): ReLU(inplace=True)(3): Dropout(p=0.5, inplace=False)(4): Linear(in_features=4096, out_features=4096, bias=True)(5): ReLU(inplace=True)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

Step 2: Specify image transformations

# Use transforms to compose all the data transformations.
transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])     # Three numbers for RGB Channels.
# transforms.Resize: Resize the input images to 256x256 pixels.
# transforms.CenterCrop: Crop the image to 224×224 pixels about the center.
# transforms.Normalize: Normalize the image by setting its mean and standard deviation to the specified values.
# transforms.ToTensor: Convert the image to Pytorch tensor datatype.

Step 3: Load the input image and pre-process it.

# Download image
# !wget https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg -O dog.jpg
img = Image.open("./dog.jpg")
img

在这里插入图片描述

# Pre-process the image.
trans_img = transform(img)img_batch = torch.unsqueeze(trans_img, 0)

Step 4: Model Inference

# Set the model to eval model.
alexnet.eval()out = alexnet(img_batch)
print(out.shape)
torch.Size([1, 1000])
# Download classes text file
!wget https://raw.githubusercontent.com/Lasagne/Recipes/master/examples/resnet50/imagenet_classes.txt
--2023-12-14 21:30:09--  https://raw.githubusercontent.com/Lasagne/Recipes/master/examples/resnet50/imagenet_classes.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 0.0.0.0, ::
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|0.0.0.0|:443... failed: Connection refused.
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|::|:443... failed: Connection refused.
# Load labels.
with open('imagenet_classes.txt') as f:classes = [line.strip() for line in f.readlines()]
# Find out the maximum score.
_, index = torch.max(out, 1)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
print(classes[index[0]], percentage[index[0]].item())
Labrador retriever 41.58513259887695
# The model predicts the image to be of a Labrador Retriever with a 41.58% confidence.
_, indices = torch.sort(out, descending=True)
[(classes[idx], percentage[idx].item()) for idx in indices[0][:5]]
[('Labrador retriever', 41.58513259887695),('golden retriever', 16.59164810180664),('Saluki, gazelle hound', 16.286897659301758),('whippet', 2.8539111614227295),('Ibizan hound, Ibizan Podenco', 2.39247727394104)]

1.4. Using ResNet for Image Classification

# Load the resnet101 model.
resnet = models.resnet101(pretrained=True)# Set the model to eval mode.
resnet.eval()# carry out model inference.
out = resnet(img_batch)# Print the top 5 classes predicted by the model.
_, indices = torch.sort(out, descending=True)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
[(classes[idx], percentage[idx].item()) for idx in indices[0][:5]]
Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /home/wsl_ubuntu/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:51<00:00, 3.47MB/s] [('Labrador retriever', 48.255577087402344),('dingo, warrigal, warragal, Canis dingo', 7.900773048400879),('golden retriever', 6.91691780090332),('Eskimo dog, husky', 3.6434390544891357),('bull mastiff', 3.046128273010254)]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/217385.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniapp交互反馈api的使用示例

官方文档链接&#xff1a;uni.showToast(OBJECT) | uni-app官网 1.uni.showToast({}) 显示消息提示框。 常用属性&#xff1a; title:页面提示的内容 image&#xff1a;改变提示框默认的icon图标 duration&#xff1a;提示框在页面显示多少秒才让它消失 添加了image属性后。 注…

前端体系:前端应用

目录 前端体系基础 html&#xff08;超文本标记语言&#xff09; css&#xff08;层叠样式单&#xff09; javascript&#xff08;&#xff09; 一、前端体系概述 二、前端框架 React Vue Angular 三、前端库和工具 lodash Redux Webpack 四、模块化和组件化 ES…

Java中的链表

文章目录 前言一、链表的概念及结构二、单向不带头非循坏链表的实现2.1打印链表2.2求链表的长度2.3头插法2.4尾插法2.5任意位置插入2.6查找是否包含某个元素的节点2.7删除第一次出现这个元素的节点2.8删除包含这个元素的所以节点2.9清空链表单向链表的测试 三、双向不带头非循坏…

RNN介绍及Pytorch源码解析

介绍一下RNN模型的结构以及源码&#xff0c;用作自己复习的材料。 RNN模型所对应的源码在&#xff1a;\PyTorch\Lib\site-packages\torch\nn\modules\RNN.py文件中。 RNN的模型图如下&#xff1a; 源码注释中写道&#xff0c;RNN的数学公式&#xff1a; 表示在时刻的隐藏状态…

可替代LM5145,5.5V-100V Vin同步降压控制器_SCT82A30

SCT82A30是一款100V电压模式控制同步降压控制器&#xff0c;具有线路前馈。40ns受控高压侧MOSFET的最小导通时间支持高转换比&#xff0c;实现从48V输入到低压轨的直接降压转换&#xff0c;降低了系统复杂性和解决方案成本。如果需要&#xff0c;在低至6V的输入电压下降期间&am…

C语言之文件操作(下)

C语言之文件操作&#xff08;下&#xff09; 文章目录 C语言之文件操作&#xff08;下&#xff09;1. 文件的顺序读写1.1 文件的顺序读写函数1.1.1 字符输入/输出函数&#xff08;fgetc/fputc&#xff09;1.1.2 ⽂本⾏输⼊/输出函数&#xff08;fgets/fputs&#xff09;1.1.3 格…

Spring Boot之自定义starter

&#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 接下来看看由辉辉所写的关于Spring Boot的相关操作吧 目录 &#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 一. starter是什么 二.为什么要使…

大模型应用_PrivateGPT

https://github.com/imartinez/privateGPT 1 功能 整体功能&#xff0c;想解决什么问题 搭建完整的 RAG 系统&#xff0c;与 FastGPT相比&#xff0c;界面比较简单。但是底层支持比较丰富&#xff0c;可用于知识库的完全本地部署&#xff0c;包含大模型和向量库。适用于保密级…

SWPU NSS新生赛

&#x1f60b;大家好&#xff0c;我是YAy_17&#xff0c;是一枚爱好网安的小白&#xff0c;正在自学ing。 本人水平有限&#xff0c;欢迎各位大佬指点&#xff0c;一起学习&#x1f497;&#xff0c;一起进步⭐️。 ⭐️此后如竟没有炬火&#xff0c;我便是唯一的光。⭐️ 最近…

万界星空科技AI低代码云MES系统

在企业生产管理过程中&#xff0c;从市场、生产现场到产品交付&#xff0c;生产制造行业都面临着诸多挑战&#xff0c;比如&#xff1a; 订单排产难度大&#xff1a;订单混乱&#xff0c;常漏排产、错排产&#xff1b;产能不明晰&#xff0c;无法承诺交期&#xff0c;常丢单&a…

流程控制之条件判断

目录 流程控制之条件判断 2.1.if语句语法 2.1.1单分支结构 2.1.2双分支结构 2.1.3多分支结构 2.2.案例 例一: 例2: 例3: 例4: 例5: 例6: 例7: 例8: 例9: 2.3.case多条件判断 2.3.1.格式 2.3.2.执行过程 例10: 流程控制之条件判断 2.1.if语句语法 2.1.1单分…

ArcGIS for Android开发引入arcgis100.15.2

最后再点击同步即可&#xff01;&#xff01;&#xff01;

oracle aq java jms使用(数据类型为XMLTYPE)

记录一次冷门技术oracle aq的使用 版本 oracle 11g 创建用户 -- 创建用户 create user testaq identified by 123456; grant connect, resource to testaq;-- 创建aq所需要的权限 grant execute on dbms_aq to testaq; grant execute on dbms_aqadm to testaq; begindbms_a…

基于Spring Boot、Mybatis、Redis和Layui的企业电子招投标系统源码实现与立项流程

招投标管理系统是一款适用于招标代理、政府采购、企业采购和工程交易等领域的企业级应用平台。该平台以项目为主线&#xff0c;从项目立项到项目归档&#xff0c;实现了全流程的高效沟通和协作。通过该平台&#xff0c;用户可以实时共享项目数据信息&#xff0c;实现规范化管理…

【数据结构入门精讲 | 第一篇】打开数据结构之门

数据结构与算法是计算机科学中的核心概念&#xff0c;也与现实生活如算法岗息息相关。鉴于全网数据结构文章良莠不齐且集成度不高&#xff0c;故开设本专栏&#xff0c;为初学者提供指引。 目录 基本概念数据结构为何面世算法基本数据类型抽象数据类型使用抽象数据类型的好处 数…

微信小程序:模态框(弹窗)的实现

效果 wxml <!--新增&#xff08;点击按钮&#xff09;--> <image classimg src"{{add}}" bindtapadd_mode></image> <!-- 弹窗 --> <view class"modal" wx:if"{{showModal}}"><view class"modal-conten…

消息队列(MQ)

对于 MQ 来说&#xff0c;不管是 RocketMQ、Kafka 还是其他消息队列&#xff0c;它们的本质都是&#xff1a;一发一存一消费。下面我们以这个本质作为根&#xff0c;一起由浅入深地聊聊 MQ。 01 从 MQ 的本质说起 将 MQ 掰开了揉碎了来看&#xff0c;都是「一发一存一消费」&…

java实现冒泡排序及其动图演示

冒泡排序是一种简单的排序算法&#xff0c;它重复地遍历要排序的数列&#xff0c;一次比较两个元素&#xff0c;如果它们的顺序错误就把它们交换过来。重复这个过程直到整个数列都是按照从小到大的顺序排列。 具体步骤如下&#xff1a; 比较相邻的两个元素&#xff0c;如果前…

世界5G大会

会议名称:世界 5G 大会 时间:2023 年 12 月 5 日-12 月 8 日 地点:河南郑州 一、会议简介 世界 5G 大会,是由国务院批准,国家发展改革委、科技部、工 信部与地方政府共同主办,未来移动通信论坛联合属地主管厅局联合 承办,邀请全球友好伙伴共同打造的全球首个 5G 领域…

Spring Boot 3 整合 WebSocket (STOMP协议) 和 Vue 3 实现实时通信

&#x1f680; 作者主页&#xff1a; 有来技术 &#x1f525; 开源项目&#xff1a; youlai-mall &#x1f343; vue3-element-admin &#x1f343; youlai-boot &#x1f33a; 仓库主页&#xff1a; Gitee &#x1f4ab; Github &#x1f4ab; GitCode &#x1f496; 欢迎点赞…