神经网络的工程基础(二)——随机梯度下降法|文末送书

相关说明

这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。
本文涉及到的代码链接如下:regression2chatgpt/ch06_optimizer/stochastic_gradient_descent.ipynb

本文将讨论利用PyTorch实现随机梯度下降法的细节。

关于大语言模型的内容,推荐参考这个专栏。

内容大纲

  • 相关说明
  • 一、随机梯度下降法:更优化的算法
  • 二、算法细节
  • 三、代码实现
  • 四、粉丝福利

一、随机梯度下降法:更优化的算法

梯度下降法虽然在理论上很美好,但在实际应用中常常会碰到瓶颈。为了说明这个问题,令 L i L_i Li表示模型在 i i i点的损失,即 L i = ( y i − a x i − b ) 2 L_i = (y_i - ax_i - b)^2 Li=(yiaxib)2,对所有数据点的损失求和后,可以得到整体损失函数: L = 1 ⁄ n ∑ i L i L = 1⁄n ∑_iL_i L=1⁄niLi 。即模型的损失函数实际上是各个数据点损失的平均值,这一观点适用于大多数模型 1

计算整体损失函数 L L L的梯度可得, ∇ L = 1 ⁄ n ∑ i L i ∇L = 1⁄n \sum_i L_i L=1⁄niLi。也就是说,损失函数的梯度等于所有数据点处梯度的平均值。但是在实际应用中,通常会使用大型数据集计算所有数据点的梯度和,这需要相当长的时间。为了加速这个计算过程,可以考虑使用随机梯度下降法(Stochastic Gradient Descent,SGD)。

二、算法细节

随机梯度下降法的核心思想是:每次迭代时只随机选择小批量的数据点来计算梯度,然后用这个小批量数据点的梯度平均值来代替整体损失函数的梯度2

为了使算法的细节更加准确,引入一个超参数,称为批量大小(Batch Size),记作m。每次随机选取m个数据,记为 I 1 , I 2 , ⋯ , I m I_1,I_2,⋯,I_m I1,I2,,Im。使用这些数据点的梯度平均值来近似代替整体损失函数的梯度: ∇ L = 1 ⁄ n ∑ i ∇ L i ≈ 1 ⁄ m ∑ j = 1 m ∇ L I j ∇L = 1⁄n ∑_i∇L_i ≈ 1⁄m ∑_{j = 1}^m∇L_{I_j } L=1⁄niLi1⁄mj=1mLIj 。由此得到新的参数迭代公式:
a k + 1 = a k − η / m ∑ j = 1 m ∂ L I j / ∂ a b k + 1 = b k − η / m ∑ j = 1 m ∂ L I j / ∂ b (1) a_{k + 1} = a_k - η/m ∑_{j = 1}^m∂L_{I_j }/∂a \\ b_{k + 1} = b_k -η/m ∑_{j = 1}^m∂L_{I_j }/∂b \tag{1} ak+1=akη/mj=1mLIj/abk+1=bkη/mj=1mLIj/b(1)

在随机梯度下降法中,所有数据点都使用了一遍,称为模型训练了一轮。由此在实际应用中常使用另一个超参数——训练轮次(Epoch),表示所有数据将被用几遍,用于控制随机梯度下降法的循环次数。换句话说,就是公式(1)被迭代运算多少次。

在一些机器学习书籍和学术文献中,还对随机梯度下降法(当m=1时)和小批量梯度下降法(当m>1时)进行了进一步的区分。然而,这两种方法之间的区别并不大,其核心思想都是基于随机采样来近似计算梯度,从而高效地更新参数、优化模型。在实际应用中,会根据问题的性质和数据规模选择合适的批次大小,以获得最佳的训练效果。因此,本书将统一使用随机梯度下降法来代表这一类方法,以保持概念清晰和简洁。

与梯度下降法相比,随机梯度下降法更高效,这是因为小批量梯度计算比整体梯度计算快得多。尽管在随机梯度下降法中,采用小批量数据估计梯度可能会引入一些噪声,但实践证明这些噪声对整个优化过程有好处,有助于模型克服局部最优的“陷阱”,逐步逼近全局最优参数。

三、代码实现

随机梯度下降法的实现与梯度下降法类似,不同之处在于,每次计算梯度时需要“随机”选取一部分数据,具体的实现步骤可以参考程序清单1(完整代码)。

  1. 在程序清单1的第2行,引入一个名为batch_size的超参数,用于控制每个批次中的数据量大小。选择合适的batch_size对算法的运行效率和稳定性至关重要。如果参数设置过大,可能会导致算法运行效率下降;而过小的参数可能使算法变得过于随机,影响收敛的稳定性。选择合适的参数需要结合具体的模型和应用场景,结合相关领域的经验进行决策。
  2. 在程序清单1的第11—13行,展示了一种随机选取批次数据的实现方式。这也是随机梯度下降法与普通梯度下降法的主要区别之一。实现随机性的方式有很多种,比如引入随机数等。这里仅呈现一种经典方法:将数据按顺序划分成批次。
