【漫话机器学习系列】087.常见的神经网络最优化算法(Common Optimizers Of Neural Nets)

常见的神经网络优化算法

1. 引言

在深度学习中,优化算法(Optimizers)用于更新神经网络的权重,以最小化损失函数(Loss Function)。一个高效的优化算法可以加速训练过程,并提高模型的性能和稳定性。本文介绍几种常见的神经网络优化算法,包括随机梯度下降(SGD)、带动量的随机梯度下降(Momentum SGD)、均方根传播算法(RMSProp)以及自适应矩估计(Adam),并提供相应的代码示例。

2. 常见的优化算法

2.1 随机梯度下降(Stochastic Gradient Descent, SGD)

随机梯度下降(SGD)是最基本的优化算法,其更新规则如下:

其中:

  • w 代表网络参数(权重);
  • α 是学习率(Learning Rate),控制更新步长;
  • ∇L(w) 是损失函数相对于权重的梯度。

代码示例(使用 PyTorch 实现 SGD)

import torch
import torch.nn as nn
import torch.optim as optim# 定义简单的线性模型
model = nn.Linear(1, 1)  # 1 个输入特征,1 个输出特征
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降# 训练步骤
for epoch in range(100):optimizer.zero_grad()  # 清空梯度inputs = torch.tensor([[1.0]], requires_grad=True)targets = torch.tensor([[2.0]])outputs = model(inputs)loss = criterion(outputs, targets)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数if epoch % 10 == 0:print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')

运行结果

Epoch [0/100], Loss: 4.9142
Epoch [10/100], Loss: 2.1721
Epoch [20/100], Loss: 0.9601
Epoch [30/100], Loss: 0.4244
Epoch [40/100], Loss: 0.1876
Epoch [50/100], Loss: 0.0829
Epoch [60/100], Loss: 0.0366
Epoch [70/100], Loss: 0.0162
Epoch [80/100], Loss: 0.0072
Epoch [90/100], Loss: 0.0032


2.2 带动量的随机梯度下降(Momentum SGD)

带动量的 SGD 在 SGD 的基础上加入动量(Momentum),用于加速收敛并减少震荡:


其中:

  • 是累积的梯度,类似于物理中的动量;
  • β 是动量系数(通常取 0.9)。

代码示例(Momentum SGD)

import torch
import torch.nn as nn
import torch.optim as optimmodel = nn.Linear(1, 1)  # 1 个输入特征,1 个输出特征
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)for epoch in range(100):optimizer.zero_grad()inputs = torch.tensor([[1.0]], requires_grad=True)targets = torch.tensor([[2.0]])outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()if epoch % 10 == 0:print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')

运行结果 

Epoch [0/100], Loss: 3.0073
Epoch [10/100], Loss: 1.3292
Epoch [20/100], Loss: 0.5875
Epoch [30/100], Loss: 0.2597
Epoch [40/100], Loss: 0.1148
Epoch [50/100], Loss: 0.0507
Epoch [60/100], Loss: 0.0224
Epoch [70/100], Loss: 0.0099
Epoch [80/100], Loss: 0.0044
Epoch [90/100], Loss: 0.0019

优点:

  • 缓解了 SGD 震荡问题,提高收敛速度;
  • 在非凸优化问题中表现更好。

2.3 均方根传播算法(RMSProp)

RMSProp 通过自适应调整学习率来加速训练,并缓解震荡问题:


其中:

  • 是梯度平方的滑动平均;
  • β 是衰减系数(一般取 0.9);
  • ϵ 是一个很小的数,防止除零错误。

代码示例(RMSProp)

import torch
import torch.nn as nn
import torch.optim as optim# 定义简单的线性模型
model = nn.Linear(1, 1)  # 1 个输入特征,1 个输出特征
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.RMSprop(model.parameters(), lr=0.01, alpha=0.9)for epoch in range(100):optimizer.zero_grad()inputs = torch.tensor([[1.0]], requires_grad=True)targets = torch.tensor([[2.0]])outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()if epoch % 10 == 0:print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')

运行结果

Epoch [0/100], Loss: 1.1952
Epoch [10/100], Loss: 0.5887
Epoch [20/100], Loss: 0.3333
Epoch [30/100], Loss: 0.1731
Epoch [40/100], Loss: 0.0752
Epoch [50/100], Loss: 0.0239
Epoch [60/100], Loss: 0.0043
Epoch [70/100], Loss: 0.0003
Epoch [80/100], Loss: 0.0000
Epoch [90/100], Loss: 0.0000

优点:

  • 适用于非平稳目标函数;
  • 能有效处理不同特征尺度的问题;
  • 在 RNN(循环神经网络)等任务上表现较好。

2.4 自适应矩估计(Adam, Adaptive Moment Estimation)

Adam 结合了动量法(Momentum)和 RMSProp,同时考虑梯度的一阶矩(平均值)和二阶矩(方差):



其中:

  • ​ 是梯度的一阶矩估计;
  • ​ 是梯度的二阶矩估计;
  • ​ 分别控制一阶矩和二阶矩的指数衰减率(通常取 0.9 和 0.999)。

