【深度学习笔记】优化算法——Adam算法

Adam算法

🏷sec_adam

本章我们已经学习了许多有效优化的技术。
在本节讨论之前,我们先详细回顾一下这些技术:

  • 在 :numref:sec_sgd中,我们学习了:随机梯度下降在解决优化问题时比梯度下降更有效。
  • 在 :numref:sec_minibatch_sgd中,我们学习了:在一个小批量中使用更大的观测值集,可以通过向量化提供额外效率。这是高效的多机、多GPU和整体并行处理的关键。
  • 在 :numref:sec_momentum中我们添加了一种机制,用于汇总过去梯度的历史以加速收敛。
  • 在 :numref:sec_adagrad中,我们通过对每个坐标缩放来实现高效计算的预处理器。
  • 在 :numref:sec_rmsprop中,我们通过学习率的调整来分离每个坐标的缩放。

Adam算法 :cite:Kingma.Ba.2014将所有这些技术汇总到一个高效的学习算法中。
不出预料,作为深度学习中使用的更强大和有效的优化算法之一,它非常受欢迎。
但是它并非没有问题,尤其是 :cite:Reddi.Kale.Kumar.2019表明,有时Adam算法可能由于方差控制不良而发散。
在完善工作中, :cite:Zaheer.Reddi.Sachan.ea.2018给Adam算法提供了一个称为Yogi的热补丁来解决这些问题。
下面我们了解一下Adam算法。

算法

Adam算法的关键组成部分之一是:它使用指数加权移动平均值来估算梯度的动量和二次矩,即它使用状态变量

v t ← β 1 v t − 1 + ( 1 − β 1 ) g t , s t ← β 2 s t − 1 + ( 1 − β 2 ) g t 2 . \begin{aligned} \mathbf{v}_t & \leftarrow \beta_1 \mathbf{v}_{t-1} + (1 - \beta_1) \mathbf{g}_t, \\ \mathbf{s}_t & \leftarrow \beta_2 \mathbf{s}_{t-1} + (1 - \beta_2) \mathbf{g}_t^2. \end{aligned} vtstβ1vt1+(1β1)gt,β2st1+(1β2)gt2.

这里 β 1 \beta_1 β1 β 2 \beta_2 β2是非负加权参数。
常将它们设置为 β 1 = 0.9 \beta_1 = 0.9 β1=0.9 β 2 = 0.999 \beta_2 = 0.999 β2=0.999
也就是说,方差估计的移动远远慢于动量估计的移动。
注意,如果我们初始化 v 0 = s 0 = 0 \mathbf{v}_0 = \mathbf{s}_0 = 0 v0=s0=0,就会获得一个相当大的初始偏差。
我们可以通过使用 ∑ i = 0 t β i = 1 − β t 1 − β \sum_{i=0}^t \beta^i = \frac{1 - \beta^t}{1 - \beta} i=0tβi=1β1βt来解决这个问题。
相应地,标准化状态变量由下式获得

v ^ t = v t 1 − β 1 t and  s ^ t = s t 1 − β 2 t . \hat{\mathbf{v}}_t = \frac{\mathbf{v}_t}{1 - \beta_1^t} \text{ and } \hat{\mathbf{s}}_t = \frac{\mathbf{s}_t}{1 - \beta_2^t}. v^t=1β1tvt and s^t=1β2tst.

有了正确的估计,我们现在可以写出更新方程。
首先,我们以非常类似于RMSProp算法的方式重新缩放梯度以获得

g t ′ = η v ^ t s ^ t + ϵ . \mathbf{g}_t' = \frac{\eta \hat{\mathbf{v}}_t}{\sqrt{\hat{\mathbf{s}}_t} + \epsilon}. gt=s^t +ϵηv^t.

与RMSProp不同,我们的更新使用动量 v ^ t \hat{\mathbf{v}}_t v^t而不是梯度本身。
此外,由于使用 1 s ^ t + ϵ \frac{1}{\sqrt{\hat{\mathbf{s}}_t} + \epsilon} s^t +ϵ1而不是 1 s ^ t + ϵ \frac{1}{\sqrt{\hat{\mathbf{s}}_t + \epsilon}} s^t+ϵ 1进行缩放,两者会略有差异。
前者在实践中效果略好一些,因此与RMSProp算法有所区分。
通常,我们选择 ϵ = 1 0 − 6 \epsilon = 10^{-6} ϵ=106,这是为了在数值稳定性和逼真度之间取得良好的平衡。

