机器学习第4天:模型优化方法—梯度下降

文章目录

前言

梯度下降原理简述

介绍

可能的问题

批量梯度下降

随机梯度下降

基本算法

存在的问题

退火算法

代码演示

小批量梯度下降


前言

若没有机器学习基础,建议先阅读同一系列以下文章

机器学习第1天:概念与体系漫游-CSDN博客

机器学习第2天:训练数据的获取与处理-CSDN博客

机器学习第3天:线性回归-CSDN博客

梯度下降原理简述

介绍

在一个多元函数中,某点的梯度方向代表函数增加最快的方向,梯度下降的原理就是,找到损失函数下降最快的方向(与梯度方向相反),然后往这个方向走,最后达到损失函数的最小值,如下图,从高的红色点到达了低的蓝色点,梯度下降就是这样一个过程

我们可以得到一个参数更新公式,把参数设为a, 梯度设为grad,那么

a=a-n*grad

为什么要有个n呢,因为梯度方向只能表示当前位置函数增加最快的方向,对于复杂的函数,当我们改变位置的时候,这个梯度可能一直在变化,所以n代表我们每走一步的距离,我们慢慢的走下去,然后每走一步再找一次方向,这样就能走到最小的位置了

可能的问题

由于算法是往最低的地方走,当走到函数局部最小值的时候,周围都比较高,那么就可能困在这里了,无法达到全局最小值别急,当然,当损失函数是一个凸函数的时候,是没有局部最小值的,只有全局最小值,例如MSE就是一个凸函数


批量梯度下降

批量梯度下降指的是用整个向量经过矩阵运算来计算梯度,容易知道,这样算法会很慢,当训练集很大时,可能要花费很多时间

我们将每个训练集实例比作一个高山上的一个点,批量梯度下降就是算出这些点的整体趋势,然后向下运动,我们将在下面看到一种不同的想法


随机梯度下降

基本算法

与批量梯度下降不同的是,随机梯度下降每次随机选择一个实例来计算梯度,这样大大减小了运行时间,并且,随机梯度下降可以摆脱局部最小值的问题,因为随机挑选,那么即使有一部分在局部最小值中,还有一部分的方向选择可以将困住的部分解救出。


存在的问题

训练集向量中的每一个实例对应于山上的某个点,随机梯度下降就是以某个点来抉择整体的下降趋势,可以预料到,下降的过程将不会那么顺利,但趋势是对的,可以看涨图来理解随机梯度下降与批量梯度下降的区别


退火算法

当随机梯度下降算法的参数越接近最小值的时候,因为随机性,可能永远到达不了最小值,会在这周围运动

这里我们可以减小步长n,让每次变化的幅度小一点,这样我们就能更加靠近全局最小值了,这就被称作退火算法

代码演示(随机梯度下降与退火算法)

import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import SGDRegressornp.random.seed(42)x = np.random.rand(100, 1)
y = 2 * x + np.random.rand(100, 1)model = SGDRegressor(max_iter=1000, tol=1e-3, penalty=None, eta0=0.1)
model.fit(x, y)
pre_y = model.predict(x)plt.scatter(x, y)
plt.plot(x, pre_y, "r-")
plt.show()

max_iter为下降批次,tol为损失函数阈值,penalty为不使用正则化(可自行搜索),eta0为最初的步长(之后会慢慢减小),整体意思就是当模型训练1000次或损失函数比0.001小时停止训练 

可以看到拟合效果也很好


小批量梯度下降

有了上面两种梯度下降的定义,小批量梯度下降应该也好理解了,它兼容二者的优点与缺点

训练快,容易到最小值,但是可能难以辨别局部最小值

当你使用GPU的时候,定义处理批次与GPU相同可以充分利用硬件资源,提高效率

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/196584.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

802.11-2020协议学习__专题__TxTime-Calculation__HR/DSSS

802.11-2020协议学习__专题__TxTime-Calculation__HR/DSSS 16.2.2 PPDU format16.2.2.1 General16.2.2.2 Long PPDU format16.2.2.3 Short PPDU format 16.3.4 HR/DSSS TXTIME calculation PREV: TBD NEXT: TBD 16.2.2 PPDU format 16.2.2.1 General 定…

五分钟,Docker安装kafka 3.5,kafka-map图形化管理工具

首先确保已经安装docker,如果是windows安装docker,可参考 wsl2安装docker 1、安装zk docker run -d --restartalways -e ALLOW_ANONYMOUS_LOGINyes --log-driver json-file --log-opt max-size100m --log-opt max-file2 --name zookeeper -p 2181:218…

各类软件docker安装

docker Docker 要求 CentOS 系统的内核版本高于 3.10 ,通过 uname -r 命令查看你当前的内核版本: uname -r 3.10.0-1062.1.2.el7.x86_64 安装 Docker: 安装 Docker:yum -y install dockerkafka和zookeeper docker pull wurstmei…

【RH850芯片】RH850U2A芯片平台Spinlock的底层实现

目录 前言 正文 1.RH850U2A上的原子操作 1.1 Link 1.2 Link generation 1.3 Success in storing 1.4 Failure in storing 1.5 Condition for successful storing 1.6 Loss of the link 1.7 示例代码 2.Spinlock代码分析 2.1 尝试获取Spinlock 2.2 释放Spinlock …

Vue前端添加水印功能

文章目录 概要技术细节附上几张调整的结果图 概要 前端Vue在页面添加水印,且不影响页面其他功能使用,初级代码水准即可使用,且有防人修改或者删除功能! 提示:适用于Vue,组件已经封装开箱即用,有…

