梯度下降算法优化—随机梯度下降、小批次、动量、Adagrad等方法pytorch实现

现有不足

现有调整网络的方法是借助成本函数的梯度下降方法,也就是给函数作切线,不断逼近最优点,即成本函数为零的点。
梯度下降的一般公式为:
梯度下降公式
即根据每个节点成本函数的梯度进行更新,使用该方法有一些问题:
1,计算量大,耗时长:我们的训练数据往往是成千上万的,每条都反向传播,计算梯度再调整参数,这等计算量就算是计算机也吃不消,更何况现在是大数据的时代,耗费的时间更是要呈指数上升。
2,易掉进局部最优的陷阱:根据梯度下降,我们找到的往往是一个极值点,而非最值点,如何找到方法跳出局部最优的陷阱而找到最优解也是目前的一个不足。

传统梯度下降使用pytorch实现的一般思路是:
1,获取数据
2,定义损失函数
3,定义优化器
4,计算损失,并反向传播计算梯度
5,更新模型参数

定义一个最简单的线性模型y=w*x+b,损失函数为预测值和实际的差,训练模型的具体代码如下:

import torch
import matplotlib.pyplot as plt
import timestart_time=time.time()
# 定义训练数据
X = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
Y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)# 初始化模型参数
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)# 定义损失函数
loss_fn = torch.nn.MSELoss()# 定义优化器
optimizer = torch.optim.SGD([w, b], lr=0.01)
epochs = 100# 批量梯度下降
for epoch in range(epochs):# 前向传播Y_pred = w * X + b# 计算损失loss = loss_fn(Y_pred, Y)# 反向传播loss.backward()# 更新参数optimizer.step()# 清空梯度optimizer.zero_grad()# 输出结果
print(f"w = {w.item()}, b = {b.item()}")# 时间统计
total_time=time.time()-start_time
print(f"time = {total_time}s")
# 绘制拟合直线
plt.scatter(X.numpy(), Y.numpy())
plt.plot(X.numpy(), (w * X + b).detach().numpy(), 'r')
plt.show()

输出结果如图:
批梯度下降输出结果

优化算法

上述问题的解决方案主要有两种:
1,通过数学或工程方法减少计算量。
2,优化路径或步长,使模型在获得最优解前尽量不走弯路(其实本质也是减少计算量)。

随机梯度下降—SGD

与传统批量梯度下降不同,随机梯度下降只选择一个节点的数据进行更新模型参数,该方法相比批量梯度下降不那么准确,但在循环过程中大致是朝着最优方向前进的,但相比批量梯度下降方法,大大提升了训练效率,算法是空间和时间的平衡,那我们现在就是准确性和时间的平衡。

在代码实现中,我们只需要在模型训练中选择一个数据参与更新参数过程即可,通过设置数据量,在模型训练中加载计算,区别的详细代码如下:

# 设置数据量为1
batch_size = 1
epochs = 100# 随机梯度下降
for epoch in range(epochs):# 创建DataLoader,选择数据集中的一个loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(X, Y), batch_size=batch_size, shuffle=True)for x_batch, y_batch in loader:# 前向传播y_pred = w * x_batch + b# 计算损失loss = loss_fn(y_pred, y_batch)

该方法虽然减少了计算量,但因为其只选择一个数据作为更新依据,形成了一种对数据敏感的模型,在下降过程中易反复震荡,增加步数,但在震荡过程中造成其收敛方向的不确定性,一定程度上也能缓解算法陷入局部极值的陷阱。

小批次梯度下降—Mini-batch

该方法介于批量梯度下降和随机梯度下降之间,即每次选择数据集中的一部分作为更新模型的依据,该方法收敛速度快于批量梯度下降,收敛性波动小于随机梯度下降,代码实现中只需更改batch_size的大小即可。

动量梯度下降—MGD

批梯度下降每次参数的更新仅与上一次的梯度有关,而梯度往往是一种趋势,为了利用这种趋势加快步长(类似人在下坡时步子会变大),结合梯度的历史数据创造了动量梯度下降,实现在相关方向上的加速,并一定程度上抑制抖动的效果。

动量公式如下:
动量公式
参数更新公式
其中β代表历史数据的占比大小,在编码中我们可以通过设置momentum参数实现动量梯度下降,修改的代码如下:

momentum = 0.5
# 优化器中设置动量大小
optimizer = torch.optim.SGD([w, b], lr=0.01, momentum=momentum)

在该实验中我们可以推断,w越接近2,b越接近0说明训练效果越好,从该实验结果可以看出动量可以加速训练过程。
训练结果
且有了动量法,如果参数设置的够大,一定程度上我们也能跳过一些小的极值点,避免了陷入局部最小的陷阱中。

AdaGrad自适应梯度下降法

