14、保存与加载PyTorch训练的模型和超参数

文章目录

  • 1. state_dict
  • 2. 模型保存
  • 3. check_point
  • 4. 详细保存
  • 5. Docker
  • 6. 机器学习常用库

1. state_dict

nn.Module 类是所有神经网络构建的基类,即自己构建一个深度神经网络也是需要继承自nn.Module类才行,并且nn.Module中的state_dict包含神经网络中的权重weight ,偏置bias,过程量buffer,举例说明:

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :NN_Embedding.py
# @Time      :2024/11/26 22:50
# @Author    :Jason Zhang
import torch
from torch import nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.linear1 = nn.Linear(3, 4)self.relu = nn.ReLU()self.linear2 = nn.Linear(4, 5)self.batch_norm = nn.BatchNorm2d(4)def forward(self, x):x = self.linear1(x)x = self.relu(x)y = self.linear2(x)return yif __name__ == "__main__":my_test = MyModel()my_keys = my_test.state_dict().keys()print(f"my_keys={my_keys}")
  • 结果:
    从结果中看出,跟说明的一样,不仅存的是weight,bias ,还有buffer
y_keys=odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias', 'batch_norm.weight', 'batch_norm.bias', 'batch_norm.running_mean', 'batch_norm.running_var', 'batch_norm.num_batches_tracked'])

2. 模型保存

保存和加载

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :torch_save.py
# @Time      :2024/11/27 21:33
# @Author    :Jason Zhang
import torch
import torchvision.models as modelsif __name__ == "__main__":run_code = 0model = models.vgg16(weights='IMAGENET1K_V1')torch.save(model.state_dict(), 'model_weights.pth')model.load_state_dict(torch.load('model_weights.pth', weights_only=True))model.eval()torch.save(model, 'model.pth')

3. check_point

# Define model
import torch
from torch import nn
from torch import optim
import torch.nn.functional as Fclass TheModelClass(nn.Module):def __init__(self):super(TheModelClass, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# Initialize model
model = TheModelClass()# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():print(var_name, "\t", optimizer.state_dict()[var_name])
Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]

4. 详细保存

在训练过程中,我们希望详细保存,以至于我们可以在中断训练中恢复训练。
保存模型

5. Docker

关于Docker方式搭建深度神经网络环境和配置
在这里插入图片描述

6. 机器学习常用库

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/479909.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Threejs进阶教程-着色器篇】9.顶点着色器入门

【Threejs进阶教程-着色器篇】9.顶点着色器入门 本系列教程第一篇地址,建议按顺序学习认识顶点着色器varying介绍顶点着色器与片元着色器分别的作用Threejs在Shader中的内置变量各种矩阵gl_Position 尝试使用顶点着色器增加分段数增强效果 制作平面鼓包效果鼓包效果…

w058基于web的美发门店管理系统

🙊作者简介:拥有多年开发工作经验,分享技术代码帮助学生学习,独立完成自己的项目或者毕业设计。 代码可以查看文章末尾⬇️联系方式获取,记得注明来意哦~🌹赠送计算机毕业设计600个选题excel文件&#xff0…

leetcode 二叉树的最大深度

104. 二叉树的最大深度 已解答 简单 相关标签 相关企业 给定一个二叉树 root ,返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 示例 1: 输入:root [3,9,20,null,null,15,7] 输出:3…

VMware ubuntu创建共享文件夹与Windows互传文件

1.如图1所示,点击虚拟机,点击设置; 图1 2.如图2所示,点击选项,点击共享文件夹,如图3所示,点击总是启用,点击添加; 图2 图3 3.如图4所示,出现命名共享文件夹…

matlab 实现混沌麻雀搜索算法的光伏MPPT控制仿真

1、内容简介 略 103-可以交流、咨询、答疑 2、内容说明 略 3、仿真分析 略 4、参考论文 略

Unity3D 截图

使用 Unity3D 自带的截图接口,制作截图工具。 截图 有时候我们想对 Unity 的窗口进行截图,如果直接使用一些截图工具,很难截取到一张完整分辨率的图片(例如,我们想要截取一张 1920 * 1080 的图片)。 其实…

STM32F10x 定时器

使用定时器实现:B5 E5的开关 添加相关的.h路径文件 添加相关的.c配置文件 led.h文件 用于声明LED函数 #ifndef __LED_H //没有定义__LED_H #define __LED_H //就定义__LED_H #define LED1_ON GPIO_ResetBits(GPIOB,GPIO_Pin_5) #defi…

