PyTorch 神经网络回归(Regression)任务:关系拟合与优化过程

PyTorch 神经网络回归(Regression)任务:关系拟合与优化过程

本教程介绍了如何使用 PyTorch 构建一个简单的神经网络来实现关系拟合,具体演示了从数据准备到模型训练和可视化的完整过程。首先,利用一维线性空间生成带噪声的数据集,接着定义了一个包含隐藏层和输出层的神经网络。通过使用均方误差损失函数和随机梯度下降优化器,逐步训练神经网络来拟合数据。为了便于理解和监控训练过程,我们使用 matplotlib 实现了动态更新的图形,展示了每次迭代后的预测结果与真实数据的对比。该教程不仅帮助读者理解神经网络的基本架构和训练流程,还展示了如何通过可视化手段更直观地观察模型的优化过程,提升了对模型调优的理解与应用能力。

文章目录

  • PyTorch 神经网络回归(Regression)任务:关系拟合与优化过程
      • 一 导入第三方库
      • 二 设置数据集
      • 三 编写神经网络
      • 四 训练神经网络
        • 可视化训练过程
      • 五 完整代码示例
      • 六 源码地址
      • 七 参考

预备课:PyTorch 激活函数详解:从原理到最佳实践

一 导入第三方库

import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
import os

二 设置数据集

# 生成一维的线性空间数据,并增加一维使其形状为 (100, 1)
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())  # 生成对应的 y 数据,加上噪声模拟真实情况

三 编写神经网络

class Net(torch.nn.Module):def __init__(self, n_feature, n_hidden, n_output):super(Net, self).__init__()self.hidden = torch.nn.Linear(n_feature, n_hidden)  # 定义隐藏层,输入维度为 n_feature,输出维度为 n_hiddenself.predict = torch.nn.Linear(n_hidden, n_output)  # 定义输出层,输入维度为 n_hidden,输出维度为 n_outputdef forward(self, x):x = F.relu(self.hidden(x))  # 使用 ReLU 激活函数处理隐藏层的输出x = self.predict(x)  # 计算最终输出return x

在此定义了神经网络的结构,其中隐藏层的输入维度为 n_feature,输出维度为 n_hidden,而输出层的输入维度为 n_hidden,输出维度为 n_output。下图展示了以 3 个神经元为例的网络结构,以帮助理解。
在这里插入图片描述

:如果对上述代码感到困惑,可以暂时将其视为固定写法,专注于理解其基本框架。

四 训练神经网络

# 初始化神经网络
net = Net(n_feature=1, n_hidden=10, n_output=1)  # 定义网络,输入输出各为 1,隐藏层有 10 个神经元
print(net)  # 打印网络结构# 定义优化器和损失函数
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  # 使用随机梯度下降法优化网络参数,学习率为 0.2
loss_func = torch.nn.MSELoss()  # 定义均方误差损失函数plt.ion()  # 开启交互模式,允许动态更新图像for epoch in range(200):prediction = net(x)  # 前向传播,使用当前网络计算预测值loss = loss_func(prediction, y)  # 计算预测值与真实值之间的误差optimizer.zero_grad()  # 清空上一步的梯度信息loss.backward()  # 反向传播,计算梯度optimizer.step()  # 根据梯度更新网络参数if epoch % 5 == 0:  # 每 5 个周期更新一次图像plt.cla()  # 清除当前图像内容plt.scatter(x.data.numpy(), y.data.numpy(), label='True Data')plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2, label='Prediction')plt.text(0.5, 0, f'Loss={loss.item():.4f}', fontdict={'size': 20, 'color': 'red'})plt.legend()  # 添加图例# 保存当前图像# file_path = os.path.join(target_directory, f'epoch_{epoch}.png')# plt.savefig(file_path)# print(f"图像已保存: {file_path}")plt.pause(0.1)  # 暂停以更新图像plt.ioff()  # 关闭交互模式
plt.show()  # 显示最终图像
可视化训练过程

可视化神经网络训练(关系拟合)

:通过引入 matplotlib 实现训练过程的可视化,帮助直观地跟踪模型的学习进展。

五 完整代码示例

