Pytorch深度学习-----神经网络模型的保存与加载(VGG16模型)

系列文章目录

PyTorch深度学习——Anaconda和PyTorch安装
Pytorch深度学习-----数据模块Dataset类
Pytorch深度学习------TensorBoard的使用
Pytorch深度学习------Torchvision中Transforms的使用(ToTensor,Normalize,Resize ,Compose,RandomCrop)
Pytorch深度学习------torchvision中dataset数据集的使用(CIFAR10)
Pytorch深度学习-----DataLoader的用法
Pytorch深度学习-----神经网络的基本骨架-nn.Module的使用
Pytorch深度学习-----神经网络的卷积操作
Pytorch深度学习-----神经网络之卷积层用法详解
Pytorch深度学习-----神经网络之池化层用法详解及其最大池化的使用
Pytorch深度学习-----神经网络之非线性激活的使用(ReLu、Sigmoid)
Pytorch深度学习-----神经网络之线性层用法
Pytorch深度学习-----神经网络之Sequential的详细使用及实战详解
Pytorch深度学习-----损失函数(L1Loss、MSELoss、CrossEntropyLoss)
Pytorch深度学习-----优化器详解(SGD、Adam、RMSprop)
Pytorch深度学习-----现有网络模型的使用及修改(VGG16模型)


文章目录

  • 系列文章目录
  • 一、网络模型的保存
    • 1.方法一
    • 2.方法二
  • 二、网络模型的加载
    • 1.方法一
    • 2.方法二
  • 三、总结


一、网络模型的保存

1.方法一

保存整个模型,包括其相关的所有参数

torch.save(obj, f, pickle_protocol=DEFAULT_PROTOCOL)

参数说明:

obj: 要保存的对象,可以是模型、张量、字典等。
f: 要保存到的文件路径或文件对象。
pickle_protocol: 序列化协议的版本,默认为DEFAULT_PROTOCOL。

代码如下:

import torch
import torchvision.models as models
from torch import nnvgg16_true = models.vgg16(weights=True)
vgg16_false = models.vgg16(weights=False)torch.save(vgg16_true, "vgg16_model_true.pth")

其中.pth是后缀标志。

在这里插入图片描述

2.方法二

只保存模型参数,在原有vgg16对象中使用.state_dict()方法即可。

代码如下:

import torch
import torchvision.models as models
from torch import nnvgg16_true = models.vgg16(weights=True)
vgg16_false = models.vgg16(weights=False)torch.save(vgg16_true.state_dict(), "vgg16_model_true_2.pth")

在这里插入图片描述

二、网络模型的加载

1.方法一

对应于上述中保存模型的方法1进行加载。

相关函数如下:

torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)

参数说明:

f: 要加载的文件路径或文件对象。
map_location: 可选参数,用于指定在哪个设备上加载模型。如果不提供该参数,默认会加载到当前设备。
pickle_module: 可选参数,用于指定用于反序列化的模块。默认为pickle。
pickle_load_args: 其他可选的用于反序列化的参数。

代码如下:

import torch
import torchvision.models as models
from torch import nnmodel1 = torch.load("vgg16_model_true.pth")  # 因为vgg16_model_true.pth是使用方法一保存的,故输出后是整个模型网络结构
print(model1)
model2 = torch.load("vgg16_model_true_2.pth")  # 因为vgg16_model_true_2.pth是使用方法二保存的,只保留模型参数,故输出后是整个字典类型
print(model2)

vgg16_model_true.pth加载结果

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

vgg16_model_true_2.pth加载结果