PMP好考吗,有多大的价值?

非常好考!PMP目前大陆地区的笔试是只有选择题的,运气好的话 蒙一个都能对,所以PMP的通过率高,这也是很多人考了吐槽PMP没用,是“水证”,但是每年考PMP 的人不减反增,大家可以想一下,…

css:项目

这是一个完整的网站制作的流程 美工会先制作一个原型图: 原型图写的不详细,就是体现一个网页大致的布局 然后美工再做一个psd样例图片 然后再交给程序员 项目 模块化开发:把代码的不同的样式封装起来,需要用到相同样式的标签就…

VsCode 插件推荐(个人常用)

VsCode 插件推荐(个人常用)

黑马程序员Java项目实战《苍穹外卖》Day01

苍穹外卖-day01 课程内容 软件开发整体介绍苍穹外卖项目介绍开发环境搭建导入接口文档Swagger 项目整体效果展示: ​ 管理端-外卖商家使用 ​ 用户端-点餐用户使用 当我们完成该项目的学习,可以培养以下能力: 1. 软件开发整体介绍 作为一…

Python双向链表、循环链表、栈

一、双向链表 1.作用 双向链表也叫双面链表。 对于单向链表而言。只能通过头节点或者第一个节点出发,单向的访问后继节点,每个节点只能记录其后继节点的信息(位置),不能向前遍历。 所以引入双向链表,双…

k8s网络服务

k8s 中向外界提供服务的几种方法port-forward、NodePort,以及 更加常用的提供服务的资源ingress。 1 kubectl port-forward service/redis 6379:6379 现在k8s中有一个pod运行在6379,本机访问映射到6379上,它可以针对部署,服务&…

eduSRC挖洞思路

声明 学习视频来自 B 站UP主泷羽sec,如涉及侵权马上删除文章。 笔记的只是方便各位师傅学习知识,以下网站只涉及学习内容,其他的都与本人无关,切莫逾越法律红线,否则后果自负。 ✍🏻作者简介:致…

Leetcode - 周赛424

目录 一,3354. 使数组元素等于零 二, 3355. 零数组变换 I 三,3356. 零数组变换 II 四,3357. 最小化相邻元素的最大差值 一,3354. 使数组元素等于零 本题实际上是一个前/后缀和的问题,就是判断前缀和与后…

Vue2中 vuex 的使用

1.安装 vuex 安装vuex与vue-router类似,vuex是一个独立存在的插件,如果脚手架初始化没有选 vuex,就需要额外安装。 yarn add vuex3 或者 npm i vuex3 233 Vue2 Vue-Router3 Vuex3 344 Vue3 Vue-Router4 Vuex4 2. 新建 store/index.j…

数据结构C语言描述5(图文结合)--队列,数组、链式、优先队列的实现

前言 这个专栏将会用纯C实现常用的数据结构和简单的算法;有C基础即可跟着学习,代码均可运行;准备考研的也可跟着写,个人感觉,如果时间充裕,手写一遍比看书、刷题管用很多,这也是本人采用纯C语言…

Windows修复SSL/TLS协议信息泄露漏洞(CVE-2016-2183) --亲测

漏洞说明: 打开链接:https://docs.microsoft.com/zh-cn/troubleshoot/windows-server/windows-security/restrict-cryptographic-algorithms-protocols-schannel 可以看到: 找到:应通过配置密码套件顺序来控制 TLS/SSL 密码 我们…

深度学习图像视觉 RKNN Toolkit2 部署 RK3588S边缘端 过程全记录

深度学习图像视觉 RKNN Toolkit2 部署 RK3588S边缘端 过程全记录 认识RKNN Toolkit2 工程文件学习路线: Anaconda Miniconda安装.condarc 文件配置镜像源自定义conda虚拟环境路径创建Conda虚拟环境 本地训练环境本地转换环境安装 RKNN-Toolkit2:添加 lin…

controller中的参数注解@Param @RequestParam和@RequestBody的不同

现在controller中有个方法:(LoginUserRequest是一个用户类对象) PostMapping("/test/phone")public Result validPhone(LoginUserRequest loginUserRequest) {return Result.success(loginUserRequest);}现在讨论Param("login…