【pytorch-04】:线性回归案例(手动构建)

文章目录

  • 1 构建数据集
  • 2 构建假设函数
  • 3 损失函数
  • 4 优化方法
  • 5 训练函数
  • 6.总结

1 构建数据集

为什么构建数据加载器?
在进行训练的时候都是采用的不是全部的数据,而是采用一个batch_size的数据进行训练,每次向模型当中送入batch_size数据,所以需要一个能够按照batch_size生成数据的数据加载器

import torch
from sklearn.datasets import make_regression # 构造数据加载器
import matplotlib.pyplot as plt
import random# 1. 构建数据集def create_dataset():# n_samples - 100 样本数量# n_feature - 设置特征个数# noise - 设置噪声,可以进行调整,出现波动# coef = True - 需要权重# bias = 14.5 - 偏置# random_state = 0 - 能够复现数据x, y, coef = make_regression(n_samples=100,n_features=1,noise=10,coef=True,bias=14.5,random_state=0)# 将构建数据转换为张量类型x = torch.tensor(x)y = torch.tensor(y)return x, y, coef# 2. 构建数据加载器, 按照一定的数据量产生数据
def data_loader(x, y, batch_size):# 2.1 计算下样本的数量data_len = len(y)# 2.2 构建数据索引data_index = list(range(data_len))# 2.3 数据集打乱random.shuffle(data_index)# 2.4 计算总的batch数量batch_number = data_len // batch_size# 遍历一个batchfor idx in range(batch_number):start = idx * batch_sizeend = start + batch_sizebatch_train_x = x[start: end]batch_train_y = y[start: end]yield batch_train_x, batch_train_ydef test01():x, y = create_dataset()plt.scatter(x, y)plt.show()for x, y in data_loader(x, y, batch_size=10):print(y)

2 构建假设函数

# 假设函数 y = wx + b
# 只有一个特征值,所以w初始值为一个标量,b的初始值也是一个标量
w = torch.tensor(0.1, requires_grad=True, dtype=torch.float64)
b = torch.tensor(0.0, requires_grad=True, dtype=torch.float64)def linear_regression(x):return w * x + b

3 损失函数

# 损失函数,采用MAS作为损失函数
def square_loss(y_pred, y_true):return (y_pred - y_true) ** 2

4 优化方法

# 优化方法 - 采用梯度下降法,进行权重参数的更新
def sgd(lr=1e-2):# 此处除以 16 使用的是批次样本的平均梯度值w.data = w.data - lr * w.grad.data / 16b.data = b.data - lr * b.grad.data / 16

5 训练函数

