机器学习的起点:线性回归Linear Regression

机器学习的起点:线性回归Linear Regression

作为机器学习的起点,线性回归是理解算法逻辑的绝佳入口。我们从定义、评估方法、应用场景到局限性,用生活化的案例和数学直觉为你构建知识框架。

回归算法


一、线性回归的定义与核心原理

定义:线性回归是一种通过线性方程(如Y = AX + B)建立自变量(X)与因变量(Y)关系的统计方法。

  • 数学表达:简单线性回归的公式为
    $y = \beta_0 + \beta_1 x + \epsilon $
    其中:
    • $\beta_0 $是截距(当X=0时Y的起点)
    • $\beta_1 $是斜率(X每增加1单位,Y的平均变化量)
    • ϵ \epsilon ϵ是误差项(代表无法用线性关系解释的随机波动)

直观理解:想象在散点图上找一条“最佳直线”,让所有数据点到这条直线的垂直距离之和最小。这条直线代表了变量间的整体趋势,例如身高与体重的关系、学习时间与考试成绩的关联等。


二、如何评估模型训练效果?

线性回归的目标是找到最优的$\beta_0 和 和 \beta_1$,使得预测值与真实值的误差最小

常用的评估指标是平方残差和(Residual Sum of Squares, RSS)或均方误差(MSE):
RSS = ∑ i = 1 n ( y i − y ^ i ) 2 \text{RSS} = \sum_{i=1}^n (y_i - \hat{y}_i)^2 RSS=i=1n(yiy^i)2

MSE = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 \text{MSE} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2 MSE=n1i=1n(yiy^i)2

为什么用平方?

  1. 避免正负误差相互抵消(如+3和-3的总误差为0,但平方和为18)。
  2. 对大误差更敏感,模型会更努力减少明显偏离的点(但这也意味着对异常值敏感)。

延伸指标

  • R²(决定系数):衡量模型解释数据变动的能力,范围0~1,越接近1拟合越好。
    例如,若R²=0.8,说明模型能解释80%的Y值变化。

三、线性回归的应用场景

线性回归的预测能力在多个领域大放异彩,以下是典型应用案例:

  1. 排队问题

    • 场景:预测新加入者的排队位置。
    • 逻辑:假设排队时间与人数呈线性关系,若当前排队人数为X,每人的平均处理时间为斜率β₁,则可预测新人的等待时间Y。
  2. 身高与体重的关系

    • 场景:已知某人身高(X),预测其体重(Y)。
    • 模型:通过大量样本数据拟合Y = β₀ + β₁X,例如β₁=0.7表示身高每增加1cm,体重平均增加0.7kg。
  3. 人口与GDP预测

    • 场景:分析某地区人口(X)与经济规模(Y)的关系。
    • 扩展:若数据包含多个变量(如教育水平、资源储量),可升级为多元线性回归:
      $ Y = \beta_0 + \beta_1X_1 + \beta_2X_2 + \cdots + \beta_nX_n $
      例如,GDP可能同时与人口、工业产值相关。

四、线性回归的局限性

尽管简单高效,线性回归并非万能,其核心问题包括:

  1. 无法捕捉非线性关系

    • 例子:若数据呈现抛物线分布(如温度与冰淇淋销量的关系),强行用直线拟合会导致预测偏差。
    • 解决方案:引入多项式回归(如\ $Y = \beta_0 + \beta_1X + \beta_2X^2 $)或使用决策树等非线性模型。
  2. 对异常值敏感

    • 例子:若数据中存在极端值(如身高2.5米的人),会显著拉偏拟合直线的斜率。
    • 解决方案:清洗数据或使用鲁棒回归(如Huber损失函数)。
  3. 多重共线性问题

    • 场景:在多元回归中,若自变量高度相关(如“房间数”和“房屋面积”),会导致模型参数不稳定。
    • 解决方案:剔除冗余变量或使用正则化技术(如岭回归)。

五、案例

通过 scikit-learn 练习线性回归是一个绝佳的实践方式!

以下是分步指南,涵盖数据生成、模型训练、评估、可视化以及应对局限性的解决方案:

1. 环境准备

确保安装以下库:

pip install numpy matplotlib scikit-learn

2. 基础线性回归(单变量)

