深入探讨梯度下降:优化机器学习的关键步骤(二)

文章目录

  • 🍀引言
  • 🍀eta参数的调节
  • 🍀sklearn中的梯度下降

🍀引言

承接上篇,这篇主要有两个重点,一个是eta参数的调解;一个是在sklearn中实现梯度下降

在梯度下降算法中,学习率(通常用符号η表示,也称为步长或学习速率)的选择非常重要,因为它直接影响了算法的性能和收敛速度。学习率控制了每次迭代中模型参数更新的幅度。以下是学习率(η)的重要性:

  • 收敛速度:学习率决定了模型在每次迭代中移动多远。如果学习率过大,模型可能会在参数空间中来回摇摆,导致不稳定的收敛或甚至发散。如果学习率过小,模型将收敛得很慢,需要更多的迭代次数才能达到最优解。因此,选择合适的学习率可以加速收敛速度。

  • 稳定性:过大的学习率可能会导致梯度下降算法不稳定,甚至无法收敛。过小的学习率可以使算法更加稳定,但可能需要更多的迭代次数才能达到最优解。因此,合适的学习率可以在稳定性和收敛速度之间取得平衡。

  • 避免局部最小值:选择不同的学习率可能会导致模型陷入不同的局部最小值。通过尝试不同的学习率,您可以更有可能找到全局最小值,而不是被困在局部最小值中。

  • 调优:学习率通常需要调优。您可以尝试不同的学习率值,并监视损失函数的收敛情况。通常,您可以使用学习率衰减策略,逐渐降低学习率以改善收敛性能。

  • 批量大小:学习率的选择也与批量大小有关。通常,小批量梯度下降(Mini-batch Gradient Descent)使用比大批量梯度下降更大的学习率,因为小批量可以提供更稳定的梯度估计。

总之,学习率是梯度下降算法中的关键超参数之一,它需要仔细选择和调整,以在训练过程中实现最佳性能和收敛性。不同的问题和数据集可能需要不同的学习率,因此在实践中,通常需要进行实验和调优来找到最佳的学习率值。


🍀eta参数的调节

在上代码前我们需要知道,如果eta的值过小会造成什么样的结果

在这里插入图片描述
反之如果过大呢

在这里插入图片描述
可见,eta过大过小都会影响效率,所以一个合适的eta对于寻找最优有着至关重要的作用


在上篇的学习中我们已经初步完成的代码,这篇我们将其封装一下
首先需要定义两个函数,一个用来返回thera的历史列表,一个则将其绘制出来

def gradient_descent(eta,initial_theta,epsilon = 1e-8):theta = initial_thetatheta_history = [initial_theta]def dj(theta): return 2*(theta-2.5) #  传入theta,求theta点对应的导数def j(theta):return (theta-2.5)**2-1  #  传入theta,获得目标函数的对应值while True:gradient = dj(theta)last_theta = thetatheta = theta-gradient*eta theta_history.append(theta)if np.abs(j(theta)-j(last_theta))<epsilon:breakreturn theta_historydef plot_gradient(theta_history):plt.plot(plt_x,plt_y)plt.plot(theta_history,[(i-2.5)**2-1 for i in theta_history],color='r',marker='+')plt.show()

其实就是上篇代码的整合罢了
之后我们需要进行简单的调参了,这里我们分别采用0.10.010.9,这三个参数进行调节

eta = 0.1
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述

eta = 0.01
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述

eta = 0.9
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述
这三张图与之前的提示很像吧,可见调参的重要性
如果我们将eta改为1.0呢,那么会发生什么

eta = 1.0
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述
那改为1.1呢

eta = 1.1
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述
我们从图可以清楚的看到,当eta为1.1的时候是嗷嗷增大的,这种情况我们需要采用异常处理来限制一下,避免报错,处理的方式是限制循环的最大值,且可以在expect中设置inf(正无穷)

def gradient_descent(eta,initial_theta,n_iters=1e3,epsilon = 1e-8):theta = initial_thetatheta_history = [initial_theta]i_iter = 1def dj(theta):  try:return 2*(theta-2.5) #  传入theta,求theta点对应的导数except:return float('inf')def j(theta):return (theta-2.5)**2-1  #  传入theta,获得目标函数的对应值while i_iter<=n_iters:gradient = dj(theta)last_theta = thetatheta = theta-gradient*eta theta_history.append(theta)if np.abs(j(theta)-j(last_theta))<epsilon:breaki_iter+=1return theta_historydef plot_gradient(theta_history):plt.plot(plt_x,plt_y)plt.plot(theta_history,[(i-2.5)**2-1 for i in theta_history],color='r',marker='+')plt.show()

