深入探讨梯度下降:优化机器学习的关键步骤(一)

文章目录

  • 🍀引言
  • 🍀什么是梯度下降?
  • 🍀损失函数
  • 🍀梯度(gradient)
  • 🍀梯度下降的工作原理
  • 🍀梯度下降的变种
    • 🍀随机梯度下降(SGD)
    • 🍀批量梯度下降(BGD)
    • 🍀小批量梯度下降(Mini-Batch GD)
  • 🍀如何选择学习率?
  • 🍀梯度下降的相关数学公式
  • 🍀梯度下降的实现(代码)
  • 🍀总结

🍀引言

在机器学习领域,梯度下降是一种核心的优化算法,它被广泛应用于训练神经网络、线性回归和其他机器学习模型中。本文将深入探讨梯度下降的工作原理,并且进行简单的代码实现


🍀什么是梯度下降?

梯度下降是一种迭代优化算法,旨在寻找函数的局部最小值(或最大值)以最小化(或最大化)一个损失函数。在机器学习中,我们通常使用梯度下降来最小化模型的损失函数,以便训练模型的参数。
这里顺便提一嘴,与梯度下降齐名的梯度上升算法目的是使效用函数最大。


🍀损失函数

在使用梯度下降之前,我们首先需要定义一个损失函数。损失函数是一个用于衡量模型预测值与实际观测值之间差异的函数。通常,我们使用均方误差(MSE)作为回归问题的损失函数,使用交叉熵作为分类问题的损失函数。


🍀梯度(gradient)

梯度是损失函数相对于模型参数的偏导数。它告诉我们如果稍微调整模型参数,损失函数会如何变化。梯度下降算法利用梯度的信息来不断调整参数,以减小损失函数的值。

🍀梯度下降的工作原理

梯度下降的核心思想是沿着损失函数的负梯度方向调整参数,直到达到损失函数的局部最小值。具体来说,梯度下降的步骤如下:

  • 初始化模型参数:首先,随机初始化模型参数或使用某种启发式方法。

  • 计算损失和梯度:使用当前模型参数计算损失函数的值,并计算损失函数相对于参数的梯度。

  • 参数更新:根据梯度的方向和学习率(learning rate)本文我称其为eta,更新模型参数。学习率是一个控制步长大小的超参数,它决定了每次迭代中参数更新的大小。

  • 重复迭代:重复步骤2和3,直到损失函数的值收敛到一个稳定的值,或达到预定的迭代次数。

🍀梯度下降的变种

在梯度下降的基础上,发展出了多种变种算法,以应对不同的问题和挑战。其中一些常见的包括

🍀随机梯度下降(SGD)

随机梯度下降每次只使用一个随机样本来估计梯度,从而加速收敛速度。它特别适用于大规模数据集和在线学习。

🍀批量梯度下降(BGD)

批量梯度下降在每次迭代中使用整个训练数据集来计算梯度。尽管计算开销较大,但通常能够更稳定地收敛到全局最小值。

🍀小批量梯度下降(Mini-Batch GD)

小批量梯度下降综合了SGD和BGD的优点,它使用一个小批量样本来估计梯度,平衡了计算效率和收敛性能。

🍀如何选择学习率?

学习率是梯度下降的关键超参数之一。选择合适的学习率可以加速收敛,但过大的学习率可能导致不稳定的训练过程。通常,我们可以采用以下方法选择学习率:

  • 网格搜索:尝试不同的学习率值,通过验证集的性能来选择最佳值。

  • 学习率衰减:开始时使用较大的学习率,随着训练的进行逐渐减小学习率。

  • 自适应学习率:使用自适应学习率算法,如Adam、Adagrad或RMSprop,它们可以自动调整学习率以适应梯度的变化。

🍀梯度下降的相关数学公式

本人数学不好,这里有说的不清楚的地方还请见谅,谢谢佬~
首先我们通过图像认识一下损失函数
在这里插入图片描述
这里的步长指的是,可能有些人会好奇为啥有一个负号呢?因为对称轴左侧的导数都是负值,这里加一个负号不就正了嘛
在这里插入图片描述

具体推导过程请查看相关佬的文章(哭~)

🍀梯度下降的实现(代码)

首先我们导入我们需要的库

import numpy as np
import matplotlib.pyplot as plt

之后我们需要举一个例子,这里我们采用numpy里面的一个分割函数linspace,同时我们举一个函数的例子

plt_x = np.linspace(-1,6,141)
plt_y = (plt_x-2.5)**2-1

之后我们使用show进行展示一下图像

plt.plot(plt_x,ply_y)
plt.show()

运行结果如下
在这里插入图片描述

上图看起来就是一个普通的曲线,方便我们进行理解