生成模拟数据
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score# 生成数据:y = 2x + 5 + 噪声
np.random.seed(42)
X = np.random.rand(100, 1) * 10  # 生成0~10之间的100个随机数
y = 2 * X + 5 + np.random.randn(100, 1) * 2  # 添加高斯噪声# 可视化数据
plt.scatter(X, y, alpha=0.7, label='真实数据')
plt.xlabel("X (自变量)")
plt.ylabel("y (因变量)")
plt.title("线性回归示例数据")
plt.show()
训练模型并预测
# 创建模型并训练
model = LinearRegression()
model.fit(X, y)# 查看参数
print(f"截距 (β0): {model.intercept_[0]:.2f}")
print(f"斜率 (β1): {model.coef_[0][0]:.2f}")# 预测新数据
X_new = np.array([[2.5], [7.0]])  # 预测X=2.5和X=7.0时的y值
y_pred = model.predict(X_new)
print(f"预测值: {y_pred.flatten()}")
评估与可视化
# 计算评估指标
y_pred_all = model.predict(X)
mse = mean_squared_error(y, y_pred_all)
r2 = r2_score(y, y_pred_all)
print(f"MSE: {mse:.2f}, R²: {r2:.2f}")# 可视化拟合直线
plt.scatter(X, y, alpha=0.7, label='真实数据')
plt.plot(X, y_pred_all, color='red', label='拟合直线')
plt.xlabel("X")
plt.ylabel("y")
plt.legend()
plt.title("线性回归拟合结果")
plt.show()
输出结果:
线性回归示例数据:

在这里插入图片描述

线性回归拟合结果

在这里插入图片描述

控制台输出结果

截距 (β0): 5.43
斜率 (β1): 1.91
预测值: [10.2003057 18.7865098]
MSE: 3.23, R²: 0.91

3. 多元线性回归(多变量)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split# 加载数据集(这里以加州房价为例)
data = fetch_california_housing()
X = data.data  # 多个特征(如收入、房龄等)
y = data.target# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 训练模型
model = LinearRegression()
model.fit(X_train, y_train)# 评估测试集
y_pred = model.predict(X_test)
print(f"MSE: {mean_squared_error(y_test, y_pred):.2f}")
print(f"R²: {r2_score(y_test, y_pred):.2f}")# 查看特征重要性(系数)
feature_names = data.feature_names
for name, coef in zip(feature_names, model.coef_):print(f"{name}: {coef:.2f}")

4. 应对线性回归的局限性

问题1:非线性关系 → 多项式回归
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号from sklearn.preprocessing import PolynomialFeatures# 生成非线性数据(抛物线)
np.random.seed(42)
X = np.linspace(-3, 3, 100).reshape(-1, 1)
y = 0.5 * X**2 + X + 2 + np.random.randn(100, 1) * 0.5# 将特征转换为多项式(如X²)
poly = PolynomialFeatures(degree=2, include_bias=False)
X_poly = poly.fit_transform(X)# 训练多项式回归模型
model = LinearRegression()
model.fit(X_poly, y)# 预测并可视化
X_test = np.linspace(-3, 3, 100).reshape(-1, 1)
X_test_poly = poly.transform(X_test)
y_pred = model.predict(X_test_poly)plt.scatter(X, y, alpha=0.7, label='真实数据')
plt.plot(X_test, y_pred, color='red', label='多项式回归')
plt.legend()
plt.show()
输出结果:

在这里插入图片描述

问题2:异常值 → 鲁棒回归(Huber损失)
# 异常值 → 鲁棒回归(Huber损失)
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号from sklearn.linear_model import HuberRegressor# 生成带异常值的数据
X = np.random.rand(50, 1) * 10
y = 3 * X + 5 + np.random.randn(50, 1) * 2
y[45:] += 30  # 添加异常值# 对比普通线性回归和Huber回归
model_linear = LinearRegression()
model_linear.fit(X, y)model_huber = HuberRegressor()
model_huber.fit(X, y.ravel())  # HuberRegressor需要y为一维数组# 可视化对比
X_test = np.linspace(0, 10, 100).reshape(-1, 1)
y_pred_linear = model_linear.predict(X_test)
y_pred_huber = model_huber.predict(X_test)plt.scatter(X, y, alpha=0.7, label='数据(含异常值)')
plt.plot(X_test, y_pred_linear, color='red', label='普通线性回归')
plt.plot(X_test, y_pred_huber, color='green', label='Huber回归')
plt.legend()
plt.show()
输出结果:

