PyTorch使用Tricks:学习率衰减 !!

文章目录

前言

1、指数衰减

2、固定步长衰减

3、多步长衰减

4、余弦退火衰减

5、自适应学习率衰减

6、自定义函数实现学习率调整:不同层不同的学习率


前言

在训练神经网络时,如果学习率过大,优化算法可能会在最优解附近震荡而无法收敛;如果学习率过小,优化算法的收敛速度可能会非常慢。因此,一种常见的策略是在训练初期使用较大的学习率来快速接近最优解,然后逐渐减小学习率,使得优化算法可以更精细地调整模型参数,从而找到更好的最优解。

通常学习率衰减有以下的措施:

  • 指数衰减:学习率按照指数的形式衰减,每次乘以一个固定的衰减系数,可以使用 torch.optim.lr_scheduler.ExponentialLR 类来实现,需要指定优化器和衰减系数。
  • 固定步长衰减:学习率每隔一定步数(或者epoch)就减少为原来的一定比例,可以使用 torch.optim.lr_scheduler.StepLR 类来实现,需要指定优化器、步长和衰减比例。
  • 多步长衰减:学习率在指定的区间内保持不变,在区间的右侧值进行一次衰减,可以使用 torch.optim.lr_scheduler.MultiStepLR 类来实现,需要指定优化器、区间列表和衰减比例。
  • 余弦退火衰减:学习率按照余弦函数的周期和最值进行变化,可以使用 torch.optim.lr_scheduler.CosineAnnealingLR 类来实现,需要指定优化器、周期和最小值。
  • 自适应学习率衰减:这种策略会根据模型的训练进度自动调整学习率,可以使用 torch.optim.lr_scheduler.ReduceLROnPlateau 类来实现。例如,如果模型的验证误差停止下降,那么就减小学习率;如果模型的训练误差上升,那么就增大学习率。
  • 自适应函数实现学习率调整:不同层不同的学习率。

1、指数衰减

指数衰减是一种常用的学习率调整策略,其主要思想是在每个训练周期(epoch)结束时,将当前学习率乘以一个固定的衰减系数(gamma),从而实现学习率的指数衰减。这种策略可以帮助模型在训练初期快速收敛,同时在训练后期通过降低学习率来提高模型的泛化能力。

在PyTorch中,可以使用 torch.optim.lr_scheduler.ExponentialLR 类来实现指数衰减。该类的构造函数需要两个参数:一个优化器对象和一个衰减系数。在每个训练周期结束时,需要调用ExponentialLR 对象的 step() 方法来更新学习率。

以下是一个使用 ExponentialLR代码示例:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR# 假设有一个模型参数
model_param = torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))# 使用SGD优化器,初始学习率设置为0.1
optimizer = SGD([model_param], lr=0.1)# 创建ExponentialLR对象,衰减系数设置为0.9
scheduler = ExponentialLR(optimizer, gamma=0.9)# 在每个训练周期结束时,调用step()方法来更新学习率
for epoch in range(100):# 这里省略了模型的训练代码# ...# 更新学习率scheduler.step()

在这个例子中,初始的学习率是0.1,每个训练周期结束时,学习率会乘以0.9,因此学习率会按照指数的形式衰减。

2、固定步长衰减

固定步长衰减是一种学习率调整策略,它的原理是每隔一定的迭代次数(或者epoch),就将学习率乘以一个固定的比例,从而使学习率逐渐减小。这样做的目的是在训练初期使用较大的学习率,加快收敛速度,而在训练后期使用较小的学习率,提高模型精度。

PyTorch提供了 torch.optim.lr_scheduler.StepLR 类来实现固定步长衰减,它的参数有:

  • optimizer:要进行学习率衰减的优化器,例如 torch.optim.SGD torch.optim.Adam等。
  • step_size:每隔多少隔迭代次数(或者epoch)进行一次学习率衰减,必须是正整数。
  • gamma:学习率衰减的乘法因子,必须是0到1之间的数,表示每次衰减为原来的 gamma倍。
  • last_epoch:最后一个epoch的索引,用于恢复训练的状态,默认为-1,表示从头开始训练。
  • verbose:是否打印学习率更新的信息,默认为False。

