梯度下降(批量梯度下降、随机梯度下降、小批量梯度下降)

在上一篇中我们推导了损失函数 J ( θ ) = 1 2 m ∑ i = 1 m ( y i − h θ ( x i ) ) 2 J(\theta) = \frac{1}{2m} \sum_{i=1}^{m} (y^{i} - h_{\theta}(x^{i}))^2 J(θ)=2m1i=1m(yihθ(xi))2的由来,结尾讲到最小化这个损失函数来找到最优的参数 θ \theta θ,通常是使用梯度下降实现的。

梯度下降广泛用于机器学习和统计建模中的参数估计,特别是在训练线性回归模型时。它的目标是最小化一个损失函数(目标函数),这个函数量化了模型预测和真实数据之间的误差。梯度下降通过迭代地调整模型的参数来减少成本函数的值。

梯度下降的过程类似于一个“盲人下山”的过程:

起始点:就像盲人随机地在山上的某个位置开始,梯度下降算法通常从一个随机的参数值开始。这个起始点可能远离最小值,也就是最低点。

目标值(误差最低值):算法的目标是找到成本函数的最小值,这就好比盲人想要下山到达山谷底部,那里是海拔最低的地方。

获取移动方向:盲人通过脚下的坡度来判断下山的方向,而梯度下降算法通过计算成本函数的梯度来确定下降的方向。梯度告诉我们成本函数上升最快的方向,我们需要往相反方向移动来降低成本函数的值。

控制移动距离(步长):盲人下山时的每一步都不会太大,以免跌倒;同理,梯度下降算法中的学习率决定了每一步下降的距离。学习率太大可能会越过最小值,太小则下降过程会非常缓慢。

递归移动(迭代更新):盲人会一步一步地移动,每走一步都基于当前位置的坡度来决定下一步的方向。梯度下降算法通过迭代地更新参数,每次迭代都基于当前参数的梯度来更新,直到找到最小值或者达到其他停止条件。
梯度下降
在这个下山(梯度下降)概念中,又细分出几种算法,其中如上述的普通梯度下降被命名为:批量梯度下降,除了批量梯度下降外还有随机梯度下降和小批量梯度下降。

批量梯度下降:

批量梯度下降(Batch Gradient Descent)和上述的方法一样,小步多次逐步找到最终的目标值,在每次迭代中使用全部的训练数据来计算损失函数的梯度。因为要用到全部训练数据,所以内存占用高、性能差、速度慢、准确度高。
批量梯度下降的详细步骤:

  1. 初始化参数:
    在开始迭代之前,首先随机选择一组参数 θ \theta θ或者从一个零向量开始。
  2. 计算梯度:
    在每次迭代中,先计算损失函数对于每个参数 θ j \theta_j θj的梯度。这涉及到对整个训练集的计算,如下所示: ∇ θ J ( θ ) = − α 1 m ∑ i = 1 m ( y i − h θ ( x i ) ) x i \nabla_{\theta} J(\theta) = -\alpha \frac{1}{m} \sum_{i=1}^{m} (y^{i}-h_{\theta}(x^{i}) )x^{i} θJ(θ)=αm1i=1m(yihθ(xi))xi

其中:

  • ∇ θ J ( θ ) \nabla_{\theta} J(\theta) θJ(θ)表示损失函数 J ( θ ) J(\theta) J(θ)关于参数向量 θ \theta θ的梯度
  • m m m是训练样本的数量
  • x i x^{i} xi是第 i i i个训练样本的特征
  • y i y^{i} yi是对应的目标值
  • y i − h θ ( x i ) y^{i} - h_{\theta}(x^{i}) yihθ(xi)是预测误差,也就是模型对第 i i i个训练样本的预测值 h θ ∗ ( x i ) h_{\theta}* (x^{i}) hθ(xi)与实际值 y i y^{i} yi之间的差异
  • x i x^{i} xi(第 i i i的特征向量)与预测误差相乘,表示梯度是如何随着特征 x i x^{i} xi的变化而变化
  • α \alpha α表示步长
  1. 更新参数:
    计算出梯度后,更新参数 θ : = θ − α ∇ θ J ( θ ) \theta := \theta - \alpha \nabla_{\theta} J(\theta) θ:=θαθJ(θ)
    α 是学习率,决定了在参数空间中移动的步长。
    迭代直至收敛:

