【李沐深度学习笔记】线性回归的从零开始实现

课程地址和说明

线性回归的从零开始实现p3
本系列文章是我学习李沐老师深度学习系列课程的学习笔记,可能会对李沐老师上课没讲到的进行补充。

线性回归的从零开始实现

不使用任何深度学习框架提供的计算功能,只使用PyTorch提供的Tensor来实现线性回归
我们将从零开始实现整个方法,包括数据流水线、模型、损失函数和小批量随机梯度下降优化器。
首先我们要安装李沐老师提供的一个d2l包,安装方法见视频评论区,导入必要的库

%matplotlib inline
import random
import torch
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" # jupyter notebook运行d2l内核挂掉解决方法
from d2l import torch as d2l
# 生成训练集
def synthetic_data(w, b, num_example):# 生成y= Xw+b+噪声X = torch.normal(0, 1, (num_example, len(w)),dtype=torch.float32) # 生成均值为0方差为1的随机数,行数是n个样本数,列数是len(w)print(X)y = torch.matmul(X, w)+b # y=Xw+by += torch.normal(0, 0.01, y.shape) # 生成均值为0,方差为1的随机噪音,形状与y相同return X, y.reshape((-1,1)) # 把X和y做成一个列向量返回,-1是说不指定行数,1是指定为一列,即列向量true_w = torch.tensor([2,-3,4],dtype=torch.float32)
true_b = 4.2
# 生成特征和标签(即训练集)
features, labels = synthetic_data(true_w, true_b, 1000) # 样本数为1000
# 打印训练样本
print("特征[0]:",features[0],"\n标签[0]:",labels[0])

运行结果:
特征[0]: tensor([ 1.3343, -1.3974, -1.3717])
标签[0]: tensor([5.5832])

  • 画一下生成的训练集
