【Pytorch】pytorch中保存模型的三种方式

【Pytorch】pytorch中保存模型的三种方式

文章目录

  • 【Pytorch】pytorch中保存模型的三种方式
    • 1. torch保存模型相关的api
      • 1.1 torch.save()
      • 1.2 torch.load()
      • 1.3 torch.nn.Module.load_state_dict()
      • 1.4 什么是state_dict()
        • 1.4. 1 举个例子
    • 2. pytorch模型文件后缀
    • 3. 存储整个模型
      • 3.1 直接保存整个模型
      • 3.2 直接加载整个模型
    • 4. 只保存模型的权重
      • 4.1 保存模型权重
      • 4.2 读取模型权重
    • 5. 使用Checkpoint保存中间结果
      • 5.1 保存Checkpoint
      • 5.2 加载Checkpoint
    • Reference

1. torch保存模型相关的api

1.1 torch.save()

torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

参考自https://pytorch.org/docs/stable/generated/torch.save.html#torch-save

Image

torch.save()的功能是保存一个序列化的目标到磁盘当中,该函数使用了Python中的pickle库用于序列化,具体参数的解释如下

参数功能
obj需要保存的对象
f指定保存的路径
pickle_module用于 pickling 元数据和对象的模块
pickle_protocol指定 pickle protocal 可以覆盖默认参数

常见用法

# dirctly save entiry model
torch.save('model.pth')
# save model'weights only
torch.save(model.state_dict(), 'model_weights.pth')
# save checkpoint
checkpint = {'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,'epoch': epoch
}
torch.save(checkpoint, 'checkpoint_path.pth')

1.2 torch.load()

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)

参考自https://pytorch.org/docs/stable/generated/torch.load.html#torch-load

Image

torch.load()的功能是加载模型,使用python中的unpickle工具来反序列化对象,并且加载到对应的设备上,具体的参数解释如下

参数功能
f对象的存放路径
map_location需要映射到的设备
pickle_module用于 unpickling 元数据和对象的模块

常见用法

# specify the device to use
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# load entiry model to cuda if available
model = torch.load('whole_model.pth', map_location=device)
# load model's weight to cuda if available
model.load_state_dict(torch.load('model_weights.pth'), map_location=device)
# load checkpoint
checkpoint = torch.load('checkpoint_path.pth', map_location=device)
# checkpoint加载出来就像个字典,预先保存的是否放置了什么内容,加载之后就可以这样来获取
loss = checkpoint['loss']
epoch = chekpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict']
optimizer.load_state_dict(checkpoint['optimizer_state_dict']

1.3 torch.nn.Module.load_state_dict()

torch.nn.Module.load_state_dict(state_dict, strict=True, assign=False)

参考自https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict

Image

torch.nn.Module.load_state_dict()将参数和缓冲区从 state_dict 复制到此模块及其后代中。 如果 strict 为 True,则 state_dict 的键必须与该模块的 state_dict() 函数返回的键完全匹配。具体的参数描述如下

参数功能
state_dict保存parameters和persistent buffers的字典
strict是否强制要求state_dict中的key和model.state_dict返回的key严格一致

1.4 什么是state_dict()

torch.nn.Module.state_dict()

参考自https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.state_dict

Image

其实state_dict可以理解为一种简单的Python Dictionary,其功能是将每层之间的参数进行一一映射并且存储在python的数据类型字典中。因此state_dict可以轻松地进行修改、保存等操作。

除了torch.nn.Module拥有state_dict()方法之外,torch.optim.Optimizer也具有state_dict()方法。如下所示

torch.optim.Optimizer.state_dict()

参考自https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.state_dict.html

1.4. 1 举个例子
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimclass SimpleModel(nn.Module):def __init__(self, input_size, output_size):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(input_size, 100)self.fc2 = nn.Linear(100, output_size)def forward(self, x):x = F.relu(self.fc1(x))return self.fc2(x)if __name__ == "__main__":model = SimpleModel(10, 2)optimizer = optim.Adam(model.parameters(), lr=0.001)print("Check Model's State Dict:")for key, value in model.state_dict().items():print(key, "\t", value.size())print("Check Optimizer's State Dict:")for key, value in optimizer.state_dict().items():print(key, "\t", value)

输出的结果如下

Check Model's State Dict:
fc1.weight       torch.Size([100, 10])
fc1.bias         torch.Size([100])
fc2.weight       torch.Size([2, 100])
fc2.bias         torch.Size([2])
Check Optimizer's State Dict:
state    {}
param_groups     [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1, 2, 3]}]

