【知识】PyTorch中不同优化器的特点和使用

转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn]

如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~

目录

1. SGD(随机梯度下降)

2. Adam(自适应矩估计)

3. AdamW

4. Adagrad

5. Adadelta

6. Adafactor

7. SparseAdam

8. Adamax

9. LBFGS

10. RMSprop

11. Rprop(弹性反向传播)

12. ASGD(平均随机梯度下降)

13. NAdam(Nesterov 加速自适应矩估计)

14. RAdam(修正 Adam)

15. Adafactor(自适应因子化梯度)

16. AMSGrad 

性能考虑

总结


torch.optim — PyTorch 2.6 documentation

1. SGD(随机梯度下降)

  • 用途:适用于小型到中型模型的基本优化。

  • 特点

    • 通过负梯度方向更新参数。

    • 可以包含动量(momentum)以加速学习并减少震荡。

    • 简单且广泛使用,但需要仔细调整学习率。

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的模型
model = nn.Linear(10, 1)  # 一个线性模型
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)# 训练循环
for input, target in dataloader:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()

2. Adam(自适应矩估计)

  • 用途:深度学习模型,尤其是需要 L2 正则化时。

  • 特点

    • 根据一阶和二阶矩估计为每个参数计算自适应学习率。

    • 支持学习率衰减的无偏估计。

    • 通常在适当设置下比 SGD 收敛更快。

optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, amsgrad=False)

3. AdamW

  • 用途:迁移学习、视觉任务,以及权重衰减关键的场景。

  • 特点

    • 将权重衰减与梯度解耦,使其更有效。

    • 在某些场景下性能超过 Adam 和 SGD。

optimizer = optim.AdamW(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)

4. Adagrad

  • 用途:处理稀疏数据,例如自然语言处理或图像识别。

  • 特点

    • 累积之前的平方梯度以调整学习率。

    • 随着训练的进行,学习率单调递减,有助于收敛。

optimizer = optim.Adagrad(model.parameters(), lr=0.01, lr_decay=0.0, weight_decay=0.0, initial_accumulator_value=0.0, eps=1e-10)

5. Adadelta

  • 用途:文本数据处理和图像分类。

  • 特点

    • 通过使用窗口和解决 Adagrad 的学习率递减问题。

    • 维护平方梯度和平方参数更新的运行平均值。

optimizer = optim.Adadelta(model.parameters(), lr=1.0, rho=0.9, eps=1e-06, weight_decay=0.0)

6. Adafactor

  • 用途:大规模模型、大批量或长序列(例如深度学习在网页文本语料库上的应用)。

  • 特点

    • 通过使用近似值替换二阶矩来减少计算开销。

    • 专为非常大的模型设计,不会增加训练时间。

optimizer = optim.Adafactor(model.parameters(), lr=0.05, eps=(1e-30, 1e-3), clip_threshold=1.0, decay_rate=-0.8, beta1=None, weight_decay=0.0, scale_parameter=True, relative_step=True, warmup_init=False)

7. SparseAdam

  • 用途:具有稀疏梯度数据的模型,例如 NLP 中的嵌入层。

  • 特点

    • 优化稀疏张量更新;结合 SparseAdam 用于密集张量和 Adagrad 用于稀疏更新。

    • 专为具有许多零值的参数设计。

optimizer = optim.SparseAdam(model.sparse_parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08)

8. Adamax

  • 用途:类似于 Adam,但基于无穷范数,某些问题上更稳定。

  • 特点

    • 使用过去梯度的最大值而不是平均值。

optimizer = optim.Adamax(model.parameters(), lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0)

9. LBFGS

  • 用途:无约束优化问题、回归以及需要二阶信息的问题。

  • 特点

    • 使用梯度评估近似海森矩阵的拟牛顿方法。

    • 比 SGD 或 Adam 需要更多内存和计算资源。

optimizer = optim.LBFGS(model.parameters(), lr=1.0, max_iter=20, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn=None)# 使用 LBFGS 需要提供一个闭包(closure)来重新评价模型
def closure():optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()return lossoptimizer.step(closure)

10. RMSprop

  • 用途:卷积神经网络和递归神经网络。

  • 特点

    • 维护平方梯度的运行平均值,并对参数更新进行归一化。

    • 解决 Adagrad 学习率单调递减的问题。

optimizer = optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0.0, momentum=0.0, centered=False)

11. Rprop(弹性反向传播)

  • 用途:神经网络中梯度大小不重要的场景。

  • 特点

    • 仅使用梯度的符号来更新参数,根据梯度符号变化调整学习率。

optimizer = optim.Rprop(model.parameters(), lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-06, 50.0))

12. ASGD(平均随机梯度下降)

  • 用途:促进某些模型的泛化。

  • 特点

    • 维护优化过程中遇到的参数的运行平均值。

optimizer = optim.ASGD(model.parameters(), lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0.0)

