《动手深度学习》线性回归简洁实现实例


🎈 作者:Linux猿

🎈 简介:CSDN博客专家🏆,华为云享专家🏆,Linux、C/C++、云计算、物联网、面试、刷题、算法尽管咨询我,关注我,有问题私聊!

🎈 欢迎小伙伴们点赞👍、收藏⭐、留言💬


本文是《动手深度学习》线性回归简洁实现实例的实现和分析,主要对代码进行详细讲解,有问题欢迎在评论区讨论交流。

一、代码实现

实现代码如下所示。

import torch
from torch.utils import data
# d2l包是李沐老师等人开发的动手深度学习配套的包,
# 里面封装了很多有关与数据集定义,数据预处理,优化损失函数的包
from d2l import torch as d2l
# nn 是神经网络 Neural Network 的缩写,提供了一系列的模块和类,实现创建、训练、保存、恢复神经网络
from torch import nn'''
1. 生成数据集,共 1000 条
true_w 和 true_b 是临时变量用于生成数据集
生成 X, y :满足关系 y = Xw + b + noise
'''
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)'''
2. 构造循环读取数据集的迭代器
'''
def load_array(data_arrays, batch_size, is_train=True):  #@save# 构造一个 PyTorch 数据迭代器,对 tensor 进行打包,包装成 dataset。dataset = data.TensorDataset(*data_arrays)# 根据数据集构造一个迭代器return data.DataLoader(dataset, batch_size, shuffle=is_train)# 小批量数据
batch_size = 10
# 设置了一个数据读取的迭代器,每次读取 batch_size(10) 条
data_iter = load_array((features, labels), batch_size)'''
3. 设置全连接层
'''
'''
# nn.Linear(in_features, out_features, bias=True)
# in_features : 输入向量的列数
# out_features : 输出向量的列数
# bias = True 是否包含偏置
执行线性变换:Yn*o = Xn*i Wi*o + b
其中:W 和 b 模型需要学习的参数
在本例中:n = 10,i = 2, o = 1
'''
net = nn.Sequential(nn.Linear(2, 1))
# 设置权重 w 和 偏置 b
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)'''
4. 定义损失函数
'''
# 均方误差,是预测值与真实值之差的平方和的平均值
loss = nn.MSELoss()
# lr 学习率 learning rate
trainer = torch.optim.SGD(net.parameters(), lr=0.03)'''
4. 训练数据
'''
# 超参数 设置批次
num_epochs = 3
for epoch in range(num_epochs): # 进行 num_epochs 个迭代周期for X, y in data_iter:l = loss(net(X) ,y) # 计算损失,net(X) 计算预测值 y1,loss(y1, y) 计算预测值和真实值之间的差距trainer.zero_grad() # 将所有模型参数的梯度置为 0l.backward() # 求梯度,不使用从零实现中 l.sum.backward 的原因是损失计算中使用了平均的 gardtrainer.step() # 优化参数 w 和 bl = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {l:f}')w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

二、实现解析

针对实例中重要的函数解析如下。

2.1 Linear 函数

nn.Linear(in_features, out_features, bias=True)

神经网络的线性层,也成为全连接层,进行 Y = XW + b 的线性变换。

参数:

in_features : 输入向量的列数

out_features : 输出向量的列数

bias = True 是否包含偏置

in_features 和 out_features 是 W 的行和列。

执行线性变换:Yn*o = Xn*i Wi*o + b

其中:W 和 b 模型需要学习的参数

在本例中:n = 10,i = 2, o = 1。

2.2 Sequential 函数

一个序列容器,用于搭建神经网络的模块,按照传入构造器的顺序添加到 nn.Sequential() 容器中。按照内部模块的顺序自动依次计算并输出结果。

2.3 MSELoss 函数

均方误差,是预测值与真实值之差的平方和的平均值,即:

2.4 TensorDataset 函数

用来对 tensor 进行打包,就好像 python 中的 zip 功能。该类通过每一个 tensor 的第一个维度进行索引。因此,该类中的 tensor 第一维度必须相等. 另外:TensorDataset 中的参数必须是 tensor。可以参考如下例子:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader# len = 12
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
# len = 12
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
# 将 tensor a 和 b 压缩在一起
train_ids = TensorDataset(a, b)
# 输出
for x, y in train_ids:print(x, y)

输出如下:

tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)

2.5 DataLoader 函数

DataLoader 是用来包装所使用的数据,每次抛出一批数据,下面来看一个例子。