重复步骤2和步骤3直到损失函数的值不再显著变化,或者达到一定的迭代次数。
所以相对于 θ j \theta_j θj的下一个位置 θ j ′ \theta_j^{\prime} θj就可以表示为 θ j \theta_j θj减去 ( − 1 m ∑ i = 1 m ( y i − h θ ( x i ) ) x i ) (-\frac{1}{m} \sum_{i=1}^{m} (y^{i}-h_{\theta}(x^{i}) )x^{i}) (m1i=1m(yihθ(xi))xi),即:

θ j ′ = θ j + α 1 m ∑ i = 1 m ( y i − h θ ( x i ) ) x j i \theta_j^{\prime} = \theta_j + \alpha \frac{1}{m} \sum_{i=1}^{m} \left(y^{i} - h_{\theta}(x^{i}) \right) x_j^{i} θj=θj+αm1i=1m(yihθ(xi))xji

随机梯度下降:

随机梯度下降是批量梯度下降的一个优化版本,每次只找一个样本,迭代速度快,但不一定每次都朝着收敛方向。对于每个参数 θ j ′ \theta_j^{\prime} θj,更新规则如下:

θ j ′ = θ j + ( y i − h θ ( x i ) ) x j i \theta_j^{\prime}= \theta_j + \left({y^{i} - h_ \theta}(x^{i}) \right) x_j^{i} θj=θj+(yihθ(xi))xji

小批量梯度下降:

小批量梯度下降在每次迭代中使用一个小批量(batch)的样本来计算梯度和更新参数。相当于取少量的数据牺牲一些每次移动的准确性,从而极大提高运算速度,因此也是最常用的梯度下降方法。对于每个参数 θ j ′ \theta_j^{\prime} θj,更新规则如下:
θ j ′ = θ j − α 1 B ∑ i = k k + B − 1 ( h θ ( x i ) − y i ) x j i \theta_j^{\prime}= \theta_j - \alpha \frac{1}{B} \sum_{i=k}^{k+B-1} \left( h_{\theta}(x^{i}) - y^{i} \right) x_j^{i} θj=θjαB1i=kk+B1(hθ(xi)yi)xji

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/211277.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mysq8l在Centos上安装后忘记root密码如何重新设置

场景 Mysql8在Windows上离线安装时忘记root密码: Mysql8在Windows上离线安装时忘记root密码-CSDN博客 如果是在Windows上忘记密码可以参考上面。 如果在Centos中安装mysql可以参考下面。 CentOS7中安装Mysql8并配置远程连接和修改密码等: CentOS7中…

element中el-table表头通过header-row-style设置样式

文章目录 一、知识点二、设置全部表头2.1、方式一2.2、方式二 三、设置某个表头四、最后 一、知识点 有些时候需要给element-ui表头设置不同样式,比如居中、背景色、字体大小等等,这时就可以用到本文要说的属性header-row-style。官网说明如下所示&…

前后端分离vue+Nodejs社区志愿者招募管理系统

1、首页 1)滑动的社区照片册 使用轮播图,对社区的活动纪念与实时事件宣传。 每个图片附有文字链接,点击跳转对应社区要闻具体页。 2)社区公告栏 日常的社区公告以及系统说明在此区域中进行说明与展示。 2、志愿活动 1)志愿活动发布 想发布需要登录 2)志愿…

Linux基础项目开发1:量产工具——输入系统(三)

前言: 前面我们已经实现了显示系统,现在我们来实现输入系统,与显示系统类似,下面让我们一起来对输入系统进行学习搭建吧 目录 一、数据结构抽象 1. 数据本身 2. 设备本身: 3. input_manager.h 二、触摸屏编程 t…

Presto基础学习--学习笔记

1,Presto背景 2011年,FaceBook的数据仓库存储在少量大型hadoop/hdfs集群,在这之前,FaceBook的科学家和分析师一直靠hive进行数据分析,但hive使用MR作为底层计算框架,是专为批处理设计的,但是随…

亿胜盈科ATR2037 无限射频前端低噪声放大器

亿胜盈科ATR2037 是一款应用于无线通信射频前端,工作频段为 0.7 到 6GHz 的超低噪声放大器。 ATR2037 低噪声放大器采用先进的 GaAs pHEMT 工艺设计和制作,ATR2037 低噪声放大器在整个工作频段内可以获得非常好的射频性能超低噪声系数。 亿胜盈科ATR203…

abapgit 安装及使用

