Python 梯度下降法(二):RMSProp Optimize

文章目录

  • Python 梯度下降法(二):RMSProp Optimize
    • 一、数学原理
      • 1.1 介绍
      • 1.2 公式
    • 二、代码实现
      • 2.1 函数代码
      • 2.2 总代码
    • 三、代码优化
      • 3.1 存在问题
      • 3.2 收敛判断
      • 3.3 函数代码
      • 3.4 总代码
    • 四、优缺点
      • 4.1 优点
      • 4.2 缺点
    • 五、相关链接

Python 梯度下降法(二):RMSProp Optimize

一、数学原理

1.1 介绍

RMSProp(Root Mean Square Propagation)是一种自适应学习率优化算法,广泛用于深度学习中的梯度下降优化。它通过调整每个参数的学习率来解决传统梯度下降法中学习率固定的问题,从而加速收敛并提高性能。

RMSProp 的核心思想是对每个参数的学习率进行自适应调整。它通过维护一个指数加权移动平均(Exponential Moving Average, EMA)的梯度平方值来调整学习率:

  • 对于梯度较大的参数,降低其学习率。
  • 对于梯度较小的参数,增加其学习率。
    这种方法可以有效缓解梯度下降中的震荡问题,尤其是在非凸优化问题中。

1.2 公式

符号说明:

θ : 需要优化的参数向量 J ( θ ) : 损失函数 g t : 在第 t 次迭代时损失函数关于 θ 的梯度, ∇ θ J ( θ t ) ρ : 衰减率,常用值为 0.9 η : 学习率,需要手动设置 ϵ : 一个及小的参数,无限趋近于零,避免不会出现零 ( 1 0 − 8 ) s t : 指数加权移动平均 \begin{array}{l} \theta&:需要优化的参数向量 \\ J(\theta)&: 损失函数 \\ g_{t}&:在第t次迭代时损失函数关于\theta的梯度,\nabla_{\theta}J(\theta_{t}) \\ \rho &: 衰减率,常用值为0.9\\ \eta&:学习率,需要手动设置 \\ \epsilon&: 一个及小的参数,无限趋近于零,避免不会出现零(10^{-8}) \\ s_{t}&:指数加权移动平均 \end{array} θJ(θ)gtρηϵst:需要优化的参数向量:损失函数:在第t次迭代时损失函数关于θ的梯度,θJ(θt):衰减率,常用值为0.9:学习率,需要手动设置:一个及小的参数,无限趋近于零,避免不会出现零(108):指数加权移动平均

  1. 初始化参数为: θ 0 , s 0 = 0 \theta_{0},s_{0}=0 θ0,s0=0

  2. 迭代更新,每次迭代 t t t中,更新指数加权移动平均:
    s t = ρ s t − 1 + ( 1 − ρ ) g t ⊙ g t s_{t}=\rho s_{t-1}+(1-\rho)g_{t}\odot g_{t} st=ρst1+(1ρ)gtgt s t s_{t} st可以理解为对梯度平方的一个平滑估计,它更关注近期的梯度信息, ρ \rho ρ控制了历史信息的衰减程度。

  3. 计算自适应学习率,公式为 η s t + ϵ \frac{\eta}{\sqrt{ s_{t}+\epsilon }} st+ϵ η,其中,分母 s t + ϵ \sqrt{ s_{t}+\epsilon} st+ϵ 起到了归一化梯度的作用,使得学习率可以更具梯度的尺寸进行自适应的调整。

  4. 更新参数: θ t + 1 = θ t − η s t + ϵ g t \theta_{t+1}=\theta_{t}- \frac{\eta}{\sqrt{ s_{t}+\epsilon }}g_{t} θt+1=θtst+ϵ ηgt

二、代码实现

2.1 函数代码

RMSProp优化算法实现:

