Python 梯度下降法(五):Adam Optimize

文章目录

  • Python 梯度下降法(五):Adam Optimize
    • 一、数学原理
      • 1.1 介绍
      • 1.2 符号说明
      • 1.3 实现流程
    • 二、代码实现
      • 2.1 函数代码
      • 2.2 总代码
      • 2.3 遇到的问题
      • 2.4 算法优化
    • 三、优缺点
      • 3.1 优点
      • 3.2 缺点
    • 四、相关链接

Python 梯度下降法(五):Adam Optimize

一、数学原理

1.1 介绍

Adam 算法结合了 Adagrad 和 RMSProp 算法的优点。Adagrad 算法会根据每个参数的历史梯度信息来调整学习率,对于出现频率较低的参数会给予较大的学习率,而对于出现频率较高的参数则给予较小的学习率。RMSProp 算法则是对 Adagrad 算法的改进,它通过使用移动平均的方式来计算梯度的平方,从而避免了 Adagrad 算法中学习率单调下降的问题。

1.2 符号说明

参数意义
g t = ∇ θ J ( θ t ) g_{t}=\nabla_{\theta}J(\theta_{t}) gt=θJ(θt) t t t时刻的梯度
m t m_{t} mt梯度的一阶矩(均值)
β 1 \beta_{1} β1一阶矩衰减率,一般取0.9
v t v_{t} vt梯度的二阶矩(未中心化的方差)
β 2 \beta_{2} β2二阶矩衰减率,一般取0.99
θ \theta θ线性拟合参数
η \eta η学习率
ϵ \epsilon ϵ无穷小量,一般取 1 0 − 8 10^{-8} 108

1.3 实现流程

  1. 初始化: θ \theta θ η \eta η m 0 ⃗ = 0 \vec{m_{0}}=0 m0 =0 v 0 ⃗ = 0 \vec{v_{0}}=0 v0 =0
  2. 计算梯度: g t = ∇ θ J ( θ t ) = 1 m X T L g_{t}=\nabla_{\theta}J(\theta_{t})=\frac{1}{m}X^{T}L gt=θJ(θt)=m1XTL
  3. 梯度的一阶矩估计(均值): m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_{t}=\beta_{1}m_{t-1}+(1-\beta_{1})g_{t} mt=β1mt1+(1β1)gt
  4. 梯度的二阶矩估计(未中心化的方差): v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_{t}=\beta_{2}v_{t-1}+(1-\beta_{2})g_{t}^{2} vt=β2vt1+(1β2)gt2
  5. 偏差修正: m t ^ = m t 1 − β 1 t 、 v t ^ = v t 1 − β 2 t \hat{m_{t}}=\frac{m_{t}}{1-\beta_{1}^{t}}、\hat{v_{t}}=\frac{v_{t}}{1-\beta_{2}^{t}} mt^=1β1tmtvt^=1β2tvt
  6. 更新参数: θ t = θ t − 1 − η m t ^ v t ^ + ϵ \theta_{t}=\theta_{t-1}-\frac{\eta \hat{m_{t}}}{\sqrt{ \hat{v_{t}} }+\epsilon} θt=θt1vt^ +ϵηmt^

二、代码实现

2.1 函数代码

# 定义 Adam 函数
def adam_optimizer(X, y, eta, num_iter=1000, beta1=0.8, beta2=0.8, epsilon=1e-8, threshold=1e-8):"""X: 数据 x  mxn,可以在传入数据之前进行数据的归一化y: 数据 y  mx1eta: 学习率num_iter: 迭代次数beta: 衰减率epsilon: 无穷小threshold: 阈值"""m, n = X.shapetheta, mt, vt, loss_ = np.random.randn(n, 1), np.zeros((n, 1)), np.zeros((n, 1)), []  # 初始化数据for iter in range(num_iter):h = X.dot(theta)err = h - yloss_.append(np.mean((err ** 2) / 2))g = (1 / m ) * X.T.dot(err)# 一阶矩估计mt = beta1 * mt + (1 - beta1) * g# 二阶矩估计vt = beta2 * vt + (1 - beta2) * g ** 2# 偏差修正mt_ = mt / (1 - pow(beta1, (iter + 1)))  # 得 + 1 不然在 iter = 0 时,分母为零vt_ = np.abs(vt / (1 - pow(beta2, (iter + 1))))# 更新参数theta = theta - (eta * mt_) / (np.sqrt(vt_) + epsilon)# 检查是否收敛if iter > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {iter + 1}")breakreturn theta.flatten(), loss_

