PyTorch 框架实现线性回归:从数据预处理到模型训练全流程

系列文章目录

01-PyTorch新手必看:张量是什么?5 分钟教你快速创建张量!
02-张量运算真简单!PyTorch 数值计算操作完全指南
03-Numpy 还是 PyTorch?张量与 Numpy 的神奇转换技巧
04-揭秘数据处理神器:PyTorch 张量拼接与拆分实用技巧
05-深度学习从索引开始:PyTorch 张量索引与切片最全解析
06-张量形状任意改!PyTorch reshape、transpose 操作超详细教程
07-深入解读 PyTorch 张量运算:6 大核心函数全面解析,代码示例一步到位!
08-自动微分到底有多强?PyTorch 自动求导机制深度解析
09-从零手写线性回归模型:PyTorch 实现深度学习入门教程
10-PyTorch 框架实现线性回归:从数据预处理到模型训练全流程


文章目录

  • 系列文章目录
  • 前言
  • 一、构建数据集
    • 1.1 示例代码
  • 二、数据加载器
    • 2.1 示例代码
  • 三、定义模型
    • 3.1 示例代码
  • 四、定义损失函数与优化器
    • 4.1 示例代码
  • 五、训练模型
    • 5.1 示例代码
  • 六、绘制结果
    • 6.1 示例代码
  • 七、主函数
    • 7.1 示例代码
    • 7.2 示例输出
  • 八、总结
    • 8.1 完整代码


前言

在之前的文章中,通过手动方式构建了一个简单的线性回归模型。然而,面对复杂的网络设计,手动实现不仅繁琐,还容易出错。为了提高效率和灵活性,可以利用 PyTorch 提供的组件来快速搭建模型。

本文将通过 PyTorch 实现线性回归,主要包括以下内容:

  • 使用 nn.MSELoss 代替自定义的平方损失函数
  • 使用 data.DataLoader 代替自定义的数据加载器
  • 使用 optim.SGD 代替自定义的优化器
  • 使用 nn.Linear 代替自定义的线性模型

一、构建数据集

首先,需要生成一组模拟的线性回归数据,并将其转换为 PyTorch 张量。使用 sklearn.datasets.make_regression 函数来创建数据。

1.1 示例代码