import torch
from torch.utils import data# len = 12
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
# len = 12
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
# 将 tensor a 和 b 压缩在一起
train_ids = data.TensorDataset(a, b)
# 输出
#for x, y in train_ids:
#    print(x, y)BATCH_SIZE = 4
loader = data.DataLoader(dataset=train_ids,batch_size=BATCH_SIZE, # 每次取 BATCH_SIZE=4 个数据shuffle=False, # 不打乱顺序,便于查看num_workers=0)for x, y in loader:print(x, y)break

输出如下:

tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9],[1, 2, 3]]) tensor([44, 55, 66, 44])

 如上所示,输出第一个 BATCH_SIZE=4。

2.6 zero_grad 函数

trainer.zero_grad() 是用来清空模型参数梯度的函数,它将模型参数的梯度缓存设置为 0。在进行反向传播时,梯度会累加,如果不清空梯度,会影响后续的梯度计算。

2.7 backward 函数

对计算图进行梯度计算,求解计算图中所有节点的梯度。

2.8 step 函数

根据 backward 函数计算出的梯度进行参数更新。

参考链接:

线性回归的实现学习_data.tensordataset_带刺的厚崽的博客-CSDN博客

nn.Sequential()_一颗磐石的博客-CSDN博客

【Pytorch基础】torch.nn.MSELoss损失函数_一穷二白到年薪百万的博客-CSDN博客

pytorch之trainer.zero_grad()_FibonacciCode的博客-CSDN博客

清空模型参数梯度的函数 - 知乎

pytorch中backward()函数详解_backward函数_Camlin_Z的博客-CSDN博客

理解Pytorch的loss.backward()和optimizer.step() - 知乎


🎈 感觉有帮助记得「一键三连支持下哦!有问题可在评论区留言💬,感谢大家的一路支持!🤞猿哥将持续输出「优质文章回馈大家!🤞🌹🌹🌹🌹🌹🌹🤞


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/175192.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

antv/g6使用教程及图配置

介绍 G6 是一款由蚂蚁金服 AntV 团队开发的 JavaScript 图形引擎,用于构建各种交互式可视化图形,包括但不限于图表、网络拓扑图、关系图、流程图等。无论是数据分析、决策支持,还是信息可视化,G6 都是一个强大的工具。 以下是 G…

蓝牙 - BLE SPP实现举例 (Bluecode Protocol Stack)

这里以一个无线扫描枪设备为例,这个设备会通过蓝牙通讯协议连接一个底座,使用的是BLE SPP进行通讯。 扫描枪用来扫条码,解析出条码信息后,将数据通过无线传输给底座,底座再通过USB将数据传送给电脑。 底座是Central d…

一篇博客理解Recyclerview的使用

从Android 5.0开始,谷歌公司推出了RecylerView控件,当看到RecylerView这个新控件的时候,大部分人会首先发出一个疑问,recylerview是什么?为什么会有recylerview也就是说recylerview的优点是什么?recylerview怎么用&…

C#,数值计算——分类与推理Svmpolykernel的计算方法与源程序