2. pytorch模型文件后缀

常用的torch模型文件后缀有.pt.pth,这是最常见的PyTorch模型文件后缀,表示模型的权重、结构和状态字典(state_dict)都被保存在其中。

torch.save(model.state_dict(), 'model_weights.pth')
torch.save(model, 'full_model.pt')

还有检查点后缀如.ckpt.checkpoint,这些后缀常被用于保存模型的检查点,包括权重和训练状态等。它们也可以表示模型的中间状态,以便在训练期间从中断的地方继续训练。

checkpoint = {'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'epoch': epoch,# 其他信息
}
torch.save(checkpoint, 'model_checkpoint.ckpt')

还有其他跨框架的数据结构例如.h5,PyTorch的模型也可以保存为HDF5文件格式用于跨框架的数据交换,可以使用h5py库来进行读写

import h5pywith h5py.File('model.h5', 'w') as f:# 将模型参数逐一保存到HDF5文件for name, param in model.named_parameters():f.create_dataset(name, data=param.numpy())

3. 存储整个模型

可以直接使用torch.save()torch.load()来加载和保存整个模型到文件中,这种方式保存了模型的所有权重、架构及其其他相关信息,即使不知道模型的结构也能够直接通过权重文件来加载模型

3.1 直接保存整个模型

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimimport osclass SimpleModel(nn.Module):def __init__(self, input_size, output_size):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(input_size, 256)self.fc2 = nn.Linear(256, 256)self.fc3 = nn.Linear(256, output_size)def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))return self.fc3(x)if __name__ == "__main__":model = SimpleModel(10, 2)# specify the save pathurl = os.path.dirname(os.path.realpath(__file__)) + '/models/'# 如果路径不存在则创建if not os.path.exists(url):os.makedirs(url)# specify the model save namemodel_name = 'simple_model.pth'# save the model to filetorch.save(model, url + model_name)

我们直接将模型保存到了当前文件夹下的./models文件夹中,

3.2 直接加载整个模型

由于我们已经保存了模型的所有相关信息,所以我们可以不知道模型的结构也能加载该模型,如下所示

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimimport osclass SimpleModel(nn.Module):def __init__(self, input_size, output_size):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(input_size, 256)self.fc2 = nn.Linear(256, 256)self.fc3 = nn.Linear(256, output_size)def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))return self.fc3(x)if __name__ == "__main__":device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# model = SimpleModel(10, 2)# specify the save pathurl = os.path.dirname(os.path.realpath(__file__)) + '/models/'# 如果路径不存在则创建if not os.path.exists(url):os.makedirs(url)# specify the model save namemodel_name = 'simple_model.pth'# load the modelif os.path.exists(url + model_name):model = torch.load(url + model_name, map_location=device)print("Success Load Model From:\n\t%s"%(url+model_name))

成功加载了模型


4. 只保存模型的权重

4.1 保存模型权重

利用前面提到的state_dict()方法来完成这一操作

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimimport osclass SimpleModel(nn.Module):def __init__(self, input_size, output_size):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(input_size, 256)self.fc2 = nn.Linear(256, 256)self.fc3 = nn.Linear(256, output_size)def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))return self.fc3(x)if __name__ == "__main__":# specify devicedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = SimpleModel(10, 2)# specify the save pathurl = os.path.dirname(os.path.realpath(__file__)) + '/models/'# 如果路径不存在则创建if not os.path.exists(url):os.makedirs(url)# specify the model save namemodel_name = 'simple_model_weights.pth'torch.save(model.state_dict(), url + model_name)

