深度学习中的学习率调度器(scheduler)分析并作图查看各方法差异

文章目录

    • 1. 指数衰减调度器(Exponential Decay Scheduler)
      • 工作原理
      • 适用场景
      • 实现示例
    • 2. 余弦退火调度器(Cosine Annealing Scheduler)
      • 工作原理
      • 适用场景
      • 实现示例
    • 3. 步长衰减调度器(Step Decay Scheduler)
      • 工作原理
      • 适用场景
      • 实现示例
    • 4. 多项式衰减与预热调度器(Polynomial Decay with Warm-up)
      • 工作原理
      • 适用场景
      • 实现示例
    • 5. 多步衰减调度器(MultiStep Decay Scheduler)
      • 工作原理
      • 适用场景
      • 实现示例
    • 总结
    • 参考资料

在深度学习模型训练过程中, 学习率调度器(Learning Rate Scheduler)是优化过程中不可或缺的重要组成部分。它们能够在训练的不同阶段自动调整学习率,从而提高模型的收敛速度和最终性能。选择合适的学习率调度器对于优化训练过程至关重要,不同的调度器适用于不同的训练需求和模型架构。本文将介绍几种常用的学习率调度器,并通过 PyTorch 提供的 torch.optim.lr_schedulertransformers 库中的调度器,展示具体的实现示例及其适用场景。可以通过 运行示例代码来作图查看学习率变化情况,能帮助大家更好的了解不同方法的区别。

1. 指数衰减调度器(Exponential Decay Scheduler)

请添加图片描述

工作原理

指数衰减调度器通过在每个训练步骤中以固定的速率减小学习率,从而逐步降低学习率。这种调度器适用于需要平稳且持续减小学习率的训练过程,有助于模型在训练后期稳定收敛。

适用场景

  • 稳定收敛:适用于希望学习率在整个训练过程中持续且缓慢降低,以避免训练后期的震荡。
  • 简单调整:当训练过程相对稳定,不需要复杂的学习率调整策略时,指数衰减是一个简单有效的选择。

实现示例

import matplotlib.pyplot as plt
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ExponentialLR# 定义优化器和参数
initial_lr = 5e-5
num_training_steps = 3000
decay_rate = 0.99
params = [torch.nn.Parameter(torch.randn(10, 10)) for _ in range(5)]  # 示例模型参数
optimizer = AdamW(params, lr=initial_lr)# 定义指数衰减调度器
scheduler = ExponentialLR(optimizer, gamma=decay_rate)# 模拟学习率变化
learning_rates = []
for step in range(num_training_steps):optimizer.step()scheduler.step()current_lr = optimizer.param_groups[0]['lr']learning_rates.append(current_lr)# 绘制学习率变化曲线
plt.figure(figsize=(12, 6))
plt.plot(learning_rates, label='Learning Rate')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.title('Exponential Decay Scheduler')
plt.legend()
plt.grid(True)
plt.show()

2. 余弦退火调度器(Cosine Annealing Scheduler)

请添加图片描述

工作原理

余弦退火调度器通过余弦函数调整学习率,使其在训练过程中呈现周期性变化。这种调度器特别适用于处理模型训练中的振荡现象,能够在训练末期提供较低的学习率以帮助模型更好地收敛。

适用场景

  • 避免局部最优:通过周期性调整学习率,可以帮助模型跳出局部最优解。
  • 动态调整:适用于需要在训练过程中动态调整学习率以应对不同训练阶段需求的场景。
  • 模型复杂度较高:对于复杂模型,如深层神经网络,余弦退火有助于更好地探索参数空间。

实现示例

import matplotlib.pyplot as plt
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR# 优化器和参数定义同上
initial_lr = 5e-5
num_training_steps = 3000
T_max = 1000  # 一个周期内的步数
params = [torch.nn.Parameter(torch.randn(10, 10)) for _ in range(5)]  # 示例模型参数
optimizer = AdamW(params, lr=initial_lr)# 定义余弦退火调度器
scheduler = CosineAnnealingLR(optimizer, T_max=T_max)# 模拟学习率变化
learning_rates = []
for step in range(num_training_steps):optimizer.step()scheduler.step()current_lr = optimizer.param_groups[0]['lr']learning_rates.append(current_lr)# 绘制学习率变化曲线
plt.figure(figsize=(12, 6))
plt.plot(learning_rates, label='Learning Rate')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.title('Cosine Annealing Scheduler')
plt.legend()
plt.grid(True)
plt.show()

3. 步长衰减调度器(Step Decay Scheduler)

在这里插入图片描述

工作原理

步长衰减调度器在训练过程中每隔一定的步数(step_size)后按指定的因子(gamma)降低学习率。这种调度器适用于需要在训练过程中分阶段减小学习率的场景,有助于模型在不同训练阶段进行有效的学习。

