12.梯度下降法的具体解析——举足轻重的模型优化算法

引言

梯度下降法(Gradient Descent)是一种广泛应用于机器学习领域的基本优化算法,它通过迭代地调整模型参数,最小化损失函数以求得到模型最优解。

通过阅读本篇博客,你可以:

1.知晓梯度下降法的具体流程

2.掌握不同梯度下降法的区别

一、梯度下降法的流程

梯度下降法的流程通常分为以下四个步骤。

1.初始化模型参数

初始化模型参数其实就是random随机一个初始的 \theta (一组 W_{0},...,W_{n})。这样我们就可以得到上图中的 Starting Point(开始点)

2.计算当下参数的梯度

计算模型参数的梯度,其实就是对于当前损失函数所在位置进行求偏导,公式:

gradient_{j} = \frac{\partial J(\theta)}{\partial \theta_{j}}

公式推导 J(\theta) 是损失函数,\theta_{j} 是样本中某个特征维度 x_{j} 对应的权值系数,也可以写成 W_{j} 。对于多元线性回归来说,损失函数 J(\theta) = \frac{1}{2}(h_{\theta}x - y)^{2} (推导过程在9.深入线性回归推导出MSE——不容小觑的线性回归算法-CSDN博客中),因为我们的MSE中 Xy 是已知的,\theta 是未知的,而 \theta 不是一个变量而是许多向量组成的矩阵,所以我们只能对含有一堆变量的函数MSE中的一个变量求导,即偏导,下面就是对 \theta_{j} 求偏导。

\frac{\partial J(\theta)}{\partial \theta_{j}} = \frac{\partial \frac{1}{2}(h_{\theta}x-y)^{2}}{\partial \theta_{j}}

由于链式求导法则,我们可以推出:

\Rightarrow \frac{\partial J(\theta)}{\partial \theta_{j}} = 2 \cdot \frac{1}{2}(h_{\theta}x-y) \cdot\frac{\partial (h_{\theta}x - y)}{\partial \theta_{j}}

在多元线性回归中,h_{\theta}x 就是 W^{T}X,也就是 \omega _{0}x_{0} + \omega_{1}x_{1}+...+\omega_{n}x_{n},我们通常把它写成\sum_{n}^{i = 0}\omega _{i}x_{i} ,所以继续推导公式:

\Rightarrow \frac{\partial J(\theta)}{\partial \theta_{j}} = (h_{\theta}x - y) \cdot \frac{\partial \sum_{n}^{i =0}(\theta_{i}x_{i}-y)}{\partial \theta_{j}}

由于我们是对 \theta_{j} 求偏导,那么和 \theta_{j} 无关的可以忽略不计,所以公式变为:

\Rightarrow \frac{\partial J(\theta)}{\partial \theta_{j}} = (h_{\theta}x - y) \cdot x_{j}

所以,我们可以得到结论:\theta_{j} 对应的梯度(gradient)与预测值 \hat{y} 和真实值 y 有关,同时还与每个特征维度 x_{j} 有关。如果我们分别对每个维度求偏导,即可得到所有维度对应的梯度值。

3.根据梯度和学习率更新参数

通过11.梯度下降法的思想——举足轻重的模型优化算法-CSDN博客的学习,我们已经知道了梯度下降法的公式:

W_{j}^{t+1} = W_{j}^{t} - \eta \cdot gradient_{j}

在获得了梯度之后,我们可以将公式表示为:

W_{j}^{t+1} = W_{j}^{t} - \eta \cdot (h_{\theta}x - y) \cdot x_{j}

通过这个公式我们就可以去更新参数逼近最优解。

4.判断是否收敛

在如何判断收敛问题上,我相信大多数的人都会认为直接判断梯度(gradient)是否为0。其实这样的方法是错误的,由于非凸损失函数的存在,gradient = 0 的情况可能是极大值!所以我们使用了另外一种方法,设置合理的阈值(Threshold)来界定函数是否收敛。即判断不等式:

Loss^{t} - Loss^{t+1} < Threshold

如果前一次的损失函数 Loss^{t} 减去这次迭代后的损失函数 Loss^{t+1} 小于我们设定的阈值Threshold ,那我们认为函数收敛,当前的参数就是我们寻求的最优解。反之,我们重复第二步与第三步,一直达到最优解为止。其实我们是在判断 Loss 的下降收益是否更合理,随着迭代次数的增多,Loss 减小的幅度不再变化就可以认为停止在最低点。