import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
import osclass Net(torch.nn.Module):def __init__(self, n_feature, n_hidden, n_output):super(Net, self).__init__()self.hidden = torch.nn.Linear(n_feature, n_hidden)  # 定义隐藏层,输入维度为 n_feature,输出维度为 n_hiddenself.predict = torch.nn.Linear(n_hidden, n_output)  # 定义输出层,输入维度为 n_hidden,输出维度为 n_outputdef forward(self, x):x = F.relu(self.hidden(x))  # 使用 ReLU 激活函数处理隐藏层的输出x = self.predict(x)  # 计算最终输出return xdef print_hi(name):print(f'Hi, {name}')# 创建保存图片的目录# target_directory = "/Users/your/Desktop/001"# if not os.path.exists(target_directory):#     os.makedirs(target_directory)# 创建数据集# 生成一维的线性空间数据,并增加一维使其形状为 (100, 1)x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)y = x.pow(2) + 0.2 * torch.rand(x.size())  # 生成对应的 y 数据,加上噪声模拟真实情况# 初始化神经网络net = Net(n_feature=1, n_hidden=10, n_output=1)  # 定义网络,输入输出各为 1,隐藏层有 10 个神经元print(net)  # 打印网络结构# 定义优化器和损失函数optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  # 使用随机梯度下降法优化网络参数,学习率为 0.2loss_func = torch.nn.MSELoss()  # 定义均方误差损失函数plt.ion()  # 开启交互模式,允许动态更新图像for epoch in range(200):prediction = net(x)  # 前向传播,使用当前网络计算预测值loss = loss_func(prediction, y)  # 计算预测值与真实值之间的误差optimizer.zero_grad()  # 清空上一步的梯度信息loss.backward()  # 反向传播,计算梯度optimizer.step()  # 根据梯度更新网络参数if epoch % 5 == 0:  # 每 5 个周期更新一次图像plt.cla()  # 清除当前图像内容plt.scatter(x.data.numpy(), y.data.numpy(), label='True Data')plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2, label='Prediction')plt.text(0.5, 0, f'Loss={loss.item():.4f}', fontdict={'size': 20, 'color': 'red'})plt.legend()  # 添加图例# 保存当前图像# file_path = os.path.join(target_directory, f'epoch_{epoch}.png')# plt.savefig(file_path)# print(f"图像已保存: {file_path}")plt.pause(0.1)  # 暂停以更新图像plt.ioff()  # 关闭交互模式plt.show()  # 显示最终图像if __name__ == '__main__':print_hi('关系拟合')

复制粘贴并覆盖到你的 main.py 中运行,运行结果如下。

Hi, 关系拟合
Net((hidden): Linear(in_features=1, out_features=10, bias=True)(predict): Linear(in_features=10, out_features=1, bias=True)
)

六 源码地址

代码地址,GitHub 之 关系拟合 。

七 参考

[1] PyTorch 官方文档

[2] 莫烦 Python

[3] 可视化神经网络 TensorFlow Playground

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/494748.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【uni-app】2025最新uni-app一键登录保姆级教程(包含前后端获取手机号方法)(超强避坑指南)

前言: 最近在配置uni-app一键登录时遇到了不少坑,uni-app的配套文档较为混乱,并且有部分更新的内容也没有及时更改在文档上,导致部分开发者跟着uni-app配套文档踩坑!而目前市面上的文章质量也层次不齐,有的…

干货分享:ISO 20000认证的适用范围、认证资料清单、认证流程等问题详解

编辑:石芸姗 审核:贺兆普 在当今这个数字化时代,信息技术(IT)已成为企业运营与发展的核心驱动力。随着技术的不断进步和业务需求的日益复杂,企业对IT服务的质量、效率及安全性提出了更高要求。 信息技术服…

Element-plus表格使用总结

这里我使用的是Vue工程进行开发学习,安装需要通过包管理器进行下载 npm install element-plus --save 然后在main.js中配置文件即可使用,如果在引入index.css时没有提示,无需担心,直接写index.css即可导入样式。 Table表格 表格…

CNN和Transfomer介绍

文章目录 CNN和Transfomer介绍CNN和Transfomer的区别1. **基本概念**2. **数据处理方式**3. **模型结构差异**4. **应用场景区别** 自注意力机制1. **自注意力机制的概念**2. **自注意力机制的实现步骤**3. **自注意力机制的优势** Transformer结构组成1. **多头注意力层&#…

如何解决 ‘adb‘ 不是内部或外部命令,也不是可运行的程序或批处理文件的问题

在cmd中输入 adb ,显示 ‘adc‘ 不是内部或外部命令,也不是可运行的程序或批处理文件的问题 解决办法:在环境变量中添加adb所在的路径 1、找到 adb.exe 的所在的文件路径,一般在 Android 安装目录下 \sdk\platform-tools\adb.exe…

数据结构---------二叉树前序遍历中序遍历后序遍历

以下是用C语言实现二叉树的前序遍历、中序遍历和后序遍历的代码示例&#xff0c;包括递归和非递归&#xff08;借助栈实现&#xff09;两种方式&#xff1a; 1. 二叉树节点结构体定义 #include <stdio.h> #include <stdlib.h>// 二叉树节点结构体 typedef struct…

网络架构与IP技术:4K/IP演播室制作的关键支撑

随着科技的不断发展&#xff0c;广播电视行业也在不断迭代更新&#xff0c;其中4K/IP演播室技术的应用成了一个引人注目的焦点。4K超高清技术和IP网络技术的结合&#xff0c;不仅提升了节目制作的画质和效果&#xff0c;还为节目制作带来了更高的效率和灵活性。那么4K超高清技术…

MySQL 8.0:explain analyze 分析 SQL 执行过程

介绍 MySQL 8.0.16 引入一个实验特性&#xff1a;explain formattree &#xff0c;树状的输出执行过程&#xff0c;以及预估成本和预估返 回行数。在 MySQL 8.0.18 又引入了 EXPLAIN ANALYZE&#xff0c;在 formattree 基础上&#xff0c;使用时&#xff0c;会执行 SQL &#…

