- 保存检查点:save
- 函数功能详解:
- 核心作用:
- 示例调用
- 加载检查点(补充说明)
- 恢复检查点:load
- 核心功能
- 代码功能解析
- 应用场景
- 示例调用
- `save` 和 `load` 的对比
介绍神经网络训练 Train
类的两个常用函数 save 和 load。
保存检查点:save
def save(self, milestone):if not self.accelerator.is_local_main_process:returndata = {'step': self.step,'model': self.accelerator.get_state_dict(self.model),'opt': self.opt.state_dict(),'ema': self.ema.state_dict(),'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None,}torch.save(data, str(self.results_path / f'model-{milestone}.pt'))
这个 save
函数的功能是保存训练过程中模型的状态,以便在需要时能够恢复训练或使用保存的模型进行推理和评估。它是典型的检查点(checkpoint)保存函数。
函数功能详解:
-
保存触发条件:
if not self.accelerator.is_local_main_process:return
- 该代码段确保只有主进程(主节点)会执行保存操作,避免在分布式训练中因多个进程同时写入文件导致冲突或冗余。
-
构建保存的数据字典:
data = {'step': self.step, # 当前训练步数,用于恢复训练进度。'model': self.accelerator.get_state_dict(self.model), # 模型的状态字典,包括所有参数和结构。'opt': self.opt.state_dict(), # 优化器的状态,用于恢复优化过程。'ema': self.ema.state_dict(), # EMA(指数移动平均)状态,用于保存和恢复平滑的权重更新。'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None, # 混合精度训练的缩放器状态(若存在)。 }
step
: 当前的训练步数,恢复时可以从中断的地方继续。model
: 使用Accelerator
包装后的模型状态字典,确保即使在分布式或混合精度训练下,也能正确提取模型参数。opt
: 优化器状态字典,包括动量、学习率等信息。ema
: 如果使用了 EMA 平滑模型参数,它的状态字典也会被保存。scaler
: 如果使用了混合精度(FP16)训练,保存GradScaler
的状态以支持后续训练。
-
保存数据到文件:
torch.save(data, str(self.results_path / f'model-{milestone}.pt'))
torch.save
将data
字典序列化并保存为.pt
文件。- 文件名格式为
model-{milestone}.pt
,其中milestone
是训练进度的标识(通常是保存检查点时的步数或 epoch)。
核心作用:
-
训练中断后的恢复:
- 保存训练的步数、模型、优化器和 EMA 的状态,以便在意外中断后可以从最近的保存点继续训练,而不需要从头开始。
-
模型的版本管理:
- 每次达到里程碑(milestone)时保存一个检查点,方便回溯特定训练阶段的模型,进行性能对比或调试。
-
支持推理或部署:
- 保存的检查点文件可以用来加载模型参数,用于推理、测试或部署到生产环境。
-
分布式训练的兼容性:
- 通过
Accelerator
管理模型和优化器状态,确保在分布式训练中正确保存参数。
- 通过
示例调用
假设训练每 1000 步保存一次,可以如下调用:
if step % 1000 == 0:trainer.save(step)
保存的文件会命名为 model-1000.pt
、model-2000.pt
等,分别对应不同训练阶段的模型状态。
加载检查点(补充说明)
与保存对应,恢复模型时通常需要加载保存的检查点:
checkpoint = torch.load('model-1000.pt')
trainer.step = checkpoint['step']
trainer.model.load_state_dict(checkpoint['model'])
trainer.opt.load_state_dict(checkpoint['opt'])
trainer.ema.load_state_dict(checkpoint['ema'])if 'scaler' in checkpoint and checkpoint['scaler'] is not None:trainer.accelerator.scaler.load_state_dict(checkpoint['scaler'])
恢复检查点:load
def load(self, milestone):accelerator = self.acceleratordevice = accelerator.devicedata = torch.load(str(self.results_path / f'model-{milestone}.pt'), map_location=device)model = self.accelerator.unwrap_model(self.model)model.load_state_dict(data['model'])print("model loaded: ", str(self.results_path / f'model-{milestone}.pt'))self.step = data['step']self.opt.load_state_dict(data['opt'])if self.accelerator.is_main_process:self.ema.load_state_dict(data["ema"])if 'version' in data:print(f"loading from version {data['version']}")if exists(self.accelerator.scaler) and exists(data['scaler']):self.accelerator.scaler.load_state_dict(data['scaler'])
这段代码的功能与 save
函数是相对的,用于从检查点加载训练状态,以便恢复中断的训练或者加载已经训练好的模型进行推理和评估。
核心功能
load
函数的主要任务是从指定的检查点文件加载模型和训练相关的状态,包括:
- 模型的权重参数。
- 训练步数。
- 优化器的状态。
- EMA 的权重状态(如果在主进程)。
- 混合精度训练的缩放器状态(若使用 FP16)。
代码功能解析
-
加载检查点文件:
data = torch.load(str(self.results_path / f'model-{milestone}.pt'), map_location=device)
- 从
self.results_path
路径中加载指定里程碑milestone
对应的检查点文件。 - 使用
map_location=device
确保将数据加载到正确的设备(例如 GPU 或 CPU)。
- 从
-
加载模型参数:
model = self.accelerator.unwrap_model(self.model) model.load_state_dict(data['model']) print("model loaded: ", str(self.results_path / f'model-{milestone}.pt'))
- 使用
accelerator.unwrap_model
解包模型(因为Accelerator
会对模型进行包装)。 - 加载检查点中保存的
model
参数到当前模型实例。 - 打印加载的检查点文件路径,便于确认。
- 使用
-
恢复训练步数:
self.step = data['step']
- 恢复训练时的当前步数,确保训练从中断点继续。
-
恢复优化器状态:
self.opt.load_state_dict(data['opt'])
- 加载优化器的状态,包括动量和学习率调度器的状态。
-
恢复 EMA 权重:
if self.accelerator.is_main_process:self.ema.load_state_dict(data["ema"])
- 如果当前是主进程,则恢复检查点中保存的 EMA 权重状态,以保持平滑的模型权重更新。
-
加载检查点版本(可选):
if 'version' in data:print(f"loading from version {data['version']}")
- 如果检查点中保存了版本信息(
version
),打印该版本号,便于调试和跟踪。
- 如果检查点中保存了版本信息(
-
恢复混合精度缩放器状态:
if exists(self.accelerator.scaler) and exists(data['scaler']):self.accelerator.scaler.load_state_dict(data['scaler'])
- 如果启用了 FP16 混合精度训练,并且检查点保存了
scaler
状态,则恢复该状态,确保继续训练时数值精度一致。
- 如果启用了 FP16 混合精度训练,并且检查点保存了
应用场景
-
训练中断后的恢复:
- 通过加载最近的检查点,可以从中断点继续训练,节省时间和计算资源。
-
模型调试与优化:
- 可以在不同的里程碑(milestone)加载模型,观察其表现,分析训练过程中的问题。
-
推理和评估:
- 加载保存的模型进行测试或部署。
示例调用
假设保存的检查点文件名为 model-1000.pt
,恢复训练的代码如下:
trainer.load(1000) # 加载第 1000 步的训练状态
加载完成后,训练可以从第 1000 步继续:
for step in range(trainer.step, trainer.train_num_steps):trainer.train_step()if step % trainer.save_and_sample_every == 0:trainer.save(step)
save
和 load
的对比
功能 | save | load |
---|---|---|
作用 | 保存训练状态 | 恢复训练状态 |
保存内容/加载内容 | 模型参数、优化器状态、EMA 权重、缩放器状态、当前步数 | 恢复模型参数、优化器状态、EMA 权重、缩放器状态、当前步数 |
触发条件 | 达到保存频率(如每 1000 步)或手动调用 | 从指定检查点加载 |
常见用途 | 检查点管理、训练中断后可恢复 | 继续训练、推理或调试 |
通过 save
和 load
的配合,整个训练过程变得更加健壮,既支持中断恢复,又方便结果管理和调试。