1 文本格式 using System; namespace Legalsoft.Truffer { public class Svmpolykernel : Svmgenkernel { public int n { get; set; } public double a { get; set; } public double b { get; set; } public double d { get; set; …

故障诊断模型 | Maltab实现LSTM长短期记忆神经网络故障诊断

文章目录 效果一览文章概述模型描述源码设计参考资料效果一览 文章概述 故障诊断模型 | Maltab实现LSTM长短期记忆神经网络故障诊断 模型描述 长短记忆神经网络——通常称作LSTM,是一种特殊的RNN,能够学习长的依赖关系。 他们由Hochreiter&Schmidhuber引入,并被许多人进行了…

美妆造型教培服务预约小程序的作用是什么

美业市场规模很高,细分类目更是比较广,而美妆造型就是其中的一类,从业者也比较多,除了学校科目外,美妆造型教培机构也有生意。 对机构来说主要目的是拓客引流-转化及赋能,而想要完善路径却是不太容易&…

机器人的触发条件有什么区别,如何巧妙的使用

简介​ 维格机器人触发条件,分为3个,分别是: 有新表单提交时、有记录满足条件时、有新的记录创建时 。 看似3个,其实是能够满足我们非常多的使用场景。 本篇将先介绍3个条件的触发条件,然后再列举一些复杂的触发条件如何用现有的触发条件来满足 注意: 维格机器人所有的…

剖析C语言中的自定义类型(结构体、枚举常量、联合)兼内存对齐与位段

目录 前言 一、结构体 1. 基本定义与使用 2. 内存对齐 3. 自定义对齐数 4. 函数传参 二、位段 三、枚举 四、联合(共同体) 总结​​​​​​​ 前言 本篇博客将介绍C语言中的结构体(struct)、枚举(enum&…

【Redis】高并发分布式结构服务器

文章目录 服务端高并发分布式结构名词基本概念评价指标1.单机架构缺点 2.应用数据分离架构应用服务集群架构读写分离/主从分离架构引入缓存-冷热分离架构分库分表(垂直分库)业务拆分⸺微服务 总结 服务端高并发分布式结构 名词基本概念 应⽤&#xff0…

【错误解决方案】ModuleNotFoundError: No module named ‘ngboost‘

1. 错误提示 在python程序,尝试导入一个名为ngboost的模块,但Python提示找不到这个模块。 错误提示:ModuleNotFoundError: No module named ‘ngboost‘ 2. 解决方案 出现上述问题,可能是因为你还没有安装这个模块,…

了解Docker的文件系统网络模式的基本原理

Docker文件系统 Linux基础 一个Linux系统运行需要两个文件系统: bootfs rbootfs bootfs(boot file system) bootfs 即引导文件系统,Linux内核启动时使用的文件系统。对于同样的内核版本的不同Lunx发行版本,其boot…

百度富文本上传图片后样式崩塌

🔥博客主页: 破浪前进 🔖系列专栏: Vue、React、PHP ❤️感谢大家点赞👍收藏⭐评论✍️ 问题描述:上传图片后,图片会变得很大,当点击的时候更是会顶开整个的容器的高跟宽 原因&#…

C++之类型转换

目录 一、C语言中的类型转换 二、C的强制类型转换 1、 static_cast 2、reinterpret_cast 3、 const_cast 4、dynamic_cast 一、C语言中的类型转换 在C语言中,如果赋值运算符左右两侧类型不同,或者形参与实参类型不匹配,或者返回值类型…

idea的设置

1.设置搜索encoding,所有编码都给换为utf-8 安装插件 eval-reset插件 https://www.yuque.com/huanlema-pjnah/okuh3c/lvaoxt#m1pdA 设置活动模板,idea有两种方式集成tomcat,一种是右上角config配置本地tomcat,一种是插件,如果使用插件集成,则在maven,pom.xml里面加上tomcat…

openGauss学习笔记-110 openGauss 数据库管理-管理用户及权限-Schema

文章目录 openGauss学习笔记-110 openGauss 数据库管理-管理用户及权限-Schema110.1 创建、修改和删除Schema110.2 搜索路径 openGauss学习笔记-110 openGauss 数据库管理-管理用户及权限-Schema Schema又称作模式。通过管理Schema,允许多个用户使用同一数据库而不…

XML教学视频(黑马程序员精讲 XML 知识!)笔记

第一章XML概述 1.1认识XML XML数据格式: 不是html但又和html有点相似 XML数据格式最主要的功能就是数据传输(一个服务器到另一个服务器,一个网站到另一个网站)配置文件、储存数据当做小型数据可使用、规范数据格式让数据具有结…

多线程---synchronized特性+原理

文章目录 synchronized特性synchronized原理锁升级/锁膨胀锁消除锁粗化 synchronized特性 互斥 当某个线程执行到某个对象的synchronized中时,其他线程如果也执行到同一个对象的synchronized就会阻塞等待。 进入synchronized修饰的代码块相当于加锁 退出synchronize…

基于Qt 文本读写(QFile/QTextStream/QDataStream)实现

​ 在很多时候我们需要读写文本文件进行读写,比如写个 Mp3 音乐播放器需要读 Mp3 歌词里的文本,比如修改了一个 txt 文件后保存,就需要对这个文件进行读写操作。本章介绍简单的文本文件读写,内容精简,让大家了解文本读写的基本操作。 ## QFile 读写文本 QFile 类提供了读…

一个注解,实现数据脱敏-plus版

shigen坚持日更的博客写手,擅长Java、python、vue、shell等编程语言和各种应用程序、脚本的开发。坚持记录和分享从业两年以来的技术积累和思考,不断沉淀和成长。 当看到这个文章名的时候,是不是很熟悉,是的shigen之前发表了一个这…

【PC】特殊空投-2023年10月

亲爱的玩家朋友们,大家好! 10月特殊空投活动来袭。本月我们也准备了超多活动等着大家来体验。快来完成任务获得丰富的奖励吧!签到活动,每周一次的PUBG空投节,还有可以领取PGC2023免费投票劵的活动等着大家!…