观察者模式(sigslot in C++)

大家&#xff0c;我是东风&#xff0c;今天抽点时间整理一下我很久前关注的一个不错的库&#xff0c;可以支持我们在使用标准C的时候使用信号槽机制进行观察者模式设计&#xff0c;sigslot 官网&#xff1a; http://sigslot.sourceforge.net/ 本文较为详尽探讨了一种观察者模…

【已解决】黑马点评项目Redis版本替换过程中误删数据库后前端显示出现的问题

为了实现基于Redis的Stream结构作为消息队列&#xff0c;实现异步秒杀下单的功能&#xff0c;换Redis版本 Redis版本太旧了&#xff0c;所以从3.2.1换成了5.0.14 此时犯了一个大忌&#xff0c;因为新的Redis打开后&#xff0c;没有缓存&#xff0c;不知道出了什么问题&#xf…

基于Spring Boot的九州美食城商户一体化系统

一、系统背景与目标 随着美食城行业的快速发展&#xff0c;传统的管理方式已经难以满足日益增长的管理需求和用户体验要求。因此&#xff0c;九州美食城商户一体化系统应运而生&#xff0c;旨在通过信息化、智能化的管理方式&#xff0c;实现美食城的商户管理、菜品管理、订单…

springboot vue 会员营销系统

springboot vue 会员营销系统介绍 演示地址&#xff1a; 开源版本&#xff1a;http://8.146.211.120:8083/ 完整版本&#xff1a;http://8.146.211.120:8086/ 移动端 http://8.146.211.120:8087/ 简介 欢迎使用springboot vue会员营销系统。本项目包含会员储值卡、套餐卡、计…

HarmonyOS NEXT 技术实践-基于意图框架服务实现智能分发

在智能设备的交互中&#xff0c;如何准确理解并及时响应用户需求&#xff0c;成为提升用户体验的关键。HarmonyOS Next 的意图框架服务&#xff08;Intents Kit&#xff09;为这一目标提供了强大的技术支持。本文将通过一个项目实现的示例&#xff0c;展示如何使用意图框架服务…

sfnt-pingpong -测试网络性能和延迟的工具

sfnt-pingpong 是一个用于测试网络性能和延迟的工具&#xff0c;通常用于测量不同网络环境下的数据包传输性能、吞吐量、延迟等指标。 它通常是基于某种网络协议&#xff08;如 TCP&#xff09;执行“ping-pong”式的测试&#xff0c;即客户端和服务器之间相互发送数据包&…

前端下载文件的几种方式使用Blob下载文件

前端下载文件的几种方式 使用Blob下载文件 在前端下载文件是个很通用的需求&#xff0c;一般后端会提供下载的方式有两种&#xff1a; 1.直接返回文件的网络地址&#xff08;一般用在静态文件上&#xff0c;比如图片以及各种音视频资源等&#xff09; 2.返回文件流&#xff08;…

智能座舱进阶-应用框架层-Jetpack主要组件

Jetpack的分类 1. DataBinding&#xff1a;以声明方式将可观察数据绑定到界面元素&#xff0c;通常和ViewModel配合使用。 2. Lifecycle&#xff1a;用于管理Activity和Fragment的生命周期&#xff0c;可帮助开发者生成更易于维护的轻量级代码。 3. LiveData: 在底层数据库更…

知乎 PB 级别 TiDB 数据库集群管控实践

以下文章来源于知乎技术专栏 &#xff0c;作者代晓磊 导读 在现代企业中&#xff0c;数据库的运维管理至关重要&#xff0c;特别是面对分布式数据库的复杂性和大规模集群的挑战。作为一款兼容 MySQL 协议的分布式关系型数据库&#xff0c;TiDB 在高可用、高扩展性和强一致性方…

SpringBoot 自动装配原理及源码解析

目录 一、引言 二、什么是 Spring Boot 的自动装配 三、自动装配的核心注解解析 3.1 SpringBootApplication 注解 &#xff08;1&#xff09;SpringBootConfiguration&#xff1a; &#xff08;2&#xff09;EnableAutoConfiguration&#xff1a; &#xff08;3&#xf…

C++中的字符串实现

短字符串优化(SSO) 实现1 实现2 写时复制 #define _CRT_SECURE_NO_WARNINGS #include<iostream> #include<cstdio> #include<cstring> #include<cstring> using std::cout; using std::endl;// 引用计数存放的位置 // 1. 存放在栈上 --- 不行 // 2. 存…

Linux 基本使用和程序部署

1. Linux 环境搭建 1.1 环境搭建方式 主要有 4 种&#xff1a; 直接安装在物理机上。但是Linux桌面使用起来非常不友好&#xff0c;所以不建议。[不推荐]。使用虚拟机软件&#xff0c;将Linux搭建在虚拟机上。但是由于当前的虚拟机软件(如VMWare之类的)存在一些bug&#xff…