import torch
from sklearn.datasets import make_regression # 构造数据集
import matplotlib.pyplot as plt
import random# 1. 构建数据集def create_dataset():# n_samples - 100 样本数量# n_feature - 设置特征个数# noise - 设置噪声,可以进行调整,出现波动# coef = True - 需要权重# bias = 14.5 - 偏置# random_state = 0 - 能够复现数据x, y, coef = make_regression(n_samples=100,n_features=1,noise=10,coef=True,bias=14.5,random_state=0)# 将构建数据转换为张量类型x = torch.tensor(x)y = torch.tensor(y)return x, y, coef# 2. 构建数据加载器, 按照一定的数据量产生数据
def data_loader(x, y, batch_size):# 2.1 计算下样本的数量data_len = len(y)# 2.2 构建数据索引data_index = list(range(data_len))# 2.3 数据集打乱random.shuffle(data_index)# 2.4 计算总的batch数量batch_number = data_len // batch_size# 遍历一个batchfor idx in range(batch_number):start = idx * batch_sizeend = start + batch_sizebatch_train_x = x[start: end]batch_train_y = y[start: end]yield batch_train_x, batch_train_ydef test01():x, y = create_dataset()plt.scatter(x, y)plt.show()for x, y in data_loader(x, y, batch_size=10):print(y)# 假设函数 y = wx + b
# 只有一个特征值,所以w初始值为一个标量,b的初始值也是一个标量
w = torch.tensor(0.1, requires_grad=True, dtype=torch.float64)
b = torch.tensor(0.0, requires_grad=True, dtype=torch.float64)def linear_regression(x):return w * x + b# 损失函数,采用MAS作为损失函数
def square_loss(y_pred, y_true):return (y_pred - y_true) ** 2# 优化方法 - 采用梯度下降法,进行权重参数的更新
def sgd(lr=1e-2):# 此处除以 16 使用的是批次样本的平均梯度值w.data = w.data - lr * w.grad.data / 16b.data = b.data - lr * b.grad.data / 16def train():# 1.加载数据集x, y, coef = create_dataset()# 2.定义训练参数# epoch - 所有样本在模型中训练一遍称为一个epochepochs = 100learning_rate = 1e-2# 3.存储训练信息epoch_loss = [] # 记录每一个epoch的损失total_loss = 0.0train_samples = 0 # 统计训练的样本的数量for _ in range(epochs):# 4. 进行训练# 获取到数据,这里对于data_loader进行遍历,会一直将所有样本都取出# for循环执行一次就是一个epochfor train_x, train_y in data_loader(x, y, batch_size=16):# 1. 将训练样本送入模型进行预测y_pred = linear_regression(train_x)  # shape[16,1]# 2. 计算预测值和真实值的平方损失# reshape(-1,1) 将形状进行多行一列,让y_pred和train_y形状一样# print(len(y_pred),y_pred.size()) # 16 torch.Size([16, 1])# print(len(train_y),train_y.size()) # 16 torch.Size([16])loss = square_loss(y_pred, train_y.reshape(-1, 1)).sum()total_loss += loss.item()train_samples += len(train_y)# 3. 梯度清零# 对w和b进行求导,所以需要对于w和b进行梯度清零if w.grad is not None:w.grad.data.zero_()if b.grad is not None:b.grad.data.zero_()# 4. 自动微分loss.backward()# 5. 参数更新sgd(learning_rate)print('loss: %.10f' % (total_loss / train_samples))# 记录每一个 epoch 的平均损失epoch_loss.append(total_loss / train_samples)# 绘制数据集散点图plt.rcParams['font.sans-serif'] = ['SimHei']plt.scatter(x, y)# 绘制拟合的直线x = torch.linspace(x.min(), x.max(), 1000)y1 = torch.tensor([v * w + b for v in x])y2 = torch.tensor([v * coef + 14.5 for v in x])plt.plot(x, y1, label='训练')plt.plot(x, y2, label='真实')plt.grid()plt.legend()plt.show()# 打印损失变化曲线plt.plot(range(epochs), epoch_loss)plt.grid()plt.title('损失变化曲线')plt.show()if __name__ == '__main__':train()

在这里插入图片描述
在这里插入图片描述

6.总结

  • 采用sklearn中的make_regression()可以构造回归数据

指定:
n_samples - 样本个数
n_features - 特征个数
noise - 噪声
coef - 斜率(生成数据参考)
bias - 偏置
random_state - 指定随机种子,让实验重现

  • 采用data_loader(数据加载器),每次拿总样本的一部分进行训练

数据加载器需要知道:
x - 特征,可以计算得到特征的个数和特征的索引
y - 目标值
batch_size - 每次产生多少个,
batch数量计算:用总的样本个数//batch_size = batch
data_loader采用生成器,对batch数量进行循环,由多少个batch就循环多少次,每次产生一个batch_size的数据
data_loader是一个生成器,本质上也是一个迭代器,所以对data_loader进行一次完整的for循环,就会得到全部的数据

  • 假设函数

首先有假设函数的形式 - y = wx + b
x 如果是一个向量,有多个因变量
y 也是一个向量
将y w x b 都转换为tensor的形式,进行计算
y 就是函数的返回值

  • 训练函数

1.先加载数据集

