深度学习中的优化方法(Momentum,AdaGrad,RMSProp,Adam)详解及调用

深度学习中常用的优化方法包括啦momentum(动量法),Adagrad(adaptive gradient自适应梯度法),RMSProp(root mean square propagation均方根传播算法),Adam(adaptive moment estimation自适应矩估计法)

指数加权平均算法

所谓指数加权平均算法是上述优化算法的基础,其作用是对历史数据和当前数据进行加权求和,具体公式如下

if t==0:

v_0 = x_0

if t > 0

v_t = \beta v_{t-1} + (1 - \beta) x_t

其中

  • v_t为时间步t的加权平均值
  • v_{t-1}为为时间步t-1的加权平均值
  • x_t为时间步t的观测值
  • \beta为平滑因子,0 < \beta < 1

可以较为明显地看出,所谓指数加权平均算法的关键就在于\beta的大小,其越大,当前时间步的值就会越偏向过去的值,换句话说,整体数值序列就会更加平滑

指数加权平均算法的可视化

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from numpy.array_api import linspacetorch.manual_seed(0)def exponential_wma(data,beta=0.9):list1 = []for n,i in enumerate(data):if n ==0:list1.append(i)else:list1.append(beta * list1[n-1] + (1-beta) * i)return list1if __name__ == '__main__':data = torch.randn(50)*10x = torch.linspace(1,50,50)fig = plt.figure()axes1 = plt.subplot(1,2,1)axes1.scatter(x,data)axes1.plot(x,data)axes1.set_xlabel('x')axes1.set_ylabel('y')axes1.set_title('original data')print(data)axes2 = plt.subplot(1,2,2)axes2.scatter(x,data,label='original_data')axes2.plot(x,exponential_wma(data),  label='ewma_curve')axes2.set_xlabel('x')axes2.set_ylabel('y')axes2.set_title('ewma')print(exponential_wma(data))axes2.legend()plt.subplots_adjust(wspace=0.4)fig.savefig('ewma.png')plt.show()# tensor([-11.2584, -11.5236,  -2.5058,  -4.3388,   8.4871,   6.9201,  -3.1601,
#         -21.1522,   3.2227, -12.6333,   3.4998,   3.0813,   1.1984,  12.3766,
#          11.1678,  -2.4728, -13.5265, -16.9593,   5.6665,   7.9351,   5.9884,
#         -15.5510,  -3.4136,  18.5301,   7.5019,  -5.8550,  -1.7340,   1.8348,
#          13.8937,  15.8633,   9.4630,  -8.4368,  -6.1358,   0.3159,  10.5536,
#           1.7784,  -2.3034,  -3.9175,   5.4329,  -3.9516,   2.0553,  -4.5033,
#          15.2098,  34.1050, -15.3118, -12.3414,  18.1973,  -5.5153, -13.2533,
#           1.8855])
# [tensor(-11.2584), tensor(-11.2849), tensor(-10.4070), tensor(-9.8002), tensor(-7.9715), tensor(-6.4823), tensor(-6.1501), tensor(-7.6503), tensor(-6.5630), tensor(-7.1700), tensor(-6.1030), tensor(-5.1846), tensor(-4.5463), tensor(-2.8540), tensor(-1.4518), tensor(-1.5539), tensor(-2.7512), tensor(-4.1720), tensor(-3.1881), tensor(-2.0758), tensor(-1.2694), tensor(-2.6976), tensor(-2.7692), tensor(-0.6392), tensor(0.1749), tensor(-0.4281), tensor(-0.5587), tensor(-0.3193), tensor(1.1020), tensor(2.5781), tensor(3.2666), tensor(2.0962), tensor(1.2730), tensor(1.1773), tensor(2.1150), tensor(2.0813), tensor(1.6428), tensor(1.0868), tensor(1.5214), tensor(0.9741), tensor(1.0822), tensor(0.5237), tensor(1.9923), tensor(5.2036), tensor(3.1520), tensor(1.6027), tensor(3.2621), tensor(2.3844), tensor(0.8206), tensor(0.9271)]

这里设置的平滑因子为0.9,通常情况下使用动量法的时候平滑因子也会设置为0.9

平滑因子越大,理论上曲线就会越平滑

