深入浅出机器学习中的梯度下降算法

大家好,在机器学习中,梯度下降算法(Gradient Descent)是一个重要的概念。它是一种优化算法,用于最小化目标函数,通常是损失函数。梯度下降可以帮助找到一个模型最优的参数,使得模型的预测更加准确,本文将介绍梯度下降算法的原理、公式以及在Python中实现这一算法。

1. 梯度下降算法的理论基础

在数学中,梯度是一个向量,表示函数在某一点的变化率和方向。在多维空间中,梯度指向函数上升最快的方向。

图片

可以通过梯度来找到函数的最小值或最大值,对于损失函数关注的是最小值。

梯度下降的核心思想是通过不断调整参数,沿着损失函数的梯度方向移动,从而逐步逼近最小值。具体步骤如下:

(1) 初始化参数:随机选择参数的初始值。

(2) 计算梯度:计算损失函数对每个参数的梯度。

(3) 更新参数:根据梯度信息调整参数,更新规则为:

其中:\theta是要优化的参数;\alpha是学习率(step size),决定每次更新的幅度;\triangledown J(\theta )是损失函数关于参数的梯度。

(4) 重复步骤:重复计算梯度和更新参数,直到收敛(即损失函数的变化非常小)。

假设我们有一个简单的线性回归问题,目标是最小化均方误差(MSE)损失函数: 

其中\hat{Y}_{i} = \theta _{0} + \theta _{1}X_{i}是模型的预测值。为了使用梯度下降,我们需要计算损失函数关于参数的梯度: 

通过求导,可以得到梯度表达式,并利用它来更新参数。

2. Python 实现梯度下降算法

接下来将通过一个简单的线性回归示例来实现梯度下降算法,以下是实现代码:

import numpy as np
import matplotlib.pyplot as plt

生成一些随机数据来模拟房屋面积与房价之间的线性关系:

# 生成数据
np.random.seed(0)# 生成自变量 X(房屋面积),范围从50到200平方米
X = 50 + 150 * np.random.rand(100)  # 生成从50到200的100个点# 生成因变量 Y(房价),假设房价与房屋面积的关系
Y = 300000 + 2000 * X + np.random.randn(100) * 20000  # 线性关系加上噪声,价格范围在30万到50万之间# 绘制生成的散点图
plt.scatter(X, Y, color='blue', alpha=0.5)
plt.title('房屋面积与房价的关系')
plt.xlabel('房屋面积 (平方米)')
plt.ylabel('房价 (人民币)')
plt.grid()
plt.show()

图片

实现梯度下降算法的核心部分:

# 将数据标准化,帮助梯度下降更快收敛
X = (X - np.mean(X)) / np.std(X)
Y = (Y - np.mean(Y)) / np.std(Y)# 梯度下降参数
alpha = 0.01  # 学习率
num_iterations = 1000  # 迭代次数
m = len(Y)  # 样本数量# 初始化参数
theta_0 = 0  # 截距
theta_1 = 0  # 斜率# 存储损失值
losses = []# 梯度下降算法实现
for i in range(num_iterations):# 计算预测值Y_pred = theta_0 + theta_1 * X# 计算损失函数 (MSE)loss = (1/m) * np.sum((Y - Y_pred) ** 2)losses.append(loss)# 计算梯度gradient_0 = -(2/m) * np.sum(Y - Y_pred)  # 截距的梯度gradient_1 = -(2/m) * np.sum((Y - Y_pred) * X)  # 斜率的梯度# 更新参数theta_0 -= alpha * gradient_0theta_1 -= alpha * gradient_1print(f'截距 (θ0): {theta_0:.4f}, 斜率 (θ1): {theta_1:.4f}')

截距 (θ0): 0.0000, 斜率 (θ1): 0.9743

通过绘制损失函数随迭代次数变化的曲线,观察梯度下降的收敛过程。

# 绘制损失函数变化曲线
plt.figure()
plt.plot(range(num_iterations), losses, color='blue')
plt.title('损失函数随迭代次数的变化')
plt.xlabel('迭代次数')
plt.ylabel('损失值 (MSE)')
plt.grid()
plt.show()

最后,我们可以将训练好的回归线可视化,以观察模型的效果。​​​​​​​​​​​​​​