注意:inf表示正无穷大


🍀sklearn中的梯度下降

这里我们还是以波士顿房价为例子
首先导入需要的库

from sklearn.datasets import load_boston
from sklearn.linear_model import SGDRegressor

之后取一部分的数据

boston = load_boston()
X = boston.data
y = boston.target
X = X[y<50]
y = y[y<50]

然后进行数据归一化

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(X,y)
std = StandardScaler()
std.fit(X_train)
X_train_std=std.transform(X_train)
X_test_std=std.transform(X_test)
sgd_reg = SGDRegressor()
sgd_reg.fit(X_train_std,y_train)

最后取得score

sgd_reg.score(X_test_std,y_test)

运行结果如下
在这里插入图片描述


请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/117838.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis之管道解读

目录 基本介绍 使用例子 管道对比 管道与原生批量命令对比 管道与事务对比 使用pipeline注意事项 基准测试 基本介绍 Redis是一种基于客户端-服务端模型以及请求/响应协议的TCP服务器。 这意味着请求通常按如下步骤处理&#xff1a; 客户端发送一个请求到服务器&am…

Redis功能实战篇之附近商户

在互联网的app当中&#xff0c;特别是像美团&#xff0c;饿了么等app。经常会看到附件美食或者商家&#xff0c; 当我们点击美食之后&#xff0c;会出现一系列的商家&#xff0c;商家中可以按照多种排序方式&#xff0c;我们此时关注的是距离&#xff0c;这个地方就需要使用到我…

已解决module ‘pip‘ has no attribute ‘pep425tags‘报错问题(如何正确查看pip版本、支持、32位、64位方法汇总)

本文摘要&#xff1a;本文已解决module ‘pip‘ has no attribute ‘pep425tags‘的相关报错问题&#xff0c;并总结提出了几种可用解决方案。同时结合人工智能GPT排除可能得隐患及错误。并且最后说明了如何正确查看pip版本、支持、32位、64位方法汇总 &#x1f60e; 作者介绍&…

【051】基于Vue、Springboot电商管理系统(含源码、详细论文、数据库)

基于Vue、Springboot、Mysql的前后端分离的电商管理系统&#xff0c;不仅功能完善&#xff0c;还有详细课设报告供查看&#xff0c;这不收藏起来&#xff0c;源码和论文获取见文末结尾 部分报告内容如下&#xff08;省略图片&#xff09; c 目录 1 引言 4 1.…

java-数组

数组静态初始化写法&#xff1a; //静态初始化数组 int[] age new int[] {7,18,19}; double[] scores new double[]{67.5,77.8,94.2,99};//静态初始化数组简化写法 int[] age1 {7,18,19}; double[] scores2 {67.5,77.8,94.2,99};数组在内存中定义方式&#xff1a; 1.在内…

paddle 1-高级

目录 为什么要精通深度学习的高级内容 高级内容包含哪些武器 1. 模型资源 2. 设计思想与二次研发 3. 工业部署 4. 飞桨全流程研发工具 5. 行业应用与项目案例 飞桨开源组件使用场景概览 框架和全流程工具 1. 模型训练组件 2. 模型部署组件 3. 其他全研发流程的辅助…

POSTGRESQL WAL 日志问题合集之WAL 如何解析

开头还是介绍一下群&#xff0c;如果感兴趣polardb ,mongodb ,mysql ,postgresql ,redis 等有问题&#xff0c;有需求都可以加群群内有各大数据库行业大咖&#xff0c;CTO&#xff0c;可以解决你的问题。加群请加 liuaustin3微信号 &#xff0c;在新加的朋友会分到3群 &#xf…

C语言入门 Day_12 一维数组

目录 前言 1.创建一维数组 2.使用一维数组 3.易错点 4.思维导图 前言 存储一个数据的时候我们可以使用变量&#xff0c; 比如这里我们定义一个记录语文考试分数的变量chinese_score&#xff0c;并给它赋值一个浮点数&#xff08;float&#xff09;。 float chinese_scoe…

【计算机网络】序列化与反序列化

文章目录 1. 如何处理结构化数据&#xff1f;序列化 与 反序列化 2. 实现网络版计算器1. Tcp 套接字的封装——sock.hpp创建套接字——Socket绑定——Bind将套接字设置为监听状态——Listen获取连接——Accept发起连接——Connect 2. 服务器的实现 ——TcpServer.hpp初始化启动…

Nuxt 菜鸟入门学习笔记四:静态资源

