一学就会的深度学习基础指令及操作步骤(5)使用预训练模型

文章目录

    • 使用预训练模型
      • 加载预训练模型
      • 图像加载与预处理
      • 预测

使用预训练模型

查看模型库和常用模型

加载预训练模型

from torchvision.models import vgg16  # VGG16模型架构的定义
from torchvision.models import VGG16_Weights  # VGG16的预训练权重配置# load the VGG16 network *pre-trained* on the ImageNet dataset
weights = VGG16_Weights.DEFAULT  # 获取默认的预训练权重(通常是ImageNet上训练的)
model = vgg16(weights=weights)  # 创建VGG16模型,并加载指定权重

当指定 weights=VGG16_Weights.DEFAULT 时,PyTorch会:

  1. 根据配置自动下载对应的预训练权重文件(如 .pth)。
  2. 将权重加载到 vgg16 定义的模型架构中,确保每层参数正确匹配。

VGG16 模型结构如下:

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

VGG16 神经网络,主要用于图像分类任务(如识别1000种物体)。它的结构设计非常规整,像搭积木一样层层堆叠。

VGG16 模型结构拆解

1. 特征提取部分 (features) —— 层层递进捕捉图像细节

  • 核心操作:重复堆叠 “卷积层 + ReLU激活”,每2~3个卷积后接一个 最大池化层 压缩尺寸。
  • 具体流程
  • 前2层:输入3通道图片 → 64通道 → 捕捉基础边缘纹理。
  • 池化:尺寸减半(如224x224 → 112x112),保留关键特征。
  • 后续层:逐步增加通道数(128 → 256 → 512),提取更复杂图案(如形状、物体局部)。
  • 池化间隔:每阶段通过池化压缩空间信息,减少计算量。

2.分类部分 (classifier) —— 综合特征做判断

  • 全局池化 (avgpool):将特征图压缩成固定大小(7x7),避免输入尺寸影响。
  • 全连接层
    • 前两层:4096维 → 通过Dropout随机关闭部分神经元,防止死记硬背。
    • 最后一层:4096维 → 1000维,输出1000个类别的概率。

VGG16 设计优势

1. 统一的小卷积核 (3x3)

  • 优势:多个小卷积核叠加可等效大卷积核的感受野,但参数更少、非线性更强。
    • 例如:2层3x3卷积 ≈ 1层5x5卷积,但参数量减少 (3x3x2=18 vs 5x5=25)。
      效果:更高效捕捉空间特征,提升模型深度。

2. 深度结构

  • 16层(13卷积 + 3全连接)的深度结构能提取多层次特征:
    • 浅层:边缘、纹理 → 中层:形状、部件 → 深层:完整物体。

3. 模块化设计

  • 每阶段结构重复(如卷积→激活→池化),代码易实现,模型可扩展性强(如VGG19)。

4. 泛化能力强

  • 在ImageNet上预训练后,可作为其他任务的基础模型(迁移学习),适应性强。

VGG16 局限性

  • 参数量大:全连接层占据大量参数(如25088→4096),计算成本高。
  • 现代替代品:后续模型(如ResNet)通过残差连接解决深层梯度问题,效果更好。

图像加载与预处理

pre_trans = weights.transforms()  # 调用weights自带的标准化预处理流程
pre_trans
>>> ImageClassification(# 1. 调整尺寸与裁剪
>>>     crop_size=[224]  # 从缩放后的图片中心裁剪出 224x224像素 的区域,作为模型输入。
>>>     resize_size=[256]    # 先把图片等比缩放到 短边256像素# 2. 归一化处理# 归一化后像素 = (原始像素 - mean) / std# 让输入数据的分布接近标准正态(均值为0,标准差为1),加速模型收敛,稳定训练过程。 (这些数值是ImageNet数据集的统计值,沿用可兼容预训练模型。)
>>>     mean=[0.485, 0.456, 0.406]  #  RGB三通道的均值,用于将像素值减去均值(中心化)
>>>     std=[0.229, 0.224, 0.225]   # RGB三通道的标准差,用于将像素值除以标准差(缩放至标准正态分布)
>>>     interpolation=InterpolationMode.BILINEAR # 缩放图片时使用双线性插值,平滑像素间的过渡,减少锯齿感
>>> )