二、梯度下降法的分类

我们根据梯度下降法流程中求取梯度的步骤样本数量的不同,将梯度下降法分为三个基本的类别。它们每次学习(更新模型参数)使用的样本个数,每次更新使用不同的样本会导致每次学习的准确性和学习时间不同

1.全量梯度下降(Batch Gradient Descent)

全量梯度下降(Batch Gradient Descent)通过使用整个数据集在每次迭代中计算损失函数的梯度,以此更新模型参数(也称批量梯度下降)。由于我们使用整个数据集的样本,所以全量梯度下降的公式为:

W_{j}^{t+1} = W_{j}^{t} - \eta \cdot \sum_{m}^{i = 1}(h_{\theta}x_{i} - y_{i}) \cdot x_{j}

在全量梯度下降中,对于 \theta 的更新,所有的样本都有贡献,也就是参与调整 \theta 。所以从理论上来说一次更新的幅度是比较大的。

全量梯度下降法的优点在于收敛稳定,每次更新都朝着全局最优的方向移动。并且能够净化噪声,由于使用整个数据集计算梯度,随机噪声对更新的影响较小,使得损失函数的路径相对平滑。

缺点也是相当明显,当数据集非常大时,全量梯度下降法每个迭代计算数据集的梯度是非常耗时且占用内存的。所以不适合处理实时数据,比如在线学习和实时更新数据场景。

上图表示的梯度下降法中两个维度参数的关系,我们可以将圆圈看成一个碗的俯视图,碗底就是我们要找的最优解。我们不难发现,全量梯度下降法每次迭代都直接向碗底行进,目标明确。

2.随机梯度下降(Stochastic Gradient Descent)

随机梯度下降(Stochastic Gradient Descent)通过使用数据集中的一个随机样本在每次迭代中计算损失函数的梯度,以此更新模型参数。由于使用随机的一个样本,所以随机梯度下降的公式就是:

W_{j}^{t+1} = W_{j}^{t} - \eta \cdot (h_{\theta}x - y) \cdot x_{j}

随机梯度下降的优点在于计算速度快,由于每次迭代只对一个样本计算梯度,因此更新速度快,适合大规模数据集。它还拥有更强的泛化能力,由于引入了随机性,SGD能更好地跳出局部最优,避免过拟合(过拟合相关内容会在专栏后续文章中更新)。并且能够处理实时数据,可以在线学习,所以适用于动态更新的场景。

同样地,由于每次更新只基于一个样本,SGD的收敛并不稳定,梯度波动较大,会导致损失函数的收敛路径不平稳。并且由于随机性的存在,SGD通常需要更多的迭代次数才能收敛到最优解,即收敛速度变慢

从上图我们可以看出,相比较全量梯度下降,SGD需要迭代更多的次数才能找到最优解。

3.小批量梯度下降(Mini-batch Gradient Descent)

小批量梯度下降(Mini-batch Gradient Descent)通过使用数据集的一部分样本在每次迭代中计算损失函数的梯度,以此更新模型参数。由于使用了数据集的部分样本,所以小批量梯度下降的公式为:

W_{j}^{t+1} = W_{j}^{t} - \eta \cdot \sum_{batchsize}^{i = 1}(h_{\theta}x_{i} - y_{i}) \cdot x_{j}

小批量梯度下降综合了全量梯度下降与随机梯度下降,在更新速度与更新次数中取得一个平衡。其每次更新从数据集中随机选择 batchsize 个样本进行学习。相对于随机梯度下降法,小批量梯度下降法降低了收敛的波动性(降低了参数更新的方差),使得更新更加稳定。相对于全量梯度下降法,其提高了每次学习的速度。

小批量梯度下降的优点在于平衡了计算效率和收敛稳定性。并且不用担心内存瓶颈而使用向量化计算,还能利用GPU的并行计算能力提高计算速度。在每个小批量中,我们可以设置不同的学习率,提高模型的训练表现。

小批量梯度下降的缺点则在于样本的大小会影响训练效果,所以我们要人为地选择合适的样本大小。

从下图中我们就能看到随机梯度下降与小批量梯度下降的区别。