最后,我们简单更新:

x t ← x t − 1 − g t ′ . \mathbf{x}_t \leftarrow \mathbf{x}_{t-1} - \mathbf{g}_t'. xtxt1gt.

回顾Adam算法,它的设计灵感很清楚:
首先,动量和规模在状态变量中清晰可见,
它们相当独特的定义使我们移除偏项(这可以通过稍微不同的初始化和更新条件来修正)。
其次,RMSProp算法中两项的组合都非常简单。
最后,明确的学习率 η \eta η使我们能够控制步长来解决收敛问题。

实现

从头开始实现Adam算法并不难。
为方便起见,我们将时间步 t t t存储在hyperparams字典中。
除此之外,一切都很简单。

%matplotlib inline
import torch
from d2l import torch as d2ldef init_adam_states(feature_dim):v_w, v_b = torch.zeros((feature_dim, 1)), torch.zeros(1)s_w, s_b = torch.zeros((feature_dim, 1)), torch.zeros(1)return ((v_w, s_w), (v_b, s_b))def adam(params, states, hyperparams):beta1, beta2, eps = 0.9, 0.999, 1e-6for p, (v, s) in zip(params, states):with torch.no_grad():v[:] = beta1 * v + (1 - beta1) * p.grads[:] = beta2 * s + (1 - beta2) * torch.square(p.grad)v_bias_corr = v / (1 - beta1 ** hyperparams['t'])s_bias_corr = s / (1 - beta2 ** hyperparams['t'])p[:] -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr)+ eps)p.grad.data.zero_()hyperparams['t'] += 1

现在,我们用以上Adam算法来训练模型,这里我们使用 η = 0.01 \eta = 0.01 η=0.01的学习率。

data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(adam, init_adam_states(feature_dim),{'lr': 0.01, 't': 1}, data_iter, feature_dim);
loss: 0.244, 0.015 sec/epoch

在这里插入图片描述

此外,我们可以用深度学习框架自带算法应用Adam算法,这里我们只需要传递配置参数。

trainer = torch.optim.Adam
d2l.train_concise_ch11(trainer, {'lr': 0.01}, data_iter)
loss: 0.254, 0.015 sec/epoch

在这里插入图片描述

Yogi

Adam算法也存在一些问题:
即使在凸环境下,当 s t \mathbf{s}_t st的二次矩估计值爆炸时,它可能无法收敛。
:cite:Zaheer.Reddi.Sachan.ea.2018 s t \mathbf{s}_t st提出了的改进更新和参数初始化。
论文中建议我们重写Adam算法更新如下:

s t ← s t − 1 + ( 1 − β 2 ) ( g t 2 − s t − 1 ) . \mathbf{s}_t \leftarrow \mathbf{s}_{t-1} + (1 - \beta_2) \left(\mathbf{g}_t^2 - \mathbf{s}_{t-1}\right). stst1+(1β2)(gt2st1).

每当 g t 2 \mathbf{g}_t^2 gt2具有值很大的变量或更新很稀疏时, s t \mathbf{s}_t st可能会太快地“忘记”过去的值。
一个有效的解决方法是将 g t 2 − s t − 1 \mathbf{g}_t^2 - \mathbf{s}_{t-1} gt2st1替换为 g t 2 ⊙ s g n ( g t 2 − s t − 1 ) \mathbf{g}_t^2 \odot \mathop{\mathrm{sgn}}(\mathbf{g}_t^2 - \mathbf{s}_{t-1}) gt2sgn(gt2st1)
这就是Yogi更新,现在更新的规模不再取决于偏差的量。

s t ← s t − 1 + ( 1 − β 2 ) g t 2 ⊙ s g n ( g t 2 − s t − 1 ) . \mathbf{s}_t \leftarrow \mathbf{s}_{t-1} + (1 - \beta_2) \mathbf{g}_t^2 \odot \mathop{\mathrm{sgn}}(\mathbf{g}_t^2 - \mathbf{s}_{t-1}). stst1+(1β2)gt2sgn(gt2st1).

论文中,作者还进一步建议用更大的初始批量来初始化动量,而不仅仅是初始的逐点估计。