上面代码等同于下面代码

IMG_WIDTH, IMG_HEIGHT = (224, 224)pre_trans = transforms.Compose([transforms.ToDtype(torch.float32, scale=True), # Converts [0, 255] to [0, 1]transforms.Resize((IMG_WIDTH, IMG_HEIGHT)),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225],),transforms.CenterCrop(224)
])

对图像进行预处理,以便能以适当的格式(1, 3, 224, 224)将其送入模型中

def load_and_process_image(file_path):# Print image's original shape, for referenceprint('Original image shape: ', mpimg.imread(file_path).shape)image = tv_io.read_image(file_path).to(device)image = pre_trans(image)  # weights.transforms()image = image.unsqueeze(0)  # Turn into a batchreturn imageprocessed_image = load_and_process_image("data/doggy_door_images/happy_dog.jpg")
print("Processed image shape: ", processed_image.shape)>>> Original image shape:  (1200, 1800, 3)
>>> Processed image shape:  torch.Size([1, 3, 224, 224])

预测

vgg_classes = json.load(open("data/imagenet_class_index.json"))def readable_prediction(image_path):# Show imageshow_image(image_path)# Load and pre-process image 加载图像并预处理image = load_and_process_image(image_path)# Make predictions 模型推理,取第一个(唯一)样本的输出output = model(image)[0]  # Unbatchpredictions = torch.topk(output, 3) # 获取概率最高的3个类别indices = predictions.indices.tolist() # 转换为列表# Print predictions in readable formout_str = "Top results: "# 映射索引到类别名称,遍历索引列表,从字典中提取对应的类别名称pred_classes = [vgg_classes[str(idx)][1] for idx in indices]out_str += ", ".join(pred_classes)print(out_str)return predictionsreadable_prediction("data/doggy_door_images/happy_dog.jpg")
>>> Original image shape:  (1200, 1800, 3)
>>> Top results: Staffordshire_bullterrier, American_Staffordshire_terrier, Labrador_retriever
>>> torch.return_types.topk( values=tensor([19.6133, 15.8125, 14.4607], device='cuda:0', grad_fn=<TopkBackward0>), indices=tensor([179, 180, 208], device='cuda:0'))readable_prediction("data/doggy_door_images/brown_bear.jpg")
>>> Original image shape:  (2592, 3456, 3)
>>> Top results: brown_bear, American_black_bear, sloth_bear
>>> torch.return_types.topk(
values=tensor([33.0100, 27.1086, 22.9985], device='cuda:0', grad_fn=<TopkBackward0>),
indices=tensor([294, 295, 297], device='cuda:0'))readable_prediction("data/doggy_door_images/sleepy_cat.jpg")
>>> Original image shape:  (1200, 1800, 3)
>>> Top results: tiger_cat, tabby, Egyptian_cat
>>> torch.return_types.topk(
values=tensor([16.7054, 13.8567, 12.5219], device='cuda:0', grad_fn=<TopkBackward0>),
indices=tensor([282, 281, 285], device='cuda:0'))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/32093.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

推荐一个比较好的开源的工作流引擎

由于DeepSeek等AI大模型的出现&#xff0c;工作流模式再次流行起来&#xff0c;低代码甚至零代码就可以实现应用开发&#xff0c;而且有DeepSeek这样的超级AI作为大脑&#xff0c;人人都可以开发自动化工作流。 比如搭建邮件助手工作流&#xff0c;可以自动润色各种邮件内容。…

CarPlanner:用于自动驾驶大规模强化学习的一致性自回归轨迹规划