程序清单1 随机梯度下降法
 1 |  # 定义每批次用到的数据量2 |  batch_size = 203 |  # 定义模型4 |  model = Linear()5 |  # 确定最优化算法6 |  learning_rate = 0.17 |  optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)8 |  9 |  for t in range(20):
10 |      # 选取当前批次的数据,用于训练模型
11 |      ix = (t * batch_size) % len(x)
12 |      xx = x[ix: ix + batch_size]
13 |      yy = y[ix: ix + batch_size]
14 |      yy_pred = model(xx)
15 |      # 计算当前批次数据的损失
16 |      loss = (yy - yy_pred).pow(2).mean()
17 |      # 将上一次的梯度清零
18 |      optimizer.zero_grad()
19 |      # 计算损失函数的梯度
20 |      loss.backward()
21 |      # 迭代更新模型参数的估计值
22 |      optimizer.step()
23 |      # 注意!loss记录的是模型在当前批次数据上的损失,该数值的波动较大
24 |      print(f'Step {t + 1}, Loss: {loss: .2f}; Result: {model.string()}')

在随机梯度下降法的执行过程中,通常使用模型的整体损失作为指标来监测算法的运行情况。但要注意的是,程序清单1中第16行定义的loss表示模型在小批量数据上的损失,这个值仅依赖于少量数据,迭代过程中会表现出极大的不稳定性,因此并不适合作为评估算法运行情况的主要标志。

如果希望更准确地监测算法的运行情况,需要在更大的数据集上估计模型的整体损失,例如在全部训练数据上计算损失,如图1所示。这种评估方式更稳定,能够更全面地反映模型的训练进展。

图1

图1

四、粉丝福利

参与方式:关注博主、点赞、收藏、评论区评论“解构大语言模型”(切记要点赞+收藏,否则抽奖无效,每个人最多评论三次!)
本次送书数量不少于3本,【阅读量越多,送得越多】
活动结束后,会私信中奖粉丝,请各位注意查看私信哦~

活动截止时间:2024-05-24 24:00:00


  1. 对于解决回归问题的模型,这个结论显然成立。对于解决分类问题的模型(比如逻辑回归模型),只需对模型的似然函数做简单的数学变换(先求对数,再求相反数),就可以得到同样的结论。 ↩︎

  2. 这在数学上是完全合理的。从统计的角度来看,用所有数据点求平均值,并不比随机抽样的方法高明很多。与线性回归参数估计值类似,两个结果都是随机变量:它们都以真实梯度为期望,只是前者的置信区间更小。 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/338202.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QT——QSlider实现,QT滑动控件的使用

目录 简介滑动块调节两种方法滑动条触发信号量理想滑动块运用(参考) 简介 QT中滑动条的控件叫QSlider,继承自QAbstractSlider类。 主要用途是通过滑块的滑动的方式在一定范围内调节某个值。根据调节的后得到的结果去执行一些处理&#xff0c…

第十三章 进程与线程

第十三章 进程与线程 程序与进程的概念 程序: 英文单词为Program,是指一系列有序指令的集合,使用编程语言所编写,用于实现一定的功能。 进程: 进程则是指启动后的程序,系统会为进程分配内存空间。 函数式…

浅析R16移动性增强那些事儿(DAPS/CHO/MRO)

R16移动性增强相关技术总结 Dual Active Protocol Handover Dual Active Protocol Handover意为双激活协议栈切换,下文简称DAPS切换,DAPS切换的核心思想是切换过程中,在UE成功连接到目标基站前继续保持和源基站的连接和数据传输,…

贷款借钱平台 小额贷款系统开发小额贷款源码 贷款平台开发搭建

这款是贷款平台源码/卡卡贷源码/小贷源码/完美版 后台51800 密码51800 数据库替换application/database.php程序采用PHPMySQL,thinkphp框架代码开源,不加密后台效果:手机版效果 这款是贷款平台源码/卡卡贷源码/小贷源码/完美版 后台51800 密码…

【软考】2024年5月系统架构设计师考试感受

目录 一 考试时间 二 考试方式 三 考试批次安排 四 回忆版真题 五 考试感受 一 考试时间 2024年5月系统架构设计师考试时间如下: 5🈷️25日上午 8点30-12点30: 综合知识和案例分析💚 5🈷️25日下午 14点30-16点30: 论文…

《庆余年算法番外篇》:范闲通过最短路径算法在阻止黑骑截杀林相

