从零开始学习线性回归:理论、实践与PyTorch实现

文章目录

  • 🥦介绍
  • 🥦基本知识
  • 🥦代码实现
  • 🥦完整代码
  • 🥦总结

🥦介绍

线性回归是统计学和机器学习中最简单而强大的算法之一,用于建模和预测连续性数值输出与输入特征之间的关系。本博客将深入探讨线性回归的理论基础、数学公式以及如何使用PyTorch实现一个简单的线性回归模型。

🥦基本知识

线性回归的数学基础
线性回归的核心思想是建立一个线性方程,它表示了自变量(输入特征)与因变量(输出)之间的关系。这个线性方程通常表示为:
y=b0+b1x1+b2x2+…+bpxp
y=b0​+b1​x1​+b2​x2​+…+bp​xp​

其中, y y y 是因变量, x 1 , x 2 , … , x p x_1, x_2, \ldots, x_p x1,x2,,xp 是自变量, b 0 , b 1 , b 2 , … , b p b_0, b_1, b_2, \ldots, b_p b0,b1,b2,,bp 是模型的参数, p p p 是特征的数量。我们的目标是找到最佳的参数值,以最小化模型的误差。

损失函数
为了找到最佳参数,我们需要定义一个损失函数来度量模型的性能。在线性回归中,最常用的损失函数是均方误差(MSE),它表示了模型预测值与实际值之间的平方差的平均值:
在这里插入图片描述
其中, n n n 是样本数量, y i y_i yi 是实际值, y ^ i \hat{y}_i y^i 是模型的预测值。

梯度下降优化
为了最小化损失函数,我们使用梯度下降算法。梯度下降通过计算损失函数相对于参数的梯度,并迭代地更新参数,以减小损失。更新规则如下:
在这里插入图片描述
其中, b j b_j bj 是第 j j j个参数, α \alpha α 是学习率, ∂ ∂ b j M S E \frac{\partial}{\partial b_j} MSE bjMSE 是损失函数对参数 b j b_j bj的偏导数。

🥦代码实现

如果你想知道实现线性回归的大体步骤,下图可以充分进行说明
在这里插入图片描述

  • 准备数据
  • 设计模型(计算) y ^ i \hat{y}_i y^i
  • 构造损失和优化器
  • 训练周期(前向,反向 ,更新)

本节还是以刘二大人的视频讲解为例,结尾会设置传送门

class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__() # 调用父类的构造函数self.linear = torch.nn.Linear(1, 1)  # 参数详情下图展示def forward(self, x):y_pred = self.linear(x)   # x代表输入样本的张量return y_pred
model = LinearModel()

所以模型类都要继承Module,此类主要包含两个函数一个是构造函数(初始化对象时调用),另一个是前向计算

好奇的小伙伴会思考为何没有反向(backward),这是因为Module会帮你进行,但是如果后期自己有更高效的方法可以自行设置。
在这里插入图片描述

  • 第一个参数 in_features:这是输入特征的数量。在这里,表示我们的模型只有一个输入特征。如果你有多个输入特征,你可以将这个参数设置为输入特征的数量。

  • 第二个参数 out_features:这是输出特征的数量。这表示我们的模型将生成一个输出。在线性回归中,通常只有一个输出,因为我们试图预测一个连续的数值。

  • 第三个参数:意思是要不要偏置量。默认true

通常情况下特征代表列,比如我们有一个n×2的y和一个n×3的x,那么我们需要一个3×2的权重,有的书中会在两边做转置,但无论咋样目的都是为了让这个矩阵乘法成立

criterion = torch.nn.MSELoss(size_average=False)  # 使用均方误差损失 
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 使用随机梯度下降优化器

在这里插入图片描述
在这里插入图片描述
model.parameters() 用于告诉优化器哪些参数需要在训练过程中进行更新,这包括模型的权重和偏置项等。在线性回归示例中,模型的参数包括权重和偏置项。

优化器的选择有许多大家可以都试试看看
在这里插入图片描述

之后就进行训练了

for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) print(epoch, loss.item())optimizer.zero_grad()   # 归零loss.backward()  # 反向optimizer.step()  # 更新
print('w = ', model.linear.weight.item()) 
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]]) 
y_test = model(x_test)
print('y_pred = ', y_test.data)

