Lucas带你手撕机器学习——线性回归

什么是线性回归

线性回归是机器学习中的基础算法之一,用于预测一个连续的输出值。它假设输入特征与输出值之间的关系是线性关系,即目标变量是输入变量的线性组合。我们可以从代码实现的角度来学习线性回归,包括如何使用 Python 进行简单的线性回归模型构建、训练、和预测。

线性回归的直观理解

你可以把线性回归理解成“画一条线来预测未来”。假设你有一张散点图,每个点代表某个物品的重量和它的价格。你的目标是找到一条直线,能够尽可能准确地描述这些点之间的关系。

线性回归的工作原理

假设我们有一些数据点,每个点都有一个输入(如重量)和一个输出(如价格)。线性回归就是在这些点之间找到一条直线,使得这条线能够“最好”地描述这些数据点。

这条直线的公式是:

在这里插入图片描述

其中:

  • y:输出,即我们想要预测的值(例如,物品的价格)
  • x:输入特征(例如,物品的重量)
  • w:线的斜率,表示重量对价格的影响有多大
  • b:截距,表示当重量为 0 时,预测的价格是多少

线性回归的基本原理

线性回归的数学公式为:

在这里插入图片描述

其中:

  • y 是预测值(目标变量)
  • x1,x2,…,xn 是输入特征
  • w1,w2,…,wn 是特征对应的权重(回归系数)
  • b 是偏置项(截距)

如何找到“最好的”直线?

“最好的”直线是指那些经过这条直线的点尽可能接近数据点。为了衡量直线的好坏,我们需要一个方法来计算直线与数据点之间的差距。

误差的概念
  • 对于每个数据点,我们可以计算它的实际价格(真实值)和用这条直线预测出来的价格之间的差距,称为“误差”。
  • 比如说,某个物品的真实价格是 10 元,但通过直线预测出来的价格是 9 元,那么这个点的误差就是 10−9=1。
均方误差(Mean Squared Error,MSE)

为了让误差的计算更稳定,我们通常不直接使用误差,而是使用“均方误差”来衡量模型的好坏:

在这里插入图片描述

其中:

  • yi:第 i 个样本的真实值
  • yi^:第 i 个样本通过模型预测的值
  • N:样本数量

均方误差的作用就是将所有数据点的误差平方后取平均值,这样可以确保误差不会因为正负抵消。我们的目标是让这个均方误差尽可能小,意味着直线与数据点之间的差距最小。

训练模型

在实际训练过程中,我们会不断调整直线的斜率 w 和截距 b,直到找到使均方误差最小的那一组 w 和 b。这就意味着找到了“最好的”直线。

代码实现

使用 Scikit-Learn 实现****线性回归

我们可以使用 Scikit-Learn 库,它提供了非常简洁的接口来进行线性回归。下面是一个完整的示例代码:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error# 生成一些模拟数据
np.random.seed(42)
X = 2 * np.random.rand(100, 1)  # 输入特征,100 个样本,1 个特征
y = 4 + 3 * X + np.random.randn(100, 1)  # 线性关系 y = 4 + 3x + 噪声# 拆分数据为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 创建线性回归模型并进行训练
model = LinearRegression()
model.fit(X_train, y_train)# 输出模型的系数和截距
print(f'权重(w): {model.coef_[0][0]}')
print(f'截距(b): {model.intercept_[0]}')# 预测并计算均方误差
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print(f'测试集上的 MSE: {mse}')# 可视化结果
plt.scatter(X_test, y_test, color='blue', label='真实值')
plt.plot(X_test, y_pred, color='red', label='预测值', linewidth=2)
plt.xlabel('X')
plt.ylabel('y')
plt.legend()
plt.title('线性回归拟合结果')
plt.show()

在这里插入图片描述

  1. 代码解释
  • 生成模拟数据: 生成了一些随机数据点 X和 y,其中 y=4 + 3X + 噪声,这样我们就有一个线性关系的示例数据。
  • 数据集拆分: 使用 train_test_split 将数据集拆分成训练集和测试集,80% 用于训练,20% 用于测试。
  • 训练模型: 使用 LinearRegression 类创建模型,并用训练集数据拟合模型。
  • 预测和评估: 使用测试集进行预测,计算预测值与真实值之间的均方误差(MSE)。
  • 结果可视化: 将真实值和预测结果在图中可视化,可以清楚地看到线性回归的拟合效果。