def yogi(params, states, hyperparams):beta1, beta2, eps = 0.9, 0.999, 1e-3for p, (v, s) in zip(params, states):with torch.no_grad():v[:] = beta1 * v + (1 - beta1) * p.grads[:] = s + (1 - beta2) * torch.sign(torch.square(p.grad) - s) * torch.square(p.grad)v_bias_corr = v / (1 - beta1 ** hyperparams['t'])s_bias_corr = s / (1 - beta2 ** hyperparams['t'])p[:] -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr)+ eps)p.grad.data.zero_()hyperparams['t'] += 1data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(yogi, init_adam_states(feature_dim),{'lr': 0.01, 't': 1}, data_iter, feature_dim);
loss: 0.245, 0.015 sec/epoch

在这里插入图片描述

小结

  • Adam算法将许多优化算法的功能结合到了相当强大的更新规则中。
  • Adam算法在RMSProp算法基础上创建的,还在小批量的随机梯度上使用EWMA。
  • 在估计动量和二次矩时,Adam算法使用偏差校正来调整缓慢的启动速度。
  • 对于具有显著差异的梯度,我们可能会遇到收敛性问题。我们可以通过使用更大的小批量或者切换到改进的估计值 s t \mathbf{s}_t st来修正它们。Yogi提供了这样的替代方案。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/273495.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux-socket套接字

前言 在当今数字化时代,网络通信作为连接世界的桥梁,成为计算机科学领域中至关重要的一部分。理解网络编程是每一位程序员必备的技能之一,而掌握套接字编程则是深入了解网络通信的关键。本博客将深入讨论套接字编程中的基本概念、常见API以及…

GitOps实践之Argo CD (2)

argocd 【-1】argocd可以解决什么问题? helm 部署是手动的?依赖流水线。而有时候仅仅更新一个小东西,流水线跑好久,CD真的不应该和CI耦合。不同环境的helm配置不同,手动修改问题多,可以用git管理起来,例如分不同环境用目录区分。argocd创建应用可以不通环境部署到不同集…

Langchain-Chatchat本地搭建ChatGLM3模型和提取PDF内容

文章目录 1、软件要求2、安装CUDA2.1、安装gcc2.2、安装CUDA 3、安装Anaconda33.1、下载Anaconda33.2、创建python虚拟环境 4、部署系统4.1、下载源码4.2、安装依赖4.3、下载模型4.4、初始化配置和知识库4.4.1、初始化配置4.4.2、初始化知识库 4.5、运行4.6、运行4.6.1、启动4.…

C语言编译成库文件的要求

keil编译成库文件 在Keil中,将C语言源文件编译成库文件通常需要进行以下步骤: 创建一个新的Keil项目,并将所需的C语言源文件添加到该项目中。 在项目设置中配置编译选项,确保生成的目标文件符合库文件的标准格式。 编译项目&…

基于PHP的餐厅管理系统APP设计与实现

目 录 摘 要 I Abstract II 引 言 1 1 相关技术 3 1.1 MVC 3 1.2 ThinkPHP 3 1.3 MySQL数据库 3 1.4 uni-app 4 1.5 本章小结 4 2 系统分析 5 2.1 功能需求 5 2.2 用例分析 7 2.3 非功能需求 8 2.4 本章小结 8 3 系统设计 9 3.1 系统总体设计 9 3.2 系统详细设计 10 3.3 本章小…

基于Java+springboot+VUE+redis实现的前后端分类版网上商城项目

基于Java springbootVUEredis实现的前后端分类版网上商城项目 博主介绍:多年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐留言…

Ajax、Axios、Vue、Element与其案例

目录 一.Ajax 二.Axios 三.Vue 四.Element 五.增删改查案例 一.依赖:数据库,mybatis,servlet,json-对象转换器 二.资源:elementvueaxios 三.pojo 四.mapper.xml与mapper接口 五.service 六.servlet 七.html页…

css flex 布局换行

默认使用display: flex;是不换行的,只需要加上flex-wrap: wrap;就行了,效果图 .app-center {display: flex;flex-wrap: wrap;justify-content:flex-start; } 通过上面我们发现虽然时间换行了,但是每行的边距不一样 加上这个就行了&#xff…

微信小程序-分包