下面是一个使用  torch.optim.lr_scheduler.StepLR 类的具体例子,假设有一个简单的线性模型,使用 torch.optim.SGD 作为优化器,初始学习率为0.1,每隔5个epoch就将学习率乘以0.8,训练100个epoch:

import torch
import matplotlib.pyplot as plt# 定义一个简单的线性模型
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.fc = torch.nn.Linear(1, 1)def forward(self, x):return self.fc(x)# 创建模型和优化器
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)# 创建固定步长衰减的学习率调度器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.8)# 记录学习率变化
lr_list = []# 模拟训练过程
for epoch in range(100):# 更新学习率scheduler.step()# 记录当前学习率lr_list.append(optimizer.param_groups[0]['lr'])# 绘制学习率曲线
plt.plot(range(100), lr_list, color='r')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.show()

学习率在每个5个epoch后都会下降为原来的0.8倍,直到最后接近0。

固定步长衰减指数衰减都是学习率衰减的策略,但它们在衰减的方式和速度上有所不同:

  • 固定步长衰减:在每隔固定的步数(或epoch)后,学习率会减少为原来的一定比例。这种策略的衰减速度是均匀的,不会随着训练的进行而改变。
  • 指数衰减:在每个训练周期(或epoch)结束时,学习率会乘以一个固定的衰减系数,从而实现学习率的指数衰减。这种策略的衰减速度是逐渐加快的,因为每次衰减都是基于当前的学习率进行的。

3、多步长衰减

多步长衰减是一种学习率调整策略,它在指定的训练周期(或epoch)达到预设的里程碑时,将学习率减少为原来的一定比例。这种策略可以在模型训练的关键阶段动态调整学习率。

在PyTorch中,可以使用 torch.optim.lr_scheduler.MultiStepLR 类来实现多步长衰减。以下是一个使用 MultiStepLR 的代码示例:

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR# 假设有一个简单的模型
model = torch.nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)# 创建 MultiStepLR 对象,设定在第 30 和 80 个 epoch 时学习率衰减为原来的 0.1 倍
scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)# 在每个训练周期结束时,调用 step() 方法来更新学习率
for epoch in range(100):# 这里省略了模型的训练代码# ...# 更新学习率scheduler.step()

在这个例子中,初始的学习率是0.1,当训练到第30个epoch时,学习率会变为0.01(即0.1*0.1),当训练到第80个epoch时,学习率会再次衰减为0.001(即0.01*0.1)。

4、余弦退火衰减

余弦退火衰减是一种学习率调整策略,它按照余弦函数的周期和最值来调整学习率。在PyTorch中,可以使用 torch.optim.lr_scheduler.CosineAnnealingLR 类来实现余弦退火衰减。

以下是一个使用 CosineAnnealingLR 的代码示例:

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib.pyplot as plt# 假设有一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.linear = nn.Linear(10, 2)model = SimpleModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)# 创建 CosineAnnealingLR 对象,周期设置为 10,最小学习率设置为 0.01
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0.01)lr_list = []
# 在每个训练周期结束时,调用 step() 方法来更新学习率
for epoch in range(100):# 这里省略了模型的训练代码# ...# 更新学习率scheduler.step()lr_list.append(optimizer.param_groups[0]['lr'])# 绘制学习率变化曲线
plt.plot(range(100), lr_list)
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title("Learning rate schedule: CosineAnnealingLR")
plt.show()

在这个例子中,初始的学习率是0.1,学习率会按照余弦函数的形式在0.01到0.1之间变化,周期为10个epoch。

5、自适应学习率衰减

自适应学习率衰减是一种学习率调整策略,它会根据模型的训练进度自动调整学习率。例如,如果模型的验证误差停止下降,那么就减小学习率;如果模型的训练误差上升,那么就增大学习率。在PyTorch中,可以使用 torch.optim.lr_scheduler.ReduceLROnPlateau 类来实现自适应学习率衰减。

以下是一个使用 ReduceLROnPlateau 的代码示例:

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau# 假设有一个简单的模型
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)# 创建 ReduceLROnPlateau 对象,当验证误差在 10 个 epoch 内没有下降时,将学习率减小为原来的 0.1 倍
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.1)# 模拟训练过程
for epoch in range(100):# 这里省略了模型的训练代码# ...# 假设有一个验证误差val_loss = ...# 在每个训练周期结束时,调用 step() 方法来更新学习率scheduler.step(val_loss)