PyTorch 实现线性回归

为了更好地理解线性回归的原理,我们也可以使用 PyTorch 从头实现一个简单的线性回归模型:

import torch
import torch.nn as nn
import torch.optim as optim# 生成模拟数据
torch.manual_seed(42)
X = torch.randn(100, 1) * 2
y = 4 + 3 * X + torch.randn(100, 1)# 定义线性模型
class LinearRegressionModel(nn.Module):def __init__(self):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(1, 1)  # 输入 1 维,输出 1 维def forward(self, x):return self.linear(x)# 创建模型、损失函数和优化器
model = LinearRegressionModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
epochs = 1000
for epoch in range(epochs):model.train()optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item()}')# 输出训练好的模型参数
[w, b] = model.parameters()
print(f'权重(w): {w.item()}')
print(f'截距(b): {b.item()}')

代码解释

  • 定义模型: 使用 nn.Module 定义了一个简单的线性模型,只包含一个线性层。
  • 定义损失函数和优化器: 选择均方误差作为损失函数(nn.MSELoss()),使用随机梯度下降(optim.SGD)优化模型。
  • 模型训练: 通过前向传播计算损失,通过反向传播计算梯度并更新模型参数。

总结

以上两种方法分别使用 Scikit-Learn 和 PyTorch 实现了线性回归模型。Scikit-Learn 的方式适合快速建模和测试,而 PyTorch 版本则更灵活,更适合理解深度学习模型的训练过程。掌握这些方法后,可以将它们应用于更复杂的模型和任务中。

感谢阅读!!我是正在澳洲深造的Lucas!!
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/453244.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

git的安装以及入门使用

文章目录 git的安装以及入门使用什么是git?git安装git官网 git初始化配置使用方式初始化配置: git的安装以及入门使用 什么是git? Git 是一个免费开源的分布式版本控制系统,使用特殊的仓库数据库记录文件变化。它记录每个文件的…

WebGl 使用uniform变量动态修改点的颜色

在WebGL中,uniform变量用于在顶点着色器和片元着色器之间传递全局状态信息,这些信息在渲染过程中不会随着顶点的变化而变化。uniform变量可以用来设置变换矩阵、光照参数、材料属性等。由于它们在整个渲染过程中共享,因此可以被所有使用该着色…

嵌入式linux系统中多路复用和信号驱动实现

大家好,今天主要给大家分享一下,如何使用linux系统中的多路复用和信号驱动的功能实现。 第一:linux多路复用基本特点 当应用程序同时处理多路数据的输入或输出时,若采用非阻塞模式,将达不到预期的效果 如果采用非阻塞模式,对多个输入进行轮询可以实现,但CPU的消耗非常大…

【设计模式系列】装饰器模式

目录 一、什么是装饰器模式 二、装饰器模式中的角色 三、装饰器模式的典型应用场景 四、装饰器模式在BufferedReader中的应用 一、什么是装饰器模式 装饰器模式是一种结构型设计模式,用于在不修改对象自身的基础上,通过创建一个或多个装饰类来给对象…

黑马 | Reids | 基础篇

黑马reids基础篇 文章目录 黑马reids基础篇一.初始Redis1.1SQL 和 NoSql的区别1.1.1结构化和非结构化1.1.2关联和非关联1.1.3查询方式1.1.4 事务1.1.5总结 1.2 认识Redis1.3 Redis安装启动默认启动:后台启动:开机自启 1.4 Redis客户端1.4.1.Redis命令行客…

windows安装mysql,跳过自定义的密码验证

1、mysql版本8 2、配置系统环境变量 3、新建my.ini文件在mysql目录下,需要指定data目录 [mysqld] # 设置3306端口 port3306# 自定义设置mysql的安装目录,即解压mysql压缩包的目录 basedirD:\hjl\app\mysql\mysql-8.0.33-winx64# 自定义设置mysql数据…

24/10/14 算法笔记 循环神经网络RNN