13. NAdam(Nesterov 加速自适应矩估计)

  • 用途:结合 Nesterov 动量和 Adam。

  • 特点

    • 结合 Nesterov 加速梯度(NAG)以提供更稳定的收敛。

optimizer = optim.NAdam(model.parameters(), lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, momentum_decay=0.004)

14. RAdam(修正 Adam)

  • 用途:需要自适应学习率但希望减少方差的场景。

  • 特点

    • 根据梯度方差动态调整学习率。

optimizer = optim.RAdam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0)

15. Adafactor(自适应因子化梯度)

  • 用途:大规模模型的内存高效优化。

  • 特点

    • 通过将大梯度分解为小成分来减少内存使用。

optimizer = optim.Adafactor(model.parameters(), lr=0.3, eps=(1e-30, 1e-3), clip_threshold=1.0, decay_rate=-0.8, beta1=None, weight_decay=0.0, scale_parameter=True, relative_step=True, warmup_init=False)

16. AMSGrad 

  • 用途: Adam 优化器的一种改进版本,旨在解决 Adam 在某些情况下可能不收敛的问题。它通过保留梯度的历史信息来防止学习率过早下降,从而提高训练的稳定性和收敛性。
  • 特点
    • 自适应学习率:AMSGrad 自适应地调整学习率,以便更好地训练神经网络。

    • 防止震荡:它可以防止 Adam 算法中的震荡现象,从而提高训练效果。

    • 改进收敛性:通过优化二阶动量,避免了 Adam 算法可能遭遇的收敛问题,特别适合长时间训练或解决深层网络难题。

# 初始化 AMSGrad 优化器,通过amsgrad参数设置
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=True)

性能考虑

  • 这些优化器的性能可能因硬件和问题的性质而异。PyTorch 将优化器分为以下几类:

    • For-loop:基本实现,但由于内核调用较慢。

    • Foreach:使用多张量操作以加快处理速度。

    • Fused:将步骤合并为单个内核以实现最大速度。

总结

  • 选择优化器取决于问题的复杂性、数据的稀疏性和硬件的可用性。像 Adam 或 AdamW 这样的自适应算法因其通用有效性而被广泛使用,而像 SGD 这样的简单方法在适当调整超参数时是最优的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/23499.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Confluence知识库管理系统安装步骤(Windows版本)

我们介绍的是安装7.15.1以下版本的安装方式,8.0以上的安装方式暂不支持。 如果你要安装8.0以上的版本,请参考本文末尾的附录中提供的相关网址。 首先我们安装之前需要准备安装所需文件以上文件可以在这里下载:【https://download.csdn.net/download/Elegant_Kevin/90412040】…

Uniapp 开发中遇到的坑与注意事项:全面指南

文章目录 1. 引言Uniapp 简介开发中的常见问题本文的目标与结构 2. 环境配置与项目初始化环境配置问题解决方案 项目初始化注意事项解决方案 常见错误与解决方案 3. 页面与组件开发页面生命周期注意事项示例代码 组件通信与复用注意事项示例代码 样式与布局问题注意事项示例代码…

学习笔记--电磁兼容性EMC

一、基本概念 电磁兼容性(Electromagnetic Compatibility,EMC)是电子电气设备在特定电磁环境中正常工作的能力,同时不会对其他设备产生不可接受的电磁干扰。其核心目标是确保设备在共享的电磁环境中既能抵抗干扰,又能避…

Unity百游修炼(2)——Brick_Breaker详细制作全流程

一、项目简介 Brick Breaker 是一款经典的打砖块游戏,本次案例将使用 Unity 引擎来实现该游戏的核心功能。 游戏画面如下: Brick_ breaker 二、项目结构概览和前期准备 (1)在 Unity 项目视图中,我们可以看到几个重要…

Java基础常见的面试题(易错!!)

面试题一:为什么 Java 不支持多继承 Java 不支持多继承主要是为避免 “菱形继承问题”(又称 “钻石问题”),即一个子类从多个父类继承到同名方法或属性时,编译器无法确定该调用哪个父类的成员。同时,多继承…

算法题(77):数组中的第k个最大元素

审题: 需要我们在时间复杂度O(n)的前提下找到数组中第k个最大元素 思路: 方法一:建堆实现 首先写一个dowmset函数,实现对第i个索引位置的向下调整。然后创建build函数,利用dowmset实现向下调整建堆,再根据k…

PCIe学习笔记1:PCIe体系架构——PCIe简介

目录 一、PCIe简介 1.1 串行传输 1.1.1 相对于并行传输的优化 1.1.2 带宽计算 1.1.3 差分信号传输 1.1.4 基于数据包的传输协议 1.2 PCIe的系统拓扑结构 1.2.1 根组件(Root Complex,RC) 1.2.2 上行端口与下行端口 1.2.3 交换机与桥 …

一天记20个忘10个之4:man