6、自定义函数实现学习率调整:不同层不同的学习率

可以通过为优化器提供一个参数组列表来实现对不同层使用不同的学习率。每个参数组是一个字典,其中包含一组参数和这组参数的学习率。

以下是一个具体的例子,假设有一个包含两个线性层的模型,想要对第一层使用学习率0.01,对第二层使用学习率0.001:

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个包含两个线性层的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.layer1 = nn.Linear(10, 2)self.layer2 = nn.Linear(2, 10)model = SimpleModel()# 创建一个参数组列表,每个参数组包含一组参数和这组参数的学习率
params = [{'params': model.layer1.parameters(), 'lr': 0.01},{'params': model.layer2.parameters(), 'lr': 0.001}]# 创建优化器
optimizer = optim.SGD(params)# 现在,当调用 optimizer.step() 时,第一层的参数会使用学习率 0.01 进行更新,第二层的参数会使用学习率 0.001 进行更新

在这个例子中,首先定义了一个包含两个线性层的模型。然后,创建一个参数组列表,每个参数组都包含一组参数和这组参数的学习率。最后创建了一个优化器SGD,将这个参数组列表传递给优化器。这样,当调用 optimizer.step() 时,第一层的参数会使用学习率0.01进行更新,第二层的参数会使用学习率0.001进行更新。

参考:深度图学习与大模型LLM

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/260953.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LDRA Testbed软件静态分析_软件质量度量

系列文章目录 LDRA Testbed软件静态分析_操作指南 LDRA Testbed软件静态分析_自动提取静态分析数据生成文档 LDRA Testbed软件静态分析_Jenkins持续集成_(1)自动进行静态分析的环境搭建 LDRA Testbed软件静态分析_Jenkins持续集成_(2)配置邮件自动发送静态分析结果 LDRA Testb…

Qt C++春晚刘谦魔术约瑟夫环问题的模拟程序

什么是约瑟夫环问题? 约瑟夫问题是个有名的问题:N个人围成一圈,从第一个开始报数,第M个将被杀掉,最后剩下一个,其余人都将被杀掉。例如N6,M5,被杀掉的顺序是:5&#xff…

离线升级esp32开发板升级包esp32-2.0.14(最新版已经3.0alpha了)

1.Arduino IDE 2.3.2最新 2024.2.20升级安装:https://www.arduino.cc/en/software 2.开发板地址 地址(esp8266,esp32) http://arduino.esp8266.com/stable/package_esp8266com_index.json,https://raw.githubusercontent.com/espressif/arduino-esp32…

09MARL深度强化学习policy gradient

文章目录 前言1、Multi-Agent Policy Gradient Theorem2、Centralised State-Value Critics2、Centralised Action-Value Critics 前言 Independent learning算法当中每个智能体看待其他智能体为环境的一部分,加剧了环境非平稳性,而CTDE框架的算法能够降…

HTTP的详细介绍

目录 一、HTTP 相关概念 二、HTTP请求访问的完整过程 1、 建立连接 2、 接收请求 3、 处理请求 3.1 常见的HTTP方法 3.2 GET和POST比较 4、访问资源 5、构建响应报文 6、发送响应报文 7、记录日志 三、HTTP安装组成 1、常见http 服务器程序 2、apache介绍和特点 …

工具分享:在线键盘测试工具

在数字化时代,键盘作为我们与计算机交互的重要媒介之一,其性能和稳定性直接影响到我们的工作效率和使用体验。为了确保键盘的每个按键都能正常工作,并帮助用户检测潜在的延迟、连点等问题,一款优质的在线键盘测试工具显得尤为重要…

QT设置窗口随窗体变化(窗口文本框随窗体的伸缩)

目录 1.建立新窗口2.最终效果 1.建立新窗口 1)在窗体中创建一个 textBrowser,记录坐标及宽高 X-100 Y-130 宽-571 高-281,窗体宽高800*600; 2)在.h头文件中插入void resizeEvent(QResizeEvent *event) override;函数 …

如何系统地自学Python?

如何系统地自学Python? 如何系统地自学Python?1.了解编程基础2.学习Python基础语法3.学习Python库和框架4.练习编写代码5.参与开源项目6.加入Python社区7.利用资源学习8.制定学习计划9.持之以恒总结 如何系统地自学Python? 作为一个Python语…

