【PyTorch实战演练】自调整学习率实例应用(附代码)

目录

0. 前言

1. 自调整学习率的常用方法

1.1  ExponentialLR 指数衰减方法

1.2 CosineAnnealingLR 余弦退火方法

1.3 ChainedScheduler 链式方法

2. 实例说明

3. 结果说明

3.1 余弦退火法训练过程

3.2 指数衰减法训练过程

3.3 恒定学习率训练过程

3.4 结果解读

4. 完整代码


0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文介绍深度学习训练中经常能使用到的实用技巧——自调整学习率,并基于PyTorch框架通过实例进行使用,最后对比同样条件下以自调整学习率和固定学习率在模型训练上的表现。

在深度学习模型训练过程中,经常会出现损失值不收敛的头疼情况,令人怀疑自己设计的网络模型存在不合理或者错误之处,但是导致这种情况的真正原因往往非常简单——就是学习率的设定不合理。

这就导致在深度学习模型训练的初期,可能需要“摸索”一个比较合适的学习率,在后期逐渐细化的过程可能还需要尝试其他的学习率,进行“分段训练”。学习率的这种“摸索”费时费力,因此自调整学习率的方法应运而生。

1. 自调整学习率的常用方法

自调整学习率是一种通用的方法,通过在训练期间动态地更新学习率以使模型更好地收敛,提高模型的精确度和稳定性。在PyTorch框架 torch.optim.lr_scheduler 中集成了大量的自调整学习率的方法:

lr_scheduler.LambdaLR

lr_scheduler.MultiplicativeLR

lr_scheduler.StepLR

lr_scheduler.MultiStepLR

lr_scheduler.ConstantLR

lr_scheduler.LinearLR

lr_scheduler.ExponentialLR

lr_scheduler.PolynomialLR

lr_scheduler.CosineAnnealingLR

lr_scheduler.ChainedScheduler

lr_scheduler.SequentialLR

lr_scheduler.ReduceLROnPlateau

lr_scheduler.CyclicLR

lr_scheduler.OneCycleLR

lr_scheduler.CosineAnnealingWarmRestarts

这里具体方法虽然很多,但是每种自动调整学习率的策略都比较简单,本文仅介绍以下两种常见的方法,其余可以查看PyTorch官网说明。

1.1  ExponentialLR 指数衰减方法

这种方法的学习率为:

lr = lr_{initial} \cdot \gamma^{epoch}

这里有两个参数:

  • lr_{initial}:初始学习率
  • \gamma:衰减指数,取值范围(0,1),一般来说取值要非常接近1(>0.95,甚至要>0.99),否则在动辄几百几千的迭代次数epoch条件下学习率很快会衰减到0

PyTorch调用代码为:

torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch)
  • optimizer:训练该模型的优化算法
  • gamma:上面的衰减指数\gamma
  • last_epoch:不用理会,默认-1即可
1.2 CosineAnnealingLR 余弦退火方法

首先我要吐槽下不知道是哪个大聪明给起的这个名字,我学过《机械加工原理与工艺》,但我不明白这个算法和退火有什么关系?名字非常唬人,方法并不复杂。

这个方法就是让学习率按余弦曲线进行震荡,其震荡范围为[0, lr_initial],震荡周期为T_max:

PyTorch调用代码为:

torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, last_epoch)

其参数同上说明不再赘述。

1.3 ChainedScheduler 链式方法

这个方法说白了就是多种自调节学习率方法进行组合使用,例如以下:

scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
scheduler2 = ExponentialLR(self.opt, gamma=0.9)
scheduler = ChainedScheduler([scheduler1, scheduler2])

最后再强调一下,无论使用哪种方法,在coding时别忘了每个epoch学习后加上 scheduler.step() ,这样才能更新学习率。

2. 实例说明

本文目的是验证自调节学习率的作用,选用一个简单的“平方网络”实例,即期望输出为输入值的平方。

训练输入数据x_train,输出数据y_train分别为:

x_train = torch.tensor([0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1],dtype=torch.float32).unsqueeze(-1)
y_train = torch.tensor([0.01,0.04,0.09,0.16,0.25,0.36,0.49,0.64,0.81,1],dtype=torch.float32).unsqueeze(-1)

验证输入数据x_val为:

x_val = torch.tensor([0.15,0.25,0.35,0.45,0.55,0.65,0.75,0.85,0.95],dtype=torch.float32).unsqueeze(-1)

网络模型使用5层全连接网络,对应这种简单问题足够用了:

