【机器学习】线性回归模型

线性回归是机器学习中最基础的算法之一。它主要用于回归任务,即预测一个连续的数值输出。本文将从零开始,带领你构建线性回归模型,逐步推演损失函数、梯度下降、学习率等核心概念,并使用numpy实现。最后,我们会通过sklearn快速实现线性回归模型。

线性回归模型简介

线性回归模型的核心思想是用一个直线(或超平面)拟合一组数据,找到特征和目标变量之间的线性关系。其数学表达式为:

y = w ⋅ x + b y = w \cdot x + b y=wx+b
其中:

  • ( y ) 是预测值(输出),
  • ( w ) 是权重(或斜率),
  • ( x ) 是输入变量(特征),
  • ( b ) 是偏置(截距)。

目标是找到合适的 ( w ) 和 ( b ),使得模型的预测结果尽可能接近真实值。

损失函数

为了衡量模型的预测值和真实值之间的差距,我们使用损失函数。常见的损失函数是均方误差(MSE, Mean Squared Error),其公式如下:

M S E = 1 n ∑ i = 1 n ( y i − y i ^ ) 2 MSE = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y_i})^2 MSE=n1i=1n(yiyi^)2
其中:

  • ( y_i ) 是第 (i) 个样本的真实值,
  • ( \hat{y_i} ) 是模型的预测值,
  • ( n ) 是样本数。

损失函数越小,说明模型越准确。

梯度下降(Gradient Descent)

为了最小化损失函数,我们使用梯度下降算法。梯度下降的基本思想是从随机初始化的参数开始,逐步调整参数,使得损失函数逐渐变小,最终找到最优解。

梯度下降的更新规则
w = w − α ∂ L ∂ w b = b − α ∂ L ∂ b w = w - \alpha \frac{\partial L}{\partial w}\\ b = b - \alpha \frac{\partial L}{\partial b} w=wαwLb=bαbL
其中:

  • ( \alpha ) 是学习率(决定每次更新的步长),
  • ( \frac{\partial L}{\partial w} ) 是损失函数关于 ( w ) 的导数(梯度),
  • ( \frac{\partial L}{\partial b} ) 是损失函数关于 ( b ) 的导数。

学习率(Learning Rate)

学习率 ( \alpha ) 是梯度下降中的重要超参数。它决定了每次参数更新的步长。学习率过大,可能会错过最优解;学习率过小,训练过程会非常缓慢,甚至陷入局部最优解。

代码实现:从零开始构建线性回归模型

接下来,我们使用 numpy 从头实现一个线性回归模型。

数据准备

我们首先构造一组简单的线性数据,用来训练我们的模型。

import numpy as np
import matplotlib.pyplot as plt# 生成数据
np.random.seed(42)
X = 2 * np.random.rand(100, 1)  # 随机生成 100 个点,范围在 [0, 2]
y = 4 + 3 * X + np.random.randn(100, 1)  # y = 4 + 3x + 噪声# 可视化数据
plt.scatter(X, y)
plt.xlabel("X")
plt.ylabel("y")
plt.title("Generated Data")
plt.show()

损失函数实现

接下来,我们实现均方误差(MSE)损失函数。

def mse_loss(y_true, y_pred):return np.mean((y_true - y_pred) ** 2)

梯度计算

我们需要计算损失函数对 ( w ) 和 ( b ) 的偏导数:

def compute_gradients(X, y, w, b):n = len(y)y_pred = X.dot(w) + bdw = (2/n) * X.T.dot(y_pred - y)db = (2/n) * np.sum(y_pred - y)return dw, db

梯度下降算法

使用梯度下降算法更新参数 ( w ) 和 ( b ):

def gradient_descent(X, y, w, b, learning_rate, iterations):for i in range(iterations):dw, db = compute_gradients(X, y, w, b)w -= learning_rate * dwb -= learning_rate * dbif i % 100 == 0:y_pred = X.dot(w) + bloss = mse_loss(y, y_pred)print(f"Iteration {i}: Loss = {loss}")return w, b

模型训练

初始化参数并开始训练:

# 初始化参数
w = np.random.randn(1, 1)
b = np.random.randn(1)# 超参数设置
learning_rate = 0.1
iterations = 1000# 训练模型
w_trained, b_trained = gradient_descent(X, y, w, b, learning_rate, iterations)
print(f"Trained weights: {w_trained}, Trained bias: {b_trained}")