代码示例(Adam)

import torch
import torch.nn as nn
import torch.optim as optim# 定义简单的线性模型
model = nn.Linear(1, 1)  # 1 个输入特征,1 个输出特征
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.Adam(model.parameters(), lr=0.01)for epoch in range(100):optimizer.zero_grad()inputs = torch.tensor([[1.0]], requires_grad=True)targets = torch.tensor([[2.0]])outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()if epoch % 10 == 0:print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')

输出结果 

Epoch [0/100], Loss: 3.6065
Epoch [10/100], Loss: 2.8894
Epoch [20/100], Loss: 2.2642
Epoch [30/100], Loss: 1.7359
Epoch [40/100], Loss: 1.3021
Epoch [50/100], Loss: 0.9555
Epoch [60/100], Loss: 0.6855
Epoch [70/100], Loss: 0.4805
Epoch [80/100], Loss: 0.3287
Epoch [90/100], Loss: 0.2192

优点:

  • 结合 Momentum 和 RMSProp 的优势;
  • 适用于大规模数据集和高维参数优化;
  • 具有自适应学习率,适用于不同类型的问题。

3. 选择合适的优化算法

优化算法特点适用场景
SGD计算简单,但容易震荡适用于大规模数据,适合凸优化问题
Momentum SGD增加动量,减少震荡,加速收敛适用于复杂深度神经网络
RMSProp自适应调整学习率,适用于非平稳问题适用于 RNN、强化学习等
Adam结合 Momentum 和 RMSProp,自适应学习率适用于大多数深度学习任务

4. 结论

在神经网络训练过程中,优化算法的选择对最终的模型性能有重要影响。SGD 是最基础的优化方法,而带动量的 SGD 在收敛速度和稳定性上有所提升。RMSProp 适用于非平稳目标函数,而 Adam 结合了 Momentum 和 RMSProp 的优势,成为当前最流行的优化算法之一。

不同任务可能需要不同的优化算法,通常的建议是:

  • 对于简单的凸优化问题,可以使用 SGD。
  • 对于深度神经网络,可以使用 Momentum SGD 或 Adam。
  • 对于 RNN 和强化学习问题,RMSProp 是一个不错的选择。

合理选择优化算法可以显著提升模型训练的效率和效果!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/15228.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

4G核心网的演变与创新:从传统到虚拟化的跨越

4G核心网 随着移动通信技术的不断发展,4G核心网已经经历了从传统的硬件密集型架构到现代化、虚拟化网络架构的重大转型。这一演变不仅提升了网络的灵活性和可扩展性,也为未来的5G、物联网(LOT)和边缘计算等技术的发展奠定了基础。…

拉格朗日插值法的matlab实现

一、基本原理 比如有如下这些点 x1x2x3x4y1y2y3y4 那么在拉个朗日原理中可以把过这些点的曲线表示为: 其g(x)y叫做一个插值基函数(开关),当xx1时,g1(x)1,而当xx2,x3,x4时,g1(x)都为0&#xf…

使用WebStorm开发Vue3项目

记录一下使用WebStorm开发Vu3项目时的配置 现在WebStorm可以个人免费使用啦!?? 基本配置 打包工具:Vite 前端框架:ElementPlus 开发语言:Vue3、TypeScript、Sass 代码检查:ESLint、Prettier IDE:WebSt…

35~37.ppt

目录 35.张秘书-《会计行业中长期人才发展规划》 题目​ 解析 36.颐和园公园(25张PPT) 题目​ 解析 37.颐和园公园(22张PPT) 题目 解析 35.张秘书-《会计行业中长期人才发展规划》 题目 解析 插入自定义的幻灯片:新建幻灯片→重用…

day44 QT核心机制

头文件&#xff1a; #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include<QLabel> //标签类头文件 #include<QPushButton> //按钮类头文件 #include<QLineEdit> //行编辑器类头文件QT_BEGIN_NAMESPACE namespace Ui { class Widget; } …

kafka服务端之副本

文章目录 概述副本剖析失效副本ISR的伸缩LWLEO与HW的关联LeaderEpoch的介入数据丢失的问题数据不一致问题Leader Epoch数据丢失数据不一致 kafka为何不支持读写分离 日志同步机制可靠性分析 概述 Kafka中采用了多副本的机制&#xff0c;这是大多数分布式系统中惯用的手法&…

[笔记] 汇编杂记(持续更新)

文章目录 前言举例解释函数的序言函数的调用栈数据的传递 总结 前言 举例解释 // Type your code here, or load an example. int square(int num) {return num * num; }int sub(int num1, int num2) {return num1 - num2; }int add(int num1, int num2) {return num1 num2;…

mysql8.0使用MHA实现高可用

一、环境配置 本实验环境共有四个节点&#xff0c; 其角色分配如下&#xff08;实验机器均为centos 7.x &#xff09; 机器名称IP配置服务角色备注manager192.168.8.145manager控制器用于监控管理master192.168.8.143数据库主服务器开启bin-log relay-log 关闭relay_logslave…