🥦完整代码

x_data = torch.Tensor([[1.0], [2.0], [3.0]]) 
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__() self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x) return y_pred
model = LinearModel()
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) print(epoch, loss.item())optimizer.zero_grad() loss.backward()optimizer.step()
print('w = ', model.linear.weight.item()) 
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]]) 
y_test = model(x_test)
print('y_pred = ', y_test.data)
predicted = model(x_data).detach().numpy()
plt.scatter(x_data, y_data, label='Original data')
plt.plot(x_data, predicted, label='Fitted line', color='r')
plt.legend()
plt.show()

运行结果如下
在这里插入图片描述

在这里插入图片描述

🥦总结

在本篇博客中,我们使用PyTorch实现了一个简单的线性回归模型,并使用随机生成的数据对其进行了训练和可视化。线性回归是一个入门级的机器学习模型,但它为理解模型训练和预测的基本概念提供了一个很好的起点。

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/150207.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python爬虫(二十二)_selenium案例:模拟登陆豆瓣

本篇博客主要用于介绍如何使用seleniumphantomJS模拟登陆豆瓣,没有考虑验证码的问题,更多内容,请参考:Python学习指南 #-*- coding:utf-8 -*-from selenium import webdriver from selenium.webdriver.common.keys import Keysimp…

IDEA 使用

目录 Git.gitignore 不上传取消idea自动 add file to git撤销commit的内容本地已经有一个开发完成的项目,这个时候想要上传到仓库中 Git .gitignore 不上传 在项目根目录下创建 .gitignore 文件夹,并添加内容: .gitignore取消idea自动 add…

Leetcode901-股票价格跨度

一、前言 本题基于leetcode901股票价格趋势这道题,说一下通过java解决的一些方法。并且解释一下笔者写这道题之前的想法和一些自己遇到的错误。需要注意的是,该题最多调用 next 方法 10^4 次,一般出现该提示说明需要注意时间复杂度。 二、解决思路 ①…

ArcGIS Engine:视图菜单的创建和鹰眼图的实现

目录 01 创建项目 1.1 通过ArcGIS-ExtendingArcObjects创建窗体应用 1.2 通过C#-Windows窗体应用创建窗体应用 1.2.1 创建基础项目 1.2.2 搭建界面 02 创建视图菜单 03 鹰眼图的实现 3.1 OnMapReplaced事件的触发 3.2 OnExtentUpdated事件的触发 04 稍作演示 01 创建项目…

【交互式阈值二进制图像】采用彩色或单色图像通过交互/手动方式阈值单色图像或彩色图像的单个色带研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

外部中断的基本操作

题目背景 定义一个 Working() 函数,使L1指示灯不断闪烁。将P32引脚定义成外部中断功能,按下S5按键就会产生外部中断触发信号,在中断响应函数中,点亮L8指示灯,延时较长一段时间后熄灭,该功能用两种方法实现…

selenium +IntelliJ+firefox/chrome 环境全套搭配

1第一步:下载IntelliJ idea 代码编辑器 2第二步:下载浏览器Chrome 3第三步:下载JDK 4第四步:配置环境变量(1JAVA_HOME 2 path) 5第五步:下载Maven 6第六步:配置环境变量&#x…

Scala第十六章节

Scala第十六章节 scala总目录 文档资料下载 章节目标 掌握泛型方法, 类, 特质的用法了解泛型上下界相关内容了解协变, 逆变, 非变的用法掌握列表去重排序案例 1. 泛型 泛型的意思是泛指某种具体的数据类型, 在Scala中, 泛型用[数据类型]表示. 在实际开发中, 泛型一般是结合…

CTFHUB SSRF

目录 web351 ​编辑 web352 web353 web354 sudo.cc 代表 127 web355 host长度 web356 web357 DNS 重定向 web358 bypass web359 mysql ssrf web360 web351 POST查看 flag.php即可 web352 <?php error_reporting(0); highlight_file(__FILE__); $url$_…

Java基础(二)

1. 面向对象基础 1.1 面向对象和面向过程的区别 面向过程把解决问题的过程拆成一个个方法&#xff0c;通过一个个方法的执行解决问题。面向对象会先抽象出对象&#xff0c;然后用对象执行方法的方式解决问题。 面向对象开发的方式更容易维护和迭代升级、易复用、易扩展。 1…