class Linear(torch.nn.Module):def __init__(self):super().__init__()self.layers = torch.nn.Sequential(torch.nn.Linear(in_features=1, out_features=3),torch.nn.Sigmoid(),torch.nn.Linear(in_features=3,out_features=5),torch.nn.Sigmoid(),torch.nn.Linear(in_features=5, out_features=10),torch.nn.Sigmoid(),torch.nn.Linear(in_features=10,out_features=5),torch.nn.Sigmoid(),torch.nn.Linear(in_features=5, out_features=1),torch.nn.ReLU(),)

优化器选用Adam,训练组及验证组的损失函数都选用MSE均方差。

3. 结果说明

对于本文对比的三种方法,其训练参数设定如下表:

学习率调整方法训练参数设定
指数衰减迭代次数epoch=2000,初始学习率initial_lr=0.0005,衰减指数gamma=0.99995
余弦退火迭代次数epoch=2000,初始学习率initial_lr=0.001(因为学习率均值为初始值的一半,公平起见给余弦退火方法初始学习率*2),震荡周期T_max=200
恒定学习率学习率lr=0.0005

各种方法的训练过程如下图:

3.1 余弦退火法训练过程

3.2 指数衰减法训练过程

3.3 恒定学习率训练过程

3.4 结果解读
  1. 指数衰减方法的gamma真的要选择非常接近1才行!理由我上面也解释了,下面是gamma=0.999时的loss下降情况,可见其下降速度之慢;

  2. 本实例中,使用余弦退火方法训练的模型在验证组表现最好,验证组损失值val=0.0027,其次是指数衰减法(val=0.0037),最差是恒定学习率(val=0.0053),但这个结果仅限本实例中,并不是说哪种方法就比哪种方法好,在不同模型上可能要因地制宜选择合适的方法
  3. 目前这些所谓的“自调节学习率”我认为仅能算是“半自动调节”,因为仍有超参数需要事前设定(例如衰减指数gamma)。而选择哪种方法最好,以及这种方法的参数如何设定呢?可能还是需要做一定的“摸索”,但是相比恒定学习率肯定会节省很多时间!

4. 完整代码

import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
import argparsetorch.manual_seed(25)x_train = torch.tensor([0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1],dtype=torch.float32).unsqueeze(-1)
y_train = torch.tensor([0.01,0.04,0.09,0.16,0.25,0.36,0.49,0.64,0.81,1],dtype=torch.float32).unsqueeze(-1)
x_val = torch.tensor([0.15,0.25,0.35,0.45,0.55,0.65,0.75,0.85,0.95],dtype=torch.float32).unsqueeze(-1)class Linear(torch.nn.Module):def __init__(self):super().__init__()self.layers = torch.nn.Sequential(torch.nn.Linear(in_features=1, out_features=3),torch.nn.Sigmoid(),torch.nn.Linear(in_features=3,out_features=5),torch.nn.Sigmoid(),torch.nn.Linear(in_features=5, out_features=10),torch.nn.Sigmoid(),torch.nn.Linear(in_features=10,out_features=5),torch.nn.Sigmoid(),torch.nn.Linear(in_features=5, out_features=1),torch.nn.ReLU(),)def forward(self,x):return self.layers(x)criterion = torch.nn.MSELoss()def train_with_CosinAnneal(epoch, initial_lr, T_max):linear1 = Linear()opt = torch.optim.Adam(linear1.parameters(), lr=initial_lr)scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=opt, T_max=T_max, last_epoch=-1)last_epoch_loss = 0for i in tqdm(range(epoch)):opt.zero_grad()total_loss = 0for j in range(len(x_train)):output = linear1(x_train[j])loss = criterion(output,y_train[j])total_loss = total_loss + loss.detach()loss.backward()opt.step()if i == epoch-1:last_epoch_loss = total_lossscheduler.step()cur_lr = scheduler.get_lr()  #查看学习率变化情况# plt.scatter(i, cur_lr, s=2, c='g')  plt.scatter(i, total_loss,s=2,c='r')val = 0for x in x_val:val = val + ((linear1(x) - x * x) / x * x)**2plt.title('CosinAnneal---val=%f---last_epoch_loss=%f'%(val,last_epoch_loss))plt.xlabel('epoch')plt.ylabel('loss')plt.show()def train_with_ExponentialLR(epoch, initial_lr, gamma):linear2 = Linear()opt = torch.optim.Adam(linear2.parameters(), lr=initial_lr)scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=opt, gamma=gamma, last_epoch=-1)last_epoch_loss = 0for i in tqdm(range(epoch)):opt.zero_grad()total_loss = 0for j in range(len(x_train)):output = linear2(x_train[j])loss = criterion(output,y_train[j])total_loss = total_loss + loss.detach()loss.backward()opt.step()if i == epoch-1:last_epoch_loss = total_lossscheduler.step()cur_lr = scheduler.get_lr()# plt.scatter(i, cur_lr, s=2, c='g')plt.scatter(i, total_loss,s=2,c='g')val = 0for x in x_val:val = val + ((linear2(x) - x * x) / x * x)**2plt.title('ExponentialLR---val=%f---last_epoch_loss=%f'%(val,last_epoch_loss))plt.xlabel('epoch')plt.ylabel('loss')plt.show()def train_without_lr_scheduler(epoch, initial_lr):linear3 = Linear()opt = torch.optim.Adam(linear3.parameters(), lr=initial_lr)last_epoch_loss = 0for i in tqdm(range(epoch)):opt.zero_grad()total_loss = 0for j in range(len(x_train)):output = linear3(x_train[j])loss = criterion(output, y_train[j])total_loss = total_loss + loss.detach()loss.backward()opt.step()if i == epoch-1:last_epoch_loss = total_lossplt.scatter(i, total_loss, s=2, c='b')val = 0for x in x_val:val = val + ((linear3(x) - x * x) / x * x)**2plt.title('without_lr_scheduler---val=%f---last_epoch_loss=%f'%(val,last_epoch_loss))plt.xlabel('epoch')plt.ylabel('loss')plt.show()if __name__ == '__main__' :epoch = 2000initial_lr = 0.0005gamma = 0.99995T_max = 200Cosine_Anneal_args = [epoch,initial_lr*2,T_max]train_with_CosinAnneal(*Cosine_Anneal_args)# ExponentialLR_args = [epoch, initial_lr, gamma]# train_with_ExponentialLR(*ExponentialLR_args)## without_scheduler_args = [epoch, initial_lr]# train_without_lr_scheduler(*without_scheduler_args)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/167413.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华为OD机试 - 代表团坐车 - 动态规划(Java 2023 B卷 200分)

