系列文章目录
PyTorch深度学习——Anaconda和PyTorch安装
Pytorch深度学习-----数据模块Dataset类
Pytorch深度学习------TensorBoard的使用
Pytorch深度学习------Torchvision中Transforms的使用(ToTensor,Normalize,Resize ,Compose,RandomCrop)
Pytorch深度学习------torchvision中dataset数据集的使用(CIFAR10)
Pytorch深度学习-----DataLoader的用法
Pytorch深度学习-----神经网络的基本骨架-nn.Module的使用
Pytorch深度学习-----神经网络的卷积操作
Pytorch深度学习-----神经网络之卷积层用法详解
Pytorch深度学习-----神经网络之池化层用法详解及其最大池化的使用
Pytorch深度学习-----神经网络之非线性激活的使用(ReLu、Sigmoid)
Pytorch深度学习-----神经网络之线性层用法
Pytorch深度学习-----神经网络之Sequential的详细使用及实战详解
Pytorch深度学习-----损失函数(L1Loss、MSELoss、CrossEntropyLoss)
Pytorch深度学习-----优化器详解(SGD、Adam、RMSprop)
Pytorch深度学习-----现有网络模型的使用及修改(VGG16模型)
文章目录
- 系列文章目录
- 一、网络模型的保存
- 1.方法一
- 2.方法二
- 二、网络模型的加载
- 1.方法一
- 2.方法二
- 三、总结
一、网络模型的保存
1.方法一
保存整个模型,包括其相关的所有参数
torch.save(obj, f, pickle_protocol=DEFAULT_PROTOCOL)
参数说明:
obj:
要保存的对象,可以是模型、张量、字典等。
f:
要保存到的文件路径或文件对象。
pickle_protocol:
序列化协议的版本,默认为DEFAULT_PROTOCOL。
代码如下:
import torch
import torchvision.models as models
from torch import nnvgg16_true = models.vgg16(weights=True)
vgg16_false = models.vgg16(weights=False)torch.save(vgg16_true, "vgg16_model_true.pth")
其中.pth是后缀标志。
2.方法二
只保存模型参数,在原有vgg16对象中使用.state_dict()方法即可。
代码如下:
import torch
import torchvision.models as models
from torch import nnvgg16_true = models.vgg16(weights=True)
vgg16_false = models.vgg16(weights=False)torch.save(vgg16_true.state_dict(), "vgg16_model_true_2.pth")
二、网络模型的加载
1.方法一
对应于上述中保存模型的方法1进行加载。
相关函数如下:
torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)
参数说明:
f:
要加载的文件路径或文件对象。
map_location:
可选参数,用于指定在哪个设备上加载模型。如果不提供该参数,默认会加载到当前设备。
pickle_module:
可选参数,用于指定用于反序列化的模块。默认为pickle。
pickle_load_args:
其他可选的用于反序列化的参数。
代码如下:
import torch
import torchvision.models as models
from torch import nnmodel1 = torch.load("vgg16_model_true.pth") # 因为vgg16_model_true.pth是使用方法一保存的,故输出后是整个模型网络结构
print(model1)
model2 = torch.load("vgg16_model_true_2.pth") # 因为vgg16_model_true_2.pth是使用方法二保存的,只保留模型参数,故输出后是整个字典类型
print(model2)
vgg16_model_true.pth加载结果
VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)
vgg16_model_true_2.pth加载结果
OrderedDict([('features.0.weight', tensor([[[[-5.5373e-01, 1.4270e-01, 5.2896e-01],[-5.8312e-01, 3.5655e-01, 7.6566e-01],[-6.9022e-01, -4.8019e-02, 4.8409e-01]],[[ 1.7548e-01, 9.8630e-03, -8.1413e-02],[ 4.4089e-02, -7.0323e-02, -2.6035e-01],[ 1.3239e-01, -1.7279e-01, -1.3226e-01]],[[ 3.1303e-01, -1.6591e-01, -4.2752e-01],[ 4.7519e-01, -8.2677e-02, -4.8700e-01],[ 6.3203e-01, 1.9308e-02, -2.7753e-01]]],[[[ 2.3254e-01, 1.2666e-01, 1.8605e-01],[-4.2805e-01, -2.4349e-01, 2.4628e-01],[-2.5066e-01, 1.4177e-01, -5.4864e-03]],[[-1.4076e-01, -2.1903e-01, 1.5041e-01],[-8.4127e-01, -3.5176e-01, 5.6398e-01],[-2.4194e-01, 5.1928e-01, 5.3915e-01]],[[-3.1432e-01, -3.7048e-01, -1.3094e-01],[-4.7144e-01, -1.5503e-01, 3.4589e-01],[ 5.4384e-02, 5.8683e-01, 4.9580e-01]]],[[[ 1.7715e-01, 5.2149e-01, 9.8740e-03],[-2.7185e-01, -7.1709e-01, 3.1292e-01],[-7.5753e-02, -2.2079e-01, 3.3455e-01]],[[ 3.0924e-01, 6.7071e-01, 2.0546e-02],[-4.6607e-01, -1.0697e+00, 3.3501e-01],[-8.0284e-02, -3.0522e-01, 5.4460e-01]],[[ 3.1572e-01, 4.2335e-01, -3.4976e-01],[ 8.6354e-02, -4.6457e-01, 1.1803e-02],[ 1.0483e-01, -1.4584e-01, -1.5765e-02]]],...,
2.方法二
import torch
import torchvision.models as models
from torch import nnvgg16_true = models.vgg16(weights=True)vgg16_true.load_state_dict(torch.load("vgg16_model_true_2.pth")) # 针对第二种加载参数的情况,使其显示完整的网络结构
print(vgg16_true)
VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)
注意: 加载模型时,要确保当前代码中使用的模型类与之前保存的模型类相同。
三、总结
torch.load()是PyTorch中用于加载保存的对象的函数,可以加载之前使用torch.save()保存的模型、张量、字典等。可以指定要加载的文件路径或文件对象,并可选地指定加载到的设备、反序列化模块等参数。