我们直接将模型权重保存到了当前文件夹下的./models文件夹中,

4.2 读取模型权重

由于我们只保存了模型的权重信息,不知道模型的结构,所以必须要先实例化模型才行。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimimport osclass SimpleModel(nn.Module):def __init__(self, input_size, output_size):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(input_size, 256)self.fc2 = nn.Linear(256, 256)self.fc3 = nn.Linear(256, output_size)def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))return self.fc3(x)if __name__ == "__main__":# specify devicedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# get modelmodel = SimpleModel(10, 2)# specify the save pathurl = os.path.dirname(os.path.realpath(__file__)) + '/models/'# 如果路径不存在则创建if not os.path.exists(url):os.makedirs(url)# specify the model save namemodel_name = 'simple_model_weights.pth'if os.path.exists(url + model_name):model.load_state_dict(torch.load(url + model_name, map_location=device))print("Success Load Model'weights From:\n\t%s"%(url+model_name))

5. 使用Checkpoint保存中间结果

5.1 保存Checkpoint

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os# 数据准备
x = torch.tensor(np.random.rand(100, 1), dtype=torch.float32)
y = 3 * x + 2 + 0.1 * torch.randn(100, 1)# 定义模型
class SimpleLinearModel(nn.Module):def __init__(self):super(SimpleLinearModel, self).__init__()self.linear = nn.Linear(1, 1)def forward(self, x):return self.linear(x)if __name__=="__main__":# specify devicedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 实例化模型model = SimpleLinearModel()# 定义损失函数和优化器criterion = nn.MSELoss()optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练循环num_epochs = 1000checkpoint_interval = 100  # 保存检查点的间隔url = os.path.dirname(os.path.realpath(__file__))+'/models/'if not os.path.exists(url):os.makedirs(url)checkpoint_file = 'checkpoint.pth'  # 检查点文件路径for epoch in range(num_epochs):# 前向传播outputs = model(x)loss = criterion(outputs, y)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()# 打印训练信息if (epoch + 1) % checkpoint_interval == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 保存检查点checkpoint = {'epoch': epoch + 1,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss.item(),}torch.save(checkpoint, url+checkpoint_file)

5.2 加载Checkpoint

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os# 数据准备
x = torch.tensor(np.random.rand(100, 1), dtype=torch.float32)
y = 3 * x + 2 + 0.1 * torch.randn(100, 1)# 定义模型
class SimpleLinearModel(nn.Module):def __init__(self):super(SimpleLinearModel, self).__init__()self.linear = nn.Linear(1, 1)def forward(self, x):return self.linear(x)if __name__=="__main__":# specify devicedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 实例化模型model = SimpleLinearModel()# 定义损失函数和优化器criterion = nn.MSELoss()optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练循环num_epochs = 1000checkpoint_interval = 100  # 保存检查点的间隔url = os.path.dirname(os.path.realpath(__file__))+'/models/'if not os.path.exists(url):os.makedirs(url)checkpoint_file = 'checkpoint.pth'  # 检查点文件路径# load from checkpointcheckpoint = torch.load(url+checkpoint_file)for key, value in checkpoint.items():print(key, '-->', value)model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])epoch = checkpoint['epoch']loss = checkpoint['loss']print('Loaded checkpoint from epoch %d. Loss %f' % (epoch, loss))

输出如下