# 定义RMSProp优化算法
def rmsprop_optimizer(X, y, eta, num_iter=1000, epsilon=1e-5, rho=0.9):"""X: 数据 x  mxny: 数据 y  nx1eta: 学习率  num_iter: 迭代次数epsilon: 无穷小rho: 衰减率"""m, n = X.shape theta, s = np.zeros((n, 1)), np.zeros((n, 1))  # 初始参数 nx1,以及指数加权移动平均 nx1loss_ = []  # 存储损失率的变化,便于绘图for _ in range(num_iter):# 计算预测值h = np.dot(X, theta)# 计算误差error = h - yloss_.append(np.mean(error**2) / 2)# 计算梯度gradient = (1/m) * np.dot(X.T, error)s = rho * s + (1 - rho) * np.pow(gradient, 2)  # 利用广播机制来进行运算每个维度上的平滑估计theta = theta - np.multiply(eta / np.sqrt(s + epsilon), gradient)  # Hadamar productreturn theta, loss_

2.2 总代码

import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # 定义RMSProp优化算法
def rmsprop_optimizer(X, y, eta, num_iter=1000, epsilon=1e-5, rho=0.9):"""X: 数据 x  mxny: 数据 y  nx1eta: 学习率  num_iter: 迭代次数epsilon: 无穷小rho: 衰减率"""m, n = X.shape theta, s = np.zeros((n, 1)), np.zeros((n, 1))  # 初始参数 nx1,以及指数加权移动平均 nx1loss_ = []  # 存储损失率的变化,便于绘图for _ in range(num_iter):# 计算预测值h = np.dot(X, theta)# 计算误差error = h - yloss_.append(np.mean(error**2) / 2)# 计算梯度gradient = (1/m) * np.dot(X.T, error)s = rho * s + (1 - rho) * np.pow(gradient, 2)  # 利用广播机制来进行运算每个维度上的平滑估计theta = theta - np.multiply(eta / np.sqrt(s + epsilon), gradient)  # Hadamar productreturn theta, loss_# 生成一些示例数据
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)
# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]# 设置超参数
eta = 0.1# RMSProp优化算法
theta, loss_ = rmsprop_optimizer(X_b, y, eta)print("最优参数 theta:")
print(theta)
plt.plot(range(len(loss_)), loss_, label="损失函数图像")
plt.title("损失函数图像")
plt.xlabel("迭代次数")
plt.ylabel("损失值")
plt.show()

1738222496_l1dhye0o4j.png1738222495967.png

可以发现,其对于损失值的下降性能也较好,损失率也较为稳定。

三、代码优化

3.1 存在问题

  • 未使用小批量数据:该代码在每次迭代时使用了全部的训练数据 Xy 来计算梯度,这相当于批量梯度下降的方式。在处理大规模数据集时,这种方式可能会导致计算效率低下,并且可能会陷入局部最优解。可以参考之前小批量梯度下降的代码,引入小批量数据的处理,以提高算法的效率和泛化能力。
  • 缺乏数据预处理:在实际应用中,输入数据 X 可能需要进行预处理,例如归一化或标准化,以确保不同特征具有相似的尺度,从而加快算法的收敛速度。(这里不进行解决,参考特征缩放:数据归一化-CSDN博客)
  • 缺乏收敛判断:代码只是简单地进行了固定次数的迭代,没有设置收敛条件。在实际应用中,可以添加收敛判断,例如当损失值的变化小于某个阈值时提前停止迭代,以节省计算资源。

这里引入Mini-batch Gradient Descent,以及收敛判断,减少计算资源

3.2 收敛判断

# 收敛判断,设定阈值,进行收敛判断
# 满足条件即停止,减少系统资源的使用
if len(loss_) > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {_ + 1}")break  # 注意,这里不能使用return

3.3 函数代码

# 定义RMSProp优化算法
def rmsprop_optimizer(X, y, eta, batch_size=32, num_iter=1000, epsilon=1e-5, rho=0.9, threshold=1e-3):"""X: 数据 x  mxn,可以在传入数据之前进行数据的归一化y: 数据 y  nx1eta: 学习率  batch_size: 批量数据的大小num_iter: 迭代次数epsilon: 无穷小rho: 衰减率threshold: 收敛阈值"""m, n = X.shape theta, s = np.zeros((n, 1)), np.zeros((n, 1))  # 初始参数 nx1,以及指数加权移动平均 nx1loss_ = []  # 存储损失率的变化,便于绘图num_batchs = m // batch_sizefor _ in range(num_iter):# 打乱数据集shuffled_indices = np.random.permutation(m)X_shuffled = X[shuffled_indices]y_shuffled = y[shuffled_indices]loss_temp = []  # 存储每次小批量样本生成的值for batch in range(num_batchs):# 选取小批量样本start_index = batch * batch_sizeend_index = start_index + batch_sizexi = X_shuffled[start_index:end_index]yi = y_shuffled[start_index:end_index]# 计算预测值h = np.dot(xi, theta)# 计算误差error = h - yiloss_temp.append(np.mean(error**2) / 2)# 计算梯度gradient = (1/m) * np.dot(xi.T, error)s = rho * s + (1 - rho) * np.pow(gradient, 2)  # 利用广播机制来进行运算每个维度上的平滑估计theta = theta - np.multiply(eta / np.sqrt(s + epsilon), gradient)  # Hadamar productloss_.append(np.mean(loss_temp))# 收敛判断if len(loss_) > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {_ + 1}")breakreturn theta, loss_