分包 1.什么是分包 分包指的是把一个完整的小程序项目,按照需求划分为不同的子包,在构建时打包成不同的分包,用户在使用时按需进行加载。 2.分包的好处 对小程序进行分包的好处主要有以下两点: 可以优化小程序首次启动的下载时间…

二维码图案样式怎么改?二维码改样式的简单方法

怎么修改二维码图案的样式呢?一般情况下生成的二维码图案大多是黑白的普通样式,那么很多人会为了提高展现效果或者增加辨识度,需要修改二维码的图案样式、添加logo、文字等其他内容,那么面对这样的需求该如何解决呢?下…

docker学习(十四)docker搭建私服

docker私服搭建,配置域名访问,设置访问密码 启动registry docker run -d \-p 5000:5000 \-v /opt/data/registry:/var/lib/registry \registrydocker pull hello-world docker tag hello-world 127.0.0.1:5000/hello-world docker push 127.0.0.1:5000…

SQL中如何添加数据

SQL中如何添加数据 一、SQL中如何添加数据(方法汇总)二、SQL中如何添加数据(方法详细解说)1. 使用SQL脚本(推荐)1.1 在表中插入1.1.1 **第一种形式**1.1.2 **第二种形式**SQL INSERT INTO 语法示例SQL INSE…

Keepalived+LVS构建高可用集群

目录 一、Keepalive基础介绍 1. Keepalive与VRRP 2. VRRP相关技术 3. 工作原理 4. 模块 5. 架构 6. 安装 7. Keepalived 相关文件 7.1 配置组成 7.2 全局配置 7.3 VRRP实例配置(lvs调度器) 7.4 虚拟服务器与真实服务器配置 二、Keepalived…

IDEA管理Git + Gitee 常用操作

文章目录 IDEA管理Git Gitee 常用操作1.Gitee创建代码仓库1.创建仓库1.点击新建仓库2.完成仓库信息填写3.创建成功4.管理菜单可以修改这个项目的设置 2.设置SSH公钥免密登录基本介绍1.找到.ssh目录2.执行指令 ssh-keygen3.将公钥信息添加到码云账户1.点击设置2.ssh公钥3.复制.…

软件开发服务合同套用模板

一、合作方式 二、合同标的 三、开发进度及软件成果交付 四、开发费用 五、付款结算方式 六、知识产权条款 七、双方的权利和义务 八、验收 九、售后服务支持 十、培训 十一、保密责任 十二、不可抗力 十三、争议的解决 十四、其它事项 软件开发全套资料获取下载…

python之海龟绘图

海龟绘图(turtle)是一个Python内置的绘图库,也被称为“Turtle Graphics”或简称“Turtles”。它采用了一种有趣的绘图方式,模拟一只小海龟在屏幕上爬行,而小海龟爬行的路径就形成了绘制的图形。这种绘图方式最初源自20…

elasticsearch篇:RestClient操作

1. RestClient ES官方提供了各种不同语言的客户端,用来操作ES。这些客户端的本质就是组装DSL语句,通过http请求发送给ES。官方文档地址:Elasticsearch Clients | Elastic 其中的Java Rest Client又包括两种: Java Low Level Res…

Microsoft SQL Server 编写汉字转拼音函数

目录 应用场景 举例 函数实现 小结 应用场景 在搜索应用中,我们一般会提供一个搜索框,输入关健字,点击查询按钮以获取结果数据。大部分情况我们会提供模糊查询的形式以在一个或多个字段进行搜索以获取结果。这样可以简化用户的操作&…

jvm八股

文章目录 运行时数据区域Java堆对象创建对象的内存布局对象的访问定位句柄直接指针 GC判断对象是否已死引用计数算法可达性分析算法 引用的类别垃圾收集算法分代收集理论标记清除算法标记复制算法标记整理算法 实现细节并发的可达性分析 垃圾收集器serial收集器ParNew收集器Par…

【PyTorch实战演练】深入剖析MTCNN(多任务级联卷积神经网络)并使用30行代码实现人脸识别

文章目录 0. 前言1. 级联神经网络介绍2. MTCNN介绍2.1 MTCNN提出背景2.2 MTCNN结构 3. MTCNN PyTorch实战3.1 facenet_pytorch库中的MTCNN3.2 识别图像数据3.3 人脸识别3.4 关键点定位 0. 前言 按照国际惯例,首先声明:本文只是我自己学习的理解&#xff…