Momentum动量法

Momentum动量法的原理就是在梯度下降的时候使用指数加权平均法计算下降的梯度

动量更新方法如下

v_{t+1} = \beta v_t + (1 - \beta) \nabla_{\theta} J(\theta_t)

其中

  • v_{t+1}是动量项(累积的梯度),用于更新参数。
  • \beta 是动量系数(通常 0 < \beta < 1 ,例如  0.9 ),它决定了之前梯度对当前更新的影响程度。
  • v_t 是前一步的动量值。
  • \nabla_{\theta} J(\theta_t) 是当前时间步  t  计算得到的梯度。
  • \theta_t 是当前模型参数。

参数更新方法如下

\theta_{t+1} = \theta_t - \eta v_{t+1}

其中

  • \theta_{t+1}是更新后的参数。
  • \theta_t 是当前的参数。
  • \eta 是学习率。
  • v_{t+1} 是当前的动量项(累积梯度)。

动量法的调用

动量法的调用一般集成在SGD优化器中,通过设置SGD优化器中的momentum参数来配置,momentum参数的值就是动量系数(平滑因子)

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.optim as optimmodel = nn.Linear(5, 1)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print(model)
print(optimizer)# Linear(in_features=5, out_features=1, bias=True)
# SGD (
# Parameter Group 0
#     dampening: 0
#     differentiable: False
#     foreach: None
#     fused: None
#     lr: 0.001
#     maximize: False
#     momentum: 0.9
#     nesterov: False
#     weight_decay: 0
# )

AdaGrad自适应梯度

AdaGrad(自适应梯度法)的作用是随着训练的进行,对学习率进行逐步衰减

累计梯度平方和

G_t = G_{t-1} + g_t^2

其中

  • G_t 是参数梯度平方的累积和(是一个对角矩阵,表示每个参数的梯度平方和)。
  • g_t  是当前时间步  t  的梯度
  •  G_{t-1}  是前面所有时间步的梯度平方和

参数更新方式

\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{G_t} + \epsilon} \cdot g_t

其中

  • \theta_t 是当前的参数
  • \eta  是全局的学习率
  • G_t 是梯度平方的累积和
  • \epsilon  是一个小常数,用于防止除以零,通常取值为  10^{-8}

自适应梯度法的调用

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(10, 100)self.fc2 = nn.Linear(100, 5)def forward(self, x):x = F.relu(self.fc1(x))x = self.fc2(x)return xif __name__ == '__main__':my_net = SimpleNet()optimizer = optim.Adagrad(my_net.parameters(), lr=0.01)print(my_net)print(optimizer)# SimpleNet(
#   (fc1): Linear(in_features=10, out_features=100, bias=True)
#   (fc2): Linear(in_features=100, out_features=5, bias=True)
# )
# Adagrad (
# Parameter Group 0
#     differentiable: False
#     eps: 1e-10
#     foreach: None
#     fused: None
#     initial_accumulator_value: 0
#     lr: 0.01
#     lr_decay: 0
#     maximize: False
#     weight_decay: 0
# )

RMSProp均方根传播法

RMSProp均方根传播法是在Adagrad的基础上做了优化,由于Adagrad中的累计平方和,会导致学习率快速下降导致模型收敛变慢

所以RMSProp对累计平方和进行了优化,转为了加权平均算法

梯度平方的指数加权平均算法

E[g^2]t = \beta E[g^2]{t-1} + (1 - \beta) g_t^2

其中

  • E[g^2]_t  是梯度平方的指数加权移动平均值(即该参数梯度的平滑历史平方值)
  • \beta  是衰减系数(通常取值接近 1,例如 0.9 或 0.99)
  • g_t  是当前时刻  t  的梯度
  • E[g^2]_{t-1}  是前一时刻的梯度平方的指数加权移动平均

参数更新

\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{E[g^2]_t + \epsilon}} \cdot g_t

其中

  • \theta_t  是当前的参数
  • \eta 是当前的学习率
  • E[g^2]_t 是梯度平方的指数加权移动平均
  • \epsilon  是一个小常数(通常取  10^{-8} ),用于防止除以零的情况