适用场景

  • 分阶段训练:适用于需要在训练的特定阶段进行学习率调整的任务,如先快速学习再细致调整。
  • 明确的训练阶段:当训练过程可以划分为多个明确的阶段,每个阶段需要不同学习率时,步长衰减是理想选择。
  • 资源受限的训练:在有限的训练资源下,通过分阶段调整学习率可以更有效地利用计算资源。

实现示例

import matplotlib.pyplot as plt
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR# 调度器参数
initial_lr = 5e-5
num_training_steps = 3000
step_size = 500  # 每隔 step_size 个 step,学习率衰减一次
gamma = 0.1      # 衰减因子
params = [torch.nn.Parameter(torch.randn(10, 10)) for _ in range(5)]  # 示例模型参数
optimizer = AdamW(params, lr=initial_lr)# 定义步长衰减调度器
scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)# 模拟学习率变化
learning_rates = []
for step in range(num_training_steps):optimizer.step()scheduler.step()current_lr = optimizer.param_groups[0]['lr']learning_rates.append(current_lr)# 绘制学习率变化曲线
plt.figure(figsize=(12, 6))
plt.plot(learning_rates, label='Learning Rate')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.title('Step Decay Scheduler')
plt.legend()
plt.grid(True)
plt.show()

4. 多项式衰减与预热调度器(Polynomial Decay with Warm-up)

在这里插入图片描述

工作原理

多项式衰减与预热调度器结合了学习率预热和多项式衰减的优势。训练初期通过预热阶段逐步增加学习率,随后按照多项式函数逐步降低学习率。这种调度器适用于如 BERT 等复杂模型的训练,有助于在训练初期稳定模型参数并在后期促进收敛。

适用场景

  • 复杂模型训练:适用于需要在训练初期进行稳定性的复杂模型,如 Transformer、BERT 等。
  • 防止初期震荡:通过预热阶段逐步增加学习率,可以防止训练初期由于学习率过高导致的梯度震荡。
  • 需要精细控制:适用于需要对学习率进行精细控制,以实现最佳收敛效果的任务。

实现示例

import matplotlib.pyplot as plt
import torch
from torch.optim import AdamW
from transformers import get_polynomial_decay_schedule_with_warmup# 调度器参数
initial_lr = 5e-5
warmup_steps = 100
num_training_steps = 3000
lr_end = 1e-7  # 最低学习率
power = 2.0    # 多项式衰减的幂次
params = [torch.nn.Parameter(torch.randn(10, 10)) for _ in range(5)]  # 示例模型参数
optimizer = AdamW(params, lr=initial_lr)# 定义多项式衰减与预热调度器
scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps, lr_end=lr_end, power=power
)  # 二次衰减# 模拟学习率变化
learning_rates = []
for step in range(num_training_steps):optimizer.step()scheduler.step()current_lr = optimizer.param_groups[0]['lr']learning_rates.append(current_lr)# 绘制学习率变化曲线
plt.figure(figsize=(12, 6))
plt.plot(learning_rates, label='Learning Rate')
plt.axvline(x=warmup_steps, color='r', linestyle='--', label='End of Warm-up')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.title('Polynomial Decay Scheduler with Warm-up')
plt.legend()
plt.grid(True)
plt.show()

5. 多步衰减调度器(MultiStep Decay Scheduler)

在这里插入图片描述

工作原理

多步衰减调度器在预设的多个步数(milestones)时刻按指定的因子(gamma)降低学习率。这种调度器允许在训练过程中在多个关键点调整学习率,适用于需要在多个阶段显著改变学习率的训练任务。

适用场景

  • 多阶段训练:适用于训练过程中有多个关键阶段,每个阶段需要不同学习率的任务。
  • 灵活调整:当训练过程不规则或需要根据训练进展手动调整学习率时,多步衰减提供了灵活性。
  • 特定任务需求:适用于一些特定任务或模型架构,需要在特定步数后调整学习率以优化性能。

实现示例

import matplotlib.pyplot as plt
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import MultiStepLR# 调度器参数
initial_lr = 5e-5
num_training_steps = 3000
milestones = [1000, 2000]  # 指定的步数
gamma = 0.1  # 衰减因子
params = [torch.nn.Parameter(torch.randn(10, 10)) for _ in range(5)]  # 示例模型参数
optimizer = AdamW(params, lr=initial_lr)# 定义多步衰减调度器
scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=gamma)# 模拟学习率变化
learning_rates = []
for step in range(num_training_steps):optimizer.step()scheduler.step()current_lr = optimizer.param_groups[0]['lr']learning_rates.append(current_lr)# 绘制学习率变化曲线
plt.figure(figsize=(12, 6))
plt.plot(learning_rates, label='Learning Rate')
for i, milestone in enumerate(milestones):if i == 0:plt.axvline(x=milestone, color='r', linestyle='--', label=f'Milestone at Step {milestone}')else:plt.axvline(x=milestone, color='r', linestyle='--')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.title('MultiStep Decay Scheduler')
plt.legend()
plt.grid(True)
plt.show()