接下来我们需要两个函数,一个为了返回导数,一个为了返回对应的y值

def dj(thera):return 2*(thera-2.5) # 求导
def j(thera)return (thera-2.5)**2-1  # 求对应的值

接下来是梯度下降的关键位置了,这里我们需要初始化两个参数以及一个范围参数,同时设置一个while循环,将前一个thera保存在last_thera中,后一个thera是前一个thera和步长的差值,这里的步长就是梯度个参数eta的乘积,最后使用if函数来终结循环,最终我们将最小值点的值、导数、以及自变量打印出来

eta = 0.1
theta =0.0
epsilon = 1e-8
while True:gradient = dj(theta)last_theta = thetatheta = theta-gradient*eta if np.abs(j(theta)-j(last_theta))<epsilon:breakprint(theta)
print(dj(theta))
print(j(theta))

运行结果如下
在这里插入图片描述
这里我们也可以使用列表来看看到底进行了多少次thera的循环

eta = 0.1
theta =0.0
epsilon = 1e-8
theta_history = [theta]
while True:gradient = dj(theta)last_theta = thetatheta = theta-gradient*eta theta_history.append(theta)if np.abs(j(theta)-j(last_theta))<epsilon:breakprint(theta)
print(dj(theta))
print(j(theta))len(theta_history)

运行结果如下

在这里插入图片描述
还可以绘制图像进行直观查看

plt.plot(plt_x,plt_y)
plt.plot(theta_history,[(i-2.5)**2-1 for i in theta_history],color='r',marker='*')
plt.show()

运行结果如下
在这里插入图片描述
这样的话就很直观了吧~

🍀总结

本节只介绍梯度下降的简单实现,下节继续学习此法中eta参数的调节

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/122232.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UML基础

统一建模语言&#xff08;UML是 Unified Modeling Language的缩写, 是用来对软件系统进行可视化建模的一种语言。UML为面向对象开发系统的产品 进行说明、可视化、和编制文档的一种标准语言。 共有9种图 UML中的图其实不止九种 (相同的图还可能会有不同的名称), 这里的九种图是…

SSM(Spring-Mybatis-SpringMVC)

文章目录 1. 介绍1.1 概念介绍 2 SSM整合框架3. SSM功能模块开发4 测试4.1 业务层接口测试4.2 表现层接口测试 5.优化 -表现层数据封装6.异常处理 1. 介绍 1.1 概念介绍 SSM项目是指基于SpringSpringMVCMyBatis框架搭建的Java Web项目。 Spring是负责管理和组织项目的IOC容器和…

selenium 动态爬取页面使用教程以及使用案例

Selenium 介绍 概述 Selenium是一款功能强大的自动化Web浏览器交互工具。它可以模拟真实用户在网页上的操作&#xff0c;例如点击、滚动、输入等等。Selenium可以爬取其他库难以爬取的网站&#xff0c;特别是那些需要登录或使用JavaScript的网站。Selenium可以自动地从Web页面…

[羊城杯 2020] easyphp

打开题目&#xff0c;源代码 <?php$files scandir(./); foreach($files as $file) {if(is_file($file)){if ($file ! "index.php") {unlink($file);}}}if(!isset($_GET[content]) || !isset($_GET[filename])) {highlight_file(__FILE__);die();}$content $_GE…

【广州华锐互动】AR技术在配电系统运维中的应用

随着科技的不断发展&#xff0c;AR(增强现实)技术逐渐走进了我们的生活。在电力行业&#xff0c;AR技术的应用也为巡检工作带来了许多新突破&#xff0c;提高了巡检效率和安全性。本文将从以下几个方面探讨AR配电系统运维系统的新突破。 首先&#xff0c;AR技术可以实现虚拟巡检…

opencv鼠标事件函数setMouseCallback()详解

文章目录 opencv鼠标事件函数setMouseCallback()详解1、鼠标事件函数&#xff1a;&#xff08;1&#xff09;鼠标事件函数原型&#xff1a;setMouseCallback()&#xff0c;此函数会在调用之后不断查询回调函数onMouse()&#xff0c;直到窗口销毁&#xff08;2&#xff09;回调函…

golang指针的学习笔记

package main // 声音文件所在的包&#xff0c;每个go文件必须有归属的包 import ("fmt" )// 引入程序中需要用的包&#xff0c;为了使用包下的函数&#xff0c;比如&#xff1a;Printin// 字符类型使用 func main(){ // 基本数据类型&#xff0c;变量存的就是值&am…

面向对象的软件测试案例 | Date.increment方法的测试