2.2 总代码

import numpy as np
import matplotlib.pyplot as plt# 设置 matplotlib 支持中文
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = False# 定义 Adam 函数
def adam_optimizer(X, y, eta, num_iter=1000, beta1=0.8, beta2=0.8, epsilon=1e-8, threshold=1e-8):"""X: 数据 x  mxn,可以在传入数据之前进行数据的归一化y: 数据 y  mx1eta: 学习率num_iter: 迭代次数beta: 衰减率epsilon: 无穷小threshold: 阈值"""m, n = X.shapetheta, mt, vt, loss_ = np.random.randn(n, 1), np.zeros((n, 1)), np.zeros((n, 1)), []  # 初始化数据for iter in range(num_iter):h = X.dot(theta)err = h - yloss_.append(np.mean((err ** 2) / 2))g = (1 / m ) * X.T.dot(err)# 一阶矩估计mt = beta1 * mt + (1 - beta1) * g# 二阶矩估计vt = beta2 * vt + (1 - beta2) * g ** 2# 偏差修正mt_ = mt / (1 - pow(beta1, (iter + 1)))  # 得 + 1 不然在 iter = 0 时,分母为零vt_ = np.abs(vt / (1 - pow(beta2, (iter + 1))))# 更新参数theta = theta - (eta * mt_) / (np.sqrt(vt_) + epsilon)# 检查是否收敛if iter > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {iter + 1}")breakreturn theta.flatten(), loss_# 生成一些示例数据
np.random.seed(42)
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)
# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]# 超参数
eta = 0.01# 运行 Adam 优化器
theta, loss_ = adam_optimizer(X_b, y, eta)print("最优参数 theta:")
print(theta)# 绘制损失函数图像
plt.plot(range(len(loss_)), loss_, label="损失函数图像")
plt.title("损失函数图像")
plt.xlabel("迭代次数")
plt.ylabel("损失值")
plt.legend()  # 显示图例
plt.grid(True)  # 显示网格线
plt.show()

1738332803_d7btmbrnt5.png1738332802724.png

2.3 遇到的问题

当偏差修正为以下算法时,出现报错:

        # 偏差修正mt_ = mt / (1 - pow(beta1, (iter)))vt_ = np.abs(vt / (1 - pow(beta2, (iter))))

1738332890_bekam4jjvm.png1738332889461.png

进行检验时,我们发现:

1738333012_ujmv5g46fw.png1738333011494.png

mt_,vt_ \text{mt\_,vt\_} mt_,vt_为无穷量,因此考虑分母为零的情况,而当 iter = 0 \text{iter}=0 iter=0时, 1 − β iter = 0 1- \beta^{\text{iter}}=0 1βiter=0,故说明索引不能从0开始,而应该从1开始,因此引入 iter + 1 \text{iter}+1 iter+1,防止分母的无穷大引入。

2.4 算法优化

由于算法过程中,如果数据量太多会引起资源的严重浪费,因此我们引入小批量梯度下降法的类似方法,批量截取数据来进行拟合。