在这里插入图片描述

问题3:多重共线性 → 岭回归
from sklearn.linear_model import Ridge
from sklearn.preprocessing import StandardScaler# 生成高共线性数据(两个强相关特征)
np.random.seed(42)
X1 = np.random.rand(100, 1) * 10
X2 = X1 + np.random.randn(100, 1) * 0.1  # X2与X1高度相关
X = np.hstack([X1, X2])
y = 2 * X1 + 3 * X2 + 5 + np.random.randn(100, 1) * 2# 标准化数据(岭回归对尺度敏感)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)# 对比普通线性回归和岭回归
model_linear = LinearRegression()
model_linear.fit(X_scaled, y)model_ridge = Ridge(alpha=10)  # alpha是正则化强度
model_ridge.fit(X_scaled, y)print("普通线性回归系数:", model_linear.coef_)
print("岭回归系数:", model_ridge.coef_)
输出结果:

在这里插入图片描述
参考文献视频:点击跳转

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/25145.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

DeepSeek 提示词:常见指令类型

🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编…

查询NFT图片地址

前言 有人给我发了nft,但是没有图片,我就很纳闷为什么,所以想一探究竟 解决思路 先说下环境吧 Sepolia 测试网 metamask钱包 需要获取nft的合约地址和token id 钱包内 nft可以查得到 思路: 我的理解就是ERC721有标准的…

一个滑块可变色的Seekbar

因项目需要,做一个如下图的滑动条,要求如下: 1、滑块跟着进度条改变颜色 2、滑块有白色边和内部颜色组成 大体思路,就是背景需要UI按照需求提供,然后变色时,根据滑动回调动态设置对应的颜色。 直接上代码…

重大更新!锂电池剩余寿命预测新增 CALCE 数据集

往期精彩内容: 单步预测-风速预测模型代码全家桶-CSDN博客 半天入门!锂电池剩余寿命预测(Python)-CSDN博客 超强预测模型:二次分解-组合预测-CSDN博客 VMD CEEMDAN 二次分解,BiLSTM-Attention预测模型…

实时时钟(RTC)/日历芯片PCF8563的I2C读写驱动(2):功能介绍

0 参考资料 PCF8563数据手册(第 11 版——2015 年 10 月 26 日).pdf 1 功能介绍 1.1 实时时钟(RTC)/日历 (1)PCF8563支持实时时钟(RTC),提供时、分、秒信息。对应寄存器…

Xcode如何高效的一键重命名某个关键字

1.选中某个需要修改的关键字; 2.右击,选择Refactor->Rename… 然后就会出现如下界面: 此时就可以一键重命名了。 还可以设置快捷键。 1.打开Settings 2.找到Key Bindings 3.搜索rename 4.出现三个,点击一个地方设置后其…

Grok 3 的崛起:AI 的新时代

AI 领域再次震动,一款全新的深度思考大型语言模型正式亮相。它不仅碾压了现有的各项基准测试,还成功登顶 LM Marina 排行榜,夺得第一名。这款 AI 不是别人,正是埃隆马斯克那款“基于事实、敢言无忌”的 Grok 3——一个号称既极为聪…

ros安装rqt_joint_trajectory_controller

有时候&#xff0c;我们可以看到别人的代码里面有这个&#xff0c;但是这个是需要安装的。 <node name"gui_controller" pkg"rqt_joint_trajectory_controller" type"rqt_joint_trajectory_controller" />sudo apt-get install ros-noeti…

ARM Linux LCD上实时预览摄像头画面

文章目录 1、前言2、环境介绍3、步骤4、应用程序编写4.1、lcd初始化4.2、摄像头初始化4.3、jpeg解码4.4、开启摄像头4.5、完整的程序如下 5、测试5.1、编译应用程序5.2、运行应用程序 6、总结 1、前言 本次应用程序主要针对支持MJPEG格式输出的UVC摄像头。 2、环境介绍 rk35…