数据防泄密软件排行榜(企业电脑防泄密软件哪一款好用,有哪些推荐)

在当今信息化社会&#xff0c;数据已经成为了企业的重要资产。然而&#xff0c;数据的安全问题也日益突出&#xff0c;尤其是数据的泄露&#xff0c;不仅会导致企业的商业秘密被竞争对手获取&#xff0c;还可能引发一系列的法律问题。因此&#xff0c;数据防泄密软件的重要性不…

it端到端运营监控

公司的运维监控已成为确保业务顺利运行的关键。特别是对于IT部门&#xff0c;端到端运维监控不仅可以帮助企业及时发现和解决问题&#xff0c;还可以提高业务效率&#xff0c;优化客户体验。端到端运维监控的概念、重要性及其实施方法。 端到端操作监控的概念 端到端操作监控&…

NPDP35岁考还有意义吗? NPDP证书认可度如何?

一句话说的好&#xff0c;“活到老&#xff0c;学到老”&#xff0c;只要想学&#xff0c;想干事&#xff0c;什么时候都不晚&#xff0c;何况35岁正是一个人的转折阶段。当然&#xff0c;考取NPDP证书是否真的对自身有意义&#xff0c;这取决于你个人情况和职业发展目标。产品…

京东数据分析平台:2023年8月京东奶粉行业品牌销售排行榜

鲸参谋监测的京东平台8月份奶粉市场销售数据已出炉&#xff01; 鲸参谋数据显示&#xff0c;8月份京东平台上奶粉的销售量将近700万件&#xff0c;环比增长约15%&#xff0c;同比则下滑约19%&#xff1b;销售额将近23亿元&#xff0c;环比增长约4%&#xff0c;同比则下滑约3%。…

QMC5883L-磁力计椭球拟合校准

1.概述 磁力计椭球拟合校准是一种将磁力计测量数据校准到真实磁场的技术。这种技术通常使用椭球模型来拟合磁力计的测量结果&#xff0c;然后通过最小二乘法来找到拟合参数的最优解。 2.总体思想 磁力计椭球拟合校准的思想包括以下几个步骤&#xff1a; 1.数据预处理&#x…

万兆光模块的价格相比千兆光模块贵多少?

万兆光模块和千兆光模块是应用非常广泛的两款产品。万兆光模块与千兆光模块相比&#xff0c;主要优势在于速率更高、带宽更大。传输速率为万兆的光模块&#xff0c;理论上可以实现每秒传输10G的数据&#xff0c;是传输速率约为千兆的光模块的10倍&#xff0c;可以在同等时间内传…

HTML 笔记:初识 HTML(HTML文本标签、文本列表、嵌入图片、背景色、网页链接)

1 何为HTML 用来描述网页的一种语言超文本标记语言(Hyper Text Markup Language)不是一种编程语言&#xff0c;而是一种标记语言 (markup language) 2 HTML标签 HTML 标签是由尖括号包围的关键词&#xff0c;比如 <html> 作用是为了“标记”页面中的内容&#xff0c;使…

TCP VS UDP

程序员写网络程序&#xff0c;主要编写的应用层代码&#xff01; 真正要发这个数据&#xff0c;需要上层协议调用下层协议&#xff0c;应用层要调用传输层&#xff0c;则传输层给应用层提供一组api&#xff0c;统称为&#xff1a;soket api 基于UDP的api 基于TCP的api 这两个协…

深圳市重点实验室申报要求-华夏泰科

深圳市重点实验室&#xff0c;以开展基础研究、应用基础研究、前沿技术研究&#xff0c;培养人才、支撑产业和社会发展为目标而建立。它为研究人员提供了一个独特的平台&#xff0c;提供了一个展示他们创新性研究的舞台。本文将深入探讨如何申报深圳市重点实验室&#xff0c;为…

实战纪实 | 某米企业src未授权访问

公众号&#xff1a;掌控安全EDU 分享更多技术文章&#xff0c;欢迎关注一起探讨学习 某米企业src漏洞挖掘 这一挖就挖到了一个未授权操作漏洞&#xff0c;写个文章记录下~~ 通过信息收集&#xff0c;发现这么一个资产。 访问 http://xxx.com 如下图所示 1.点击头像-点击授权登…