loss --> 0.01629752665758133
(test_ros_python) sjh@sjhR9000X:~/Documents/python_draft$  cd /home/sjh/Documents/python_draft ; /usr/bin/env /home/sjh/anaconda3/envs/metaRL/bin/python /home/sjh/.vscode/extensions/ms-python.python-2023.18.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher 40897 -- /home/sjh/Documents/python_draft/check_checkpoint.py 
epoch --> 1000
model_state_dict --> OrderedDict([('linear.weight', tensor([[2.6938]])), ('linear.bias', tensor([2.1635]))])
optimizer_state_dict --> {'state': {0: {'momentum_buffer': None}, 1: {'momentum_buffer': None}}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1]}]}
loss --> 0.01629752665758133
Loaded checkpoint from epoch 1000. Loss 0.016298

我们成功从断点处加载checkpoint, 可以再从这个断点处继续训练

Reference

参考一

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/158318.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Netty深入浅出Java网络编程学习笔记(二) Netty进阶应用篇

目录 四、应用 1、粘包与半包 现象分析 粘包 半包 本质 解决方案 短链接 定长解码器 行解码器 长度字段解码器——LTC 2、协议设计与解析 协议的作用 Redis协议 HTTP协议 自定义协议 组成要素 编码器与解码器 编写测试类 Sharable注解 自定义编解码器能否使用Sharable注解 3、在…

科技为饮食带来创新,看AI如何打造智能营养时代

在当今社会,快节奏的生活方式、便捷的食品选择以及现代科技的快速发展正深刻地重塑着我们对健康的认知和实践,它已经不再仅仅是一个话题,而是一个备受关注的社会焦点。在这个纷繁复杂的交汇点上,AI技术的介入为我们开辟了前所未有…

【Rust笔记】浅聊 Rust 程序内存布局

浅聊Rust程序内存布局 内存布局看似是底层和距离应用程序开发比较遥远的概念集合,但其对前端应用的功能实现颇具现实意义。从WASM业务模块至Nodejs N-API插件,无处不涉及到FFI跨语言互操作。甚至,做个文本数据的字符集转换也得FFI调用操作系统…

云原生网关可观测性综合实践

作者:钰诚 可观测性 可观测性(Observability)是指系统、应用程序或服务的运行状态、性能和行为能够被有效地监测、理解和调试的能力。 随着系统架构从单体架构到集群架构再到微服务架构的演进,业务越来越庞大,也越来…

QMidi Pro for Mac:打造您的专属卡拉OK体验

你是否曾经厌倦于在KTV里与朋友们争夺麦克风?是否想要在家中享受自定义的卡拉OK体验?现在,有了QMidi Pro for Mac,一切变得简单而愉快! QMidi Pro是一款功能强大的卡拉OK播放器,专为Mac用户设计。它充分利…

【C语言】程序环境和预处理

程序环境: 1、C语言的任何一种实现,存在两个不同的环境; 2、翻译环境:将源代码转换成可执行的二进制指令(机器指令);.c文件(源文件——文本信息的代码)->&#xff0…

论文学习——Class-Conditioned Latent Diffusion Model For DCASE 2023