该方法针对学习力α,实现步长的自动调整,其学习率的更新公式为:自适应梯度下降
其中w表示特征梯度,epsilon表示很小的一个数,用于防止分母为零,即通过梯度平方和来调整学习率,使用该方法可以实现梯度大的位置减小学习速率,梯度小的位置增大学习率,而加了求和公式,使得其在稀疏特征中表现更好。

编码中,我们可以使用optimizer = torch.optim.Adagrad([w, b], lr=0.1)将优化器设置为Adagrad实现自适应梯度下降。

但该方法由于其积累平方和,必然导致后期学习率变小,可能造成难以收敛的后果。

RMSProp

为了解决AdamGrad算法后期学习率太小的问题,RMSProp 通过引入一个衰减系数来解决这个问题,使得历史信息能够指数级衰减,将算法A中梯度的求和改为衰减系数
该方法在编码中可使用optimizer = torch.optim.RMSprop([w, b], lr=0.01,weight_decay=0.9)设置RMSProp优化器并指定衰减系数。

Adam

该算法结合了动量法和RMSProp的思想,梯度上结合动量,而学习率结合RMSProp,具体公式如下:
在这里插入图片描述
最终的更新公式为:
在这里插入图片描述
该方法同时考虑了基于自身变化趋势的迭代更新,和学习率根据梯度的变化β1β2分别控制这两步的权重,项目中可以通过optimizer = torch.optim.Adam([w, b], lr=0.01)设置优化器实现,该算法在当前阶段可以说基本完美,只是执行过程中一定程度依赖于衰减系数,并且因为存储多步计算,空间上占用可能稍多,但可以说根本不是问题。

总结

本文介绍了几种梯度下降优化算法,包括其原理和代码实现,但并不一定说最好的算法就一定产生最好的结果,每种算法都有其适应的领域,在实际中还是要具体问题具体分析。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/450677.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

探索OpenCV的人脸检测:用Haar特征分类器识别图片中的人脸

目录 简介 OpenCV和Haar特征分类器 实现人脸检测 1. 导入所需库 2. 加载图片和Haar特征分类器 3. 检测人脸 4. 标注人脸 5. 显示 6、结果展示 结论 简介 在计算机视觉和图像处理领域,人脸识别是一项重要的技术。它不仅应用于安全监控、人机交互&#xff0…

10秒钟用Midjourney画出国风味的变形金刚

上魔咒 Optimus Prime comes from the movie Transformers, Chinese style, Wu ShanMing, Ink Painting Halo Dyeing, Conceptual of the Digita Art, MasterComposition, Romantic Ancient Style, Inspired by traditional patterns and symbols, Minimalism, do not con…

day01 -- MybatisPlus

1. MybatisPlus简介 有基础的同学可结合资源中的代码一起看 MyBatis 的增强工具,在 MyBatis 的基础上只做增强不做改变,为简化开发、提高效率而生 特性 通用的 CRUD 操作:内置通用 Mapper、通用 Service,仅仅通过少量配置即可实…

私有化部署大模型最佳解决方案 Ollama (8B)模型

私有化部署大模型Ollama 为什么需要私有化部署大模型一、Ollama本地部署Llama3大模型二、Langchain4j调用Ollama本地部署模型API三、Ollama本地部署nomic向量模型四、Spring AI调用Ollama本地部署模型API 为什么需要私有化部署大模型 企业考虑成本和数据隐私问题,会…

021_Thermal_Transient_in_Matlab统一偏微分框架之热传导问题

Matlab求解有限元专题系列 固体热传导方程 固体热传导的方程为: ρ C p ( ∂ T ∂ t u t r a n s ⋅ ∇ T ) ∇ ⋅ ( q q r ) − α T d S d t Q \rho C_p \left( \frac{\partial T}{\partial t} \mathbf{u}_{\mathtt{trans}} \cdot \nabla T \right) \nab…

BM算法(手算版)

BM 算法 BM 算法是一种字符串匹配的算法。 与 KMP 相比&#xff0c;BM 算法不扫描全部输入字符&#xff0c;平均匹配时间 c・n, 常量 c <1 (随机或真实文本), 但最坏情况是 O (n・m). 可以将 BM 算法的最坏情况改进到 O (n)&#xff1a;通过记录文本后缀中最…

计算机系统简介

一、计算机的软硬件概念 1.硬件&#xff1a;计算机的实体&#xff0c;如主机、外设、硬盘、显卡等。 2.软件&#xff1a;由具有各类特殊功能的信息&#xff08;程序&#xff09;组成。 系统软件&#xff1a;用来管理整个计算机系统&#xff0c;如语言处理程序、操作系统、服…

群晖前面加了雷池社区版,安装失败,然后无法识别出用户真实访问IP

有nas的相信对公网都不模式&#xff0c;在现在基础上传带宽能有100兆的时代&#xff0c;有公网代表着家里有一个小服务器&#xff0c;像百度网盘&#xff0c;优酷这种在线服务都能部署为私有化服务。但现在运营商几乎不可能提供公网ip&#xff0c;要么自己买个云服务器做内网穿…

