Pytorch模型保存与加载,并在加载的模型基础上继续训练
系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)
一、只保存参数
1. 保存
一般地,采用一条语句即可保存参数:
torch.save(model.state_dict(), path)
其中model指定义的模型实例变量,如 model=vgg16( ), path是保存参数的路径,如 path=‘./model.pth’ , path=‘./model.tar’, path=‘./model.pkl’, 保存参数的文件一定要有后缀扩展名。
特别地,如果还想保存某一次训练采用的优化器、epochs等信息,可将这些信息组合起来构成一个字典,然后将字典保存起来:
state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
torch.save(state, path)
2. 加载
针对上述第一种情况,也只需要一句即可加载模型:
model.load_state_dict(torch.load(path))
针对上述第二种以字典形式保存的方法,加载方式如下:
checkpoint = torch.load(path) # load model
model.load_state_dict(checkpoint['model']) # load parameters
optimizer.load_state_dict(checkpoint['optimizer']) # load optimizer
epoch = checkpoint(['epoch']) # load epoch for training continue
二、保存整个模型
1. 保存
torch.save(model, path)
2. 加载
model = torch.load(path)
三、在训练中pytorch通过Dataloader加载数据
torch.utils.data.DataLoader():
构建可迭代的数据装载器, 我们在训练的时候,每一个for
循环,每一次iteration
,就是从DataLoader
中获取一个batch_size
大小的数据的。
例如:
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)