剧情背景 在《庆余年 2》22集中,林相跟大宝交代完为人处世的人生哲理之后,就要跟大宝告别了 在《庆余年 2》23集中,林相在告老还乡的路上与婉儿和大宝告别后 范闲也在与婉儿的对话中知道黑骑调动是绝密,并把最近一次告老还乡梅…

民国漫画杂志《时代漫画》第39期.PDF

时代漫画39.PDF: https://url03.ctfile.com/f/1779803-1248636473-6bd732?p9586 (访问密码: 9586) 《时代漫画》的杂志在1934年诞生了,截止1937年6月战争来临被迫停刊共发行了39期。 ps: 资源来源网络!

Base64码转换

title: Base64码转换 date: 2024-06-01 20:30:28 tags: vue3 后端图片前端显示乱码 现象 后端传来一个图片,前端能够接收,但是console.log()后发现图片变成了乱码,但是检查后台又发现能够正常的收到了这张图片。 处理方法 笔者有尝试将图…

Leetcode3165. 不包含相邻元素的子序列的最大和(Go中的线段树分治包含多类数据使用maintain进行维护)

题目截图 题目分析 不能取相邻的,就是打家劫舍 然后更改某一个值就是单点更新 更新后,需要更新区间的值 需要注意的是,使用分治时需要考虑到一头一尾的问题,所以有4种情况(选or不选在两个位置) 这四种情况…

利用 Scapy 库编写 Teardrop 攻击脚本

一、介绍 Teardrop攻击是一种历史上比较著名的拒绝服务(Denial of Service, DoS)攻击,主要利用了IP数据包分片和重组过程中的漏洞来攻击目标系统。以下是对Teardrop攻击的详细介绍: 1.1 攻击原理 IP协议允许数据包在传输过程中…

jpom ruoyi 发布后端

添加ssh 添加标签 添加仓库 添加构建 构建 命令 APP_NAMEenterprise IMAGE_NAMEenterprise:latest APP_PORT8080 RUN_ENVjenkins cd ruoyi-admin docker stop $APP_NAME || true docker rm $APP_NAME || true docker rmi $IMAGE_NAME || true docker build -f Dockerfil…

System-Verilog 实现DE2-115倒车雷达模拟

System-Verilog 实现DE2-115倒车雷达模拟 引言: 随着科技的不断进步,汽车安全技术也日益成为人们关注的焦点。在众多汽车安全辅助系统中,倒车雷达以其实用性和高效性脱颖而出,成为现代汽车不可或缺的一部分。倒车雷达系统利用超声…

Django ORM魔法:用Python代码召唤数据库之灵!

探索Django ORM的神奇世界,学习如何用Python代码代替复杂的SQL语句,召唤数据库之灵,让数据管理变得轻松又有趣。从基础概念到高级技巧,阿佑带你一步步成为Django ORM的魔法师,让你的应用开发速度飞起来! 文…

【Java】面向对象的三大特征:封装、继承、多态

封装 什么叫封装? 在我们写代码的时候经常会涉及两种角色: 类的实现者 和 类的调用者。 封装的本质就是让类的调用者不必太多的了解类的实现者是如何实现类的, 只要知道如何使用类就行了,这样就降低了类使用者的学习和使用成本&a…

Windows环境安装redis

1、下载redis https://github.com/tporadowski/redis/releases 2、解压 .zip 3、更改文件名 更改文件名称为:redis 4、将本地解压后的redis,作为本地服务器下的应用服务 从redis文件路径下,执行cmd .\redis-server --service-install re…

LeetCode - 贪心(Greedy)算法集合(Python)[分配问题|区间问题]

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/139242199 贪心算法,是在每一步选择中,都采取当前状态下,最好或最优(即最有利)的选择&…

基于SSM框架的手机商城项目

后端: 订单管理 客户管理: 商品管理 类目管理 前端: 首页:

Python 学习笔记【1】

此笔记仅适用于有任一编程语言基础,且对面向对象有一定了解者观看 文章目录 数据类型字面量数字类型数据容器字符串列表元组 type()方法数据类型强转 注释单行注释多行注释 输出基本输出连续输出,中间用“,”分隔更复杂的输出格式 变量定义del方法 标识符…

基础—SQL—DQL(数据查询语言)排序查询

一、引言 排序查询这里面涉及的关键字:ORDER BY。在我们日常的开发中,这个是很常见的,比如打开一个网购的商城,这里面可以找到一个销量的排序、综合的排序、价格的排序(升序、降序)等等。接下来就学习这一部…

8-Django项目--登录及权限

目录 templates/login/login.html templates/login/404.html views/login.py utils/pwd_data.py auth.py settings.py 登录及权限 登录 views.py 中间件 auth.py templates/login/login.html {% load static %} <!DOCTYPE html> <html lang"en"&g…