【机器学习】多元线性回归

在实际应用中,许多问题都包含多个特征(输入变量),而不仅仅是单个输入变量。多元线性回归是线性回归的扩展,它能够处理多个输入特征并建立它们与目标变量的线性关系。本教程将系统性推演多元线性回归,包括向量化处理、特征放缩、梯度下降的收敛性和学习率选择等,并使用numpy实现。最后,我们会通过sklearn快速实现多元线性回归模型。

多元线性回归模型简介

多元线性回归的模型公式为:
y = X ⋅ w + b y = X \cdot w + b y=Xw+b
其中:

  • ( y ) 是预测值(输出),
  • ( X ) 是输入特征矩阵(每一行表示一个样本,每一列表示一个特征),
  • ( w ) 是权重向量,
  • ( b ) 是偏置项(通常看作一个常数项)。

模型的目标是找到最优的权重向量 ( w ) 和偏置 ( b ),使得预测值与真实值的差异最小。

向量化处理

在机器学习中,向量化是一种通过矩阵运算来加速模型训练的方式。我们将模型的多个样本和特征表示为矩阵形式,这样能够利用线性代数库(如 numpy)中的优化操作来加速计算。

多元线性回归的预测可以用向量化表示为:

Y pred = X ⋅ w + b Y_{\text{pred}} = X \cdot w + b Ypred=Xw+b
其中:

  • ( X ) 是 ( n \times m ) 的矩阵,表示 ( n ) 个样本的 ( m ) 个特征,
  • ( w ) 是 ( m \times 1 ) 的权重向量,
  • ( b ) 是常数偏置项。

损失函数

我们依然使用**均方误差(MSE, Mean Squared Error)**作为损失函数,用来衡量模型预测值与真实值之间的差异。其公式为:

M S E = 1 n ∑ i = 1 n ( y i − y i ^ ) 2 MSE = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y_i})^2 MSE=n1i=1n(yiyi^)2

特征放缩(Feature Scaling)

在多元线性回归中,特征的尺度对模型训练的影响较大。如果某些特征的值范围过大,会导致它们在梯度下降中主导权重更新,导致收敛速度变慢甚至无法收敛。因此,通常会对特征进行标准化或归一化。

  1. 归一化(Normalization):将特征缩放到[0, 1]区间,公式如下:

X ′ = X − X min X max − X min X' = \frac{X - X_{\text{min}}}{X_{\text{max}} - X_{\text{min}}} X=XmaxXminXXmin

  1. 标准化(Standardization):将特征的均值归零,标准差为1,公式如下:

X ′ = X − μ σ X' = \frac{X - \mu}{\sigma} X=σXμ

梯度下降的收敛性

在使用梯度下降优化模型参数时,梯度下降的收敛性取决于学习率的选择和损失函数的性质。如果学习率过大,梯度下降可能会在更新中超出最优值;学习率过小,收敛速度会非常慢。通常我们需要通过实验选择合适的学习率。

代码实现:多元线性回归模型

接下来,我们使用numpy从头实现多元线性回归模型。

数据准备

我们生成一个包含多个特征的数据集。

import numpy as np
import matplotlib.pyplot as plt# 生成多元数据
np.random.seed(42)
X = 2 * np.random.rand(100, 3)  # 生成 100 个样本,3 个特征
true_w = np.array([3, 4, 5])
y = X.dot(true_w) + 6 + np.random.randn(100)  # y = 3x1 + 4x2 + 5x3 + 6 + 噪声# 查看数据维度
print("X shape:", X.shape)
print("y shape:", y.shape)

损失函数

实现均方误差损失函数:

def mse_loss(y_true, y_pred):return np.mean((y_true - y_pred) ** 2)

梯度计算

实现梯度计算函数:

def compute_gradients(X, y, w, b):n = len(y)y_pred = X.dot(w) + bdw = (2/n) * X.T.dot(y_pred - y)db = (2/n) * np.sum(y_pred - y)return dw, db

梯度下降

我们定义梯度下降函数,更新权重和偏置:

def gradient_descent(X, y, w, b, learning_rate, iterations):for i in range(iterations):dw, db = compute_gradients(X, y, w, b)w -= learning_rate * dwb -= learning_rate * dbif i % 100 == 0:y_pred = X.dot(w) + bloss = mse_loss(y, y_pred)print(f"Iteration {i}: Loss = {loss}")return w, b

特征放缩

我们可以通过 StandardScaler 对特征进行标准化。

def standardize(X):mean = np.mean(X, axis=0)std = np.std(X, axis=0)X_scaled = (X - mean) / stdreturn X_scaled