<论文>DeepSeek-R1:通过强化学习激励大语言模型的推理能力(深度思考)

一、摘要 本文跟大家来一起阅读DeepSeek团队发表于2025年1月的一篇论文《DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning | Papers With Code》&#xff0c;新鲜的DeepSeek-R1推理模型&#xff0c;作者规模属实庞大。如果你正在使用Deep…

【Android开发AI实战】选择目标跟踪基于opencv实现——运动跟踪

文章目录 【Android 开发 AI 实战】选择目标跟踪基于 opencv 实现 —— 运动跟踪一、引言二、Android 开发与 AI 的融合趋势三、OpenCV 简介四、运动跟踪原理&#xff08;一&#xff09;光流法&#xff08;二&#xff09;卡尔曼滤波&#xff08;三&#xff09;粒子滤波 五、基于…

第1章 特征工程

原文&#xff1a;第1章 特征工程 俗话说&#xff0c;“巧妇难为无米之炊”。在机器学习中&#xff0c;数据和特征便是“米”&#xff0c;模型和算法则是“巧妇”。没有充足的数据、合适的特征&#xff0c;再强大的模型结构也无法得到满意的输出。正如一句业界经典的话所说&…

idea 如何使用deepseek 保姆级教程

1.安装idea插件codegpt 2.注册deepseek并生成apikey deepseek 开发平台&#xff1a; DeepSeek​​​​​​​ 3.在idea进行codegpt配置 打开idea的File->Settings->Tools->CodeGPT->Providers->Custom OpenAI Chat Completions的URL填写 https://api.deepseek…

多光谱成像技术在华为Mate70系列的应用

华为Mate70系列搭载了光谱技术的产物——红枫原色摄像头&#xff0c;这是一款150万像素的多光谱摄像头。 相较于普通摄像头&#xff0c;它具有以下优势&#xff1a; 色彩还原度高&#xff1a;色彩还原准确度提升约 120%&#xff0c;能捕捉更多光谱信息&#xff0c;使拍摄照片色…

10vue3实战-----实现登录的基本功能

10vue3实战-----实现登录的基本功能 1.基本页面的搭建2.账号登录的验证规则配置3.点击登录按钮4.表单的校验5.账号的登录逻辑和登录状态保存6.定义IAccount对象类型 1.基本页面的搭建 大概需要搭建成这样子的页面: 具体的搭建界面就不多讲。各个项目都有自己的登录界面&#…

vue学习5

1.自定义创建项目 2.ESlint代码规范 正规的团队需要统一的编码风格 JavaScript Standard Style 规范说明&#xff1a;https://standardjs.com/rules-zhcn.html 规则中的一部分&#xff1a; (1)字符串使用单引号 ‘aabc’ (2)无分号 const name ‘zs’ (3)关键字后加空格 if(n…

QTreeView和QTableView单元格添加超链接

QTreeView和QTableView单元格添加超链接的方法类似,本文仅以QTreeView为例。 在QTableView仿Excel表头排序和筛选中已经实现了超链接的添加,但是需要借助delegate,这里介绍一种更简单的方式,无需借助delegate。 一.效果 二.实现 QHTreeView.h #ifndef QHTREEVIEW_H #def…

Qt监控设备离线检测/实时监测设备上下线/显示不同的状态图标/海康大华宇视华为监控系统

一、前言说明 监控系统中一般有很多设备&#xff0c;有些用户希望知道每个设备是否已经上线&#xff0c;最好有不同的状态图标提示&#xff0c;海康的做法是对设备节点的图标和颜色变暗处理&#xff0c;离线的话就变暗&#xff0c;有可能是加了透明度&#xff0c;而大华的处理…

IDEA+DeepSeek让Java开发起飞

1.获取DeepSeek秘钥 登录DeepSeek官网 : https://www.deepseek.com/ 进入API开放平台&#xff0c;第一次需要注册一个账号 进去之后需要创建一个API KEY&#xff0c;然后把APIkey记录保存下来 接着我们获取DeepSeek的API对话接口地址&#xff0c;点击左边的&#xff1a;接口…

docker学习笔记

1.docker与虚拟机技术的不同 传统虚拟机&#xff1a;虚拟出一条硬件&#xff0c;运行一个完整的操作系统&#xff0c;然后在这个系统上安装和运行软件。容器内的应用直接运行在&#xff0c;宿主机的内容&#xff0c;容器是没有自己的内核的&#xff0c;也没有虚拟我们的硬件每…

Linux之kernel(4)netlink通信

Linux内核(04)之netlink通信 Author: Once Day Date: 2023年1月3日 一位热衷于Linux学习和开发的菜鸟&#xff0c;试图谱写一场冒险之旅&#xff0c;也许终点只是一场白日梦… 漫漫长路&#xff0c;有人对你微笑过嘛… 全系列文章可查看专栏: Linux内核知识_Once-Day的博客-…