通过github创建自己网页链接的方法

文章目录 要使用GitHub创建静态网页链接&#xff0c;可以按照以下详细步骤进行操作&#xff1a;一、准备阶段二、创建仓库并配置三、准备并上传静态网站文件四、配置GitHub Pages五、访问和更新你的静态网页 要使用GitHub创建静态网页链接&#xff0c;可以按照以下详细步骤进行…

uniapp微信小程序调用百度OCR

uniapp编写微信小程序调用百度OCR 公司有一个识别行驶证需求&#xff0c;调用百度ocr识别 使用了image-tools这个插件&#xff0c;因为百度ocr接口用图片的base64 这里只是简单演示&#xff0c;accesstoken获取接口还是要放在服务器端&#xff0c;不然就暴露了自己的百度项目k…

Xshell使用密钥远程登录Ubuntu 22.04报错:所选的用户密钥未在远程主机上注册。请再试一次

报错截图如下&#xff1a; 问题原因&#xff1a; Ubuntu 22.04 不支持 Xshell使用的私钥。 查看系统支持的私钥&#xff1a;sudo sshd -T | egrep "pubkey" ~$ sudo sshd -T | egrep "pubkey" pubkeyauthentication yes pubkeyacceptedalgorithms ssh-ed…

基于SpringBoot+Vue的旅游服务平台【提供源码+答辩PPT+参考文档+项目部署】

&#x1f4a5; ① 前言&#xff1a;这两年毕业设计和毕业答辩的要求和难度不断提升&#xff0c;传统的JavaWeb项目缺少创新和亮点&#xff0c;往往达不到毕业答辩的要求&#xff01; ❗② 如何解决这类问题&#xff1f; 让我们能够顺利通过毕业&#xff0c;我也一直在不断思考、…

ROS 的 urdf 中 link 和 joint 的子标签中 origin 的含义

主要参考文章——主要文章&#xff0c;官方关于urdf的介绍和官方文档的翻译解析 link标签里面的origin含义 link标签里面有三个主要的子标签&#xff0c;分别是visual——连杆的外观和坐标系&#xff0c;collisoin——连杆的碰撞属性和inertial——连杆的惯性设置 首先&…

【AIGC】AI如何匹配RAG知识库: Embedding实践,语义搜索

引言 RAG作为减少模型幻觉和让模型分析、回答私域相关知识最简单高效的方式&#xff0c;我们除了使用之外可以尝试了解其是如何实现的。在实现RAG的过程中Embedding是非常重要的手段。本文将带你简单地了解AI工具都是如何通过Embedding去完成语义分析匹配的。 Embedding技术简…

低空经济发展迅猛,无人机设计制造技术详解

低空经济的迅猛发展&#xff0c;为无人机设计制造技术带来了新的机遇和挑战。无人机作为低空经济中的重要组成部分&#xff0c;其设计制造技术直接关系到无人机的性能、安全性和应用场景的拓展。以下是对无人机设计制造技术的详细解析&#xff1a; 一、无人机设计技术 1. 气动…

【HTML + CSS 魔法秀】打造惊艳 3D 旋转卡片

HTML结构 box 类是整个组件的容器。item-wrap 类是每个旋转卡片的包装器&#xff0c;每个都有一个内联样式–i&#xff0c;用于控制动画的延迟。item类是实际的卡片内容&#xff0c;包含一个图片。 <template><div class"box"><div class"item…

STM32L010F4 最小系统设计

画一个 STM32L010F4 的测试板子...... by 矜辰所致前言 最近需要用到一个新的 MCU&#xff1a; STM32L010F4 &#xff0c;上次测试的 VL53L0X 需要移植到这个芯片上&#xff0c;网上一搜 STM32L010F4&#xff0c;都是介绍资料&#xff0c;没有最小系统&#xff0c;使用说明等。…

计算生物学与生物信息学漫谈-1-测序一路走来

最近工作中&#xff0c;反思自己计算生物学基础非常薄弱&#xff0c;然而作为一门非常新兴的交叉学科&#xff0c;涉及计算机、物理、生物、数学等多多学科&#xff0c;国内并没有这样完善的教程&#xff0c;因此想要自己做一个教程&#xff0c;使用费曼学习法学习&#xff0c;…

探讨淘宝商品 API 接口:运用及收益

在当今电子商务蓬勃发展的时代&#xff0c;淘宝作为全球领先的电商平台&#xff0c;拥有海量的商品资源和庞大的用户群体。而淘宝商品 API 接口的出现&#xff0c;为开发者和企业提供了一种强大的工具&#xff0c;能够实现对淘宝商品数据的高效获取和利用。本文将深入探讨淘宝商…

C语言 | Leetcode C语言题解之第492题构造矩形

题目&#xff1a; 题解&#xff1a; class Solution { public:vector<int> constructRectangle(int area) {int w sqrt(1.0 * area);while (area % w) {--w;}return {area / w, w};} };