# 定义 Adam 函数
def adam_optimizer(X, y, eta, num_iter=1000, batch_size=32, beta1=0.8, beta2=0.8, epsilon=1e-8, threshold=1e-8):"""X: 数据 x  mxn,可以在传入数据之前进行数据的归一化y: 数据 y  mx1eta: 学习率num_iter: 迭代次数batch_size: 小批量分支法的批量数beta: 衰减率epsilon: 无穷小threshold: 阈值"""m, n = X.shapetheta, mt, vt, loss_ = np.random.randn(n, 1), np.zeros((n, 1)), np.zeros((n, 1)), []  # 初始化数据num_batchs = m // batch_sizefor _ in range(num_iter):range_shuffle = np.random.permutation(m)X_shuffled = X[range_shuffle]y_shuffled = y[range_shuffle]loss_temp = []for iter in range(num_batchs):start_index = batch_size * iterend_index = start_index + batch_sizexi = X_shuffled[start_index:end_index]yi = y_shuffled[start_index:end_index]h = xi.dot(theta)err = h - yiloss_temp.append(np.mean((err ** 2) / 2))g = (1 / m ) * xi.T.dot(err)# 一阶矩估计mt = beta1 * mt + (1 - beta1) * g# 二阶矩估计vt = beta2 * vt + (1 - beta2) * g ** 2# 偏差修正mt_ = mt / (1 - pow(beta1, (iter + 1)))vt_ = np.abs(vt / (1 - pow(beta2, (iter + 1))))# 更新参数theta = theta - (eta * mt_) / (np.sqrt(vt_) + epsilon)loss_.append(np.mean(loss_temp))# 检查是否收敛if _ > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {iter + 1}")breakreturn theta.flatten(), loss_

1738333762_rdxih0p4h8.png1738333761148.png

使用小批量进行Adam优化,可以大大节省系统的资源。

三、优缺点

3.1 优点

对不同参数调整学习率:Adam 能够为模型的每个参数自适应地调整学习率。它会根据参数的梯度历史信息,对出现频率较低的参数给予较大的学习率,对出现频率较高的参数给予较小的学习率。这使得模型在训练过程中能够更好地处理不同尺度和变化频率的参数,加速收敛过程。

无需手动精细调整:在很多情况下,Adam 算法提供的默认超参数就能取得不错的效果,不需要像传统优化算法那样进行大量的手动调参,节省了时间和精力。

低内存需求:Adam 只需要存储梯度的一阶矩估计(均值)和二阶矩估计(未中心化的方差),不需要像一些二阶优化方法那样存储复杂的海森矩阵(Hessian matrix),因此内存占用相对较小,适合处理大规模数据集和深度神经网络。

快速收敛:通过结合梯度的一阶矩和二阶矩信息,Adam 能够更准确地估计梯度的方向和大小,从而在大多数情况下比传统的随机梯度下降(SGD)算法更快地收敛到最优解。

利用稀疏信息:在处理稀疏数据(如自然语言处理中的词向量)时,Adam 能够根据数据的稀疏性调整学习率。对于那些很少出现的特征,算法会给予较大的学习率,使得模型能够更有效地学习这些特征,避免因数据稀疏而导致的学习困难

偏差修正机制:Adam 算法引入了偏差修正机制,用于修正一阶矩和二阶矩估计在训练初期的偏差。这使得算法在训练的早期阶段更加稳定,能够避免因初始估计不准确而导致的训练波动或不收敛问题。

3.2 缺点

自适应特性的局限性:虽然 Adam 能够自适应地调整学习率,但在某些情况下,这种自适应特性可能会导致算法陷入局部最优解。由于学习率会随着训练过程自动调整,可能会在接近局部最优解时过早地降低学习率,使得算法难以跳出局部最优区域,从而无法找到全局最优解。

需要一定的调参经验:尽管 Adam 提供了默认的超参数,但在某些复杂的任务或数据集上,这些默认参数可能不是最优的。例如, β \beta β ϵ \epsilon ϵ的取值会影响算法的性能,如果选择不当,可能会导致收敛速度变慢、模型性能下降等问题。因此,在实际应用中,可能仍然需要进行一定的超参数调优。

过度适应训练数据:由于 Adam 算法在训练过程中过于关注梯度的历史信息和自适应调整学习率,可能会导致模型过度适应训练数据,从而降低模型的泛化能力。在某些情况下,使用 Adam 训练的模型在测试集上的表现可能不如使用其他优化算法训练的模型。

四、相关链接