是德科技keysight N5173B信号发生器,是一款经济高效的仪器

是德科技keysight N5173B信号发生器安捷伦N5173B信号源 是德N5173B微波模拟信号发生器&#xff0c;拥有 9 kHz 至 40 GHz 的频率覆盖范围&#xff0c;N5173B为宽带滤波器、放大器、接收机等器件的参数测试提供了必要的信号&#xff0c;是一款经济高效的仪器。 N5173B特点&…

【Redis】在Java中以及Spring环境下操作Redis

Java环境下&#xff1a; 1.创建maven 项目 2.导入依赖 <!-- redis --><dependency><groupId>redis.clients</groupId><artifactId>jedis</artifactId><version>4.3.2</version></dependency> 此处使用的是Jedis&…

registry 容器镜像测试

registry 封装容器部署环境测试 封装打包镜像 dockerfile # 阶段 1&#xff1a;构建阶段&#xff08;使用多阶段构建以减少最终镜像大小&#xff09; FROM golang:1.22-alpine AS builder # 安装构建所需工具 RUN #apk add --no-cache git # 设置工作目录 WORKDIR /app # 将…

Python视频网站(Django框架)

有需要请加文章底部Q哦 可远程调试 Python视频网站(Django框架) 一 介绍 此Python视频网站基于Django框架开发&#xff0c;数据库mysql&#xff0c;前端jquery.js。系统角色分为用户和管理员。 技术栈:Python3(Django框架)MySQLjquery.jsPyCharmnavicat 二 功能 用户 1 注册…

多元数据直观表示(R语言)

一、实验目的&#xff1a; 通过上机试验&#xff0c;掌握R语言实施数据预处理及简单统计分析中的一些基本运算技巧与分析方法&#xff0c;进一步加深对R语言简单统计分析与图形展示的理解。 二、实验内容&#xff1a; bank.csv文件中数据来自1969-1971年美国一家银行的474名职…

在MacOS上打造本地部署的大模型知识库(一)

一、在MacOS上安装Ollama docker run -d -p 3000:8080 --add-hosthost.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main 最后停掉Docker的ollama&#xff0c;就能在webui中加载llama模…

Fiddler在Windows下抓包Https

文章目录 1.Fiddler Classic 配置2.配置浏览器代理自动代理手动配置浏览器代理 3.抓取移动端 HTTPS 流量&#xff08;可选&#xff09;解决抓取 HTTPS 失败问题1.Fiddler证书过期了 默认情况下&#xff0c;Fiddler 无法直接解密 HTTPS 流量。需要开启 HTTPS 解密&#xff1a; 1…

常用的AI文本大语言模型汇总

AI文本【大语言模型】 1、文心一言https://yiyan.baidu.com/ 2、海螺问问https://hailuoai.com/ 3、通义千问https://tongyi.aliyun.com/qianwen/ 4、KimiChat https://kimi.moonshot.cn/ 5、ChatGPThttps://chatgpt.com/ 6、魔塔GPT https://www.modelscope.cn/studios/iic…

(python)Arrow库使时间处理变得更简单

前言 Arrow库并不是简单的二次开发,而是在datetime的基础上进行了扩展和增强。它通过提供更简洁的API、强大的时区支持、丰富的格式化和解析功能以及人性化的显示,填补了datetime在某些功能上的空白。如果你需要更高效、更人性化的日期时间处理方式,Arrow库是一个不错的选择…

游戏引擎学习第127天

仓库:https://gitee.com/mrxiao_com/2d_game_3 为本周设定阶段 我们目前的渲染器已经实现了令人惊讶的优化&#xff0c;经过过去两周的优化工作后&#xff0c;渲染器在1920x1080分辨率下稳定地运行在60帧每秒。这个结果是意料之外的&#xff0c;因为我们没有预计会达到这样的…

leetcode 73. 矩阵置零

题目如下 数据范围 如果一个点m(i,j) 0其中i j都大于0那么按照题目要求对应的m[0][j] m[i][0]都要赋值为0. 所以我们可以令第一行和第一列作为标记是否对应的列和行需要置为0. 又因为我们没法判断第一行和第一列所以需要额外两个变量标记第一列和第二列。 这样就可以满足题…