transformers进行学习率调整lr_scheduler(warmup)

一、get_scheduler实现warmup

1、warmup基本思想

Warmup(预热)是深度学习训练中的一种技巧,旨在逐步增加学习率以稳定训练过程,特别是在训练的早期阶段。它主要用于防止在训练初期因学习率过大导致的模型参数剧烈波动或不稳定。预热阶段通常是指在训练开始时,通过多个步长逐步将学习率从一个较低的值增加到目标值(通常是预定义的最大学习率)。

2、warmup基本实现

from transformers import get_schedulerscheduler = get_scheduler(name="cosine",  # 可以选择 'linear', 'cosine', 'polynomial', 'constant', 'constant_with_warmup'optimizer=optimizer,num_warmup_steps=100,  # 预热步数num_training_steps=num_training_steps  # 总的训练步数
)#linear:线性学习率下降
#cosine:余弦退火
#polynomial:多项式衰减
#constant:常数学习率
#constant_with_warmup:预热后保持常数# 上述代码等价于
from transformers import get_cosine_scheduler_with_warmupscheduler = get_cosine_scheduler_with_warmup(optimizer=optimizer,num_warmup_steps=100,  # 预热步数num_training_steps=num_training_steps  # 总的训练步数
)# 同理等价于linear, polynomial, constant分别等价于
from transformers import (get_constant_schedule, get_polynomial_decay_schedule_with_warmup, get_linear_schedule_with_warmup)

 二、各种warmup策略学习率变化规律

1、get_constant_schedule学习率变化规律

2、get_cosine_schedule_with_warmup学习率变化规律

3、get_cosine_with_hard_restarts_schedule_with_warmup学习率变化规律

4、get_linear_schedule_with_warmup学习率变化规律

5、get_polynomial_decay_schedule_with_warmup学习率变化规律(power=2, power=1类似于linear)

6、注意事项

  • 如果网络中不同框架采用不同的学习率,上述的warmup策略仍然有效(如图二、5中所示) 
  • 给schduler设置的number_training_steps一定要和训练过程相匹配,如下所示。

7、可视化学习率过程

import matplotlib.pyplot as plt
from transformers import get_scheduler
from torch.optim import AdamW
import torch
import math# 定义一些超参数learning_rate = 1e-3  # 初始学习率# 假设有一个模型
model = torch.nn.Linear(10, 2)# 获得训练总的步数
epochs = 50
batch_size = 32
#train_loader = ***
#num_train_loader = len(train_loader)
num_train_loader = 1235num_training_steps = epochs * math.ceil(num_train_loader/batch_size) # 总的训练步数# 定义优化器
optimizer = AdamW(model.parameters(), lr=learning_rate)# 创建学习率调度器
scheduler = get_scheduler(name="cosine",  # 可以选择 'linear', 'cosine', 'polynomial', 'constant', 'constant_with_warmup'optimizer=optimizer,num_warmup_steps=100,  # 预热步数num_training_steps=num_training_steps  # 总的训练步数
)# 存储每一步的学习率
learning_rates = []# for step in range(num_training_steps):
#    optimizer.step()
#    scheduler.step()
#    learning_rates.append(optimizer.param_groups[0]['lr'])for epoch in range(epochs):# for batch in train_loader:for step in range(0, num_train_loader, batch_size):optimizer.zero_grad()# loss.backward()optimizer.step()scheduler.step()learning_rates.append(optimizer.param_groups[0]['lr'])# 绘制学习率曲线
plt.plot(learning_rates)
plt.xlabel("Training Steps")
plt.ylabel("Learning Rate")
plt.title("Learning Rate Schedule")
plt.show()

实验结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/385037.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity3D之TCP网络通信(客户端)

文章目录 概述TCP核心类异步机制 Unity中创建TCP客户端Unity中其它脚本获取TCP客户端接受到的数据后续改进 本文将以Unity3D应用项目作为客户端去连接制定的服务器为例进行相关说明。 Unity官网参考资料: https://developer.unity.cn/projects/6572ea1bedbc2a001ef…

20240725java的Controller、DAO、DO、Mapper、Service层、反射、AOP注解等内容的学习

在Java开发中,‌controller、‌dao、‌do、‌mapper等概念通常与MVC(‌Model-View-Controller)‌架构和分层设计相关。‌这些概念各自承担着不同的职责,‌共同协作以构建和运行一个应用程序。‌以下是这些概念的解释:‌…

当全球银行系统“崩溃”时会发生什么?

有句名言:“当美国打喷嚏时,世界就会感冒……”换句话说,当人们对美国及其经济稳定性的信心下降时,其他经济体(以及黄金、白银和股票等资产)的价值往往会下降。 与任何其他资产类别一样,加密货…

黑马JavaWeb企业级开发(知识清单)03——HTML实现正文:排版(音视频、换行、段落)、布局标签(div、span)、盒子模型

文章目录 前言一、正文排版1. 视频标签: < video >2. 音频标签: < audio >3. 换行标签: < br >4. 段落标签 < p >5. vscode实现 二、布局1. 盒子模型2. 布局标签< div >和< span >3. VScode实现 三、源代码和运行结果总结 前言 本篇文章是…

leetcode3098. 求出所有子序列的能量和

官解 class Solution(object):# 定义常量mod int(1e9 7) # 模数&#xff0c;用于防止结果溢出inf float(inf) # 无穷大&#xff0c;用于初始化时的特殊值def sumOfPowers(self, nums, k):n len(nums) # 数组长度res 0 # 用于存储最终结果# 三维动态规划表&#xff0c;…