总结

本篇博客讲解了梯度下降法的流程和大致的分类。希望可以对大家起到作用,谢谢。


关注我,内容持续更新(后续内容在作者专栏《从零基础到AI算法工程师》)!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/437986.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据仓库简介(一)

数据仓库概述 1. 什么是数据仓库&#xff1f; 数据仓库&#xff08;Data Warehouse&#xff0c;简称 DW&#xff09;是由 Bill Inmon 于 1990 年提出的一种用于数据分析和挖掘的系统。它的主要目标是通过分析和挖掘数据&#xff0c;为不同层级的决策提供支持&#xff0c;构成…

云服务架构与华为云架构

目录 1.云服务架构是什么&#xff1f; 1.1 云服务模型 1.2 云部署模型 1.3 云服务架构的组件 1.4 云服务架构模式 1.5 关键设计考虑 1.6 优势 1.7 常见的云服务架构实践 2.华为云架构 2.1 华为云服务模型 2.2 华为云部署模型 2.3 华为云服务架构的核心组件 2.4 华…

【C++】STL标准模板库容器set

&#x1f984;个人主页:修修修也 &#x1f38f;所属专栏:C ⚙️操作环境:Visual Studio 2022 目录 &#x1f4cc;关联式容器set(集合)简介 &#x1f4cc;set(集合)的使用 &#x1f38f;set(集合)的模板参数列表 &#x1f38f;set(集合)的构造函数 &#x1f38f;set(集合)的迭代…

翔云 OCR:发票识别与验真

在数字化时代&#xff0c;高效处理大量文档和数据成为企业和个人的迫切需求。翔云 OCR 作为一款强大的光学字符识别工具&#xff0c;在发票识别及验真方面表现出色&#xff0c;为我们带来了极大的便利。 一、翔云 OCR 简介 翔云 OCR 是一款基于先进的人工智能技术开发的文字识别…

搭建k8s集群服务(kubeadm方式)

准备工作 操作系统版本&#xff1a;CentOS Linux release 7.9.2009 (Core) 虚拟机硬件配置&#xff1a;2核8G内存&#xff08;最低2G&#xff09;&#xff0c;硬盘最低25G&#xff1b; linux内核版本&#xff08;3.10版本尝试失败&#xff09;&#xff1a;5.4.268-1.el7.elr…

基于Java+VUE+echarts大数据智能道路交通信息统计分析管理系统

大数据智能交通管理系统是一种基于Web的系统架构&#xff0c;通过浏览器/服务器&#xff08;B/S&#xff09;模式实现对城市交通数据的高效管理和智能化处理。该系统旨在通过集成各类交通数据&#xff0c;包括但不限于车辆信息、行驶记录、违章情况等&#xff0c;来提升城市管理…

【Python】AudioLazy:基于 Python 的数字信号处理库详解

AudioLazy 是一个用于 Python 的开源数字信号处理&#xff08;DSP&#xff09;库&#xff0c;设计目的是简化信号处理任务并提供更直观的操作方式。它不仅支持基础的滤波、频谱分析等功能&#xff0c;还包含了滤波器、信号生成、线性预测编码&#xff08;LPC&#xff09;等高级…

两个向量所在平面的法线,外积,叉积,行列式

偶尔在一个数学题里面看到求两向量所在平面的法线&#xff0c;常规方法可以通过法线与两向量垂直这一特点&#xff0c;列两个方程求解&#xff1b;另外一种方法可以通过求解两个向量的叉积&#xff0c;用矩阵行列式 (determinant) 的方式&#xff0c;之前还没见过&#xff0c;在…

【计算机网络】传输层UDP和TCP协议

目录 再谈端口号端口号范围划分认识知名端口号查看知名端口号两个问题 UDP协议UDP特点UDP的缓冲区基于UDP的应用层协议 TCP协议TCP协议格式确认应答机制超时重传机制连接管理机制&#xff08;三次握手与四次挥手&#xff09;理解TIME_WAIT状态理解CLOSE_WAIT状态滑动窗口快重传…

【C++】迭代器失效问题解析