【MySQL】如何处理DB读写分离数据不一致问题?

文章内容 1、前言读写库数据不一致问题我们如何解决?方案一:利用数据库自身特性方案二:不解决方案三:客户端保存法方案四:缓存标记法方案五:本地缓存标记 那DB读写分离情况下,如何解决缓存和数据…

h5网页和 Android APP联调,webview嵌入网页,网页中window.open打开新页面,网页只在webview中打开,没有重开一个app窗口

我是h5网页开发,客户app通过webview嵌入我的页面 点击标题window.open跳转到长图页面,客户的需求是在app里新开一个窗口展示长图页面,window.open打开,ios端是符合客户需求的,但是在安卓端他会在当前webview打开 这…

Find My资讯|苹果Vision Pro无法通过Find My进行远程定位和发声

苹果 Vision Pro 头显现在已经正式开售,不过根据该公司日前发布的支持文件,这款头显目前缺乏一系列关键查找功能,用户无法在 iCloud 网站或Find My应用中获悉头显的位置,也无法让这款头显远程播放声音。 不过支持文件同时提到 V…

3、windows环境下vscode开发c/c++环境配置(二)

前言:上一篇文章写了windows环境下,配置vscode的c/c开发环境,这一篇讲vscode开发c/c的配置文件,包括c_cpp_propertues.json,task.json及launch.json。 一、总体流程 通过c/c插件我们就可以来编写c/c程序了&#xff0c…

新版AI系统ChatGPT源码支持GPT-4/支持AI绘画去授权

源码获取方式 搜一搜:万能工具箱合集 点击资源库直接进去获取源码即可 如果没看到就是待更新,会陆续更新上 新版AI系统ChatGPT网站源码支持GPT-4/支持AI绘画/Prompt应用/MJ绘画源码/PCH5端/免授权,支持关联上下文,意间绘画模型…

村级数据下载

简介 我开发了一个网站,是一个提供2010年-2023年的中国行政区划关系的查询网站,github开源地址为:https://github.com/ruiduobao/gaode_MAP_CUN,五级行政区划的上下级关系来源于统计局发布的各个年份对应的统计用区划代码和城乡划…

Android 基础技术——Framework

笔者希望做一个系列,整理 Android 基础技术,本章是关于 Framework 简述 Android 系统启动流程 当按电源键触发开机,首先会从 ROM 中预定义的地方加载引导程序 BootLoader 到 RAM 中,并执行 BootLoader 程序启动 Linux Kernel&…

已解决Application run failed org.springframework.beans.factory.BeanNot

问题原因:SpringBoot的版本与mybiats-puls版本不对应且,spring自带的mybiats与mybiats-puls版本不对应 这里我用的是3.2.2版本的SpringBoot,之前mybiats-puls版本是3.5.3.1有所不同。 问题:版本对不上 解决办法:完整…

NoSQL 数据库管理工具,搭载强大支持:Redis、Memcached、SSDB、LevelDB、RocksDB,为您的数据存储提供无与伦比的灵活性与性能!

NoSQL 数据库管理工具,搭载强大支持:Redis、Memcached、SSDB、LevelDB、RocksDB,为您的数据存储提供无与伦比的灵活性与性能! 【官网地址】:http://www.redisant.cn/nosql 介绍 直观的用户界面 从单一应用程序中同…

计算机视觉基础:【矩阵】矩阵选取子集

OpenCV的基础是处理图像,而图像的基础是矩阵。 因此,如何使用好矩阵是非常关键的。 下面我们通过一个具体的实例来展示如何通过Python和OpenCV对矩阵进行操作,从而更好地实现对图像的处理。 示例 示例:选取矩阵中指定的行和列的…

我为什么不喜欢关电脑?

程序员为什么不喜欢关电脑? 你是否注意到,程序员们似乎从不关电脑?别以为他们是电脑上瘾,实则是有他们自己的原因!让我们一起揭秘背后的原因,看看程序员们真正的“英雄”本色! 一、上大学时。 …

C++:C++入门基础

创作不易,感谢三连 !! 一、什么是C C语言是结构化和模块化的语言,适合处理较小规模的程序。对于复杂的问题,规模较大的程序,需要高度的抽象和建模时,C语言则不合适。为了解决软件危机&#xff…