abapgit 需求 SA[ BASIS 版本 702 及以上 版本查看路径如下: 安装步骤如下: 1. 下载abapgit 独立版本 程序 链接如下:raw.githubusercontent.com/abapGit/build/main/zabapgit_standalone.prog.abap 2.安装开发版本 2.1 在线安装 前置条…

揭秘:软件测试中Web请求的完整流程!

在软件开发的过程中,测试是一个至关重要的环节。而在现代互联网应用中,Web请求是很常见的一个测试需求。本文将介绍Web请求的完整测试流程,帮助读者更好地理解软件测试的关键步骤。 一、测试准备阶段 在进行Web请求测试之前,测试团…

Could not resolve all files for configuration ‘:app:debugCompileClasspath‘.

修改前 修改后 maven {url https://developer.huawei.com/repo/}

Scrum敏捷开发流程及支撑工具

Scrum是一种敏捷开发框架,用于管理复杂的项目。以下这些步骤构成了Scrum敏捷开发流程的核心。通过不断迭代、灵活应对变化和持续反馈,Scrum框架帮助团队快速交付高质量的产品。 以下是Scrum敏捷开发流程的基本步骤: 产品Backlog创建&#xf…

idea通过remote远程调试云服务器

引用了第三方的包,调试是看不到运行流程,于是想到了idea的remote方法 -agentlib:jdwptransportdt_socket,servery,suspendn,address9002 写一个.sh文件并启动 nohup java -jar -agentlib:jdwptransportdt_socket,servery,suspendn,address9002 ./demo.j…

Nacos多数据源插件

Nacos从2.2.0版本开始,可通过SPI机制注入多数据源实现插件,并在引入对应数据源实现后,便可在Nacos启动时通过读取application.properties配置文件中spring.datasource.platform配置项选择加载对应多数据源插件.本文档详细介绍一个多数据源插件如何实现以及如何使其生效。 注意:…

MYSQL练题笔记-高级查询和连接-连续出现的数字

一、题目相关内容 1)相关的表和题目 2)帮助理解题目的示例,提供返回结果的格式 二、自己初步的理解 其实这一部分的题目很简单,但是没啥思路啊,怎么想都想不通,还是看题解吧,中等题就是中等题…

【虚拟机】Docker基础 【二】【数据卷和挂载本地目录】

2.2.数据卷 容器是隔离环境,容器内程序的文件、配置、运行时产生的容器都在容器内部,我们要读写容器内的文件非常不方便。大家思考几个问题: 如果要升级MySQL版本,需要销毁旧容器,那么数据岂不是跟着被销毁了&#x…

机器学习---线性回归算法

1、什么是回归? 从大量的函数结果和自变量反推回函数表达式的过程就是回归。线性回归是利用数理统计中回归分析来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法。 2、一元线性回归 3、多元线性回归 如果回归分析中包括两个或两个以上的自变量&a…

详解前后端交互时PO,DTO,VO模型类的应用场景

前后端交互时的数据传输模型 前后端交互流程 前后端交互的流程: 前端与后端开发人员之间主要依据接口进行开发 前端通过Http协议请求后端服务提供的接口后端服务的控制层Controller接收前端的请求Contorller层调用Service层进行业务处理Service层调用Dao持久层对数据持久化 …

Android : AndroidStudio开发工具优化

1.开启 gradle 单独的守护进程 Windows: 进入目录 C:\Users\Administrator\.gradle 创建文件: gradle.properties # Project-wide Gradle settings. # IDE (e.g. Android Studio) users: # Settings specified in this file will override any Gradle s…

MySQL安全最佳实践指南(2024版)

MySQL以其可靠性和效率在各种可用的数据库系统中脱颖而出。然而,与任何保存有价值数据的技术一样,MySQL数据库也是网络罪犯有利可图的目标。 这使得MySQL的安全性不再仅是一种选择,而是一种必要。这份全面的指南将深入研究保护MySQL数据库的…

集成开发环境 PyCharm 的安装【侯小啾python基础领航计划 系列(二)】

集成开发环境PyCharm的安装【侯小啾python基础领航计划 系列(二)】 大家好,我是博主侯小啾, 🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔…

AI - FlowField(流场寻路)

FlowField流场寻路,利用网格存储每个点对目标点的推力,网格上的单位根据对于推力进行移动。用于大量单位进行寻路对于同一目的地的寻路,常用于rts游戏等。 对应一张网格地图(图中黑块是不可行走区域) 生成热度图 计算所有网格对于目标点(…