面向对象技术产生了更好的系统结构&#xff0c;更规范的编码风格&#xff0c;它极大地优化了数据使用的安全性&#xff0c;提高了程序代码的可重用性&#xff0c;使得一些人就此认为面向对象技术开发出的程序无须进行测试。应该看到&#xff0c;尽管面向对象技术的基本思想保证…

【前端】场景题:如何在ul标签中插入多个节点 使用文档片段

直接插入的问题&#xff1a;会回流多次。每插入一次li就会回流一次&#xff0c;消耗性能。 这里可以使用文档片段来解决这个问题。 // 创建文档片段 let node document.createDocumentFragment()DocumentFragment节点存在于内存中&#xff0c;并不在DOM中&#xff0c;所以将子…

Chrome 和 Edge 上出现“status_breakpoint”错误解决办法

文章目录 STATUS_BREAKPOINTSTATUS_BREAKPOINT报错解决办法Chrome浏览器 Status_breakpoint 错误修复- 将 Chrome 浏览器更新到最新版本- 卸载不再使用的扩展程序和应用程序- 安装计算机上可用的任何更新&#xff0c;尤其是 Windows 10- 重启你的电脑。 Edge浏览器 Status_brea…

flutter架构全面解析

Flutter 是一个跨平台的 UI 工具集&#xff0c;它的设计初衷&#xff0c;就是允许在各种操作系统上复用同样的代码&#xff0c;例如 iOS 和 Android&#xff0c;同时让应用程序可以直接与底层平台服务进行交互。如此设计是为了让开发者能够在不同的平台上&#xff0c;都能交付拥…

分类任务评价指标

分类任务评价指标 分类任务中&#xff0c;有以下几个常用指标&#xff1a; 混淆矩阵准确率&#xff08;Accuracy&#xff09;精确率&#xff08;查准率&#xff0c;Precision&#xff09;召回率&#xff08;查全率&#xff0c;Recall&#xff09;F-scorePR曲线ROC曲线 1. 混…

浅谈Mysql读写分离的坑以及应对的方案 | 京东云技术团队

一、主从架构 为什么我们要进行读写分离&#xff1f;个人觉得还是业务发展到一定的规模&#xff0c;驱动技术架构的改革&#xff0c;读写分离可以减轻单台服务器的压力&#xff0c;将读请求和写请求分流到不同的服务器&#xff0c;分摊单台服务的负载&#xff0c;提高可用性&a…

C#模拟PLC设备运行

涉及&#xff1a;控件数据绑定&#xff0c;动画效果 using System; using System.Windows.Forms;namespace PLCUI {public partial class MainForm : Form{ public MainForm(){InitializeComponent();}private void MainForm_Load(object sender, EventArgs e){// 方式2&#x…

现货黄金走势图中的止盈点

对平仓时机的把握能力&#xff0c;是衡量现货黄金投资者水平的重要标志&#xff0c;止盈点设置得是否合理&#xff0c;在行情兑现的时候能否及时地离场&#xff0c;是事关投资者账户浮盈最终能否落袋为安的“头等大事”&#xff0c;要在现货黄金走势图中把握止盈点&#xff0c;…

四旋翼飞行器基本模型(MatlabSimulink)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

C++11新特性① | C++11 常用关键字实战详解

目录 1、引言 2、C11 新增关键字详解 2.1、auto 2.2、override 2.3、final 2.4、nullptr 2.5、使用delete阻止拷贝类对象 2.6、decltype 2.7、noexcept 2.8、constexpr 2.9、static_assert VC常用功能开发汇总&#xff08;专栏文章列表&#xff0c;欢迎订阅&#xf…

微服务介绍

在认识微服务之前&#xff0c;需要先了解一下与微服务对应的单体式&#xff08;Monolithic&#xff09;式架构。在Monolithic架构中&#xff0c;系统通常采用分层架构模式&#xff0c; 按技术维度对系统进行划分&#xff0c;比如持久化层、业务逻辑层、表示层。 Monolithic架构…

PYTHON知识点学习-字典

&#x1f308;write in front&#x1f308; &#x1f9f8;大家好&#xff0c;我是Aileen&#x1f9f8;.希望你看完之后&#xff0c;能对你有所帮助&#xff0c;不足请指正&#xff01;共同学习交流. &#x1f194;本文由 Aileen_0v0&#x1f9f8; 原创 CSDN首发&#x1f412; 如…

Si24R2F+畜牧 耳标测体温开发资料

Si24R2F是针对IOT应用领域推出的新款超低功耗2.4G内置NVM单发射芯片。广泛应用于2.4G有源活体动物耳标&#xff0c;带实时测温计步功能。相较于Si24R2E&#xff0c;Si24R2F增加了温度监控、自动唤醒间隔功能&#xff1b;发射功率由7dBm增加到12dBm&#xff0c;距离更远&#xf…