OrderedDict([('features.0.weight', tensor([[[[-5.5373e-01,  1.4270e-01,  5.2896e-01],[-5.8312e-01,  3.5655e-01,  7.6566e-01],[-6.9022e-01, -4.8019e-02,  4.8409e-01]],[[ 1.7548e-01,  9.8630e-03, -8.1413e-02],[ 4.4089e-02, -7.0323e-02, -2.6035e-01],[ 1.3239e-01, -1.7279e-01, -1.3226e-01]],[[ 3.1303e-01, -1.6591e-01, -4.2752e-01],[ 4.7519e-01, -8.2677e-02, -4.8700e-01],[ 6.3203e-01,  1.9308e-02, -2.7753e-01]]],[[[ 2.3254e-01,  1.2666e-01,  1.8605e-01],[-4.2805e-01, -2.4349e-01,  2.4628e-01],[-2.5066e-01,  1.4177e-01, -5.4864e-03]],[[-1.4076e-01, -2.1903e-01,  1.5041e-01],[-8.4127e-01, -3.5176e-01,  5.6398e-01],[-2.4194e-01,  5.1928e-01,  5.3915e-01]],[[-3.1432e-01, -3.7048e-01, -1.3094e-01],[-4.7144e-01, -1.5503e-01,  3.4589e-01],[ 5.4384e-02,  5.8683e-01,  4.9580e-01]]],[[[ 1.7715e-01,  5.2149e-01,  9.8740e-03],[-2.7185e-01, -7.1709e-01,  3.1292e-01],[-7.5753e-02, -2.2079e-01,  3.3455e-01]],[[ 3.0924e-01,  6.7071e-01,  2.0546e-02],[-4.6607e-01, -1.0697e+00,  3.3501e-01],[-8.0284e-02, -3.0522e-01,  5.4460e-01]],[[ 3.1572e-01,  4.2335e-01, -3.4976e-01],[ 8.6354e-02, -4.6457e-01,  1.1803e-02],[ 1.0483e-01, -1.4584e-01, -1.5765e-02]]],...,

2.方法二

import torch
import torchvision.models as models
from torch import nnvgg16_true = models.vgg16(weights=True)vgg16_true.load_state_dict(torch.load("vgg16_model_true_2.pth"))  # 针对第二种加载参数的情况,使其显示完整的网络结构
print(vgg16_true)
VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

注意: 加载模型时,要确保当前代码中使用的模型类与之前保存的模型类相同。

三、总结

torch.load()是PyTorch中用于加载保存的对象的函数,可以加载之前使用torch.save()保存的模型、张量、字典等。可以指定要加载的文件路径或文件对象,并可选地指定加载到的设备、反序列化模块等参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/82865.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构刷题训练——链表篇(一)

目录 前言 题目一:链表的中间节点 思路 分析 题解 题目二:链表中倒数第k个结点 思路 分析 题解 题目三:合并两个有序链表 思路 分析 题解 方法二 题解 题目四:链表的回文结构 思路 分析 题解 总结 前言 今天我将开…

2023华数杯C题总结

前言 对这次比赛中遇到的问题和卡住的思路进行复盘,整理相关心得,供以后比赛参考 🧡1.认识数据类型🧡 连续变量:母亲年龄、妊娠时间、CBTS、EPDS、HADS、整晚睡醒时间、婴儿年龄 无序分类变量:婚姻状态、…

Gpt微信小程序搭建的前后端流程 - 前端小程序部分-2.确定交互所需的后端API(二)

Gpt微信小程序搭建的前后端流程 - 前端小程序部分-2.确定交互所需的后端API(二) 参考微信小程序-小柠AI智能聊天,可自行先体验。 根据上一节的小程序静态页面设计,需要从后端获取数据的主要4个点: 登录流程;获取今日已提问次数&a…

Unity制作护盾——2、力场冲击波护盾

Unity制作力场护盾 大家好,我是阿赵。   继续做护盾,这一期做一个力场冲击波护盾。 一、效果展示 主要的效果并不是这个球,而是护盾在被攻击的时候,会出现一个扩散的冲击波。比如上图在右边出现了冲击波 如果在左边被攻击&am…

MongoDB 6.0.8 安装配置

一、前言 MongoDB是一个基于分布式文件存储的数据库。由C语言编写。旨在为WEB应用提供可扩展的高性能数据存储解决方案。在高负载的情况下,添加更多的节点,可以保证服务器性能。 MongoDB 将数据存储为一个文档,数据结构由键值(key>value…

[分享]STM32G070 串口 乱码 解决方法

硬件 NUCLEO-G070RB 工具 cubemx 解决方法 7bit 改为 8bit printf 配置方法 添加头文件 #include <stdio.h> 添加重定向代码 #ifdef __GNUC__#define PUTCHAR_PROTOTYPE int __io_putchar(int ch)#else#define PUTCHAR_PROTOTYPE int fputc(int ch, FILE *f)#endi…

卷积神经网络实现MNIST手写数字识别 - P1

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f366; 参考文章&#xff1a;365天深度学习训练营-第P1周&#xff1a;实现mnist手写数字识别&#x1f356; 原作者&#xff1a;K同学啊 | 接辅导、项目定制&#x1f680; 文章来源&#xff1a;K同…

SPM(Swift Package Manager)开发及常见事项

SPM怎么使用的不再赘述&#xff0c;其优点是Cocoapods这样的远古产物难以望其项背的&#xff0c;而且最重要的是可二进制化、对xcproj项目无侵入&#xff0c;除了网络之外简直就是为团队开发的项目库依赖最好的管理工具&#xff0c;是时候抛弃繁杂低下的cocoapods了。 一&…

C语言:打开调用堆栈

第一步&#xff1a;打断点 第二步&#xff1a;FnF5 第三步&#xff1a;按如图找到调用堆栈

使用Flask.Request的方法和属性,获取get和post请求参数(二)

1、Flask中的request 在Python发送Post、Get等请求时&#xff0c;我们使用到requests库。Flask中有一个request库&#xff0c;有其特有的一些方法和属性&#xff0c;注意跟requests不是同一个。 2、Post请求&#xff1a;request.get_data() 用于服务端获取客户端请求数据。注…

积累常见的有针对性的python面试题---python面试题001

1.考点列表的.remove方法的参数是传入的对应的元素的值,而不是下标 然后再看remove这里,注意这个是,删除写的那个值,比如这里写3,就是删除3, 而不是下标. remove不是下标删除,而是内容删除. 2.元组操作,元组不支持修改,某个下标的内容 可以问他如何修改元组的某个元素 3.…

【MMU】认识 MMU 及内存映射的流程

MMU&#xff08;Memory Manager Unit&#xff09;&#xff0c;是内存管理单元&#xff0c;负责将虚拟地址转换成物理地址。除此之外&#xff0c;MMU 实现了内存保护&#xff0c;进程无法直接访问物理内存&#xff0c;防止内存数据被随意篡改。 目录 一、内存管理体系结构 1、…

idea打开多个项目需要开多个窗口(恢复询问弹窗)

【版权所有&#xff0c;文章允许转载&#xff0c;但须以链接方式注明源地址&#xff0c;否则追究法律责任】【创作不易&#xff0c;点个赞就是对我最大的支持】 前言 仅作为学习笔记&#xff0c;供大家参考 总结的不错的话&#xff0c;记得点赞收藏关注哦&#xff01; 使用…

【TypeScript】中定义与使用 Class 类的解读理解

目录 类的概念类的继承 &#xff1a;类的存取器&#xff1a;类的静态方法与静态属性&#xff1a;类的修饰符&#xff1a;参数属性&#xff1a;抽象类&#xff1a;类的类型: 总结&#xff1a; 类的概念 类是用于创建对象的模板。他们用代码封装数据以处理该数据。JavaScript 中的…

一起学SF框架系列7.1-spring-AOP-基础知识

AOP(Aspect-oriented Programming-面向切面编程&#xff09;是一种编程模式&#xff0c;是对OOP(Object-oriented Programming-面向对象编程&#xff09;一种有益补充。在OOP中&#xff0c;万事万物都是独立的对象&#xff0c;对象相互耦合关系是基于业务进行的&#xff1b;但在…

MySQL之深入InnoDB存储引擎——Undo页

文章目录 一、UNDO日志格式1、INSERT操作对应的UNDO日志2、DELETE操作对应的undo日志3、UPDATE操作对应的undo日志1&#xff09;不更新主键2&#xff09;更新主键的操作 3、增删改操作对二级索引的影响 二、UNDO页三、UNDO页面链表四、undo日志具体写入过程五、回滚段1、回滚段…

C语言系列之原码、反码和补码

一.欢迎来到我的酒馆 讨论c语言中&#xff0c;原码、反码、补码。 目录 一.欢迎来到我的酒馆二.原码 二.原码 2.1在计算机中&#xff0c;所有数据都是以二进制存储的&#xff0c;但不是直接存储二进制数&#xff0c;而是存储二进制的补码。原码很好理解&#xff0c;就是对应的…

SQL Server数据库如何添加Oracle链接服务器(Windows系统)

SQL Server数据库如何添加Oracle链接服务器 一、在添加访问Oracle的组件1.1 下载Oracle的组件 Oracle Provider for OLE DB1.2 注册该组件1.2.1 下载的压缩包解压位置1.2.2 接着用管理员运行Cmd 此处一定要用管理员运行&#xff0c;否则会报错 二、配置环境变量三、 重启SQL Se…

IDEA开启并配置services窗口

一、选择view -> Tool Windows -> Services 二、底下栏会出现Services 然后右键添加工程即可

Apache DolphinScheduler 3.1.8 版本发布,修复 SeaTunnel 相关 Bug

近日&#xff0c;Apache DolphinScheduler 发布了 3.1.8 版本。此版本主要基于 3.1.7 版本进行了 bug 修复&#xff0c;共计修复 16 个 bug, 1 个 doc, 2 个 chore。 其中修复了以下几个较为重要的问题&#xff1a; 修复在构建 SeaTunnel 任务节点的参数时错误的判断条件修复 …