目录 专栏导读一、题目描述二、输入描述三、输出描述四、解题思路五、Java算法源码六、效果展示1、输入2、输出3、说明 华为OD机试 2023B卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试(JAVA)真题(A卷B卷&#…

人工智能(5):深度学习简介

1 深度学习 —— 神经网络简介 深度学习(Deep Learning)(也称为深度结构学习【Deep Structured Learning】、层次学习【Hierarchical Learning】或者是深度机器学习【Deep Machine Learning】)是一类算法集合,是机器学…

Jenkins+Ant+Jmeter接口自动化集成测试

一、Jenkins安装配置 1、安装配置JDK1.6环境变量; 2、下载jenkins.war,放入C:\jenkins目录下,目录位置随意; Jenkins启动方法: cmd进入Jenkins目录下,执行java -jar jenkins.war 浏览器输入:l…

html5语义化标签

目录 前言 什么是语义化标签 常见的语义化标签 语义化的好处 前言 HTML5 的设计目的是为了在移动设备上支持多媒体。之前网页如果想嵌入视频音频,需要用到 flash ,但是苹果设备是不支持 flash 的,所以为了改变这一现状,html5 …

解决 Element-ui中 表格(Table)使用 v-if 条件控制列显隐时数据展示错乱的问题

本文 Element-ui 版本 2.x 问题 在 el-table-column 上需根据不同 v-if 条件来控制列显隐时&#xff0c;就会出现列数据展示错乱的情况&#xff08;要么 A 列的数据显示在 B 列上&#xff0c;或者后端返回有数据的但是显示的为空&#xff09;&#xff0c;如下所示。 <tem…

使用screen实现服务器代码一直运行

1.安装screen sudo apt install screen 2.创建一个screen&#xff08;创建一个名为chatglm的新的链接&#xff0c;用来一直运行 screen -S chatglm 3.查看进程列表 screen -ls 创建之后&#xff0c;就可以在当前窗口利用cd命令进入要执行的项目中&#xff0c;开始执行&#xf…

pow函数

pow函数 pow的翻译是指数表达式 第一个参数为底数&#xff0c;第二个参数为指数 返回值为&#xff1a; 头文件为include <math.h> #include <stdio.h> #include <math.h>int main() {int ret (int)pow(10, 2);printf("%d\n", ret);return 0; }…

分享一下我家网络机柜,家庭网络设备推荐

家里网络机柜搞了几天终于搞好了&#xff0c;非专业的&#xff0c;走线有点乱&#xff0c;勿喷。 从上到下的设备分别是&#xff1a; 无线路由器&#xff08;当ap用&#xff09;:TL-XDR6088 插排&#xff1a;德木pdu机柜插排 硬盘录像机&#xff1a;TL-NVR6108-L8P 第二排左边…

云计算认证有哪些?认证考了有什么用?

云计算作为一项快速发展的技术&#xff0c;对人才的需求持续增长。无论是男生还是女生&#xff0c;只要具备相关的技能和知识&#xff0c;都可以在云计算领域找到就业机会。 目前入行云计算最好最便捷的方式就是考证&#xff0c;拿到一个云计算相关的证书&#xff0c;就能开启…

2023最新UI酒桌喝酒游戏小程序源码 娱乐小程序源码 带流量主

2023最新UI酒桌喝酒游戏小程序源码 娱乐小程序源码 带流量主 修改增加了广告位&#xff0c;根据文档直接替换&#xff0c;原版本没有广告位 直接上传源码到开发者端即可 通过后改广告代码&#xff0c;然后关闭广告展示提交&#xff0c;通过后打开即可 无广告引流 流量主版…

No171.精选前端面试题,享受每天的挑战和学习

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云课上架的前后端实战课程《Vue.js 和 Egg.js 开发企业级健康管理项目》、《带你从入…

SD NAND对比TF卡优势(以CSNP4GCR01-AMW为例)

最近做的一个项目&#xff0c; 需要加大容量存储&#xff0c;这让我想到之前在做ARM的开发板使用的TF卡方案&#xff0c;但是TF卡需要携带卡槽的&#xff0c;但是有限的PCB板布局已经放不下卡槽的位置。 这个时候就需要那种能够不用卡槽&#xff0c;直接贴在板子上面&#xff0…

1、VMware虚拟机及网络配置

一、VMware虚拟网络编辑器 1、选择NAT模式并配置子网 2、进入NAT设置&#xff0c;配置网关 3、宿主机网络适配器设置 二、创建虚拟机 在这里插入图片描述 三、开启虚拟机&#xff0c;安装操作系统 在该网段内配置静态ip&#xff0c;指定网关为前面NAT配置的网关地址…

DDD与微服务的千丝万缕

一、软件设计发展过程二、什么是DDD&#xff1f;2.1 战略设计2.2 战术设计2.3 名词扫盲1. 领域和子域2. 核心域、通用域和支撑域3. 通用语言4. 限界上下文5. 实体和值对象6. 聚合和聚合根 2.4 事件风暴2.5 领域事件 三、DDD与微服务3.1 DDD与微服务的关系3.2 基于DDD进行微服务…

《动手学深度学习 Pytorch版》 8.7 通过时间反向传播

8.7.1 循环神经网络的梯度分析 本节主要探讨梯度相关问题&#xff0c;因此对模型及其表达式进行了简化&#xff0c;进行如下表示&#xff1a; h t f ( x t , h t − 1 , w h ) o t g ( h t , w o ) \begin{align} h_t&f(x_t,h_{t-1},w_h)\\ o_t&g(h_t,w_o) \end{ali…

【马蹄集】—— 概率论专题:第二类斯特林数

概率论专题&#xff1a;第二类斯特林数 目录 MT2224 矩阵乘法MT2231 越狱MT2232 找朋友MT2233 盒子与球MT2234 点餐 MT2224 矩阵乘法 难度&#xff1a;黄金    时间限制&#xff1a;5秒    占用内存&#xff1a;128M 题目描述 输入两个矩阵&#xff0c;第一个矩阵尺寸为 l…

pytorch 入门 (三)案例一:mnist手写数字识别

本文为&#x1f517;小白入门Pytorch内部限免文章 &#x1f368; 本文为&#x1f517;小白入门Pytorch中的学习记录博客&#x1f366; 参考文章&#xff1a;【小白入门Pytorch】mnist手写数字识别&#x1f356; 原作者&#xff1a;K同学啊 目录 一、 前期准备1. 设置GPU2. 导入…

蓝桥每日一题(day 4: 蓝桥592.门牌制作)--模拟--easy

#include <iostream> using namespace std; int main() {int res 0;for(int i 1; i < 2021; i ){int b i;while(b){if (b % 10 2) res ;b / 10;}}cout << res; return 0; }

最强英文开源模型LLaMA架构探秘,从原理到源码

导读&#xff1a; LLaMA 65B是由Meta AI&#xff08;原Facebook AI&#xff09;发布并宣布开源的真正意义上的千亿级别大语言模型&#xff0c;发布之初&#xff08;2023年2月24日&#xff09;曾引起不小的轰动。LLaMA的横空出世&#xff0c;更像是模型大战中一个搅局者。虽然它…