模型可视化

训练结束后,我们可以将拟合直线与原始数据进行对比:

# 绘制拟合直线
plt.scatter(X, y)
plt.plot(X, X.dot(w_trained) + b_trained, color='red')
plt.xlabel("X")
plt.ylabel("y")
plt.title("Linear Regression Fit")
plt.show()

使用 sklearn 实现线性回归

最后,我们使用 sklearn 库快速实现同样的线性回归模型。

from sklearn.linear_model import LinearRegression# 训练模型
lin_reg = LinearRegression()
lin_reg.fit(X, y)# 输出权重和偏置
print(f"Sklearn Trained weights: {lin_reg.coef_}, Sklearn Trained bias: {lin_reg.intercept_}")# 绘制拟合直线
plt.scatter(X, y)
plt.plot(X, lin_reg.predict(X), color='red')
plt.xlabel("X")
plt.ylabel("y")
plt.title("Linear Regression with Sklearn")
plt.show()

总结

在本教程中,我们通过 numpy 实现了线性回归模型,深入理解了损失函数、梯度下降和学习率等概念。最后,我们通过 sklearn 验证了结果。希望这篇文章能帮助你打下机器学习的基础,深入理解线性回归背后的原理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/460906.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Cyber​​Panel filemanager/upload 远程命令执行漏洞复现

0x01 产品简介 CyberPanel是一个开源的Web控制面板,它提供了一个用户友好的界面,用于管理网站、电子邮件、数据库、FTP账户等。CyberPanel旨在简化网站管理任务,使非技术用户也能轻松管理自己的在线资源。 0x02 漏洞概述 该漏洞源于filemanager/upload接口未做身份验证和…

(C#面向初学者的 .NET 的生成 AI) 第 2 部分-什么是 AI 和 ML?

从本部分开始Luis Quintanilla介绍AI和机器学习,需要学习的一些东西是什么是AI和ML?作为一名.net开发人员如何学习使用AI和ML。 1、首先什么是AI 和 ML? 你可以把它看作是基本相同事物的不同层次。 在顶层的是AI(人工智能&#xf…

Swarm-LIO: Decentralized Swarm LiDAR-inertial Odometry论文翻译

文章目录 前言一、介绍二、相关工作三、方法A. 问题表述B. 框架概述C. 群体系统的初始化D. 去中心化激光雷达-惯性状态估计 四. 实验A. 室内飞行B. 退化环境飞行C. 去中心化部署 五. 结论和未来工作 前言 原文:原文 准确的自我状态和相对状态估计是完成群体任务的关…

wsl 使用docker 部署oracle11g数据库

wsl 使用docker 部署oracle11g数据库 1. 下载oracle11g sudo docker pull registry.cn-hangzhou.aliyuncs.com/helowin/oracle_11g2. 运行oracle11g容器(docker-compose) services:oracle-1.0:container_name: oracle11gimage: oracle11g:1.0restart:…

IDEA集成JProfiler

目录 下载并安装JProfiler下载安装包管理员身份运行配置许可证邮箱复制注册码配置永久许可证选择IDE集成 在IDEA中下载并安装JProfiler插件启动并使用JProfiler进行性能分析启动Java应用程序:自动运行JProfiler 知识扩充功能 下载并安装JProfiler 下载安装包 官网…

Tomcat 和 Docker部署Java项目的区别

在 Java 项目部署中,Tomcat 和 Docker 是两种常见的选择。虽然它们都可以用来运行 Java 应用,但它们在定位、部署方式、依赖环境、资源隔离、扩展性和适用场景等方面有显著区别。 1. 功能定位 1.1 Tomcat Apache Tomcat 是一种轻量级的 Java 应用服务器…

AI-基本概念-多层感知器模型/CNN/RNN/自注意力模型

1 需求 神经网络 …… 深度学习 …… 深度学习包含哪些神经网络: 全连接神经网络卷积神经网络循环神经网络基于注意力机制的神经网络 2 接口 3 CNN 在这个示例中: 首先定义了一个简单的卷积神经网络SimpleCNN,它包含两个卷积层、两个池…

Leaflet查询矢量瓦片偏移的问题

1、问题现象 使用Leaflet绘制工具查询出来的结果有偏移 2、问题排查 1)Leaflet中latLngToContainerPoint和latLngToLayerPoint的区别 2)使用Leaflet查询需要使用像素坐标 3)经排查发现,container获取的坐标是地图容器坐标&…