注意:在多步衰减调度器的绘图代码中,plt.axvline 函数仅在第一个里程碑处添加标签,后续的里程碑标签设置为 None'_nolegend_',以避免图例中出现重复的标签。

总结

以上示例代码展示了不同学习率调度器的实现方式以及学习率随训练步骤变化的过程。选择合适的调度器可以根据具体任务和模型的需求来优化训练效果。以下是各类调度器的快速参考:

  • 指数衰减调度器(Exponential Decay Scheduler):适用于希望学习率持续且缓慢降低,稳定收敛的训练过程。
  • 余弦退火调度器(Cosine Annealing Scheduler):适用于需要动态调整学习率以避免局部最优,尤其适合复杂模型。
  • 步长衰减调度器(Step Decay Scheduler):适用于分阶段训练,明确划分训练阶段的任务。
  • 多项式衰减与预热调度器(Polynomial Decay with Warm-up):适用于复杂模型训练,防止初期震荡并促进后期收敛。
  • 多步衰减调度器(MultiStep Decay Scheduler):适用于多阶段训练,需要在多个关键点调整学习率的任务。

在实际应用中,可以根据模型的复杂度、数据集的特性以及训练的阶段性需求,灵活选择和调整学习率调度策略,以实现最佳的训练效果。

参考资料

  • PyTorch 官方文档 - Learning Rate Scheduler
  • Transformers 库 - 调度器

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/1846.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

WEB攻防-通用漏洞_XSS跨站_权限维持_捆绑钓鱼_浏览器漏洞