同时RMSProp也支持动量法用于记录历史梯度,但是与SGD中国的动量法有所不同

动量项 v_t 的更新

v_t = \beta v_{t-1} + \frac{\eta}{\sqrt{E[g^2]_t + \epsilon}} g_t

其中

  • v_t 是当前动量项
  • v_{t-1}  是前一时刻的动量项
  • \beta  是动量系数,通常接近 1(例如 0.9 或 0.99),表示过去动量的影响
  • \eta  是学习率
  • g_t 是当前的梯度
  • E[g^2]_t  是梯度平方的指数加权移动平均,用来动态调整每个参数的学习率

这里动量项的加入会使模型加快熟练

均方根传播法的调用

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(10, 100)self.fc2 = nn.Linear(100, 5)def forward(self, x):x = F.relu(self.fc1(x))x = self.fc2(x)return xif __name__ == '__main__':my_net = SimpleNet()optimizer = optim.RMSprop(my_net.parameters(), lr=0.01, alpha=0.99,momentum=0.9)print(my_net)print(optimizer)# SimpleNet(
#   (fc1): Linear(in_features=10, out_features=100, bias=True)
#   (fc2): Linear(in_features=100, out_features=5, bias=True)
# )
# RMSprop (
# Parameter Group 0
#     alpha: 0.99
#     capturable: False
#     centered: False
#     differentiable: False
#     eps: 1e-08
#     foreach: None
#     lr: 0.01
#     maximize: False
#     momentum: 0.9
#     weight_decay: 0
# )

Adam自适应矩估计法

Adam自适应矩估计法,与RMSProp不同,其完整融合了Momentum和AdaGrad方法

一阶矩估计(动量)

  • m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
  • m_t 是梯度的一阶矩的指数加权移动平均值(即动量)
  • \beta_1  是控制动量的衰减系数,通常取  0.9 
  • g_t  是当前时刻的梯度

二阶矩估计(梯度平方的指数加权移动平均)

  • v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
  • v_t 是梯度平方的指数加权移动平均值(相当于 RMSProp 中的累积梯度平方)
  • \beta_2  是控制二阶矩的衰减系数,通常取  0.999 

一阶矩和二阶矩的偏差校正

由于一阶矩和二阶矩在初始化的时候,不具备前一时刻的动量,所以由于(1-衰减系数)的存在,刚开始的梯度下降幅度偏小,所以这里使用了偏差校正去放大一开始的下降幅度

  • \hat{m_t} = \frac{m_t}{1 - \beta_1^t}
  • \hat{v_t} = \frac{v_t}{1 - \beta_2^t}
  • \hat{m_t}  是一阶矩的偏差校正值
  • \hat{v_t}  是二阶矩的偏差校正值

参数更新

  • \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v_t}} + \epsilon} \hat{m_t}
  • \theta_t  是当前的参数
  • \eta  是学习率
  • \epsilon  是一个小常数,通常取  10^{-8} ,用于防止除以零的情况
  • \hat{m_t}  是偏差校正后的一阶矩
  • \hat{v_t}  是偏差校正后的二阶矩

参数设置说明

  • \beta_1  和  \beta_2 :分别控制一阶矩和二阶矩的衰减速率。通常推荐值为 \beta_1  = 0.9  和 \beta_2  = 0.999 
  • \epsilon :用于防止除以零的问题,保证数值稳定性,通常设置为  10^{-8} 
  • 学习率  \eta :这是 Adam 的全局学习率,通常设置为  0.001 
  • \beta_1^t  是  \beta_1  的  t  次方,随着迭代步数  t  的增加,\beta_1^t  趋近于 0,从而消除偏差

自适应矩估计法的调用

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(10, 100)self.fc2 = nn.Linear(100, 5)def forward(self, x):x = F.relu(self.fc1(x))x = self.fc2(x)return xif __name__ == '__main__':my_net = SimpleNet()optimizer = optim.Adam(my_net.parameters(), lr=0.01, betas=(0.9, 0.999), eps=1e-8)print(my_net)print(optimizer)# SimpleNet(
#   (fc1): Linear(in_features=10, out_features=100, bias=True)
#   (fc2): Linear(in_features=100, out_features=5, bias=True)
# )
# Adam (
# Parameter Group 0
#     amsgrad: False
#     betas: (0.9, 0.999)
#     capturable: False
#     differentiable: False
#     eps: 1e-08
#     foreach: None
#     fused: None
#     lr: 0.01
#     maximize: False
#     weight_decay: 0
# )