据说,给你一个支点,你就能撬起地球。 那好,今天,我给你一个 man,如果你能完成记20个忘10个的任务,你就真的很 man 了。 零、热身 young manold manmedical man 一、man之复合词 1.1 man复合词 chairm…

SpringBoot之自定义简单的注解和AOP

1.引入依赖 <!-- AOP依赖--> <dependency><groupId>org.aspectj</groupId><artifactId>aspectjweaver</artifactId><version>1.9.8</version> </dependency>2.自定义一个注解 package com.example.springbootdemo3.an…

利用开源小智AI制作桌宠机器狗

本文主要介绍如何利用开源小智AI制作桌宠机器狗 1 源码下载 首先下载小智源码,下载地址, 下载源码后,使用vsCode打开,需要在vscode上安装esp-idf,安装方式请自己解决 2 源码修改 2.1添加机器狗控制代码 在目录main/iot/things下添加dog.cc文件,内容如下; #include…

深入理解IP子网掩码子网划分{作用} 以及 不同网段之间的ping的原理 以及子网掩码的区域划分

目录 子网掩码详解 子网掩码定义 子网掩码进一步解释 子网掩码的作用 计算总结表 子网掩码计算 子网掩码对应IP数量计算 判断IP是否在同一网段 1. 计算步骤 2. 示例 3. 关键点 总结 不同网段通信原理与Ping流程 1. 同网段通信 2. 跨网段通信 网段计算示例 3. P…

利用python和gpt写一个conda环境可视化管理工具

最近在学习python&#xff0c;由于不同的版本之间的差距较大&#xff0c;如果是用环境变量来配置python的话&#xff0c;会需要来回改&#xff0c;于是请教得知可以用conda来管理&#xff0c;但是conda在管理的时候老是要输入命令&#xff0c;感觉也很烦&#xff0c;于是让gpt帮…

Linux内核,slub分配流程

我们根据上面的流程图&#xff0c;依次看下slub是如何分配的 首先从kmem_cache_cpu中分配&#xff0c;如果没有则从kmem_cache_cpu的partial链表分配&#xff0c;如果还没有则从kmem_cache_node中分配&#xff0c;如果kmem_cache_node中也没有&#xff0c;则需要向伙伴系统申请…

使用Windbg调试目标进程排查C++软件异常的一般步骤与要点分享

目录 1、概述 2、将Windbg附加到已经启动起来的目标进程上&#xff0c;或者用Windbg启动目标程序 2.1、将Windbg附加到已经启动起来的目标进程上 2.2、用Windbg启动目标程序 2.3、Windbg关联到目标进程上会中断下来&#xff0c;输入g命令将该中断跳过去 3、分析实例说明 …

51单片机测试题AI作答测试(DeepSeek Kimi)

单片机测试题 DeepSeek Kimi 单项选择题 &#xff08;10道&#xff09; 6题8题判断有误 6题判断有误 智谱清言6题靠谱&#xff0c;但仔细斟酌&#xff0c;题目出的貌似有问题&#xff0c;详见 下方。 填空题 &#xff08;9道&#xff09; 脉宽调制&#xff08;Pulse …

模版语法vscode

这里注意&#xff1a;<template></template>里面只能写一个根标签&#xff0c;其他在嵌套&#xff1a; <script > export default {data(){return{tthtml:"<a hrefhttps://itbaizhan.com>百战程序员</a>"}} } </script><tem…

洛谷B3637 最长上升子序

B3637 最长上升子序列 - 洛谷 代码区&#xff1a; #include<bits/stdc.h>using namespace std;int main(){int n;cin >> n;int arry[n],dp[n];for(int i0;i<n;i){cin >>arry[i];dp[i]1;}/*在 i 之前可能存在多个 j 满足 arry[j] < arry[i]&#xff0c…

kotlin 知识点 七 泛型的高级特性

对泛型进行实化 泛型实化这个功能对于绝大多数Java 程序员来讲是非常陌生的&#xff0c;因为Java 中完全没有这个概 念。而如果我们想要深刻地理解泛型实化&#xff0c;就要先解释一下Java 的泛型擦除机制才行。 在JDK 1.5之前&#xff0c;Java 是没有泛型功能的&#xff0c;…

Day 49 卡玛笔记

这是基于代码随想录的每日打卡 1143. 最长公共子序列 给定两个字符串 text1 和 text2&#xff0c;返回这两个字符串的最长 公共子序列 的长度。如果不存在 公共子序列 &#xff0c;返回 0 。 一个字符串的 子序列 是指这样一个新的字符串&#xff1a;它是由原字符串在不改变…

重新求职刷题DAY18

1.513. 找树左下角的值 给定一个二叉树的 根节点 root&#xff0c;请找出该二叉树的 最底层 最左边 节点的值。 假设二叉树中至少有一个节点。 示例 1: 外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 输入: root [2,1,3] 输出: 1思路&#xff1a; 这…