25年2月来自浙大和菜鸟网络的论文“CarPlanner: Consistent Auto-regressive Trajectory Planning for Large-scale Reinforcement Learning in Autonomous Driving”。 轨迹规划对于自动驾驶至关重要&#xff0c;可确保在复杂环境中安全高效地导航。虽然最近基于学习的方法&a…

Fedora41安装MySQL8.4.4

Fedora41安装MySQL8.4.4 Fedora41用yum仓库安装MySQL8.4.4 笔记250310下载安装启动mysqld服务查看生成的初始密码 , 用初始密码登录登录后,必须修改初始密码才能执行其它操作可选设置降低密码强度要求, 使用简单密码降低 validate_password 组件对密码强度的要求 用SET GLOBAL命…

信息安全意识之安全组织架构图

一、信息安全技术概论1.网络在当今社会中的重要作用2.信息安全的内涵 网络出现前&#xff1a;主要面向数据的安全&#xff0c;对信息的机密性、完整性和可用性的保护&#xff0c;即CIA三元组 网络出现后&#xff0c;还涵盖了面向用户的安全&#xff0c;即鉴别&#xff0c;授权&…

安卓Android与iOS设备管理对比:企业选择指南

目录 一、管理方式差异 Android Enterprise方案包含三种典型模式&#xff1a; Apple MDM方案主要提供两种模式&#xff1a; 二、安全防护能力 Android系统特点&#xff1a; 三、应用管理方案 四、设备选择建议 五、典型场景推荐 需求场景 推荐方案 六、决策建议要点…

linunx ubuntu24.04.02装libfuse2导致无法开机进不了桌面解决办法

osu.appimage运行需要libfuse2 然后我就下了fuse,打了两把第二天无法开机 这样是不能开机的 这样是可以开机的 解决办法一&#xff1a;玩星火商店的osu&#xff0c;好了问题解决 解决办法二&#xff1a; 在这个页面 ctrl alt f2进入tty6 sudo apt install ubuntu-desktop 进…

mysql-8.0.41-winx64 手动安装详细教程(2025版)

mysql-8.0.41-winx64 手动安装详细教程&#xff08;2025版&#xff09; 一、下载安装包二、配置环境变量三、安装配置四、启动 MySQL 服务&#xff0c;修改密码 一、下载安装包 安装地址如下&#xff1a; https://dev.mysql.com/downloads/mysql/使用7-zip或其他解压软件&…

wireguard搭配udp2raw部署内网

前言 上一篇写了使用 wireguard 可以非常轻松的进行组网部署&#xff0c;但是如果服务器厂商屏蔽了 udp 端口&#xff0c;那就没法了 针对 udp 被服务器厂商屏蔽的情况&#xff0c;需要使用一款 udp2raw 或 socat 类似的工具&#xff0c;来将 udp 打包成 tcp 进行通信 这里以…

[杂学笔记] TCP和UDP的区别,对http接口解释 , Cookie和Session的区别 ,http和https的区别 , 智能指针 ,断点续传

文章目录 1. TCP和UDP的区别2. 对http接口解释3. Cookie和Session的区别4. http和https的区别5. 智能指针6.断点续传 1. TCP和UDP的区别 tcp的特点&#xff1a; 面向连接&#xff0c;可靠性高&#xff0c;全双工&#xff0c;面向字节流udp特点&#xff1a;无连接&#xff0c;不…

Qt入门笔记

目录 一、前言 二、创建Qt项目 2.1、使用向导创建 2.2、最简单的Qt应用程序 2.2.1、main函数 2.2.2、widget.h文件 2.2.3、widget.cpp文件 2.3、Qt按键Botton 2.3.1、创建一个Botton 2.3.2、信号与槽 2.3.3、按键使用信号与槽的方法 2.4、文件Read与Write-QFile类 2…

Unity辅助工具_头部与svn