补充

在以上四个优化器中,都支持配置weight_decay参数,其为正则化系数,添加后可以对模型添加L2正则

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/440657.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

word无法复制粘贴

word无法复制粘贴 使用word时复制粘贴报错 如下&#xff1a; 报错&#xff1a;运行时错误‘53’&#xff0c;文件未找到&#xff1a;MathPage.WLL 这是mathtype导致的。 解决方法 1&#xff09;在mathtype下载目录下找到"\MathType\MathPage\64"下的"mathpa…

Python并发编程挑战与解决方案

Python并发编程挑战与解决方案 并发编程是现代软件开发中的一项核心能力&#xff0c;它允许多个任务同时运行&#xff0c;提高程序的性能和响应速度。Python因其易用性和灵活性而广受欢迎&#xff0c;但其全局解释器锁&#xff08;GIL&#xff09;以及其他特性给并发编程带来了…

python数据分析与可视化工具介绍-matplotlib库

众所周知&#xff0c;python的数据分析库主要是numpy&#xff0c;pandas&#xff0c;和matplotlib&#xff0c;而前面两个主要是数据处理工具库&#xff0c;最后一个才是真正的作图展示工具库。本节来学习一下matploatlib工具库的使用。 Matplotlib常用绘图函数 pyplot简介 m…

Oracle中TRUNC()函数详解

文章目录 前言一、TRUNC函数的语法二、主要用途三、测试用例总结 前言 在Oracle中&#xff0c;TRUNC函数用于截取或截断日期、时间或数值表达式的部分。它返回一个日期、时间或数值的截断版本&#xff0c;根据提供的格式进行截取。 一、TRUNC函数的语法 TRUNC(date) TRUNC(d…

字符编码发展史5 — UTF-16和UTF-32

上一篇《字符编码发展史4 — Unicode与UTF-8》我们讲解了Unicode字符集与UTF-8编码。本篇我们将继续讲解字符编码的第三个发展阶段中的UTF-16和UTF-32。 2.3. 第三个阶段 国际化 2.3.2. Unicode的编码方式 2.3.2.2. UTF-16 UTF-16也是一种变长编码&#xff0c;对于一个Unic…

1、如何查看电脑已经连接上的wifi的密码?

在电脑桌面右下角的如下位置&#xff1a;双击打开查看当前连接上的wifi的名字&#xff1a;ZTE-kfdGYX-5G 按一下键盘上的win R 键, 输入【cmd】 然后&#xff0c;按一下【回车】。 输入netsh wlan show profile ”wifi名称” keyclear : 输入完成后&#xff0c;按一下回车&…

浏览器前端向后端提供服务

WEB后端向浏览器前端提供服务是最常见的场景&#xff0c;前端向后端的接口发起GET或者POST请求&#xff0c;后端收到请求后执行服务器端任务进行处理&#xff0c;完成后向前端发送响应。 那浏览器前端向后端提供服务是什么鬼&#xff1f; 说来话长&#xff0c;长话短说。我在人…

AFSim仿真系统 --- 系统简解_06 平台及平台类型

平台及平台类型 在AFSIM模拟中&#xff0c;当在被模拟的场景中定义平台时&#xff0c;创建仿真实体&#xff08;如车辆和结构&#xff09;。 AFSIM是一个用于创建仿真的对象框架&#xff0c;而平台封装了对象的原则身份或定义。 平台可以拥有系统&#xff08;或平台部分&#x…

自然语言处理-语言转换

文章目录 一、语言模型二、统计语言模型1.含义与方法2.存在的问题 三、神经语言模型1.含义与方法2.one-hot编码3.词嵌入-word2vec4.模型的训练过程 四、总结 自然语言处理&#xff08;NLP&#xff09;中的语言转换方法主要涉及将一种形式的语言数据转换为另一种形式&#xff0c…

[Cocoa]_[初级]_[使用NSNotificationCenter作为目标观察者实现时需要注意的事项]