3.4 总代码

import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # 定义RMSProp优化算法
def rmsprop_optimizer(X, y, eta, batch_size=32, num_iter=1000, epsilon=1e-5, rho=0.9, threshold=1e-3):"""X: 数据 x  mxn,可以在传入数据之前进行数据的归一化y: 数据 y  nx1eta: 学习率  batch_size: 批量数据的大小num_iter: 迭代次数epsilon: 无穷小rho: 衰减率threshold: 收敛阈值"""m, n = X.shape theta, s = np.zeros((n, 1)), np.zeros((n, 1))  # 初始参数 nx1,以及指数加权移动平均 nx1loss_ = []  # 存储损失率的变化,便于绘图num_batchs = m // batch_sizefor _ in range(num_iter):# 打乱数据集shuffled_indices = np.random.permutation(m)X_shuffled = X[shuffled_indices]y_shuffled = y[shuffled_indices]loss_temp = []  # 存储每次小批量样本生成的值for batch in range(num_batchs):# 选取小批量样本start_index = batch * batch_sizeend_index = start_index + batch_sizexi = X_shuffled[start_index:end_index]yi = y_shuffled[start_index:end_index]# 计算预测值h = np.dot(xi, theta)# 计算误差error = h - yiloss_temp.append(np.mean(error**2) / 2)# 计算梯度gradient = (1/m) * np.dot(xi.T, error)s = rho * s + (1 - rho) * np.pow(gradient, 2)  # 利用广播机制来进行运算每个维度上的平滑估计theta = theta - np.multiply(eta / np.sqrt(s + epsilon), gradient)  # Hadamar productloss_.append(np.mean(loss_temp))  # 使用平均值作为参考# 收敛判断if len(loss_) > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {_ + 1}")breakreturn theta, loss_# 生成一些示例数据
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)
# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]# 设置超参数
eta = 0.1# RMSProp优化算法
theta, loss_ = rmsprop_optimizer(X_b, y, eta)print("最优参数 theta:")
print(theta)
plt.plot(range(len(loss_)), loss_, label="损失函数图像")
plt.title("损失函数图像")
plt.xlabel("迭代次数")
plt.ylabel("损失值")
plt.show()

1738223819_x4hrij0zgr.png1738223818170.png

四、优缺点

4.1 优点

  • 自适应学习率:RMSProp 能够根据参数的梯度变化情况自适应地调整学习率。对于梯度较大的参数,学习率会自动减小;对于梯度较小的参数,学习率会相对增大。这使得算法在处理不同尺度的梯度时更加稳定,有助于加快收敛速度。
  • 缓解 Adagrad 学习率衰减过快问题:与 Adagrad 算法不同,RMSProp 使用指数加权移动平均来计算梯度平方的累积值,避免了 Adagrad 中学习率单调递减且后期学习率过小的问题,使得算法在训练后期仍然能够继续更新参数。

4.2 缺点

  • 对超参数敏感:RMSProp 的性能依赖于超参数 η \eta η ρ \rho ρ的选择。如果超参数设置不当,可能会导致算法收敛速度慢或者无法收敛到最优解。
  • 可能陷入局部最优:和其他基于梯度的优化算法一样,RMSProp 仍然有可能陷入局部最优解,尤其是在损失函数具有复杂的地形时。

五、相关链接