✨ Blog’s 主页: 白乐天_ξ( ✿&#xff1e;◡❛) &#x1f308; 个人Motto&#xff1a;他强任他强&#xff0c;清风拂山冈&#xff01; &#x1f525; 所属专栏&#xff1a;C深入学习笔记 &#x1f4ab; 欢迎来到我的学习笔记&#xff01; 一、迭代器失效的概念 迭代器的作用…

【PyTorch】生成对抗网络

生成对抗网络是什么 概念 Generative Adversarial Nets&#xff0c;简称GAN GAN&#xff1a;生成对抗网络 —— 一种可以生成特定分布数据的模型 《Generative Adversarial Nets》 Ian J Goodfellow-2014 GAN网络结构 Recent Progress on Generative Adversarial Networks …

Python | Leetcode Python题解之第450题删除二叉搜索树中的节点

题目&#xff1a; 题解&#xff1a; class Solution:def deleteNode(self, root: Optional[TreeNode], key: int) -> Optional[TreeNode]:cur, curParent root, Nonewhile cur and cur.val ! key:curParent curcur cur.left if cur.val > key else cur.rightif cur i…

解决Excel时出现“被保护单元格不支持此功能“的解决办法,详细喂饭级教程

今天有个朋友发过来一个excel文件&#xff0c;本来想修改表格的内容&#xff0c;但是提示&#xff0c;被保护单元格不支持此功能&#xff0c;对于这个问题&#xff0c;找到一个解决方法&#xff0c;现记录下来&#xff0c;分享给有需要的朋友。 表格文件名为aaa.xls,以WPS为例。…

什么是转义字符

1.什么是转义字符 转义字符是一组特殊的字符&#xff0c;转义字符顾名思义就是&#xff1a;转变原来的意思。 比如&#xff1a;我们有一组字符&#xff0c;其中的n能完整的打印出来&#xff0c;如下&#xff1a; #include <stdio.h> int main() { printf("asnfd&…

Typora解决图片复制到其他博客平台,解决图片显示转存失败(CSDN除外)

目录 一、Typora这个Markdown编辑器的确好用1.1 安装 二、 问题“图片转存失败”2.1 问题具体显示如下&#xff1a;2.2 问题分析&#xff1a;其实就是图片在typora里面是使用的本地路径&#xff0c;因此不显示&#xff0c; 三、解决方案3.1打开Typora&#xff0c;按下述图片显示…

【Verilog学习日常】—牛客网刷题—Verilog企业真题—VL74

异步复位同步释放 描述 题目描述&#xff1a; 请使用异步复位同步释放来将输入数据a存储到寄存器中&#xff0c;并画图说明异步复位同步释放的机制原理 信号示意图&#xff1a; clk为时钟 rst_n为低电平复位 d信号输入 dout信号输出 波形示意图&#xff1a; 输入描…

网络原理-数据链路层

在这一层中和程序员距离比较遥远&#xff0c;除非是做交换机开发&#xff0c;否则不需要了解数据链路层 由AI可知&#xff1a; 数据链路层&#xff08;Data Link Layer&#xff09;是OSI&#xff08;Open Systems Interconnection&#xff09;七层网络模型中的第二层&#xff0…

Elasticsearch 开放推理 API 增加了对 Google AI Studio 的支持

作者&#xff1a;来自 Elastic Jeff Vestal 我们很高兴地宣布 Elasticsearch 的开放推理 API 支持 Gemini 开发者 API。使用 Google AI Studio 时&#xff0c;开发者现在可以与 Elasticsearch 索引中的数据进行聊天、运行实验并使用 Google Cloud 的模型&#xff08;例如 Gemin…

用网络分析仪测试功分器驻波的5个步骤

在射频系统中&#xff0c;功分器的驻波比直接关系到信号的稳定性和传输效率。本文将带您深入了解驻波比的测试方法和影响其结果的因素。 一、功分器驻波比 驻波(Voltage Standing Wave Ratio)&#xff0c;简称SWR或VSWR&#xff0c;是指频率相同、传输方向相反的两种波&#xf…

TCN模型实现电力数据预测

关于深度实战社区 我们是一个深度学习领域的独立工作室。团队成员有&#xff1a;中科大硕士、纽约大学硕士、浙江大学硕士、华东理工博士等&#xff0c;曾在腾讯、百度、德勤等担任算法工程师/产品经理。全网20多万粉丝&#xff0c;拥有2篇国家级人工智能发明专利。 社区特色&a…