# 可视化回归线
plt.figure()
plt.scatter(X, Y, color='blue', alpha=0.5)
plt.plot(X, theta_0 + theta_1 * X, color='red', linewidth=2)
plt.title('梯度下降后的线性回归拟合')
plt.xlabel('房屋面积 (标准化)')
plt.ylabel('房价 (标准化)')
plt.grid()plt.tight_layout()  # 调整子图间距
plt.show()

图片

梯度下降算法在许多机器学习算法中得到了广泛应用,比如线性回归、逻辑回归、神经网络等,可以用于分类问题,通过优化对数损失函数,也可以用于深度学习,反向传播算法依赖于梯度下降来更新权重。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/482099.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

树莓派5+文心一言 -> 智能音箱

一、简介 效果:运行起来后,可以连续对话 硬件:树莓派5、麦克风、音箱,成本500-1000 软件:snowboy作为唤醒词、百度语音作为语音识别、brain作为指令匹配、百度文心一言作为对话模块、微软的edge-tts语音合成... 二…

Springboot——SseEmitter流式输出

文章目录 前言SseEmitter 简介测试demo注意点异常一 ResponseBodyEmitter is already set complete 前言 最近做AI类的开发,看到各大AI模型的输出方式都是采取的一种EventStream的方式实现。 不是通常的等接口处理完成后,一次性返回。 而是片段式的处理…

5G学习笔记之随机接入

目录 1. 概述 2. MSG1 2.1 选择SSB 2.2 选择Preamble Index 2.3 选择发送Preamble的时频资源 2.4 确定RA-RNTI 2.5 确定发送功率 3. MSG2 4. MSG3 5. MSG4 6. 其它 6.1 切换中的随机接入 6.2 SI请求的随机接入 6.3 通过PDCCH order重新建立同步 1. 概述 随机接入…

【Linux-多线程】重谈地址空间+内存管理方式

一、背景知识 a.重谈地址空间 我们之前已经说过,CPU内部见的地址,以及我们打印出来的地址都是虚拟地址;物理内存加载到CPU,CPU内执行进程创建内核数据结构,页表等,通过页表映射到物理磁盘上;也…

Spark Optimization —— Reducing Shuffle

Spark Optimization : Reducing Shuffle “Shuffling is the only thing which Nature cannot undo.” — Arthur Eddington Shuffle Shuffle Shuffle I used to see people playing cards and using the word “Shuffle” even before I knew how to play it. Shuffling in c…

Elasticsearch——Java API 操作

Elasticsearch 软件是由Java语言开发的,所以也可以通过JavaAPI的方式对 Elasticsearch服务进行访问。 创建 Maven 项目 我们在 IDEA 开发工具中创建 Maven 项目(模块也可)ES。并修改pom文件&#xff0c;增加Maven依赖关系。 #直接复制在pom文件的<dependencies></de…

量化的8位LLM训练和推理使用bitsandbytes在AMD GPUs上

Quantized 8-bit LLM training and inference using bitsandbytes on AMD GPUs — ROCm Blogs 在这篇博客文章中&#xff0c;我们将介绍bitsandbytes的8位表示方式。正如你将看到的&#xff0c;bitsandbytes的8位表示方式显著地减少了微调和推理大语言模型&#xff08;LLMs&…

自回归(Autoregressive)模型概述

自回归&#xff08;Autoregressive&#xff09;模型概述 自回归&#xff08;Autoregressive&#xff0c;简称AR&#xff09;模型是一类基于“历史数据”来预测未来数据的模型。其核心思想是模型的输出不仅依赖于当前输入&#xff0c;还依赖于先前的输出。自回归模型通常用于时…

Win11电脑亮度无法调节以及夜间模式点击没有用失效解决方法

一、问题 最近&#xff0c;突然感觉屏幕亮度十分刺眼&#xff0c;想调整为夜间模式&#xff0c;发现点了夜间模式根本没用&#xff0c;亮度也是变成了灰色。 明明前几天还能调节的&#xff0c;这实在是太难受了&#xff01; 二、原因 这是远程控制软件向日葵的问题 在向日葵…

Linux笔记---进程:进程终止