Python 梯度下降法合集:

  • Python 梯度下降法(一):Gradient Descent-CSDN博客
  • Python 梯度下降法(二):RMSProp Optimize-CSDN博客
  • Python 梯度下降法(三):Adagrad Optimize-CSDN博客
  • Python 梯度下降法(四):Adadelta Optimize-CSDN博客
  • Python 梯度下降法(五):Adam Optimize-CSDN博客
  • Python 梯度下降法(六):Nadam Optimize-CSDN博客
  • Python 梯度下降法(七):Summary-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/10873.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【2025年更新】1000个大数据/人工智能毕设选题推荐

文章目录 前言大数据/人工智能毕设选题&#xff1a;后记 前言 正值毕业季我看到很多同学都在为自己的毕业设计发愁 Maynor在网上搜集了1000个大数据的毕设选题&#xff0c;希望对大家有帮助&#xff5e; 适合大数据毕业设计的项目&#xff0c;完全可以作为本科生当前较新的毕…

three.js+WebGL踩坑经验合集(6.2):负缩放,负定矩阵和行列式的关系(3D版本)

本篇将紧接上篇的2D版本对3D版的负缩放矩阵进行解读。 (6.1):负缩放&#xff0c;负定矩阵和行列式的关系&#xff08;2D版本&#xff09; 既然three.js对3D版的负缩放也使用行列式进行判断&#xff0c;那么&#xff0c;2D版的结论用到3D上其实是没毛病的&#xff0c;THREE.Li…

反向代理模块jmh

1 概念 1.1 反向代理概念 反向代理是指以代理服务器来接收客户端的请求&#xff0c;然后将请求转发给内部网络上的服务器&#xff0c;将从服务器上得到的结果返回给客户端&#xff0c;此时代理服务器对外表现为一个反向代理服务器。 对于客户端来说&#xff0c;反向代理就相当…

软件工程经济学-日常作业+大作业

目录 一、作业1 作业内容 解答 二、作业2 作业内容 解答 三、作业3 作业内容 解答 四、大作业 作业内容 解答 1.建立层次结构模型 (1)目标层 (2)准则层 (3)方案层 2.构造判断矩阵 (1)准则层判断矩阵 (2)方案层判断矩阵 3.层次单排序及其一致性检验 代码 …

【回溯】目标和 字母大小全排列

文章目录 494. 目标和解题思路&#xff1a;回溯784. 字母大小写全排列解题思路&#xff1a;回溯 494. 目标和 494. 目标和 给你一个非负整数数组 nums 和一个整数 target 。 向数组中的每个整数前添加 或 - &#xff0c;然后串联起所有整数&#xff0c;可以构造一个 表达式…

告别复杂,拥抱简洁:用plusDays(7)代替plus(7, ChronoUnit.DAYS)

前言 你知道吗?有时候代码里的一些小细节看起来很简单,却可能成为你调试时的大麻烦。在 Java 中,我们用 LocalDateTime 进行日期和时间的操作时,发现一个小小的替代方法可以让代码更简洁,功能更强大。这不,今天我们就来探讨如何用 LocalDateTime.now().plusDays(7) 替代…

《苍穹外卖》项目学习记录-Day10订单状态定时处理

利用Cron表达式生成器生成Cron表达式 1.处理超时订单 查询订单表把超时的订单查询出来&#xff0c;也就是订单的状态为待付款&#xff0c;下单的时间已经超过了15分钟。 //select * from orders where status ? and order_time < (当前时间 - 15分钟) 遍历集合把数据库…

【深度分析】微软全球裁员计划不影响印度地区,将继续增加当地就业机会

当微软的裁员刀锋掠过全球办公室时&#xff0c;班加罗尔的键盘声却愈发密集——这场资本迁徙背后&#xff0c;藏着数字殖民时代最锋利的生存法则。 表面是跨国公司的区域战略调整&#xff0c;实则是全球人才市场的地壳运动。微软一边在硅谷裁撤年薪20万美金的高级工程师&#x…

Linux中 端口被占用如何解决

lsof命令查找 查找被占用端口 lsof -i :端口号 #示例 lsof -i :8080 lsof -i :3306 netstat命令查找 查找被占用端口 netstat -tuln | grep 端口号 #示例 netstat -tuln | grep 3306 netstat -tuln | grep 6379 ss命令查找 查找被占用端口 ss -tunlp | grep 端口号 #示例…