模型训练

初始化参数并训练模型:

# 初始化参数
w = np.random.randn(3)
b = np.random.randn(1)# 特征标准化
X_scaled = standardize(X)# 超参数设置
learning_rate = 0.01
iterations = 1000# 训练模型
w_trained, b_trained = gradient_descent(X_scaled, y, w, b, learning_rate, iterations)
print(f"Trained weights: {w_trained}, Trained bias: {b_trained}")

可视化模型

对于多元回归,权重无法直接用图像展示,但可以展示损失值的收敛曲线:

# 绘制损失曲线
losses = []
for i in range(1000):dw, db = compute_gradients(X_scaled, y, w, b)w -= learning_rate * dwb -= learning_rate * dby_pred = X_scaled.dot(w) + bloss = mse_loss(y, y_pred)losses.append(loss)plt.plot(losses)
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.title("Loss Curve")
plt.show()

使用 sklearn 实现多元线性回归

最后,我们使用sklearn快速实现多元线性回归。

from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler# 特征标准化
scaler = StandardScaler()
X_scaled_sklearn = scaler.fit_transform(X)# 训练模型
lin_reg = LinearRegression()
lin_reg.fit(X_scaled_sklearn, y)# 输出权重和偏置
print(f"Sklearn Trained weights: {lin_reg.coef_}, Sklearn Trained bias: {lin_reg.intercept_}")

总结