OpenHarmony应用开发入门教程(一、开篇)

前言 华为正式宣布2024年发布的华为鸿蒙OS Next版将不再兼容安卓系统。这一重大改变,预示着华为鸿蒙OS即将进入一个全新的阶段。 都说科技无国界,这是骗人的鬼话。谷歌的安卓12.0系统早已发布,但是自从受到美影响,谷歌就拒绝再向…

CAD长方形纤维插件2D

插件介绍 CAD长方形纤维插件2D版本可用于在AutoCAD软件内生成随机分布的长方形纤维图形,生成的dwg格式模型可用于模拟二维随机分布的纤维复合材料、随机初始裂缝等,同时模型可导入COMSOL、Abaqus、ANSYS、Fluent等有限元软件内进行仿真分析计算。 插件…

【算法萌新闯力扣】:找到所有数组中消失对数字

力扣热题:找到所有数组中消失对数字 开篇 这两天刚交了蓝桥杯的报名费,刷题的积极性高涨。算上打卡题,今天刷了10道算法题了,题目都比较简单,挑选了一道还不错的题目与大家分享。 题目链接:448.找到所有数组中消失对…

(二)Pytorch快速搭建神经网络模型实现气温预测回归(代码+详细注解)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、数据集二、导入数据以及展示部分1.导入数据集以及对数据集进行处理2.展示数据(看看就好) 三(1)、搭建网络进…

ubuntu 20.04安装 Anaconda教程

在安装Anaconda之前需要先安装ros(防止跟conda冲突,先装ros)。提前安装好cuda 和cudnn。 本博客参考:ubuntu20.04配置ros noetic和cuda,cudnn,anaconda,pytorch深度学习的环境 安装完conda后,输入: pyth…

Flink(六)【DataFrame 转换算子(下)】

前言 今天学习剩下的转换算子:分区、分流、合流。 每天出来自学是一件孤独又充实的事情,希望多年以后回望自己的大学生活,不会因为自己的懒惰与懈怠而悔恨。 回答之所以起到了作用,原因是他们自己很努力。 …

人工智能基础_机器学习036_多项式回归升维实战3_使用线性回归模型_对天猫双十一销量数据进行预测_拟合---人工智能工作笔记0076

首先我们拿到双十一从2009年到2018年的数据 可以看到上面是代码,我们自己去写一下 首先导包,和准备数据 from sklearn.linear_model import SGDRegressor import numpy as np import matplotlib.pyplot as plt X=np.arange(2009.2020)#左闭右开,2009到2019 获取从2009到202…

MIKE水动力笔记20_由dfs2网格文件提取dfs1断面序列文件

本文目录 前言Step 1 MIKE Zero工具箱Step 2 提取dfs1 前言 在MIKE中,dfs2是一个一个小格格的网格面的时间序列文件,dfs1是一条由多个点组成的线的时间序列文件。 如下两图: 本博文内容主要讲如何从dfs2网格文件中提取dfs1断面序列文件。 …

CI/CD -gitlab

目录 一、常用命令 二、部署 一、常用命令 官网:https://about.gitlab.com/install/ gitlab-ctl start # 启动所有 gitlab 组件 gitlab-ctl stop # 停止所有 gitlab 组件 gitlab-ctl restart # 重启所有 gitlab 组件 gitlab-ctl statu…

linux进程间通信之信号

摘要 本文旨在研究Linux进程间通信的机制之一:信号。信号是由操作系统来处理的,说明信号的处理在内核态。信号不一定会立即被处理,此时会储存在信 号的信号表中。最后,我们会对这种通信方式的优缺点进行全面的分析,并给…

C++ opencv基本用法【学习笔记(九)】

这篇博客为修改过后的转载,因为没有转载链接,所以选了原创 文章目录 一、vs code 结合Cmake debug1.1 配置tasks.json1.2 配置launch.json 二、图片、视频、摄像头读取显示2.1 读取图片并显示2.2 读取视频文件并显示2.3 读取摄像头并写入文件 三、图片基…

设计模式-行为型模式-责任链模式

一、什么是责任链模式 责任链模式是一种设计模式。在责任链模式里,很多对象由每一个对象对其下家的引用而连接起来形成一条链。请求在这个链上传递,直到链上的某一个对象决定处理此请求。发出这个请求的客户端并不知道链上的哪一个对象最终处理这个请求&…

Spring Security OAuth2.0 实现分布式系统的认证和授权

Spring Security OAuth2.0 实现分布式系统的认证和授权 1. 基本概念1.1 什么是认证?1.2 什么是会话?1.2.1 基于 session 的认证方式1.2.2 基于 token 的认证方式 1.3 什么是授权?1.3.1 授权的数据模型 1.4 RBAC 介绍 2. Spring Security2.1 S…

JPA整合Sqlite解决Dialect报错问题, 最新版Hibernate6

前言 我个人项目中,不想使用太重的数据库,而内嵌数据库中SQLite又是最受欢迎的, 因此决定采用这个数据库。 可是JPA并不支持Sqlite,这篇文章就是记录如何解决这个问题的。 原因 JPA屏蔽了底层的各个数据库差异, 但是…

【2023春李宏毅机器学习】生成式学习的两种策略

文章目录 1 各个击破2 一步到位3 两种策略的对比 生成式学习的两种策略:各个击破、一步到位 对于文本生成:把每一个生成的元素称为token,中文当中token指的是字,英文中的token指的是word piece。比如对于unbreakable,他…