qt-Quick3D笔记之官方例程Runtimeloader Example运行笔记

qt-Quick3D笔记之官方例程Runtimeloader Example运行笔记 文章目录 qt-Quick3D笔记之官方例程Runtimeloader Example运行笔记1.例程运行效果2.例程缩略图3.项目文件列表4.main.qml5.main.cpp6.CMakeLists.txt 1.例程运行效果 运行该项目需要自己准备一个模型文件 2.例程缩略图…

高性能消息队列Disruptor

定义一个事件模型 之后创建一个java类来使用这个数据模型。 /* <h1>事件模型工程类&#xff0c;用于生产事件消息</h1> */ no usages public class EventMessageFactory implements EventFactory<EventMessage> { Overridepublic EventMessage newInstance(…

Spring Boot项目如何使用MyBatis实现分页查询

写在前面&#xff1a;大家好&#xff01;我是晴空๓。如果博客中有不足或者的错误的地方欢迎在评论区或者私信我指正&#xff0c;感谢大家的不吝赐教。我的唯一博客更新地址是&#xff1a;https://ac-fun.blog.csdn.net/。非常感谢大家的支持。一起加油&#xff0c;冲鸭&#x…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.27 线性代数王国:矩阵分解实战指南

1.27 线性代数王国&#xff1a;矩阵分解实战指南 #mermaid-svg-JWrp2JAP9qkdS2A7 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-JWrp2JAP9qkdS2A7 .error-icon{fill:#552222;}#mermaid-svg-JWrp2JAP9qkdS2A7 .erro…

EasyExcel使用详解

文章目录 EasyExcel使用详解一、引言二、环境准备与基础配置1、添加依赖2、定义实体类 三、Excel 读取详解1、基础读取2、自定义监听器3、多 Sheet 处理 四、Excel 写入详解1、基础写入2、动态列与复杂表头3、样式与模板填充 五、总结 EasyExcel使用详解 一、引言 EasyExcel 是…

FIDL:Flutter与原生通讯的新姿势,不局限于基础数据类型

void initUser(User user); } 2、执行命令./gradlew assembleDebug&#xff0c;生成IUserServiceStub类和fidl.json文件 3、打开通道&#xff0c;向Flutter公开方法 FidlChannel.openChannel(getFlutterEngine().getDartExecutor(), new IUserServiceStub() { Override void…

DIFY源码解析

偶然发现Github上某位大佬开源的DIFY源码注释和解析&#xff0c;目前还处于陆续不断更新地更新过程中&#xff0c;为大佬的专业和开源贡献精神点赞。先收藏链接&#xff0c;后续慢慢学习。 相关链接如下&#xff1a; DIFY源码解析

87.(3)攻防世界 web simple_php

之前做过&#xff0c;回顾 12&#xff0c;攻防世界simple_php-CSDN博客 进入靶场 <?php // 显示当前 PHP 文件的源代码&#xff0c;方便调试或查看代码结构 // __FILE__ 是 PHP 的一个魔术常量&#xff0c;代表当前文件的完整路径和文件名 show_source(__FILE__);// 包含…

x86-64数据传输指令

关于汇编语言一些基础概念的更详细的介绍&#xff0c;可移步MIPS指令集&#xff08;一&#xff09;基本操作_mips指令 sw-CSDN博客 该指令集中一个字2字节。 该架构有16个64位寄存器&#xff0c;名字都以%r开头&#xff0c;每个寄存器的最低位字节&#xff0c;低1~2位字节&…

网络工程师 (8)存储管理

一、页式存储基本原理 &#xff08;一&#xff09;内存划分 页式存储首先将内存物理空间划分成大小相等的存储块&#xff0c;这些块通常被称为“页帧”或“物理页”。每个页帧的大小是固定的&#xff0c;例如常见的页帧大小有4KB、8KB等&#xff0c;这个大小由操作系统决定。同…

全程Kali linux---CTFshow misc入门(25-37)

第二十五题&#xff1a; 提示&#xff1a;flag在图片下面。 直接检查CRC&#xff0c;检测到错误&#xff0c;就直接暴力破解。 暴力破解CRC的python代码。 import binascii import struct def brute_force_ihdr_crc(filename): # 读取文件二进制数据 with open(filen…