文章目录 引言正文AbstractIntroductionSystem Overview2.1 Latent Diffusion with sound-class-based conditioning以声音类别为条件的潜在扩散模型2.2 Variational Autoencoder and neural vocoder变分自编码器和神经声码器FAD-oriented Postprocessing filter(专…

Linux开启SSH

Linux开启SSH 1.虚拟机确定连通性 如果是虚拟机的话则需要进行确定和宿主主机之间能正常联通(不能联通还远程个啥) 获取到虚拟机的IP 参考文章:Linux获取本机IP地址使用宿主机ping一下虚拟机的IP查看是否联通 2.安装SSH服务端 安装工具来使得能够通过SSH进行连接 命令 sudo a…

springBoot组件注册

springBoot组件注册 前言1、创建组件文件2、写属性3、生成get和set方法4、以前注册的方法5、现在注册的方法6、在启动文件查看7、多实例Scope("prototype")8、注册第三方包导入对应的场景启动器注册组件查看是否存在也可以通过Import(FastsqlException.class)导入但是…

C++医院影像科PACS源码:三维重建、检查预约、胶片打印、图像处理、测量分析等

PACS连接DICOM接口的医疗器械(如CT、MRI、CR、DR、DSA、各种窥镜成像系统设备等),实现图像无损传输,实现DICOM胶片打印机回传打印功能,支持各种图像处理,可以进行窗技术调节,与登记台管理系统共…

Spring Boot 中的 TransactionTemplate 是什么,如何使用

Spring Boot中的TransactionTemplate:简化事务管理 事务管理是任何应用程序中至关重要的部分,特别是在处理数据库操作时。Spring Boot提供了多种方式来管理事务,其中之一是使用TransactionTemplate。本文将深入探讨TransactionTemplate是什么…

树莓派玩转openwrt软路由:5.OpenWrt防火墙配置及SSH连接

1、SSH配置 打开System -> Administration,打开SSH Access将Interface配置成unspecified。 如果选中其他的接口表示仅在给定接口上侦听,如果未指定,则在所有接口上侦听。在未指定下,所有的接口均可通过SSH访问认证。 2、防火…

如何在手机上设置节日提醒和倒计时天数?

在平淡的生活和工作中,时不时有各种各样节日的点缀,为我们的日常增添了一些仪式感,例如春节、元宵节、情人节、端午节、七夕节等。此外还有一些特殊的日子也值得纪念,例如恋爱纪念日、结婚纪念日、亲朋好友生日等。面对这些节日&a…

CodeForces每日好题10.14

给你一个字符串 让你删除一些字符让它变成一个相邻的字母不相同的字符串,问你最小的删除次数 以及你可以完成的所有方/案数 求方案数往DP 或者 组合数学推公式上面去想,发现一个有意思的事情 例如1001011110 这个字符串你划分成1 00 1 0 1111 0 每…

Step 1 搭建一个简单的渲染框架

Step 1 搭建一个简单的渲染框架 万事开头难。从萌生到自己到处看源码手抄一个mini engine出来的想法,到真正敲键盘去抄,转眼过去了很久的时间。这次大概的确是抱着认真的想法,打开VS从零开始抄代码。不知道能坚持多久呢。。。 本次的主题是搭…

多城镇信息发布付费置顶公众号开源版开发

多城镇信息发布付费置顶公众号开源版开发 以下是多城镇信息发布付费置顶公众号的功能列表: 信息发布:用户可以在公众号上发布各类信息,如房屋租售、二手物品交易、招聘信息等。 信息置顶:用户可以选择付费将自己的信息置顶在公众…

vue2时间处理插件——dayjs

在vue时间处理上有很多的方法和实现,可以自己实现,但是效率不高,所以,在框架开发中我们一般不会手写,一般是使用集成的第三方插件来解决我们的问题,在vue3中大家一般都使用Moment.js来处理,所以…

print() 函数

二、print() 函数 这里先说一下 print() 函数,如果你是新手,可能对函数不太了解,没关系,在这里你只要了解它的组成部分和作用就可以了,后面函数这一块会详细说明的。 print() 函数由两部分构成 : 指令&a…

Folium笔记: Popup

1 介绍 在 folium 中,Popup 是一个用于在地图上显示附加信息的对象。当在地图上点击一个标记(例如,一个点或者一个形状)时,Popup 会显示出来。Popup 可以包含纯文本,但也可以包含HTML代码 2 主要参数 htm…

【SCSS篇】Vite+Vue3项目全局引入scss文件

文章目录 前言一、安装与使用1.1 安装1.2 scss 全局文件编写1.2.1 概述 1.3 全局引入和配置1.4 组件内使用 vue2 项目引入 sass附:忽略ts类型检测 前言 Sass 是世界上最成熟、最稳定、最强大的专业级CSS扩展语言!在日常项目开发过程中使用非常广泛&…