Python 梯度下降法合集:

  • Python 梯度下降法(一):Gradient Descent-CSDN博客
  • Python 梯度下降法(二):RMSProp Optimize-CSDN博客
  • Python 梯度下降法(三):Adagrad Optimize-CSDN博客
  • Python 梯度下降法(四):Adadelta Optimize-CSDN博客
  • Python 梯度下降法(五):Adam Optimize-CSDN博客
  • Python 梯度下降法(六):Nadam Optimize-CSDN博客
  • Python 梯度下降法(七):Summary-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/11136.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【上篇】-分两篇步骤介绍-如何用topview生成和自定义数字人-关于AI的使用和应用-如何生成数字人-优雅草卓伊凡-如何生成AI数字人

【上篇】-分两篇步骤介绍-如何用topview生成和自定义数字人-关于AI的使用和应用-如何生成数字人-优雅草卓伊凡-如何生成AI数字人 背景 AI数字人有很多应用目前&#xff0c;本文做如何生成数字人&#xff0c;因为后续就连我们公司自己也会有很多关于AI数字人的使用&#xff0c…

MapReduce简单应用(一)——WordCount

目录 1. 执行过程1.1 分割1.2 Map1.3 Combine1.4 Reduce 2. 代码和结果2.1 pom.xml中依赖配置2.2 工具类util2.3 WordCount2.4 结果 参考 1. 执行过程 假设WordCount的两个输入文本text1.txt和text2.txt如下。 Hello World Bye WorldHello Hadoop Bye Hadoop1.1 分割 将每个文…

tensorboard的基本使用及案例

TensorBoard 是一个可视化工具&#xff0c;用于展示机器学习模型的训练过程和结果。以下是 TensorBoard 的基本使用方法及一些案例。 基本使用 安装 安装 TensorBoard&#xff1a; pip install tensorboard 如果使用 PyTorch&#xff0c;还需要安装 torch 和 torchvision&…

【ArcGIS遇上Python】批量提取多波段影像至单个波段

本案例基于ArcGIS python,将landsat影像的7个波段影像数据,批量提取至单个波段。 相关阅读:【ArcGIS微课1000例】0141:提取多波段影像中的单个波段 文章目录 一、数据准备二、效果比对二、python批处理1. 编写python代码2. 运行代码一、数据准备 实验数据及完整的python位…

吴恩达深度学习——超参数调试

内容来自https://www.bilibili.com/video/BV1FT4y1E74V&#xff0c;仅为本人学习所用。 文章目录 超参数调试调试选择范围 Batch归一化公式整合 Softmax 超参数调试 调试 目前学习的一些超参数有学习率 α \alpha α&#xff08;最重要&#xff09;、动量梯度下降法 β \bet…

Alibaba开发规范_编程规约之命名风格

文章目录 命名风格的基本原则1. 命名不能以下划线或美元符号开始或结束2. 严禁使用拼音与英文混合或直接使用中文3. 类名使用 UpperCamelCase 风格&#xff0c;但以下情形例外&#xff1a;DO / BO / DTO / VO / AO / PO / UID 等4. 方法名、参数名、成员变量、局部变量使用 low…

从0开始,来看看怎么去linux排查Java程序故障

一&#xff0c;前提准备 最基本前提&#xff1a;你需要有liunx环境&#xff0c;如果没有请参考其它文献在自己得到local建立一个虚拟机去进行测试。 有了虚拟机之后&#xff0c;你还需要安装jdk和配置环境变量 1. 安装JDK&#xff08;以OpenJDK 17为例&#xff09; 下载JDK…

智能园区管理系统助力企业安全与效率双提升的成功案例分析

内容概要 在当今迅速发展的商业环境中&#xff0c;企业面临着资产管理、风险控制和运营效率提高等多重挑战。为了应对这些挑战&#xff0c;智能园区管理系统应运而生&#xff0c;为企业提供了全新的解决方案。例如&#xff0c;快鲸智慧园区&#xff08;楼宇&#xff09;管理系…

nacos 配置管理、 配置热更新、 动态路由

文章目录 配置管理引入jar包添加 bootstrap.yaml 文件配置在application.yaml 中添加自定义信息nacos 配置信息 配置热更新采用第一种配置根据服务名确定配置文件根据后缀确定配置文件 动态路由DynamicRouteLoaderNacosConfigManagerRouteDefinitionWriter 路由配置 配置管理 …