24.7.17数据结构|顺序表

目录 大O的工程意义&#xff1f; 线性表 引入&#xff1a; 主要掌握【代码实现】&#xff1a; 一、线性结构 1、逻辑描述 2、顺序表 1、如何定义结构 1&#xff09;静态顺序表 1&#xff09;动态顺序表 2、写代码 &#xff08;1&#xff09;【clion创建工程】 ​编…

【数据结构】详解二叉树及其操作

无论你觉得自己多么的了不起&#xff0c;也永远有人比你更强。&#x1f493;&#x1f493;&#x1f493; 目录 ✨说在前面 &#x1f34b;知识点一&#xff1a;二叉树的遍历 • &#x1f330;1.创建一棵二叉树 • &#x1f330;2.二叉树的遍历 •&#x1f525;前序遍历 •&a…

Apache DolphinScheduler 3.2.2 版本正式发布!

Apache DolphinScheduler 3.2.2 版本正式发布&#xff01; 近日&#xff0c;Apache DolphinScheduler 发布了 3.2.2 版本。此版本主要基于 3.2.1 版本进行了 bug 修复&#xff0c;新增若干特性&#xff0c;并进行了众多改进和 Bug 修复&#xff0c;以及文档修复等。 &#x1…

【前端 08】简单学习js字符串

JavaScript中的String对象详解 在JavaScript中&#xff0c;字符串&#xff08;String&#xff09;是一种非常基础且常用的数据类型&#xff0c;用于表示文本数据。虽然JavaScript中的字符串是原始数据类型&#xff0c;但它们的行为类似于对象&#xff0c;因为JavaScript为字符…

谷粒商城实战笔记-52~53-商品服务-API-三级分类-新增-修改

文章目录 一&#xff0c;52-商品服务-API-三级分类-新增-新增效果完成1&#xff0c;点击Append按钮&#xff0c;显示弹窗2&#xff0c;测试完整代码 二&#xff0c;53-商品服务-API-三级分类-修改-修改效果完成1&#xff0c;添加Edit按钮并绑定事件2&#xff0c;修改弹窗确定按…

vue3-print-nb实现打印pdf分页

安装插件 npm install vue3-print-nb --savevue3 引入 import print from vue3-print-nb // 打印插件 app.use(print)使用 这里使用的是对象配置方式 对象配置方式——在js中定义一个对象&#xff0c;对象中可配置打印区域相关属性&#xff0c;在需要打印的单据内容最外面的…

【Django】在vscode中新建Django应用并新增路由

文章目录 打开一个终端输入新建app命令在app下的views.py内写一个视图app路由引入该视图项目路由引入app路由项目(settings.py)引入app&#xff08;AntappConfig配置类&#xff09;运行项目 打开一个终端 输入新建app命令 python manage.py startapp antapp在app下的views.py内…

MySQL第一阶段:多表查询、事务

继续我的MySQL之旅&#xff0c;继续上篇的DDL、DML、DQL、以及一些约束&#xff0c;该到了多表查询和事务的学习总结&#xff0c;以及相关的案例实现&#xff0c;为未来的复习以及深入的理解做好知识储备。 目录 多表查询 连接查询 内连接 外连接 子查询 事务 事务简介…

为什么用LeSS?

实现适应性 LeSS是一个产品开发的组织系统&#xff0c;旨在最大化一个组织的适应性。关于适应性&#xff08;或者敏捷性&#xff0c;也就是敏捷开发的初衷&#xff09;我们是指优化&#xff1a; 以相对低的成本改变方向的能力&#xff0c;主要是基于通过频繁交付产生的探索。从…

pyuic5将ui文件转换为py文件报错:one input ui-file must be specified;no element found;

ERROR 1 文件命名不规范Solution 1:文件命名不能有空格 ERROR 2未选中ui文件 Solution 2:选中要转换成py 的文件

writing classes ... [xxx of xxxx] 执行时间太长

一、问题展示 二、解决方法 打开设置【File - Settings…】修改堆大小

使用vfbox网关实现modbus opc profinet iec61850等协议间的转换

在当今物联网&#xff08;IoT&#xff09;与工业自动化日益融合的时代背景下&#xff0c;协议转换网关作为连接不同设备与系统之间的桥梁&#xff0c;扮演着至关重要的角色。VFBox协议转换网关&#xff0c;作为这一领域内的佼佼者&#xff0c;以其高效、灵活、可靠的性能&#…

鸿蒙APP架构及开发入门

1.鸿蒙系统 1.1 什么是鸿蒙 鸿蒙是一款面向万物互联时代的、全新的分布式操作系统。 在传统的单设备系统能力基础上&#xff0c;鸿蒙提出了基于同一套系统能力、适配多种终端形态的分布式理念&#xff0c;能够支持手机、平板、智能穿戴、智慧屏、车机、PC、智能音箱、耳机、…

从代码层面熟悉UniAD,开始学习了解端到端整体架构

0. 简介 最近端到端已经是越来越火了&#xff0c;以UniAD为代表的很多工作不断地在不断刷新端到端的指标&#xff0c;比如最近SparseDrive又重新刷新了所有任务的指标。在端到端火热起来之前&#xff0c;成熟的模块化自动驾驶系统被分解为不同的独立任务&#xff0c;例如感知、…

【Django】网上蛋糕商城后台-商品管理

1.商品管理功能 当管理员点击商品管理时&#xff0c;发送服务器请求 path(admin/goods_list/, viewsAdmin.goods_list), # 处理商品列表请求 def goods_list(request):try:type request.GET["type"]except:type 0try:ym request.GET["ym"]except:ym …