RNN: 一种专门用于处理序列数据的神经网络,它能够捕捉时间序列中的动态特征。RNN的核心特点是其循环连接,这允许网络在不同时间步之间传递信息,从而实现对序列数据的记忆和处理能力。 应用的场景: 自然语言处理(NLP&…

[241021] X-CMD 内测版 v0.4.12 新功能: starship ohmyposh ping tping docker ascii

目录 X-CMD 发布内测版 v0.4.12📃Changelog🎨 starship🎨 ohmyposh🎨 theme🌐 ping🌐 tping🐋 docker💻 mac - 集成 MacOS 实用功能🔄 ascii🦖 deno&#x1f…

探索秘境:如何使用智能体插件打造专属的小众旅游助手『小众旅游探险家』

文章目录 摘要引言智能体介绍和亮点展示介绍亮点展示 已发布智能体运行效果智能体创意想法创意想法创意实现路径拆解 如何制作智能体可能会遇到的几个问题快速调优指南总结未来展望 摘要 本文将详细介绍如何使用智能体平台开发一款名为“小众旅游探险家”的旅游智能体。通过这…

怎么设置打别人电话显示自己公司名称?

在日常生活中,想必许多人都曾接到过显示公司名称的来电。相较于常规的电话号码,这类带有企业信息的来电无疑更具可信度,让人更愿意接听。在这个骚扰电话和推销电话泛滥、信任缺失的现代社会,这些能够自证身份的电话号码就像是一张…

职场经验:如何封装自动化测试框架?

封装自动化测试框架,测试人员不用关注框架的底层实现,根据指定的规则进行测试用例的创建、执行即可,这样就降低了自动化测试门槛,能解放出更多的人力去做更深入的测试工作。 本篇文章就来介绍下,如何封装自动化测试框…

filebeat接入nginx和mysql获取日志

下载nginx (1) 直接下载 yum install nginx -y(2)查看状态启动 systemctl start nginx systemctl status nginx(3)配置文件检查 nginx -t(4)端口检查 netstat -tulpn | grep :80&am…

Linux系统:配置Apache支持CGI(Ubuntu)

配置Apache支持CGI 根据以下步骤配置,实现Apache支持CGI 安装Apache: 可参照文章: Ubuntu安装Apache教程。执行以下命令,修改Apache2配置文件000-default.conf: sudo vim /etc/apache2/sites-enabled/000-default.con…

相同的树算法

给你两棵二叉树的根节点 p 和 q ,编写一个函数来检验这两棵树是否相同。 如果两个树在结构上相同,并且节点具有相同的值,则认为它们是相同的。 示例 1: 输入:p [1,2,3], q [1,2,3] 输出:true示例 2&…

屏幕画面卡住不动声音正常怎么办?电脑屏幕卡住不动解决方法

在数字时代,电脑作为我们日常生活与工作中不可或缺的伙伴,偶尔也会遇到一些小状况。其中,“屏幕画面卡住不动,但是声音依然正常”的情况就是一种常见的问题。本文将探讨这一现象的原因,并提供几种可能的解决方案&#…

Pyqt5设计打开电脑摄像头+可选择哪个摄像头(如有多个)

目录 专栏导读库的安装代码介绍完整代码总结 专栏导读 🌸 欢迎来到Python办公自动化专栏—Python处理办公问题,解放您的双手 🏳️‍🌈 博客主页:请点击——> 一晌小贪欢的博客主页求关注 👍 该系列文…

注册安全分析报告:北外网校

前言 由于网站注册入口容易被黑客攻击,存在如下安全问题: 暴力破解密码,造成用户信息泄露短信盗刷的安全问题,影响业务及导致用户投诉带来经济损失,尤其是后付费客户,风险巨大,造成亏损无底洞…

安装Maven配置以及构建Maven项目(2023idea)

一、下载Maven绿色软件 地址:http://maven.apache.org/download.cgi 尽量不要选择最高版本的安装,高版本意味着高风险的不兼容问题,选择低版本后续问题就少。你也可以选择尝试。 压缩后: 打开后: 在该目录下新建mvn-…

手机思维导图怎么制作?5个软件教你在手机上绘制思维导图

手机思维导图怎么制作?5个软件教你在手机上绘制思维导图 在手机上制作思维导图不仅可以帮助你快速理清思路,还可以随时随地进行创作和调整。以下是5款适合手机上绘制思维导图的软件,它们功能强大、操作简单,帮助你轻松上手。 迅…

2024年游戏买量还有空间吗?

本人从事游戏行业多年,一直做游戏分发的工作,但近年来随着我国大经济背景的整体向下,不仅仅影响了实体企业,游戏行业买量也明显受到影响。 2024年,游戏买量市场呈现出以下几个主要特点: 小游戏买量爆发&am…