文章目录 public 目录assets 目录全局样式导入 Nuxt 官网地址&#xff1a; https://nuxt.com/ Nuxt 使用以下两个目录来处理 CSS、fonts 和图片等静态资源&#xff1a; public 目录 public 目录用作静态资产的公共服务器&#xff0c;可通过应用程序定义的 URL 公开获取。 换…

隧道结构健康监测系统,保障隧道稳定安全运行

隧道是地下隐蔽工程&#xff0c;会受到潜在、无法预知的地质因素影响&#xff0c;早期修建的隧道经常出现隧道拱顶开裂、地表沉降、隧道渗漏水、围岩变形、附近建筑物倾斜等隧道的健康问题变得日益突出&#xff0c;作为城市生命线不可或缺的一部分&#xff0c;为了确保隧道工程…

Go 官方标准编译器中所做的优化

本文是对#102 Go 官方标准编译器中实现的优化集锦汇总[1] 内容的记录与总结. 优化1-4: 字符串和字节切片之间的转化 1.紧跟range关键字的 从字符串到字节切片的转换&#xff1b; package mainimport ( "fmt" "strings" "testing")var cs10086 s…

图像翻拍检测——反射分量分离的特征融合

随着计算机技术的迅速发展&#xff0c;需要建立人与信息一一对应的安保认证技术&#xff0c;通过建立完整的映射网络体系&#xff0c;从而确保每个人的人身、财产、隐私等的安全.与指纹、基因等人体生物特征识别系统相比&#xff0c;人脸识别系统更加友好&#xff0c;不需要人的…

unity面试题(性能优化篇)

CPU 预处理、缓存数据 注释空的unity函数 运算cpu->gpu 减少昂贵计算(开方) 限制帧数 加载(预加载、分帧加载、异步加载、对象池) 慎用可空类型比较 避免频繁计算(分帧、隔帧) 算法优化 变体收集预热 使用clear操作代替容器的new操作 unity spine使用二进制格式…

Data Rescue Professional for Mac:专业的数据恢复工具

在数字化时代&#xff0c;我们的生活和工作离不开电脑和存储设备。但是&#xff0c;意外情况时常发生&#xff0c;例如误删除文件、格式化硬盘、病毒攻击等&#xff0c;这些都可能导致重要的数据丢失。面对数据丢失&#xff0c;我们迫切需要一款可靠的数据恢复工具。今天&#…

CentOS 8 安装 Code Igniter 4

在安装好LNMP运行环境基础上&#xff0c;将codeigniter4文件夹移动到/var/nginx/html根目录下&#xff0c;浏览器地址栏输入IP/codeigniter/pulbic 一直提示&#xff1a; Cache unable to write to "/var/nginx/html/codeigniter/writable/cache/". 找了好久&…

文献阅读:Semantic Communications for Speech Signals

目录 论文简介动机&#xff1a;为什么作者想要解决这个问题&#xff1f;贡献&#xff1a;作者在这篇论文中完成了什么工作(创新点)&#xff1f;规划&#xff1a;他们如何完成工作&#xff1f;自己的看法(作者如何得到的创新思路) 论文简介 作者 Zhenzi Weng Zhijin Qin Geoffre…

【MySQL】用户管理

之前我们一直都使用root身份来对mysql进行操作&#xff0c;但这样存在安全隐患。这时&#xff0c;就需要使用MySQL的用户管理 目录 一、用户 1.1 用户信息 1.2 添加用户 1.3 删除用户 1.4 修改用户密码 二、用户权限 2.1 赋予授权 2.2 回收权限 一、用户 1.1 用户信息…

Kubernetes技术--使用kubeadm快速部署一个K8s集群

这里我们配置一个单master集群。(一个Master节点,多个Node节点) 1.硬件环境准备 一台或多台机器,操作系统 CentOS7.x-86_x64。这里我们使用安装了CentOS7的三台虚拟机 硬件配置:2GB或更多RAM,2个CPU或更多CPU,硬盘30GB或更多 2.主机名称和IP地址规划 3. 初始化准备工作…

Scala的函数式编程与高阶函数,匿名函数,偏函数,函数的闭包、柯里化,抽象控制,懒加载等

Scala的函数式编程 函数式编程 解决问题时&#xff0c;将问题分解成一个一个的步骤&#xff0c;将每个步骤进行封装&#xff08;函数&#xff09;&#xff0c;通过调用这些封装好的步骤&#xff0c;解决问题。 例如&#xff1a;请求->用户名、密码->连接 JDBC->读取…