Vue生成名片二维码带logo并支持下载

一、需求 生成一张名片,名片上有用户信息以及二维码,名片支持下载功能(背景样式可更换,忽略本文章样图样式)。 二、参考文章 这不是我自己找官网自己摸索出来的,是借鉴各位前辈的,学以致用&am…

如何利用网站进行仿牌运营?

对于很多人来说,仿牌网站的运营是一项充满挑战的任务,很多初学者对如何开始感到困惑,甚至不清楚仿牌网站的运作机制。此外,搜索引擎对新网站的审核期也使得许多站长倍感压力。那么,如何才能在这一过程中有效地进行SEO优…

数字IC开发:布局布线

数字IC开发:布局布线 前端经过DFT,综合后输出网表文件给后端,由后端通过布局布线,将网表转换为GDSII文件;网表文件只包含单元器件及其连接等信息,GDS文件则包含其物理位置,具体的走线&#xff1…

传智杯 第六届-复赛-C

题目描述: 小红有一个数组,她每次可以选择数组的一个元素 xxx ,将这个元素分成两个元素 aaa 和 bbb ,使得 abxabxabx。请问小红最少需要操作多少次才可以使得数组的所有元素都相等。 输入描述: 第一行输入一个整数 n(1≤n≤10^5)…

华为配置 之 GVRP协议

目录 简介: 配置GVRP: 总结: 简介: GVRP(GARP VLAN Registration Protocol),称为VLAN注册协议,是用来维护交换机中的VLAN动态注册信息,并传播该信息到其他交换机中&…

外包干了7天,技术明显退步。。。。。

先说一下自己的情况,本科生,22年通过校招进入南京某软件公司,干了接近2年的功能测试,今年年初,感觉自己不能够在这样下去了,长时间呆在一个舒适的环境会让一个人堕落!而我已经在一个企业干了2年的功能测试&…

openGauss开源数据库实战十

文章目录 任务十 openGauss逻辑结构:数据库管理任务目标实施步骤一、登录到openGauss二、创建数据库三、查看数据库集群中有哪些数据库四、查看数据库默认表空间的信息五、查看数据库下有哪些模式六、查看数据库下有哪些表七、修改数据库的默认表空间八、重命名数据库九、删除数…

H3C OSPF配置

OSPF配置实验 实验拓扑图 实验需求 1.配置IP地址 2.分区域配置OSPF&#xff0c;实现全网互通 3.为了路由结构稳定&#xff0c;要求路由器使用环回口作为Router-id&#xff0c;ABR的环回口宣告进骨干区域 实验配置 1.配置IP地址 R1&#xff1a; <H3C>system-view …

飞桨首创 FlashMask :加速大模型灵活注意力掩码计算,长序列训练的利器

在 Transformer 类大模型训练任务中&#xff0c;注意力掩码&#xff08;Attention Mask&#xff09;一方面带来了大量的冗余计算&#xff0c;另一方面因其 O ( N 2 ) O(N^2) O(N2)巨大的存储占用导致难以实现长序列场景的高效训练&#xff08;其中 N N N为序列长度&#xff09;…

乘云而上,OceanBase再越山峰

一座山峰都是一个挑战&#xff0c;每一次攀登都是一次超越。 商业数据库时代&#xff0c;面对国外数据库巨头这座大山&#xff0c;实现市场突破一直都是中国数据库产业多年夙愿&#xff0c;而OceanBase在金融核心系统等领域的攻坚克难&#xff0c;为产业突破交出一副令人信服的…

为什么要使用Golang以及如何入门

什么是golang&#xff1f; Go是一种开放源代码的编程语言&#xff0c;于2009年首次发布&#xff0c;由Google的Rob Pike&#xff0c;Robert Griesemer和Ken Thompson开发。基于C的语法&#xff0c;它进行了一些更改和改进&#xff0c;以安全地管理内存使用&#xff0c;管理对象…

《文心一言插件设计与开发》赛题三等奖方案 | NoteTable

一年一度的 CCF大数据与计算智能大赛&#xff08;简称2024 CCF BDCI大赛&#xff09;又开始啦~~ 程序员们可冲一波嗷~ 大赛地址&#xff1a;http://go.datafountain.cn/6506 现在我们再次释放往届获奖方案&#xff0c; 为新一届大赛的同学们提供一些方案和灵感参考~ 大家借鉴借…