场景 在开发Cocoa程序时&#xff0c;由于界面是用Objective-C写的。无法使用C的目标观察者[1]类。如果是使用第二种方案2[2],那么也需要增加一个代理类。那么有没有更省事的办法&#xff1f; 说明 开发界面的时候&#xff0c;经常是需要在子界面里传递数据给主界面&#xff0…

Windows 搭建 Gitea

一、准备工作 1. 安装 Git&#xff1a;Gitea 依赖 Git 进行代码管理&#xff0c;所以首先需要确保系统中安装了 Git。 下载地址&#xff1a;https://git-scm.com/downloads/win 2. 安装数据库&#xff08;可选&#xff09; 默认情况下&#xff0c;Gitea 使用 SQLite 作为内…

Nginx的基础讲解之重写conf文件

一、Nginx 1、什么是nginx&#xff1f; Nginx&#xff08;engine x&#xff09;是一个高性能的HTTP和反向代理web服务器&#xff0c;同时也提供了IMAP/POP3/SMTP服务。 2、用于什么场景 Nginx适用于各种规模的网站和应用程序&#xff0c;特别是需要高并发处理和负载均衡的场…

微信步数C++

题目&#xff1a; 样例解释&#xff1a; 【样例 #1 解释】 从 (1,1) 出发将走 2 步&#xff0c;从 (1,2) 出发将走 4 步&#xff0c;从 (1,3) 出发将走 4 步。 从 (2,1) 出发将走 2 步&#xff0c;从 (2,2) 出发将走 3 步&#xff0c;从 (2,3) 出发将走 3 步。 从 (3,1) 出发将…

AI 激活新势能,中小企业全媒体营销绽放无限可能

什么是全媒体营销&#xff1a; 全媒体营销是一种利用多种媒介渠道进行品牌、产品或服务推广的营销策略。它结合了传统媒体&#xff08;如电视、广播、报纸、杂志&#xff09;和新媒体&#xff08;如互联网、社交媒体、移动应用等&#xff09;的优势&#xff0c;以实现信息的广…

力扣之1322.广告效果

题目&#xff1a; sql建表语句&#xff1a; Create table If Not Exists Ads (ad_id int,user_id int,action ENUM (Clicked, Viewed, Ignored) ); Truncate table Ads; insert into Ads (ad_id, user_id, action) values (1, 1, Clicked); insert into Ads (ad_id, use…

【重学 MySQL】五十八、文本字符串(包括 enum set)类型

【重学 MySQL】五十八、文本字符串&#xff08;包括 enum set&#xff09;类型 CHAR 和 VARCHARTEXT 系列ENUMSET示例注意事项 在 MySQL 中&#xff0c;文本字符串类型用于存储字符数据。这些类型包括 CHAR、VARCHAR、TEXT 系列&#xff08;如 TINYTEXT、TEXT、MEDIUMTEXT 和 L…

基于SSM的仿win10界面的酒店管理系统

基于SSM的仿win10界面的酒店管理系统 运行环境: jdk1.8 eclipse tomcat7 mysql5.7 项目技术: jspssm&#xff08;springspringmvcmybatis&#xff09;mysql 项目功能模块&#xff1a;基础功能、房间类型、楼层信息、附属功能

AtCoder ABC373 A-D题解

ABC372 的题解没写是因为 D 是单调栈我不会(⊙︿⊙) 比赛链接:ABC373 总结&#xff1a;wssb。听说 E 很水&#xff1f;有时间我看看。 Problem A: Code #include <bits/stdc.h> using namespace std; int mian(){int ans0;for(int i1;i<12;i){string S;cin>&g…

[Offsec Lab] ICMP Monitorr-RCE+hping3权限提升

信息收集 IP AddressOpening Ports192.168.52.218TCP:22,80 $ nmap -p- 192.168.52.218 --min-rate 1000 -sC -sV -Pn PORT STATE SERVICE VERSION 22/tcp open ssh OpenSSH 7.9p1 Debian 10deb10u2 (protocol 2.0) | ssh-hostkey: | 2048 de:b5:23:89:bb:9f:d4:1…

表面缺陷检测系统源码分享

表面缺陷检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vision …