目录 XSS的分类 XSS跨站-后台植入Cookie&表单劫持 【例1】:利用beef或xss平台实时监控Cookie等凭据实现权限维持 【例2】:XSS-Flash钓鱼配合MSF捆绑上线 【例3】:XSS-浏览器网马配合MSF访问上线 XSS的分类 反射型(非持久…

【AIGC-ChatGPT进阶提示词指令】智慧母婴:打造基于成长树的儿童发展引导系统

第一次进入全站综合热榜,有点紧张 好了,开始今天的内容,今天的内容是基于育儿的系统 今天继续回馈大家,最近都是可以在自媒体上使用的提示词。提示词在最下方 引言 在人工智能时代,如何将传统育儿智慧与现代教育理念有…

使用葡萄城+vue实现Excel

最终实现效果如下 包含增加复选框 设置公式 设置背景颜色等,代码实在太多 有需要可留言 第一步:创建表头 请使用官网提供的网址:在线 Excel 编辑器 | SpreadJS 在线表格编辑器 1.点击下方号,创建一个新的sheet页 默认新创建的she…

【Qt】01-了解QT

踏入QT的殿堂之路 前言一、创建工程文件1.1 步骤介绍1.2 编译介绍方法1、方法2、编译成功 二、了解框架2.1 main.cpp2.2 .Pro文件2.2.1 注释需要打井号。2.2.2 F1带你进入帮助模式2.2.3 build文件 2.3 构造函数 三、编写工程3.1 main代码3.2 结果展示 四、指定父对象4.1 main代…

《异步编程之美》— 全栈修仙《Java 8 CompletableFuture 对比 ES6 Promise 以及Spring @Async》

哈喽,大家好!在平常开发过程中会遇到许多意想不到的坑,本篇文章就记录在开发过程中遇到一些常见的问题,看了许多博主的异步编程,我只能说一言难尽。本文详细的讲解了异步编程之美,是不可多得的好文&#xf…

unity——Preject3——面板基类

目录 1.Canvas Group Canvas Group 的功能 Canvas Group 的常见用途 如何使用 Canvas Group 2.代码 3.代码分析 类分析:BasePanel 功能 作用 实际应用 代码解析:hideCallBack?.Invoke(); 语法知识点 作用 虚函数(virtual)和抽象类(abstract)的作用与区别 …

Windows service运行Django项目

系统:Windows Service 软件:nssm,nginx 配置Django项目 1、把Django项目的静态文件整理到staticfiles文件夹中 注:settings中的设置 STATIC_URL /static/ STATIC_ROOT os.path.join(BASE_DIR, staticfiles/) STATICFILES_DI…

SQL面试题1:连续登陆问题

引言 场景介绍: 许多互联网平台为了提高用户的参与度和忠诚度,会推出各种连续登录奖励机制。例如,游戏平台会给连续登录的玩家发放游戏道具、金币等奖励;学习类 APP 会为连续登录学习的用户提供积分,积分可兑换课程或…

【大数据】机器学习-----线性模型

一、线性模型基本形式 线性模型旨在通过线性组合输入特征来预测输出。其一般形式为: 其中: x ( x 1 , x 2 , ⋯ , x d ) \mathbf{x}(x_1,x_2,\cdots,x_d) x(x1​,x2​,⋯,xd​) 是输入特征向量,包含 d d d 个特征。 w ( w 1 , w 2 , ⋯ ,…

OpenCV基础:矩阵的创建、检索与赋值

本文主要是介绍如何使用numpy进行矩阵的创建,以及从矩阵中读取数据,修改矩阵数据。 创建矩阵 import numpy as npa np.array([1,2,3]) b np.array([[1,2,3],[4,5,6]]) #print(a) #print(b)# 创建全0数组 eros矩阵 c np.zeros((8,8), np.uint8) #prin…

(蓝桥杯)二维数组前缀和典型例题——子矩阵求和

题目描述 小 A 同学有着很强的计算能力,张老师为了检验小 AA同学的计算能力,写了一个 n 行 m 列的矩阵数列。 张老师问了小 A 同学 k 个问题,每个问题会先告知小 A 同学 4 个数 x1,y1,x2,y2画出一个子矩阵,张老师请小 A同学计算出…

Node.js - HTTP

1. HTTP请求 HTTP(Hypertext Transfer Protocol,超文本传输协议)是客户端和服务器之间通信的基础协议。HTTP 请求是由客户端(通常是浏览器、手机应用或其他网络工具)发送给服务器的消息,用来请求资源或执行…

[读书日志]8051软核处理器设计实战(基于FPGA)第七篇:8051软核处理器的测试(verilog+C)

6. 8051软核处理器的验证和使用 为了充分测试8051的性能,我们需要测试每一条指令。在HELLO文件夹中存放了整个测试的C语言工程文件。主函数存放在指令被分为五大类,和上面一样。 打开后是这样的文件结构。HELLO.c是主文件,这是里面的代码&am…

深入浅出 Android AES 加密解密:从理论到实战

深入浅出 Android AES 加密解密:从理论到实战 在现代移动应用中,数据安全是不可忽视的一环。无论是用户隐私保护,还是敏感信息的存储与传输,加密技术都扮演着重要角色。本文将以 AES(Advanced Encryption Standard&am…

IDEA编译器集成Maven环境以及项目的创建(2)

选择:“File” ---> "Othoer Setting" --> "Settings for New Projects..." --->搜索“Maven” 新建项目 利用maven命令去编译这个项目 利用maven去打包

Open FPV VTX开源之默认MAVLink设置

Open FPV VTX开源之默认MAVLink设置 1. 源由2. 准备3. 连接4. 安装5. 配置6. 测试6.1 启动wfb-ng服务6.2 启动wfb-ng监测6.3 启动QGroundControl6.4 观察测试结果 7. 总结8. 参考资料9. 补充9.1 telemetry_tx异常9.2 DEBUG串口部分乱码9.3 PixelPilot软件问题 1. 源由 飞控图传…

gesp(C++五级)(4)洛谷:B3872:[GESP202309 五级] 巧夺大奖

gesp(C五级)(4)洛谷:B3872:[GESP202309 五级] 巧夺大奖 题目描述 小明参加了一个巧夺大奖的游戏节目。主持人宣布了游戏规则: 游戏分为 n n n 个时间段,参加者每个时间段可以选择一个小游戏。 游戏中共有…

像JSONDecodeError: Extra data: line 2 column 1 (char 134)这样的问题怎么解决

问题介绍 今天处理返回的 JSON 的时候,出现了下面这样的问题: 处理这种问题的时候,首先你要看一下当前的字符串格式是啥样的,比如我查看后发现是下面这样的: 会发现这个字符串中间没有逗号,也就是此时的J…

道旅科技借助云消息队列 Kafka 版加速旅游大数据创新发展

作者:寒空、横槊、娜米、公仪 道旅科技:科技驱动,引领全球旅游分销服务 道旅科技 (https://www.didatravel.com/home) 成立于 2012 年,总部位于中国深圳,是一家以科技驱动的全球酒店资源批发商…

导出文件,能够导出但是文件打不开

背景: 在项目开发中,对于列表的查询,而后会有导出功能,这里导出的是一个excell表格。实现了两种,1.导出的文件,命名是前端传输过去的;2.导出的文件,命名是根据后端返回的文件名获取的…