Unity调用者按钮增加PlaySideButton using QQu; using UnityEditor; using UnityEngine; [InitializeOnLoad] public class PlaySideButton {static PlaySideButton(){UnityEditorToolbar.RightToolbarGUI.Add(OnRightToolbarGUI);UnityEditorToolbar.LeftToolbarGUI.Add(OnLe…

ubuntu软件

视频软件&#xff0c;大部分的编码都能适应 sudo apt install vlc图片软件 sudo apt install gwenview截图软件 sudo apt install flameshot设置快捷键 flameshot flameshot gui -p /home/cyun/Pictures/flameshot也就是把它保存到一个自定义的路径 菜单更换 sudo apt r…

Spring (十)事务

目录 一 Spring数据库的相关配置&#xff1a; 1 导入包&#xff1a; 2 配置数据库连接信息 3 可以直接使用&#xff1a;DataSource,JdbcTemplate 二 事务管理&#xff1a; 1 事务管理的实现 1.1 开启Spring事务管理 1.2 为指定方法添加事务 2 关键类与接口 2.1 事务拦…

【MySQL_04】数据库基本操作(用户管理--配置文件--远程连接--数据库信息查看、创建、删除)

文章目录 一、MySQL 用户管理1.1 用户管理1.11 mysql.user表详解1.12 添加用户1.13 修改用户权限1.14 删除用户1.15 密码问题 二、MySQL 配置文件2.1 配置文件位置2.2 配置文件结构2.3 常用配置参数 三、MySQL远程连接四、数据库的查看、创建、删除4.1 查看数据库4.2 创建、删除…

Java算术运算符与算术表达式

Java算术运算符包括&#xff08;加、正号&#xff09;、-&#xff08;减、负号&#xff09;、*&#xff08;乘&#xff09;、/&#xff08;除&#xff09;、%&#xff08;求余&#xff09;、&#xff08;自增&#xff09;和--&#xff08;自减&#xff09;。它们用于构建算术表…

【网络】HTTP协议、HTTPS协议

HTTP与HTTPS HTTP协议概述 HTTP(超文本传输协议):工作在OSI顶层应用层,用于客户端(浏览器)与服务器之间的通信,B/S模式 无状态:每次请求独立,服务器不保存客户端状态(通过Cookie/Session扩展状态管理)。基于TCP:默认端口80(HTTP)、443(HTTPS),保证可靠传输。请…

Linux 网络:skb 数据管理

文章目录 1. 前言2. skb 数据管理2.1 初始化2.2 数据的插入2.2.1 在头部插入数据2.2.2 在尾部插入数据 2.2 数据的移除 3. 小结 1. 前言 限于作者能力水平&#xff0c;本文可能存在谬误&#xff0c;因此而给读者带来的损失&#xff0c;作者不做任何承诺。 2. skb 数据管理 数…

考研数学复习之定积分定义求解数列极限(超详细教程)

定积分定义求解数列极限是一种将数列极限问题转化为定积分问题进行求解的方法。这种方法通常适用于那些和式数列极限,其主要思路是将数列的项看作是某个函数在某一点或某一段区间上的取值或某种形式的和,然后利用定积分的性质和计算方法,来求解这类数列的极限。 定积分定义 设函…

Linux-基础开发工具

1.软件包管理器 1.1什么是软件包 • 在Linux下安装软件, ⼀个通常的办法是下载到程序的源代码, 并进⾏编译, 得到可执⾏程序. • 但是这样太⿇烦了, 于是有些⼈把⼀些常⽤的软件提前编译好, 做成软件包(可以理解成windows上 的安装程序)放在⼀个服务器上, 通过包管理器可以很…

【无人机路径规划】基于麻雀搜索算法(SSA)的无人机路径规划(Matlab)

效果一览 代码获取私信博主基于麻雀搜索算法&#xff08;SSA&#xff09;的无人机路径规划&#xff08;Matlab&#xff09; 一、算法背景与核心思想 麻雀搜索算法&#xff08;Sparrow Search Algorithm, SSA&#xff09;是一种受麻雀群体觅食行为启发的元启发式算法&#xff0…