解决deepspeed框架的bug:不保存调度器状态,模型训练重启时学习率从头开始

deepspeed存在一个bug,即在训练时不保存调度器状态,因此如果训练中断后再重新开始训练,调度器还是会从头开始而不是接着上一个checkpoint的调度器状态来训练。这个bug在deepspeed的github中也有其他人提出:https://github.com/microsoft/DeepSpeed/issues/3875
因此我们需要写一个保存调度器状态的代码,才可以解决这个问题。
具体方法是加一个callback类,专门负责保存调度器的状态以及在训练重新开始时加载调度器的状态:
先在训练文件中给trainer加一个callback

from smoe.callbacks.save_model import SchedulerStateCallback
trainer.add_callback(SchedulerStateCallback)
class SchedulerStateCallback(TrainerCallback):def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):if os.environ.get("RANK", "0") == "0":#scheduler = kwargs['lr_scheduler']scheduler = kwargs.get("lr_scheduler")if scheduler is None:return scheduler_state = scheduler.state_dict()#save_path = os.path.join(args.output_dir, SCHEDULER_NAME)# 使用 PREFIX_CHECKPOINT_DIR 和 global_step 创建检查点目录名checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"# 完整的检查点目录路径checkpoint_path = os.path.join(args.output_dir, checkpoint_folder)# 如果目录不存在,则创建它if not os.path.exists(checkpoint_path):os.makedirs(checkpoint_path)# 完整的保存路径save_path = os.path.join(checkpoint_path, SCHEDULER_NAME)# 保存scheduler状态torch.save(scheduler_state, save_path)def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):# 如果resume_from_checkpoint设置了有效路径if args.resume_from_checkpoint is not None:load_path = os.path.join(args.resume_from_checkpoint, SCHEDULER_NAME)# 如果该路径下有保存的调度器状态,则加载它if os.path.exists(load_path):#scheduler = kwargs['lr_scheduler']scheduler = kwargs.get("lr_scheduler")if scheduler is None:return scheduler_state = torch.load(load_path)scheduler.load_state_dict(scheduler_state)

解决效果如下,我们可以看到,在chaeckpoint10重新开始训练的时候,学习率是接着之前的学习率开始的(5.5e-7),而不是从头开始(0.5e-7):
在这里插入图片描述在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/126537.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ResNet 09