在本教程中,我们深入推演了多元线性回归的基本原理,从向量化、特征放缩、梯度下降收敛性到学习率选择,并使用numpy实现了完整的多元线性回归模型。通过sklearn的实现,我们验证了结果并加速了训练流程。希望这篇教程能帮助你进一步理解多元线性回归模型的核心概念。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/458594.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[项目详解][boost搜索引擎#2] 建立index | 安装分词工具cppjieba | 实现倒排索引

目录 编写建立索引的模块 Index 1. 设计节点 2.基本结构 3.(难点) 构建索引 1. 构建正排索引(BuildForwardIndex) 2.❗构建倒排索引 3.1 cppjieba分词工具的安装和使用 3.2 引入cppjieba到项目中 倒排索引代码 本篇文章,我们将继续项…

C++《vector的模拟实现》

在之前《vector》章节当中我们学习了STL当中的vector基本的使用方法,了解了vector当中各个函数该如何使用,在学习当中我们发现了vector许多函数的使用是和我们之前学习过的string类的,但同时也发现vector当中一些函数以及接口是和string不同的…

在Postgresql中对空间数据进行表分区的实践

在数据库管理中,合理地对数据进行分区可以提高查询性能和数据管理效率。 本文将详细介绍在Postgresql中对空间数据进行表分区的实践过程。 测试计算机容量有限,测试最大数据量为1,000,000条。 关键字: Postgresql PostGIS 表分区 空间数据 测试计算…

Easy Excel合并单元格情况简单导入导出

需求 实现报表数据的导入导出&#xff0c;表格中部分数据是系统生成&#xff0c;部分数据是甲方填写&#xff0c;录入系统。 批号唯一 Maven <dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version>3.…

【modbus协议】libmodbus库移植基于linux平台

文章目录 下载库函数源码编译路径添加libmodbus 源码分析核心数据结构常用接口函数 开发 TCP Server 端开发TCP Client 端 下载库函数源码 编译路径添加 libmodbus 源码分析 核心数据结构 modbus_t结构体&#xff1a; 这是 libmodbus 的核心数据结构&#xff0c;代表一个 Mod…

机房巡检机器人有哪些功能和作用

随着数据量的爆炸式增长和业务的不断拓展&#xff0c;数据中心面临诸多挑战。一方面&#xff0c;设备数量庞大且复杂&#xff0c;数据中心内服务器、存储设备、网络设备等遍布&#xff0c;这些设备需时刻保持良好运行状态&#xff0c;因为任何一个环节出现问题都可能带来严重后…

从0到1学习node.js(express模块)

文章目录 Express框架1、初体验express2、什么是路由3、路由的使用3、获取请求参数4、电商项目商品详情场景配置路由占位符规则5、小练习&#xff0c;根据id参数返回对应歌手信息6、express和原生http模块设置响应体的一些方法7、其他响应设置8、express中间件8.1、什么是中间件…

如何搭建直播美颜SDK平台的最佳实践?美颜API的实现与集成详解

本篇文章&#xff0c;将从技术实现、平台搭建、API集成以及性能优化四个方面&#xff0c;为开发者详解如何搭建一个直播美颜SDK平台。 一、直播美颜SDK平台的技术架构 一般的美颜效果包括磨皮、亮肤、瘦脸、大眼等&#xff0c;这些效果的实现需要依赖图像增强和滤镜算法。核心…

【51单片机】第一个小程序 —— 点亮LED灯

学习使用的开发板&#xff1a;STC89C52RC/LE52RC 编程软件&#xff1a;Keil5 烧录软件&#xff1a;stc-isp 开发板实图&#xff1a; 文章目录 单片机介绍LED灯介绍练习创建第一个项目点亮LED灯LED周期闪烁 单片机介绍 单片机&#xff0c;英文Micro Controller Unit&#xff0…

创建ODBC数据源SQLConfigDataSource函数的用法

网络上没有这个函数能实际落地的用法说明&#xff0c;我实践后整理一下&#xff1a; 1.头文件与额外依赖库&#xff1a; #include <odbcinst.h> #pragma comment(lib, "legacy_stdio_definitions.lib") 2.调用函数&#xff1a; if (!SQLConfigDataSourceW(…

阿里云镜像源无法访问?使用 DaoCloud 镜像源加速 Docker 下载(Linux 和 Windows 配置指南)

&#x1f680; 作者主页&#xff1a; 有来技术 &#x1f525; 开源项目&#xff1a; youlai-mall &#x1f343; vue3-element-admin &#x1f343; youlai-boot &#x1f343; vue-uniapp-template &#x1f33a; 仓库主页&#xff1a; GitCode&#x1f4ab; Gitee &#x1f…

java :String 类

在我们之前的讲解中我们已经了解了很多的Java知识&#xff0c;这节我们讲Java中字符如何定义以及关于String如何使用还有常见的string函数。 【本节目标】 1. 认识 String 类 2. 了解 String 类的基本用法 3. 熟练掌握 String 类的常见操作 4. 认识字符串常量池 5. 认识 …

江协科技STM32学习- P21 ADC模数转换器

&#x1f680;write in front&#x1f680; &#x1f50e;大家好&#xff0c;我是黄桃罐头&#xff0c;希望你看完之后&#xff0c;能对你有所帮助&#xff0c;不足请指正&#xff01;共同学习交流 &#x1f381;欢迎各位→点赞&#x1f44d; 收藏⭐️ 留言&#x1f4dd;​…

基于SpringCloud的WMS管理系统源码

商品管理&#xff1a;商品类型&#xff0c;规格&#xff0c;详情等设置。 采购管理&#xff1a;采购单录入。 销售管理&#xff1a;销售单录入。 库存管理&#xff1a;库存查询、库存日志 采用前后端分离的模式&#xff0c;微服务版本前端 后端采用Spring Boot、Spring Cl…

python实现放烟花效果庆祝元旦

马上就要2025年元旦啦&#xff0c;提前祝大家新年快乐 完整代码下载地址&#xff1a;https://download.csdn.net/download/ture_mydream/89926458

vLLM推理部署Qwen2.5

vLLM vLLM 是一个用于大模型推理的高效框架。它旨在提供高性能、低延迟的推理服务&#xff0c;并支持多种硬件加速器&#xff0c;如 GPU 和 CPU。 vLLM 适用于大批量Prompt输入&#xff0c;并对推理速度要求高的场景&#xff0c;吞吐量比HuggingFace Transformers高10多倍。 …

手指关节分割系统:视觉算法突破

手指关节分割系统源码&#xff06;数据集分享 [yolov8-seg-C2f-RFAConv&#xff06;yolov8-seg-fasternet-bifpn等50全套改进创新点发刊_一键训练教程_Web前端展示] 1.研究背景与意义 项目参考ILSVRC ImageNet Large Scale Visual Recognition Challenge 项目来源AAAI Glob…

灵动AI:艺术与科技的融合

灵动AI视频官网地址&#xff1a;https://aigc.genceai.com/ 灵动AI 科技与艺术的完美融合之作。它代表着当下最前沿的影像技术&#xff0c;为我们带来前所未有的视觉盛宴。 AI 视频以强大的人工智能算法为基石&#xff0c;能够自动分析和理解各种场景与主题。无论是壮丽的自然…

网络学习/复习2套接字

LinuxCode/code26 zc/C语言程序学习 - 码云 - 开源中国

c语言中整数在内存中的存储

整数的二进制表示有三种&#xff1a;原码&#xff0c;反码&#xff0c;补码 有符号的整数&#xff0c;三种表示方法均有符号位和数值位两部分&#xff0c;符号位都是用‘0’表示“正&#xff0c;用1表示‘负’ 最高位的以为被当作符号位&#xff0c;剩余的都是数值位。 整数…