import torch
from sklearn.datasets import make_regression# 构建数据集
def create_dataset():x, y, coef = make_regression(n_samples=150,       # 样本数量n_features=1,        # 特征数量noise=15,            # 噪声大小coef=True,           # 返回系数bias=10.0,           # 偏置项random_state=42      # 随机种子)# 转换为 PyTorch 张量x = torch.tensor(x, dtype=torch.float32)y = torch.tensor(y, dtype=torch.float32)return x, y, coef

在上面的代码中,生成了一个 150 个样本、带有噪声的数据集,并将其转换为 PyTorch 支持的张量格式,便于后续训练使用。


二、数据加载器

为了更方便地处理数据批量,使用 PyTorch 的 DataLoader 来加载数据集。

2.1 示例代码

from torch.utils.data import TensorDataset, DataLoader# 构建数据加载器
def create_dataloader(x, y):dataset = TensorDataset(x, y)  # 创建数据集对象dataloader = DataLoader(dataset, batch_size=20, shuffle=True)  # 批量大小为20return dataloader

DataLoader 可以轻松实现数据的分批次处理,同时支持打乱数据顺序以提高模型的泛化能力。


三、定义模型

PyTorch 提供了 nn.Linear 作为线性模型的实现,只需定义输入和输出的特征数量。

线性模型的核心公式为:
y ^ = x ⋅ w + b \hat{y} = x \cdot w + b y^=xw+b
其中,w 是权重,b 是偏置,均为模型的可学习参数。

3.1 示例代码

import torch.nn as nn# 构建线性模型
def create_model():model = nn.Linear(in_features=1, out_features=1)  # 输入和输出特征均为1return model

四、定义损失函数与优化器

使用 nn.MSELoss 作为损失函数,并使用 optim.SGD 优化器来更新模型参数。

4.1 示例代码

import torch.optim as optim# 定义损失函数和优化器
def create_loss_and_optimizer(model):criterion = nn.MSELoss()  # 均方误差损失optimizer = optim.SGD(model.parameters(), lr=0.01)  # 学习率为0.01return criterion, optimizer

五、训练模型

将数据加载器、模型、损失函数和优化器整合到训练循环中,通过梯度下降来优化模型。

5.1 示例代码

# 训练模型
def train_model(dataloader, model, criterion, optimizer, epochs=100):for epoch in range(epochs):for batch_x, batch_y in dataloader:# 前向传播:计算预测值y_pred = model(batch_x)# 计算损失loss = criterion(y_pred, batch_y.unsqueeze(1))# 清空梯度optimizer.zero_grad()# 反向传播:计算梯度loss.backward()# 更新参数optimizer.step()# 打印每10个epoch的损失if (epoch + 1) % 10 == 0:print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")

在训练过程中,每次从数据加载器中取出一个批次的数据进行训练,并更新模型参数。


六、绘制结果

训练完成后,可视化模型的拟合效果,与真实数据进行对比。

6.1 示例代码

# 可视化模型拟合结果
def plot_results(x, y, model, coef, bias):# 设置中文字体plt.rcParams['font.sans-serif'] = ['SimHei']  # 使用黑体plt.rcParams['axes.unicode_minus'] = False   # 正常显示负号# 绘制散点图plt.scatter(x.numpy(), y.numpy(), label="数据点", alpha=0.7)# 绘制模型预测的直线x_range = torch.linspace(x.min(), x.max(), 100).reshape(-1, 1)y_pred = model(x_range).detach().numpy()plt.plot(x_range.numpy(), y_pred, label="拟合直线", color="r")# 绘制真实的直线coef = torch.tensor(coef, dtype=torch.float32)bias = torch.tensor(bias, dtype=torch.float32)y_true = coef * x_range + biasplt.plot(x_range.numpy(), y_true.numpy(), label="真实直线", color="g", linestyle="--")# 添加标题和标签plt.title("线性回归拟合结果")plt.xlabel("自变量 X")plt.ylabel("因变量 Y")# 显示图例和网格plt.legend()plt.grid(True)# 显示绘图plt.show()

七、主函数

7.1 示例代码

if __name__ == "__main__":# 构建数据x, y, coef = create_dataset()dataloader = create_dataloader(x, y)model = create_model()criterion, optimizer = create_loss_and_optimizer(model)# 训练模型train_model(dataloader, model, criterion, optimizer, epochs=100)# 绘制结果plot_results(x, y, model, coef, bias=10.0)

7.2 示例输出

在这里插入图片描述
运行结果将显示数据点与拟合直线,直线与真实线性关系高度吻合,说明模型训练效果良好。


八、总结

本文通过使用 PyTorch 实现线性回归模型,完成了以下内容:

  • 构建数据集并设计数据加载器;
  • 使用 nn.Linear 定义线性假设函数;
  • 使用 nn.MSELoss 设计均方误差损失函数;
  • 使用 optim.SGD 实现随机梯度下降优化方法;
  • 训练模型并可视化拟合结果。

8.1 完整代码

import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt# 构建数据集
def create_dataset():x, y, coef = make_regression(n_samples=150,       # 样本数量n_features=1,        # 特征数量noise=15,            # 噪声大小coef=True,           # 返回系数bias=10.0,           # 偏置项random_state=42      # 随机种子)# 转换为 PyTorch 张量x = torch.tensor(x, dtype=torch.float32)y = torch.tensor(y, dtype=torch.float32)return x, y, coef# 构建数据加载器
def create_dataloader(x, y):dataset = TensorDataset(x, y)  # 创建数据集对象dataloader = DataLoader(dataset, batch_size=20, shuffle=True)  # 批量大小为20return dataloader# 构建线性模型
def create_model():model = nn.Linear(in_features=1, out_features=1)  # 输入和输出特征均为1return model# 定义损失函数和优化器
def create_loss_and_optimizer(model):criterion = nn.MSELoss()  # 均方误差损失optimizer = optim.SGD(model.parameters(), lr=0.01)  # 学习率为0.01return criterion, optimizer# 训练模型
def train_model(dataloader, model, criterion, optimizer, epochs=100):for epoch in range(epochs):for batch_x, batch_y in dataloader:# 前向传播:计算预测值y_pred = model(batch_x)# 计算损失loss = criterion(y_pred, batch_y.unsqueeze(1))# 清空梯度optimizer.zero_grad()# 反向传播:计算梯度loss.backward()# 更新参数optimizer.step()# 打印每10个epoch的损失if (epoch + 1) % 10 == 0:print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")# 可视化模型拟合结果
def plot_results(x, y, model, coef, bias):# 设置中文字体plt.rcParams['font.sans-serif'] = ['SimHei']  # 使用黑体plt.rcParams['axes.unicode_minus'] = False   # 正常显示负号# 绘制散点图plt.scatter(x.numpy(), y.numpy(), label="数据点", alpha=0.7)# 绘制模型预测的直线x_range = torch.linspace(x.min(), x.max(), 100).reshape(-1, 1)y_pred = model(x_range).detach().numpy()plt.plot(x_range.numpy(), y_pred, label="拟合直线", color="r")# 绘制真实的直线coef = torch.tensor(coef, dtype=torch.float32)bias = torch.tensor(bias, dtype=torch.float32)y_true = coef * x_range + biasplt.plot(x_range.numpy(), y_true.numpy(), label="真实直线", color="g", linestyle="--")# 添加标题和标签plt.title("线性回归拟合结果")plt.xlabel("自变量 X")plt.ylabel("因变量 Y")# 显示图例和网格plt.legend()plt.grid(True)# 显示绘图plt.show()# 主函数
if __name__ == "__main__":# 构建数据x, y, coef = create_dataset()dataloader = create_dataloader(x, y)model = create_model()criterion, optimizer = create_loss_and_optimizer(model)# 训练模型train_model(dataloader, model, criterion, optimizer, epochs=100)# 绘制结果plot_results(x, y, model, coef, bias=10.0)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/503547.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Elasticsearch:优化的标量量化 - 更好的二进制量化

作者:来自 Elastic Benjamin Trent 在这里,我们解释了 Elasticsearch 中的优化标量量化以及如何使用它来改进更好的二进制量化 (Better Binary Quantization - BBQ)。 我们的全新改进版二进制量化 (Better Binary Quantization - BBQ) 索引现在变得更强大…

科普CMOS传感器的工作原理及特点

在当今数字化成像的时代,图像传感器无疑是幕后的关键 “功臣”,它宛如一位神奇的 “光影魔法师”,通过光电效应这一奇妙的物理现象,将光子巧妙地转换成电荷,为图像的诞生奠定基础。而在众多类型的图像传感器中&#xf…

IDEA中Maven依赖包导入失败报红的潜在原因

在上网试了别人的八个问题总结之后依然没有解决&#xff1a; IDEA中Maven依赖包导入失败报红问题总结最有效8种解决方案_idea导入依赖还是报红-CSDN博客https://blog.csdn.net/qq_43705131/article/details/106165960 江郎才尽之后突然想到一个原因&#xff1a;<dep…

Java100道面试题

1.JVM内存结构 1. 方法区&#xff08;Method Area&#xff09; 方法区是JVM内存结构的一部分&#xff0c;用于存放类的相关信息&#xff0c;包括&#xff1a; 类的结构&#xff08;字段、方法、常量池等&#xff09;。字段和方法的描述&#xff0c;如名称、类型、访问修饰符…

虚表 —— 隐藏行(简单版)

因为隐藏行改变了listview内部行号处理机制&#xff0c;需要处理大量细节&#xff0c;如listview内部用于传递行号的各种消息、通知等、封装的各种读取行号的函数等。 所以在工作量很大&#xff0c;一处纰漏可能导致重大bug的情况下&#xff0c;仅对隐藏行功能进行了简单封装&…

UDP -- 简易聊天室

目录 gitee&#xff08;内有详细代码&#xff09; 图解 MessageRoute.hpp UdpClient.hpp UdpServer.hpp Main.hpp 运行结果&#xff08;本地通信&#xff09; 如何分开对话显示&#xff1f; gitee&#xff08;内有详细代码&#xff09; chat_room zihuixie/Linux_Lear…

python制作翻译软件

本文复刻此教程&#xff1a;制作属于自己的翻译软件-很简单【Python】_哔哩哔哩_bilibili 一、明确需求&#xff08;以搜狗翻译为例&#xff09; &#xff08;1&#xff09;网址&#xff1a;https://fanyi.sogou.com/text &#xff08;2&#xff09; 数据&#xff1a;翻译内容…

uni-app 资源引用(绝对路径和相对路径)方法汇总

文章目录 一、前言&#x1f343;二、绝对路径和相对路径2.1 绝对路径2.2 相对路径 三、引用组件四、引用js4.1 js 文件引入4.2 NPM支持 五、引用css六、引用json6.1 json文件引入 七、引用静态资源7.1 模板内引入静态资源7.2 css 引入静态资源7.3 js/uts 引入静态资源7.4 静态资…

在 ASP.NET CORE 中上传、下载文件

创建 Web API 来提供跨客户端和服务器的文件上传和下载是常有的事。本文将介绍如何通过 ASP.NET CORE 来实现。 首先在 Visual Studio 中创建空的 Web API 项目&#xff0c;然后选择目标框架 .Net Core 3.1。 创建名为 FileController 的控制器&#xff0c;提供操作文件的接口…

基于 GEE Sentinel-1 数据集提取水体

目录 1 水体提取原理 2 完整代码 3 运行结果 1 水体提取原理 水体提取是地理信息和遥感技术的关键应用之一&#xff0c;对于多个领域都具有重要的应用价值。它有助于更好地管理水资源&#xff0c;保护环境&#xff0c;减少灾害风险&#xff0c;促进可持续发展&#xff0c;以…

微信小程序获取图片使用session(上篇)

概述&#xff1a; 我们开发微信小程序&#xff0c;从后台获取图片现实的时候&#xff0c;通常采用http get的方式&#xff0c;例如以下代码 <image class"user_logo" src"{{logoUrl}}"></image>变量logoUrl为ur图片l的请求地址 但是对于很多…

新年感悟:2025年1月7日高铁随想

2025年1月7日&#xff0c;乘坐在从珠海去广州南的C7676高铁上&#xff0c;突然悟明白两个事情。 首先&#xff0c;不管学习任何东西&#xff0c;总结是一个非常关键的经验。以前&#xff0c;总是幻想着能找到一本书&#xff0c;或者一个特别优秀的老师&#xff0c;仅仅通过看看…

centOS7

特殊权限 set_uid 赋予所有者身份 chmod us 文件 set_gid 赋予所有组身份 chmod gs 文件/目录 sticky_bit 防火墙 firewall-cmd 开启端口 firewall-cmd --zonepublic --add-port8080/tcp --permanent 重启防火墙 systemctl restart firewalld 查看开启的所有端口 fi…

Hbuilder ios 离线打包sdk版本4.36,HbuilderX 4.36生成打包资源 问题记录

1、打包文档地址https://nativesupport.dcloud.net.cn/AppDocs/usesdk/ios.html#%E9%85%8D%E7%BD%AE%E5%BA%94%E7%94%A8%E7%89%88%E6%9C%AC%E5%8F%B7 2、配置应用图标 如果没有appicon文件&#xff0c;此时找到 Assets.xcassets 或者 Images.xcassets(看你sdk引入的启动文件中…

HCIA-Access V2.5_8_2_EPON基本架构和关键参数

EPON数据利用方式 EPON和GPON同样只有一根光纤&#xff0c;所以为了避免双向发送数据出现冲突&#xff0c;我们同样采用WDM技术&#xff0c;那么主要利用两个波长&#xff0c;一个是1490纳米的波长&#xff0c;一个是1310纳米的波长&#xff0c;下行OLT给ONU发送数据的时候&…

新一代智能工控系统网络安全合规解决方案

01.新一代智能工控系统概述 新一代智能工控系统是工业自动化的核心&#xff0c;它通过集成人工智能、工业大模型、物联网、5G等技术&#xff0c;实现生产过程的智能化管理和控制。这些系统具备实时监控、自动化优化、灵活调整等特点&#xff0c;能够提升生产效率、保证产品质量…

前端使用Get传递数组形式的数据

前端使用Get传递数组形式的数据 前端后端接收 不能直接使用 JSON.stringify()传输参数&#xff0c;或者直接用json数据传输&#xff0c;后端均会应为包含了非法的符号 [与 ]而报错。 前端 主要在于对Array形式的数据进行转换&#xff0c;拼接成字符串&#xff0c;采用join方…

Centos 下安装 GitLab16.2.1

参考 https://blog.csdn.net/weixin_46059351/article/details/140649426 https://blog.csdn.net/qq_46028493/article/details/144993598 Centos 安装 GitLab 修改 yum 的配置 首先查看目前配置的 yum&#xff1a; cat /etc/yum.repos.d/CentOS-Base.repo应该是这个样子…

uniapp 微信小程序 自定义日历组件

效果图 功能&#xff1a;可以记录当天是否有某些任务或者某些记录 具体使用&#xff1a; 子组件代码 <template><view class"Accumulate"><view class"bx"><view class"bxx"><view class"plank"><…

刚体变换矩阵的逆

刚体运动中的变换矩阵为&#xff1a; 求得变换矩阵的逆矩阵为&#xff1a; opencv应用 cv::Mat R; cv::Mat t;R.t(), -R.t()*t