1. 进程终止概念与分类 进程终止是指一个正在运行的进程结束其执行的操作。以下是一些常见的导致进程终止的情况&#xff1a; 一、正常终止 完成任务当进程完成了它被设计要执行的任务后&#xff0c;就会正常终止。收到特定信号在操作系统中&#xff0c;进程可能会收到来自操作…

【工具推荐】dnsx——一个快速、多用途的 DNS 查询工具

basic/基本使用方式 echo baidu.com | dnsx -recon # 查询域名所有记录echo baidu.com | dnsx -a -resp # 查询域名的a记录echo baidu.com | dnsx -txt -resp # 查询域名的TXT记录echo ip | dnsx -ptr -resp # ip反查域名 A记录查询 TXT记录查询 ip反查域名 help/帮助信息 输…

【树莓派5】移动热点获取树莓派IP并初次登录SSH

本篇文章包含的内容 1 打开系统热点2 烧录系统设置3 配置 MobaXterm4 初次启动树莓派配置选项4.1 换源4.2 更新软件包4.3 安装vim编辑器4.4 更改CPU FAN温度转速 Windows版本&#xff1a;Windows11 24H2树莓派&#xff1a;树莓派5&#xff0c;Raspberry Pi 5SSH软件&#xff1a…

【Git系列】Git 提交历史分析:深入理解`git log`命令

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

第144场双周赛:移除石头游戏、两个字符串得切换距离、零数组变换 Ⅲ、最多可收集的水果数目

Q1、[简单] 移除石头游戏 1、题目描述 Alice 和 Bob 在玩一个游戏&#xff0c;他们俩轮流从一堆石头中移除石头&#xff0c;Alice 先进行操作。 Alice 在第一次操作中移除 恰好 10 个石头。接下来的每次操作中&#xff0c;每位玩家移除的石头数 恰好 为另一位玩家上一次操作…

Python parsel库学习总结

parsel库是Python中用于解析HTML文件的库&#xff0c;其能通过CSS选择器、xpath、正则表达式来定位html中的元素。 通过css选择器定位元素 from parsel import Selectorhtml """ <html><head><a class"option1">这是一个伪html片…

【HarmonyOS学习日志(11)】计算机网络之概念,组成和功能

文章目录 计算机网络概念计算机网络&#xff0c;互连网与互联网的区别计算机网络互连网互联网&#xff08;因特网&#xff0c;Internet&#xff09; 计算机网络的组成和功能计算机网络的组成从组成部分看从工作方式看从逻辑功能看 计算机网络的功能数据通信资源共享分布式处理提…

Vue3 开源UI 框架推荐 (大全)

一 、前言 &#x1f4a5;这篇文章主要推荐了支持 Vue3 的开源 UI 框架&#xff0c;包括 web 端和移动端的多个框架&#xff0c;如 Element-Plus、Ant Design Vue 等 web 端框架&#xff0c;以及 Vant、NutUI 等移动端框架&#xff0c;并分别介绍了它们的特性和资源地址。&#…

视觉语言动作模型VLA的持续升级:从π0之参考基线Octo到OpenVLA、TinyVLA、DeeR-VLA、3D-VLA

第一部分 VLA模型π0之参考基线Octo 1.1 Octo的提出背景与其整体架构 1.1.1 Octo的提出背景与相关工作 许多研究使用从机器人收集的大量轨迹数据集来训练策略 从早期使用自主数据收集来扩展策略训练的工作[71,48,41,19-Robonet,27,30]到最近探索将现代基于transformer的策略…

k8s--pod创建、销毁流程

文章目录 一、pod创建流程二、pod销毁流程 一、pod创建流程 1、用户通过kubectl或其他api客户端提交pod spec给apiserver,然后会进行认证、鉴权、变更、校验等一系列过程2、apiserver将pod对象的相关信息最终存入etcd中,待写入操作执行完成,apiserver会返回确认信息给客户端3、…

相同的二叉树

给你两棵二叉树的根节点 p 和 q &#xff0c;编写一个函数来检验这两棵树是否相同。 如果两个树在结构上相同&#xff0c;并且节点具有相同的值&#xff0c;则认为它们是相同的。 示例 1&#xff1a; 输入&#xff1a;p [1,2,3], q [1,2,3] 输出&#xff1a;true示例 2&…