Linux-CentOS的yum源

1、什么是yum yum是CentOS的软件仓库管理工具。 2、yum的仓库 2.1、yum的远程仓库源 2.1.1、国内仓库 国内较知名的网络源(aliyun源&#xff0c;163源&#xff0c;sohu源&#xff0c;知名大学开源镜像等) 阿里源:https://opsx.alibaba.com/mirror 网易源:http://mirrors.1…

16.[前端开发]Day16-HTML+CSS阶段练习(网易云音乐五)

完整代码 网易云-main-left-rank&#xff08;排行榜&#xff09; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name&q…

【ts + java】古玩系统开发总结

src别名的配置 开发中文件和文件的关系会比较复杂&#xff0c;我们需要给src文件夹一个别名吧 vite.config.js import { defineConfig } from vite import vue from vitejs/plugin-vue import path from path// https://vitejs.dev/config/ export default defineConfig({pl…

使用Pygame制作“俄罗斯方块”游戏

1. 前言 俄罗斯方块&#xff08;Tetris&#xff09; 是一款由方块下落、行消除等核心规则构成的经典益智游戏&#xff1a; 每次从屏幕顶部出现一个随机的方块&#xff08;由若干小方格组成&#xff09;&#xff0c;玩家可以左右移动或旋转该方块&#xff0c;让它合适地堆叠在…

小程序设计和开发:什么是竞品分析,如何进行竞品分析

一、竞品分析的定义 竞品分析是指对竞争对手的产品进行深入研究和比较&#xff0c;以了解市场动态、发现自身产品的优势和不足&#xff0c;并为产品的设计、开发和营销策略提供参考依据。在小程序设计和开发中&#xff0c;竞品分析可以帮助开发者了解同类型小程序的功能、用户体…

Vue简介

目录 Vue是什么&#xff1f;为什么要使用Vue&#xff1f;Vue的三种加载方式拓展&#xff1a;什么是渐进式框架&#xff1f; Vue是什么&#xff1f; Vue是一套用于构建用户界面的渐进式 JavaScript (主张最少)框架 &#xff0c;开发者只需关注视图层。另一方面&#xff0c;当与…

Linux多路转接poll

Linux多路转接poll 1. poll() poll() 结构包含了要监视的 event 和发生的 event &#xff0c;接口使用比 select() 更方便。且 poll 并没有最大数量限制&#xff08;但是数量过大后性能也是会下降&#xff09;。 2. poll() 的工作原理 poll() 不再需要像 select() 那样自行…

C++【深入底层,手撕vector】

vector是向量的意思&#xff0c;看了vector的底层实现之后&#xff0c;能够很明确的认识到它其实就是我们经常使用的顺序表。在我们的认知中&#xff0c;顺序表会有一个数组、数据的size以及容量的大小。vector作为一个向量容器&#xff0c;它可以存放任意类型的数据。所以在实…

基于FPGA的BT656编解码

概述 BT656全称为“ITU-R BT.656-4”或简称“BT656”,是一种用于数字视频传输的接口标准。它规定了数字视频信号的编码方式、传输格式以及接口电气特性。在物理层面上,BT656接口通常包含10根线(在某些应用中可能略有不同,但标准配置为10根)。这些线分别用于传输视频数据、…

关于系统重构实践的一些思考与总结

文章目录 一、前言二、系统重构的范式1.明确目标和背景2.兼容屏蔽对上层的影响3.设计灰度迁移方案3.1 灰度策略3.2 灰度过程设计3.2.1 case1 业务逻辑变更3.2.2 case2 底层数据变更&#xff08;数据平滑迁移&#xff09;3.2.3 case3 在途新旧流程兼容3.2.4 case4 接口变更3.2.5…

Microsoft Power BI:融合 AI 的文本分析

Microsoft Power BI 是微软推出的一款功能强大的商业智能工具&#xff0c;旨在帮助用户从各种数据源中提取、分析和可视化数据&#xff0c;以支持业务决策和洞察。以下是关于 Power BI 的深度介绍&#xff1a; 1. 核心功能与特点 Power BI 提供了全面的数据分析和可视化功能&…