2.获取到数据之后定义训练参数

  • epochs = 数据集中训练样本的个数,将数据集中的样本都训练一遍,称作一个epoch
  • 每个epoch的损失都要进行收集,所以定义一个epoch_loss列表
  • 进行梯度迭代的时候指定超参数学习率

3.存储信息定义参数

  • loss_epoch
  • total_loss
  • train_samples

4.分多个epoch,多个batch进行训练

  • 获取数据
  • 训练模型,得到y_pred
  • 计算损失函数
  • 进行梯度计算,反向传播
  • 进行参数更新

5.绘图展示

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/476097.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

实验室管理效率提升:Spring Boot技术的力量

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统,它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等,非常…

STM32H7开发笔记(2)——H7外设之多路定时器中断

STM32H7开发笔记(2)——H7外设之多路定时器中断 文章目录 STM32H7开发笔记(2)——H7外设之多路定时器中断0.引言1.CubeMX配置2.软件编写 0.引言 本文PC端采用Win11STM32CubeMX4.1.0.0Keil5.24.2的配置,硬件使用STM32H…

springboot基于微信小程序的旧衣回收系统的设计与实现

摘 要 微信小程序的旧衣回收系统是一种专为环保生活设计的应用软件。这款小程序的主要功能包括:系统首页、个人中心、用户管理、回收人员管理、旧衣服分类管理、旧衣信息管理、回收预约管理、回收派单管理、回收订单管理、积分商品管理、积分兑换管理、管理员管理、…

路由缓存后跳转到新路由时,上一路由中的tip信息框不销毁问题解决

上一路由tip信息框不销毁问题解决 路由缓存篇问题描述及截图解决思路关键代码 路由缓存篇 传送门 问题描述及截图 路由缓存后跳转新路由时,上一个路由的tip信息框没销毁。 解决思路 在全局路由守卫中获取DOM元素,通过css去控制 关键代码 修改文…

40分钟学 Go 语言高并发:并发下载器开发实战教程

并发下载器开发实战教程 一、系统设计概述 1.1 功能需求表 功能模块描述技术要点分片下载将大文件分成多个小块并发下载goroutine池、分片算法断点续传支持下载中断后继续下载文件指针定位、临时文件管理进度显示实时显示下载进度和速度进度计算、速度统计错误处理处理下载过…

【前端】JavaScript中的indexOf()方法详解:基础概念与背后的应用思路

博客主页: [小ᶻZ࿆] 本文专栏: 前端 文章目录 💯前言💯什么是indexOf()方法?参数解释返回值示例 💯indexOf() 方法的工作原理💯特殊案例:undefined 的处理示例代码图示解释 💯i…

HarmonyOS4+NEXT星河版入门与项目实战------Button组件

文章目录 1、控件图解2、案例实现1、代码实现2、代码解释3、运行效果4、总结1、控件图解 这里我们用一张完整的图来汇整 Button 的用法格式、属性和事件,如下所示: 按钮默认类型就是胶囊类型。 2、案例实现 这里我们实现一个根据放大和缩小按钮来改变图片大小的功能。 功…

WPF窗体基本知识-笔记-命名空间

窗体程序关闭方式 命名空间:可以理解命名空间的作用为引用下面的控件对象 给控件命名:一般都用x:Name,也可以用Name但是有的控件不支持 布局控件(容器)的类型 布局控件继承于Panel的控件,其中下面的border不是布局控件,panel是抽象类 在重叠的情况下,Zindex值越大的就在上面 Z…

【Qt】QComboBox设置默认显示为空

需求 使用QComboBox,遇到一个小需求是,想要设置未点击出下拉列表时,内容显示为空。并且不想在下拉列表中添加一个空条目。 实现 使用setPlaceholderText()接口。我们先来看下帮助文档: 这里说的是,placeholderText是…

音频信号采集前端电路分析

音频信号采集前端电路 一、实验要求 要求设计一个声音采集系统 信号幅度:0.1mVpp到1Vpp 信号频率:100Hz到16KHz 搭建一个带通滤波器,滤除高频和低频部分 ADC采用套件中的AD7920,转换率设定为96Ksps ;96*161536 …