一、发展 1989年,Yann LeCun提出了一种用反向传导进行更新的卷积神经网络,称为LeNet。 1998年,Yann LeCun提出了一种用反向传导进行更新的卷积神经网络,称为LeNet-5 AlexNet是2012年ISLVRC 2012(ImageNet Large Sca…

MySQL——备份和还原

备份 热备 即MySQL服务在运行的时候进行的备份 mysqldump命令 mysqldump --databases db1 db2 db3 > dump.sql mysqldump -uroot -pSanchuang1234# --all-databases >all_db.sql mysqldump -uroot -pSanchuang123# --databases TENNIS >/backup/tennis.sql mysq…

Always On 数据库无法自动同步的问题

问题: 在给客户的SQL Server 2019 配置好Always On 之后,不久就出现高可用组里的一个库无法正常同步。 第一次出现,以为是偶发性问题,直接右键点击恢复数据同步,没一会就同步好了;过了一个月问题又出现了…

高云USB下载器仿真器用户手册(包括在线逻辑分析仪的使用方法)

高云 USB 仿真器用户手册 一.简介 仿真器用于高云 GOWIN 公司所生产的 FPGA,可用于程序下载和调试。主要特点如下: 1.支持宽电压1.2V - 3.6V; 2.速度最高可达30Mb/s,极速完成下载和波形调试功能; 3.完美支持在线逻…

【个人博客系统网站】项目的发布 · 通过公网IP访问我们的网站 · 思考总结

【JavaEE】进阶 个人博客系统(6) 文章目录 【JavaEE】进阶 个人博客系统(6)1. 项目发布1.1 后端代码修改1.1.1 数据库密码1.1.2 端口号修改1.1.3 文件保存地址修改1.1.4 静态资源映射修改 1.2 云服务器1.2.1 建库建表1.2.2 必要…

自动化驱动程序管理

在部署操作系统时,每次都从下载和分发所需的驱动程序中实现真正的独立性可能是一场艰苦的战斗。特别是具有硬件多样化的环境,并且需要支持新的硬件类型时。借助 OS Deployer,可以对所有端点使用一个映像,无论品牌和型号如何&#…

使用maven idea环境

目录 idea三种方式执行maven命令 工程导入 生命周期lifecycle 插件和目标 常用命令 创建模块工程后 idea三种方式执行maven命令 想在哪个工程模块上执行就点开哪一个 如果觉得双击完clean再双击install麻烦,可以 如果有需要还可以给命令后面加参数 ​​​ 第三种…

C# 共享项目的应用

概述 共享项目也可以称为共享资产项目,它允许在多个目标项目之间共享的代码。 它支持编译器指令,可以有条件地包含特定于平台的代码,以便编译为引用共享项目的项目的子集。 还有 IDE 支持,可帮助管理编译器指令并直观显示代码在每个应用程序中的外观。 什么是共享项目? …

XL-LightHouse 与 Flink 和 ClickHouse 流式大数据统计系统

一个Flink任务只能并行处理一个或少数几个数据流,而XL-LightHouse一个任务可以并行处理数万个、几十万个数据流; 一个Flink任务只能实现一个或少数几个数据指标,而XL-LightHouse单个任务就能支撑大批量、数以万计的数据指标。 1、XL-LightHo…

Excel文件生成与下载(SpringBoot项目)(easypoi)

说明 通过接口&#xff0c;导出表格。 使用SpringBoot框架和easypoi表格解析框架&#xff0c;生成Excel表格&#xff0c;并通过接口下载。 表格示例 依赖 版本 <easypoi.version>4.4.0</easypoi.version>依赖 <!-- easypoi --> <dependency><…

springboot整合mybatisPlus全技巧(2-常用开发技巧:通用字段插入)

本系列专题基于 springboot 整合 mybatisPlus 的各种文章早已烂大街的背景下&#xff0c;根据 整合过程&#xff0c;MP开发中的常见技巧&#xff0c;MP开发中遇到的各种坑 三个方面&#xff0c;来对这一专题做一个全面且实用的总结&#xff0c;基本上只要你吃透这篇文章&#x…

Linux mac Windows三系统 局域网文件共享方法

主要工具&#xff1a; Samba是一个开源的软件套件&#xff0c;允许Linux系统与Windows系统之间共享文件和打印机。 一、首先是Linux共享的设置 ①安装 sudo apt-get install samba ②创建共享文件夹 sudo mkdir /home/share ③配置用户 sudo smbpasswd -a kequan ④修改…

Java智慧工地信息化管理平台源码,依托计算机信息、网络通讯、物联网、系统集成及云计算技术建立

Java智慧工地源码 智慧工地APP源码 系统定义&#xff1a; 智慧工地信息化管理平台是依托计算机信息、网络通讯、物联网、系统集成及云计算技术&#xff0c;通过数据采集、信息动态交互、智能分析&#xff0c;建立起来的一套集成的项目建设综合管理系统。实现项目管理信息化、网…

图像噪声--添加噪声

椒盐噪声 椒盐噪声就是给图片添加黑白噪点&#xff0c;椒指的是黑色的噪点(0,0,0),盐指的是白色的噪点(255,255,255)&#xff0c;通过num来控制噪声多少&#xff0c;值越大添加的噪声越多&#xff0c;图像损坏的更加严重。 void add_salt_pepper_noise(Mat& src,Mat& …

淘宝双11数据分析与预测课程案例中(林子雨)错误点总结

问题一&#xff1a;可视化代码中男女买家各个年龄段对比散点图中数值不显示以及坐标不正确问题如下图 解决方法&#xff1a; 1修改坐标 2修改数值 修改后散点图 问题二&#xff1a;各省份的总成交量对比中地图显示不出来 有时间再写

JavaScript中的原型链(prototype chain)

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ JavaScript中的原型链⭐ 写在最后 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎来到前端入门之旅&#xff01;这个专栏是为那些对Web开发感兴趣、刚刚踏…

手机电脑scoket通信 手机软件 APP inventor 服务端程序python

python scoket 通信 再帮助同学坐课题的时候接触到了scoket通信&#xff0c;了解到这应该是基层网络通信的原理&#xff0c;于是就导出搜索了一下相关的资料&#xff0c;简单来说scoket通信就是&#xff0c;可以让不同设备在同一个网络环境的条件下&#xff0c;可以实现相互通…

视频汇聚/视频云存储/视频监控管理平台EasyCVR安全检查的相关问题及解决方法2.0

开源EasyDarwin视频监控TSINGSEE青犀视频平台EasyCVR能在复杂的网络环境中&#xff0c;将分散的各类视频资源进行统一汇聚、整合、集中管理&#xff0c;在视频监控播放上&#xff0c;TSINGSEE青犀视频安防监控汇聚平台可支持1、4、9、16个画面窗口播放&#xff0c;可同时播放多…

Ribbon负载均衡+Nacos服务搭建

Ribbon负载均衡 流程 首先通过RibbonLoadBalanceerClient获取服务名&#xff0c;并传给DynamicServerListLoadBalancer——>通过EureKa-server获取服务名对应服务列表(也就是被注册到EureKa中的服务&#xff0c;可能包括不同端口的)&#xff0c;然后我们会根据IRule中的服务…

基于SSM的网络游戏公司官方平台

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用JSP技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…