网络训练中的检查点——保存和恢复训练状态

  • 保存检查点:save
    • 函数功能详解:
    • 核心作用:
    • 示例调用
    • 加载检查点(补充说明)
  • 恢复检查点:load
    • 核心功能
    • 代码功能解析
    • 应用场景
    • 示例调用
  • `save` 和 `load` 的对比

介绍神经网络训练 Train 类的两个常用函数 save 和 load。

保存检查点:save

 def save(self, milestone):if not self.accelerator.is_local_main_process:returndata = {'step': self.step,'model': self.accelerator.get_state_dict(self.model),'opt': self.opt.state_dict(),'ema': self.ema.state_dict(),'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None,}torch.save(data, str(self.results_path / f'model-{milestone}.pt')) 

这个 save 函数的功能是保存训练过程中模型的状态,以便在需要时能够恢复训练或使用保存的模型进行推理和评估。它是典型的检查点(checkpoint)保存函数

函数功能详解:

  1. 保存触发条件

    if not self.accelerator.is_local_main_process:return
    
    • 该代码段确保只有主进程(主节点)会执行保存操作,避免在分布式训练中因多个进程同时写入文件导致冲突或冗余。
  2. 构建保存的数据字典

    data = {'step': self.step,  # 当前训练步数,用于恢复训练进度。'model': self.accelerator.get_state_dict(self.model),  # 模型的状态字典,包括所有参数和结构。'opt': self.opt.state_dict(),  # 优化器的状态,用于恢复优化过程。'ema': self.ema.state_dict(),  # EMA(指数移动平均)状态,用于保存和恢复平滑的权重更新。'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None,  # 混合精度训练的缩放器状态(若存在)。
    }
    
    • step: 当前的训练步数,恢复时可以从中断的地方继续。
    • model: 使用 Accelerator 包装后的模型状态字典,确保即使在分布式或混合精度训练下,也能正确提取模型参数。
    • opt: 优化器状态字典,包括动量、学习率等信息。
    • ema: 如果使用了 EMA 平滑模型参数,它的状态字典也会被保存。
    • scaler: 如果使用了混合精度(FP16)训练,保存 GradScaler 的状态以支持后续训练。
  3. 保存数据到文件

    torch.save(data, str(self.results_path / f'model-{milestone}.pt'))
    
    • torch.savedata 字典序列化并保存为 .pt 文件。
    • 文件名格式为 model-{milestone}.pt,其中 milestone 是训练进度的标识(通常是保存检查点时的步数或 epoch)。

核心作用:

  1. 训练中断后的恢复

    • 保存训练的步数、模型、优化器和 EMA 的状态,以便在意外中断后可以从最近的保存点继续训练,而不需要从头开始。
  2. 模型的版本管理

    • 每次达到里程碑(milestone)时保存一个检查点,方便回溯特定训练阶段的模型,进行性能对比或调试。
  3. 支持推理或部署

    • 保存的检查点文件可以用来加载模型参数,用于推理、测试或部署到生产环境。
  4. 分布式训练的兼容性

    • 通过 Accelerator 管理模型和优化器状态,确保在分布式训练中正确保存参数。

示例调用

假设训练每 1000 步保存一次,可以如下调用:

if step % 1000 == 0:trainer.save(step)

保存的文件会命名为 model-1000.ptmodel-2000.pt 等,分别对应不同训练阶段的模型状态。

加载检查点(补充说明)

与保存对应,恢复模型时通常需要加载保存的检查点:

checkpoint = torch.load('model-1000.pt')
trainer.step = checkpoint['step']
trainer.model.load_state_dict(checkpoint['model'])
trainer.opt.load_state_dict(checkpoint['opt'])
trainer.ema.load_state_dict(checkpoint['ema'])if 'scaler' in checkpoint and checkpoint['scaler'] is not None:trainer.accelerator.scaler.load_state_dict(checkpoint['scaler'])

恢复检查点:load

    def load(self, milestone):accelerator = self.acceleratordevice = accelerator.devicedata = torch.load(str(self.results_path / f'model-{milestone}.pt'), map_location=device)model = self.accelerator.unwrap_model(self.model)model.load_state_dict(data['model'])print("model loaded: ", str(self.results_path / f'model-{milestone}.pt'))self.step = data['step']self.opt.load_state_dict(data['opt'])if self.accelerator.is_main_process:self.ema.load_state_dict(data["ema"])if 'version' in data:print(f"loading from version {data['version']}")if exists(self.accelerator.scaler) and exists(data['scaler']):self.accelerator.scaler.load_state_dict(data['scaler']) 

这段代码的功能与 save 函数是相对的,用于从检查点加载训练状态,以便恢复中断的训练或者加载已经训练好的模型进行推理和评估。

核心功能

load 函数的主要任务是从指定的检查点文件加载模型和训练相关的状态,包括:

  1. 模型的权重参数
  2. 训练步数
  3. 优化器的状态
  4. EMA 的权重状态(如果在主进程)。
  5. 混合精度训练的缩放器状态(若使用 FP16)

代码功能解析

  1. 加载检查点文件

    data = torch.load(str(self.results_path / f'model-{milestone}.pt'), map_location=device)
    
    • self.results_path 路径中加载指定里程碑 milestone 对应的检查点文件。
    • 使用 map_location=device 确保将数据加载到正确的设备(例如 GPU 或 CPU)。
  2. 加载模型参数

    model = self.accelerator.unwrap_model(self.model)
    model.load_state_dict(data['model'])
    print("model loaded: ", str(self.results_path / f'model-{milestone}.pt'))
    
    • 使用 accelerator.unwrap_model 解包模型(因为 Accelerator 会对模型进行包装)。
    • 加载检查点中保存的 model 参数到当前模型实例
    • 打印加载的检查点文件路径,便于确认。
  3. 恢复训练步数

    self.step = data['step']
    
    • 恢复训练时的当前步数,确保训练从中断点继续。
  4. 恢复优化器状态

    self.opt.load_state_dict(data['opt'])
    
    • 加载优化器的状态,包括动量和学习率调度器的状态。
  5. 恢复 EMA 权重

    if self.accelerator.is_main_process:self.ema.load_state_dict(data["ema"])
    
    • 如果当前是主进程,则恢复检查点中保存的 EMA 权重状态,以保持平滑的模型权重更新。
  6. 加载检查点版本(可选)

    if 'version' in data:print(f"loading from version {data['version']}")
    
    • 如果检查点中保存了版本信息(version),打印该版本号,便于调试和跟踪。
  7. 恢复混合精度缩放器状态

    if exists(self.accelerator.scaler) and exists(data['scaler']):self.accelerator.scaler.load_state_dict(data['scaler'])
    
    • 如果启用了 FP16 混合精度训练,并且检查点保存了 scaler 状态,则恢复该状态,确保继续训练时数值精度一致。

应用场景

  1. 训练中断后的恢复

    • 通过加载最近的检查点,可以从中断点继续训练,节省时间和计算资源。
  2. 模型调试与优化

    • 可以在不同的里程碑(milestone)加载模型,观察其表现,分析训练过程中的问题。
  3. 推理和评估

    • 加载保存的模型进行测试或部署。

示例调用

假设保存的检查点文件名为 model-1000.pt,恢复训练的代码如下:

trainer.load(1000)  # 加载第 1000 步的训练状态

加载完成后,训练可以从第 1000 步继续:

for step in range(trainer.step, trainer.train_num_steps):trainer.train_step()if step % trainer.save_and_sample_every == 0:trainer.save(step)

saveload 的对比

功能saveload
作用保存训练状态恢复训练状态
保存内容/加载内容模型参数、优化器状态、EMA 权重、缩放器状态、当前步数恢复模型参数、优化器状态、EMA 权重、缩放器状态、当前步数
触发条件达到保存频率(如每 1000 步)或手动调用从指定检查点加载
常见用途检查点管理、训练中断后可恢复继续训练、推理或调试

通过 saveload 的配合,整个训练过程变得更加健壮,既支持中断恢复,又方便结果管理和调试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/475812.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【卡尔曼滤波】数据预测Prediction观测器的理论推导及应用 C语言、Python实现(Kalman Filter)

【卡尔曼滤波】数据预测Prediction观测器的理论推导及应用 C语言、Python实现(Kalman Filter) 更新以gitee为准: 文章目录 数据预测概念和适用方式线性系统的适用性 数据预测算法和卡尔曼滤波公式推导状态空间方程和观测器先验估计后验估计…

大模型时代的具身智能系列专题(十三)

迪士尼研究中心 瑞士苏黎世迪斯尼研究中心致力于不同领域的业务活动,其中包括电影、电视、公园和度假村以及消费产品。我们针对所有这些领域进行科研工作。我们开发能使我们将后道生产元素整合到前级生产中的技术。由此可节省许多昂贵的效果,这些效果最…

IDEA2023设置控制台日志输出到本地文件

1、Run->Edit Configurations 2、选择要输出日志的日志,右侧,IDEA2023的Logs在 Modify option 里 选中就会展示Logs栏。注意一定要先把这个日志文件创建出来,不然不会自动创建日志文件的 IDEA以前版本的Logs会直接展示出来 3、但是…

o1的风又吹到多模态,直接吹翻了GPT-4o-mini

开源LLaVA-o1:一个设计用于进行自主多阶段推理的新型VLM。与思维链提示不同,LLaVA-o1独立地参与到总结、视觉解释、逻辑推理和结论生成的顺序阶段。 LLaVA-o1超过了一些更大甚至是闭源模型的性能,例如Gemini-1.5-pro、GPT-4o-mini和Llama-3.…

AJAX的基本使用

AJAX的基本使用 🎉🎉🎉欢迎来到我的博客,我是一名自学了2年半前端的大一学生,熟悉的技术是JavaScript与Vue.目前正在往全栈方向前进, 如果我的博客给您带来了帮助欢迎您关注我,我将会持续不断的更新文章!!!🙏🙏&#x…

DDei在线设计器V1.2.43版发布

2024-11-21-----V1.2.43 一、bug 修复 1. 修复只读情况下,连线依然可以通过特殊点调整的 bug 2. 修复了同一页面多个实例时,部分方法只会引用最后一个实例的问题 3. 修复了组合控件和容器控件改变容器后没有清理的问题,优化了容器的实现 4. …

C++进阶:哈希表实现

目录 一:哈希表的概念 1.1直接定址法 1.2哈希冲突 1.3负载因子 1.4实现哈希函数的方法 1.4.1除法散列法/除留余数法 1.4.2乘法散列法 1.4.3全域散列法 1.5处理哈希冲突 1.5.1开放地址法 线性探测 二次探测 ​编辑 双重散列 1.5.2链地址法 二.代码实现 2.1开放地址…

鸿蒙NEXT开发案例:血型遗传计算

【引言】 血型遗传计算器是一个帮助用户根据父母的血型预测子女可能的血型的应用。通过选择父母的血型,应用程序能够快速计算出孩子可能拥有的血型以及不可能拥有的血型。这个过程不仅涉及到了简单的数据处理逻辑,还涉及到UI设计与交互体验的设计。 【…

(十八)JavaWeb后端开发案例——会话/yml/过滤器/拦截器

目录 1.业务逻辑实现 1.1 登录校验技术——会话 1.1.1Cookie 1.1.2session 1.1.3JWT令牌技术 2.参数配置化 3.yml格式配置文件 4.过滤器Filter 5.拦截器Interceptor 1.业务逻辑实现 Day10-02. 案例-部门管理-查询_哔哩哔哩_bilibili //Controller层/*** 新增部门*/Pos…

2024.5 AAAiGLaM:通过邻域分区和生成子图编码对领域知识图谱对齐的大型语言模型进行微调

GLaM: Fine-Tuning Large Language Models for Domain Knowledge Graph Alignment via Neighborhood Partitioning and Generative Subgraph Encoding 问题 如何将特定领域知识图谱直接整合进大语言模型(LLM)的表示中,以提高其在图数据上自…

amd显卡和nVidia显卡哪个好 amd和英伟达的区别介绍

AMD和英伟达是目前市场上最主要的两大显卡品牌,它们各有自己的特点和优势,也有不同的适用场景和用户群体。那么,AMD显卡和英伟达显卡到底哪个好?它们之间有什么区别?我们又该如何选择呢?本文将从以下几个方…

接口加密了怎么测?

🍅 点击文末小卡片 ,免费获取软件测试全套资料,资料在手,涨薪更快 1、定义加密需求 确定哪些数据需要进行加密。这可以是用户敏感信息、密码、身份验证令牌等。确定使用的加密算法,如对称加密(如AES&am…

接口上传视频和oss直传视频到阿里云组件

接口视频上传 <template><div class"component-upload-video"><el-uploadclass"avatar-uploader":action"uploadImgUrl":on-progress"uploadVideoProcess":on-success"handleUploadSuccess":limit"lim…

springboot基于数据挖掘的广州招聘可视化分析系统

摘 要 基于数据挖掘的广州招聘可视化分析系统是一个创新的在线平台&#xff0c;旨在通过深入分析大数据来优化和改善广州地区的招聘流程。系统利用Java语言、MySQL数据库&#xff0c;结合目前流行的 B/S架构&#xff0c;将广州招聘可视化分析管理的各个方面都集中到数据库中&a…

VIM的下载使用与基本指令【入门级别操作】

VIM——超级文本编辑器 在当今时代&#xff0c;功能极其复杂的代码编辑器和集成开发环境&#xff08;IDE&#xff09;有很多。 但如果只想要一个超轻量级的代码编辑器&#xff0c;用于 Unix、C 或其他语言/系统&#xff0c;而不需要那些华而不实的功能&#xff0c;该怎么办呢&…

心情追忆-首页“毒“鸡汤AI自动化

之前&#xff0c;我独自一人开发了一个名为“心情追忆”的小程序&#xff0c;旨在帮助用户记录日常的心情变化及重要时刻。我从项目的构思、设计、前端&#xff08;小程序&#xff09;开发、后端搭建到最终部署。经过一个月的努力&#xff0c;通过群聊分享等方式&#xff0c;用…

开源代码统计工具cloc的简单使用

一.背景 公司之前开发了个小系统&#xff0c;要去申请著作权&#xff0c;需要填写代码数量。应该怎么统计呢&#xff1f;搜索了一下&#xff0c;还是用开源工具cloc吧&#xff01;我的操作系统是windows&#xff0c;代码主要是java项目和vue项目。 二.到哪里找 可以去官方下载…

基于单片机的条形码识别结算设计

本设计基于单片机的条形码辨识与结算系统。该系统主要用于超市、商场等场所的商品结算&#xff0c;实现了在超市内对不同种类商品进行自动识别及自动分类结算的功能。该系统由STM32F103C8T6单片机、摄像头、显示、蜂鸣器报警、按键和电源等多个模块构成。该系统可实现商品自动识…

进程间通信的信号艺术:机制、技术与实战应用深度剖析

目录 1 什么是信号 2 为什么要有信号 3 对于信号的反应 3.1 默认行为 3.2 signal()函数 -- 自定义行为对信号做出反应 3.3 对信号进行忽略 4 信号的产生的类型 4.1 kill命令 4.2 键盘输入产生信号 4.3 系统调用接口 4.3.1 kill() 4.3.2 raise() 函数 4.4 软件条件 …

美畅物联丨JT/T 808 终端设备如何加入畅联云平台

在道路运输行业中&#xff0c;JT/T 808终端设备的应用正变得越来越广泛&#xff0c;把该设备接入畅联云平台&#xff0c;能够达成更高效的车辆管理与监控功能。今天&#xff0c;我们就来探讨一下JT/T 808终端设备接入畅联云平台的步骤与要点。 一、了解畅联云平台接入要求 首先…