[开源]1.2K star!中后台方向的低代码可视化平台,超赞!

大家好,我是JavaCodexPro! “时间就是金钱,效率就是生命”,快速搭建高质量中后台的低代码可视化搭建平台尤为重要! 今天JavaCodexPro给大家分享一款超赞的低代码可视化搭建平台 - Marsview ,旨在简化开发…

Leetcode 完全二叉树的节点个数

不讲武德的解法 java 实现 class Solution {public int countNodes(TreeNode root) {if(root null) return 0;return countNodes(root.left) countNodes(root.right) 1;} }根据完全二叉树和满二叉树的性质做 class Solution {public int countNodes(TreeNode root) {if (r…

基于CVE安全公告号,全面修复麒麟ARM系统OpenSSH漏洞

前言:负责的其中一个从0开始搭建的某生产项目上线前需要做青藤安全扫描,过了后才允许上线,该项目从操作系统、中间件、数据库、容器等全国产信创化,公司公告为CVE安全公告号,而修复漏洞的责任归我,需要根据…

【每日 C/C++ 问题】

一、什么是 C 中的初始化列表?它的作用是什么? 作用:c提供了初始化列表语法,用来初始化属性 语法:构造函数():属性1(值1),属性2(值…

【前端知识】Javascript前端框架Vue入门

前端框架VUE入门 概述基础语法介绍组件特性组件注册Props 属性声明事件组件 v-model(双向绑定)插槽Slots内容与出口 组件生命周期样式文件使用1. 直接在<style>标签中写CSS2. 引入外部CSS文件3. 使用CSS预处理器4. 在main.js中全局引入CSS文件5. 使用CSS Modules6. 使用P…

【代码pycharm】动手学深度学习v2-04 数据操作 + 数据预处理

数据操作 数据预处理 1.数据操作运行结果 2.数据预处理实现运行结果 第四课链接 1.数据操作 import torch # 张量的创建 x1 torch.arange(12) print(1.有12个元素的张量&#xff1a;\n,x1) print(2.张量的形状&#xff1a;\n,x1.shape) print(3.张量中元素的总数&#xff1…

《Python浪漫的烟花表白特效》

一、背景介绍 烟花象征着浪漫与激情&#xff0c;将它与表白结合在一起&#xff0c;会创造出别具一格的惊喜效果。使用Python的turtle模块&#xff0c;我们可以轻松绘制出动态的烟花特效&#xff0c;再配合文字表白&#xff0c;打造一段专属的浪漫体验。 接下来&#xff0c;让…

CSS中Flex布局应用实践总结

① 两端对齐 比如 要求ul下的li每行四个&#xff0c;中间间隔但是需要两段对齐&#xff0c;如下图所示&#xff1a; 这是除了基本的flex布局外&#xff0c;还需要用到:nth-of-type伪类来控制每行第一个与第四个的padding。 .hl_list{width: 100%;display: flex;align-items…

STM32与CS创世SD NAND(贴片SD卡)结合完成FATFS文件系统移植与测试是一个涉及硬件与软件综合应用的复杂过程

一、前言 在STM32项目开发中&#xff0c;经常会用到存储芯片存储数据。 比如&#xff1a;关机时保存机器运行过程中的状态数据&#xff0c;上电再从存储芯片里读取数据恢复&#xff1b;在存储芯片里也会存放很多资源文件。比如&#xff0c;开机音乐&#xff0c;界面上的菜单图…

Matlab实现海鸥优化算法优化随机森林算法模型 (SOA-RF)(附源码)

目录 1.内容介绍 2.部分代码 3.实验结果 4.内容获取 1内容介绍 海鸥优化算法&#xff08;Seagull Optimization Algorithm, SOA&#xff09;是一种基于海鸥群体行为的新型元启发式优化算法。SOA通过模拟海鸥在寻找食物时的飞行模式和集体行动来探索解空间&#xff0c;寻找最优…