# 画一下生成的训练集
d2l.set_figsize()
# 将第二个特征(下标为1)和最终的真实值(标签)画在二维图像中
d2l.plt.scatter(features[:,1].detach().numpy(), # 将特征的tensor从计算图中detach出来才能转到numpy对象labels.detach().numpy(),1);

运行结果:

  • 定义一个data_iter函数,该函数接受批量大小、特征矩阵和标签向量作为输入,生成大小为batch_size的小批量
# 数据迭代器
# 这个函数的作用是先打乱样本数据的下标,然后每隔1个batch_size取一次下标
# 并将当前样本的下标以及其后面的batch_size-1个样本的下标存入一个张量之中
# 返回这个存储下标的张量中的各下标对应的特征值和标签
def data_iter(batch_size, features, labels):num_examples = len(features) #样本数量即是特征矩阵的行数,特征矩阵的每一列对应着不同的特征indices = list(range(num_examples)) # 生成每一个样本对应的下标(从0开始)# 将样本的下标打乱,以做到让这些样本是随机读取的,没有特定的顺序# 这样就能以随机的顺序访问每一个样本random.shuffle(indices)# 从0开始到样本数量(左闭右开),每一次隔batch_size大小访问for i in range(0, num_examples, batch_size):# batch_indices每隔batch_size取一次下表存入张量中# 使用min(i+batch_size,num_examples)是因为有可能下标会超出范围# 所以要取每隔batch_size后的下标与张量向量的最大下标中最小的一个作为左闭右开区间的右侧值batch_indices = torch.tensor(indices[i:min(i+batch_size,num_examples)])yield features[batch_indices], labels[batch_indices] # yield返回单个返回值并记住返回的位置,下一次迭代从这个位置开始(方便迭代)# 假定batch_size 为10 
batch_size = 10for X,y in data_iter(batch_size, features, labels):print(X,'\n',y)break # 加个break先迭代一次看看,应该会显示10个样本的特征和其对应的标签

运行结果:
tensor([[-0.1669, 0.6399, 1.8548],
[-0.9414, -1.1511, -0.2822],
[ 0.3021, -0.5745, -0.8592],
[-0.0539, -0.8632, -1.8610],
[ 1.4092, 2.3873, -1.7796],
[-0.0033, -0.2928, -1.1324],
[ 0.4779, 1.7110, -0.0561],
[-0.3564, -0.5307, -0.2051],
[ 0.9909, -0.0604, -0.2683],
[ 1.6115, 0.9627, 0.8231]])
tensor([[ 9.3663],
[ 4.6398],
[ 3.0753],
[-0.7490],
[-7.2540],
[ 0.5359],
[-0.1931],
[ 4.2597],
[ 5.2664],
[ 7.8300]])

  • 做完上述步骤,数据就搞定了,接下来定义初始化模型参数
# 定义初始化模型参数
# 需要计算梯度requires_grad=True
w = torch.normal(0, 0.01, size=(3,1), requires_grad=True) # w初始化为一个均值为0,方差为0.01的列数为3(3个特征,视频里那个两个特征有问题)的正态分布的向量
# 偏差直接置为0,标量,所以第一个参数是1
b = torch.zeros(1, requires_grad=True)
  • 定义模型 ,损失函数
# 定义模型
def linreg(X, w, b):return torch.matmul(X,w)+b
# 定义损失函数
def squared_loss(y_hat, y):# 返回均方损失return (y_hat - y.reshape(y_hat.shape))**2/2
  • 定义优化算法
# 定义优化算法
# 小批量随机梯度下降
# lr是学习率, param是前一个点的向量(随机初始化)
def sgd(params, lr, batch_size):# with语法的好处在于,它确保资源的正确释放,即使在发生异常的情况下也能够被处理# 更新的时候不要有梯度的计算with torch.no_grad():for param in params:# 按批量梯度下降公式写(要求均值)param -= lr * param.grad / batch_size# 不要累积梯度,将梯度设置为0param.grad.zero_()
  • 训练过程
# 训练过程
# 指定超参数
lr = 0.03 # 学习率
num_epochs = 3 # num_epochs指的是把整个数据扫描几遍(目前设置为3)
net = linreg # 网络:线性回归
loss = squared_loss # 损失函数:均方损失# 每次扫epoch次数据
for epoch in range(num_epochs):# 遍历数据for X, y in data_iter(batch_size, features, labels):# 预测值是net(X,w,b),对应的是线性回归,真实值是y,l是求得其损失l = loss(net(X,w,b),y) # 对均方损失的向量求和,求和后计算梯度(对应均方损失函数的公式)l.sum().backward()# 得到梯度后,使用小批量随机梯度下降法对梯度进行更新,传入[w,b]参数sgd([w,b], lr, batch_size)# 用优化好的参数,对所有的样本求损失,看作验证集,故不需要梯度来更新参数with torch.no_grad():train_l = loss(net(features, w, b), labels)print(f'epoch{epoch + 1}, loss{float(train_l.mean()):f}')

运行结果:
epoch1, loss0.045511
epoch2, loss0.000137
epoch3, loss0.000050
可以看到,迭代次数多了,损失函数越小

  • 比较真实参数和通过训练学到的参数来评估训练的成功程度
print ( f'w的估计误差(真实值与预测值误差):{true_w-w.reshape(true_w.shape)} ' )
print ( f'b的估计误差(真实值与预测值误差):{true_b-b}')

运行结果:

w的估计误差(真实值与预测值误差):tensor([-0.0003, -0.0005, 0.0003], grad_fn=<SubBackward0>)
b的估计误差(真实值与预测值误差):tensor([0.0004], grad_fn=<RsubBackward1>)
这个误差很小了

  • 注意
    学习率不能太大太小,如果学习率特别小,跑很多次才能将损失减小,学习率太大会震荡,会出现无穷大的值
    学习率太小的情况

    学习率太大的情况

  • 直接打印w和b就得到了你预测模型的参数,然后对比一下真实值就知道你和真实的线性回归相差多少了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/142745.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

合肥对新通过(CMMI)五级、四级、三级认证的软件企业,对新通过信息技术服务标准(ITSS)认证的软件企业,给予最高50万奖励

合肥市加快软件产业发展 推进软件名城创建若干政策实施细则 为贯彻落实《合肥市人民政府办公室关于印发合肥市加快软件产业发展推进软件名城创建若干政策的通知》&#xff08;合政办〔2023〕9号&#xff09;文件精神&#xff0c;规范政策资金管理&#xff0c;制定本实施细则。…

linux内网渗透

一、信息收集 主机发现&#xff1a; nmap -sP 192.168.16.0/24 端口探测 masscan -p 1-65535 192.168.16.168 --rate1000 开放端口如下 nmap端口详细信息获取 nmap -sC -p 8888,3306,888,21,80 -A 192.168.16.168 -oA ddd4-port目录扫描 gobuster dir…

同创永益CNBR平台——云原生时代下的系统稳定器

随着各行业数字化的快速发展&#xff0c;企业的业务运作、经营管理越来越依赖于云原生系统的可靠运行。信息系统服务的连续性, 业务数据的完整性、正确性、有效性会直接关系到企业的生产、经营与决策活动。一旦因自然灾害、设备故障或人为因素等引起信息数据丢失和云原生业务处…

TextSniper for Mac: 革新您的文本识别体验

你是否曾经需要从图片或扫描文档中提取文本&#xff0c;却苦于没有合适的工具&#xff1f;那么&#xff0c;TextSniper for Mac将是你的完美解决方案。这款文本识别工具将彻底改变你处理图像和扫描文件的方式&#xff0c;让你更快速、更高效地完成任务。 TextSniper for Mac 是…

Apache Hive安装部署详细图文教程

目录 一、Apache Hive 元数据 1.1 Hive Metadata 1.2 Hive Metastore 二、Metastore 三种配置方式 ​2.1 内嵌模式 ​2.2 本地模式 ​2.3 远程模式 ​三、Hive 部署实战 3.1 安装前准备 3.2 Hadoop 与 Hive 整合 3.3 远程模式安装 3.3.1 安装 MySQL 3.3.2 …

基于Matlab求解2023华为杯研究生数学建模竞赛E题——出血性脑卒中临床智能诊疗建模实现步骤(附上源码+数据)

文章目录&#xff0c;源码见文末下载 背景介绍准备工作&#xff1a;处理数据第一题&#xff1a;血肿扩张风险相关因素探索建模a&#xff09;问题b&#xff09;问题 第二题&#xff1a; 血肿周围水肿的发生及进展建模&#xff0c;并探索治疗干预和水肿进展的关联关系a&#xff0…

【yolox训练过程中遇到的问题集合】

这里写目录标题 深度学习遇到的一系列bugVScode无法激活conda1.vscode加载web 视图报错2.CUDA out of memory3.voc2007数据集中的txt文件4.object has no attribute ‘cache‘5.KeyError:model6.No module named loguru7.Python AttributeError: module ‘distutils‘ has no a…

TensorFlow入门(五、指定GPU运算)

一般情况下,下载的TensorFlow版本如果是GPU版本,在运行过程中TensorFlow能自动检测。如果检测到GPU,TensorFlow会默认利用找到的第一个GPU来执行操作。如果机器上有超过一个可用的GPU,除第一个之外的其他GPU默认是不参与计算的。如果想让TensorFlow使用这些GPU执行操作,需要将运…

R语言学习笔记

R语言学习笔记 一.准备环境二.认识控制台三.R包四.数据结构1.向量Vector1.1创建向量1.2访问向量中的数据1.3向量的循环补齐 2.矩阵matrix2.1创建矩阵2.2访问矩阵中的数据 3数组Array3.1创建数组3.2访问数组中的数据 4.数据框Dataframe4.1创建数据框4.2访问数据框中的数据 5因子…

大数据Flink(八十六):DML:Group 聚合和Over 聚合

文章目录 DML:Group 聚合和Over 聚合 一、DML:Group 聚合

【Git】Deepin提示git remote: HTTP Basic: Access denied 错误解决办法

git remote: HTTP Basic: Access denied 错误解决办法 1.提交代码的时候提示2. 原因3.解决方案 1.提交代码的时候提示 git remote: HTTP Basic: Access denied 错误解决办法 2. 原因 本地git配置的用户名、密码与gitlabs上注册的用户名、密码不一致。 3.解决方案 如果账号…

数据结构 - 泛型

目录 前言 1. 什么是泛型? 2. 为什么需要泛型? 引入泛型之前 引入泛型之后 3.泛型类 4.泛型的界限 1.上下界 2.通配符 前言 今天给大家介绍一下泛型的使用 1. 什么是泛型? 一般的类和方法&#xff0c;只能使用具体的类型: 要么是基本类型&#xff0c;要么是自定义…

Java高级应用——异常处理

文章目录 异常处理概念Java异常体系Error 和 Exception编译时异常和运行时异常Java异常处理的方式 异常处理 概念 异常处理是在程序执行过程中遇到错误或异常情况时的一种机制&#xff0c;它允许程序在错误发生时进行适当的处理&#xff0c;而不会导致程序崩溃或产生不可预测…

【办公自动化】用Python将PDF文件转存为图片(文末送书)

&#x1f935;‍♂️ 个人主页&#xff1a;艾派森的个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持&#xff0c;我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f4…

【嵌入式】使用嵌入式芯片唯一ID进行程序加密实现

目录 一 背景说明 二 原理介绍 三 设计实现 四 参考资料 一 背景说明 项目程序需要进行加密处理。 考虑利用嵌入式芯片的唯一UID&#xff0c;结合Flash读写来实现。 加密后的程序&#xff0c;可以使得从芯片Flash中读取出来的文件&#xff08;一般为HEX格式&#xff09;不能…

rsync+inotify实时同步数据

一、相关简介 1、rsync&#xff08;remote synchronize&#xff09; rsync是 Liunx/Unix 下的一个远程数据同步工具&#xff0c;它可通过 LAN/WAN 快速同步多台主机间的文件和目录。   Linux 之间同步文件一般有两种方式&#xff0c;分别是 rsync 与 scp &#xff0c;scp 相…

前端性能优化汇总

1.减少HTTP请求次数和请求的大小 &#xff08;三大类&#xff09; 文件的合并和压缩&#xff1a;&#xff08;1&#xff09;&#xff08;6&#xff09; 延迟加载&#xff1a;&#xff08;3&#xff09;&#xff08;4&#xff09; 用新的文件格式代替传统文件格式&#xff1a;&a…

基于LQR算法的一阶倒立摆控制

1. 一阶倒立摆建模 2. 数学模型 倒立摆的受力分析网上有很多&#xff0c;这里就不再叙述。直接放线性化后的方程&#xff1a; F (Mm)x″-mLφ″ (ImL)φ″ mLx″ mgLφ&#xff08;F为外力&#xff0c;x为物块位移&#xff0c;M&#xff0c;m为物块和摆杆的质量&#xff0c;…

抽象轻松java

嗨嗨嗨&#xff01; 没想到吧&#xff0c;出现了抽象轻松第4种语言系列&#xff08;我也没想到&#xff09; 简单的java程序&#xff0c;看完就懂的简单逻辑——购物车系统 购物车&#xff0c;首先要有商品吧&#xff0c;现实中的商品有什么属性&#xff1f; 名字&#xff0…

CSS详细基础(一)选择器基础

本帖开始&#xff0c;我们来介绍CSS——即&#xff0c;层叠样式表~ 层叠样式表是一种用来表现HTML&#xff08;标准通用标记语言的一个应用&#xff09;或XML&#xff08;标准通用标记语言的一个子集&#xff09;等文件样式的计算机语言。简单的说&#xff0c;层叠就是对一个元…