从0开始深度学习(7)——线性回归的简洁实现

在从0开始深度学习(5)——线性回归的逐步实现中,我们手动编写了数据构造模块、损失函数模块、优化器等,但是在现代深度学习框架下,这些已经包装好了
本章展示如果利用深度学习框架简洁的实现线性回归

0 导入头文件

import random
import torch
import matplotlib.pyplot as plt
from torch.utils import data
import numpy as np
from torch import nn#nn是神经网络的缩写

1 生成数据集

和之前的数据一样

def synthetic_data(w, b, num_examples):  #@save"""生成y=Xw+b+噪声"""X = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(X, w) + by += torch.normal(0, 0.01, y.shape)return X, y.reshape((-1, 1))true_w = torch.tensor([2, -3.4])# 真实的W,是个二维张量
true_b = 4.2# 真实的b
features, labels = synthetic_data(true_w, true_b, 1000)# 生成1000个点# 绘制散点图
plt.scatter(features[:, 0].numpy(), labels.numpy(), 1.0)
plt.xlabel('Feature')
plt.ylabel('Label')
plt.title('Scatter Plot of Generated Data')
plt.show()

2 读取数据

直接使用torch中的TensorDatasetDataLoader

  • TensorDataset 是 PyTorch 中的一个类,它将数据和对应的标签组合成一个数据集对象。
  • DataLoader 是 PyTorch 提供的一个迭代器,可以用来批量加载数据,并且能够处理多线程数据读取、数据打乱等任务。
# 读取数据
def load_data(data_array,batch_size):dataset=data.TensorDataset(*data_array)return data.DataLoader(dataset,batch_size,shuffle=True)batch_size=10
data_iter=load_data((features,labels),batch_size)

3 定义模型

直接使用torch自带的神经网络中的全连接层,全连接层和线性回归模型都使用线性变换来生成输出, 所以可以用全连接层来实现线性回归

net = nn.Sequential(nn.Linear(2, 1))
# 第一个参数是输出的特征形状,第二个是输出的特征形状
# 因为我们的w是个二维向量,所以这里的形状是2

4 初始化参数

我们的函数是 y = w x + b y=wx+b y=wx+b,所以有一个权重 w w w和偏置项 b b b

#初始化权重,通常情况下,权重可以从一个正态分布中初始化,这样可以确保权重的初始值既不是太大也不是太小,有助于模型的收敛。
net[0].weight.data.normal_(0,0.01)# 从均值为 0、标准差为 0.01 的正态分布中初始化权重。
#初始化偏置项,偏置通常初始化为 0
net[0].bias.data.fill_(0)

5 定义损失函数和优化器

之前是手写的,这里我们可以直接使用torch自带的

# 定义损失函数
loss=nn.MSELoss()
#定义优化算法
trainer=torch.optim.SGD(net.parameters(),lr=0.01)
#第一个参数是指,返回所有需要更新的参数,第二个是学习率

6 训练模型

注意: 每次都要初始化梯度为0,避免梯度累积,每次反向传播之前将梯度清零,可以确保每次更新都是基于当前批次的数据

total_epochs=3
for epoch in range(total_epochs):for X,y in data_iter:# X是特征数据,y是标签l=loss(net(X),y)# 前向传播,生成预测,并计算损失trainer.zero_grad()# 初始化梯度l.backward()# 反向传播计算梯度trainer.step()# 调用优化器更新参数l=loss(net(features),labels)print(f'epoch {epoch + 1}, loss {l:f}')

7 评估模型

最后和我们的真实权重 w w w和偏置项 b b b做差,观察差距

w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/444024.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu 22.04.4 LTS更换下载源

方法1:使用图形界面更换下载源 1. 打开软件和更新应用 2. 在Ubuntu 软件标签中,点击“下载自”旁边的下拉菜单,选择“其他” 3. 点击“选择最佳服务器”来自动选择最快的服务器 4. 选择服务器 5. 确定并关闭窗口,系统会提示您重新…

ElasticSearch备考 -- Multi match

一、题目 索引task有3个字段a、b、c,写一个查询去匹配这三个字段为mom,其中b的字段评分比a、c字段大一倍,将他们的分数相加作为最后的总分数 二、思考 通过题目要求对多个字段进行匹配查询,可以考虑multi match、bool query操作。…

【C++第十八章】Map和Set

Map和Set map和set的介绍 容器分为两种,序列式容器和关联式容器,序列式容器因为底层是线性序列的数据结构,存储的是元素本身,而关联式容器中不单是为了存储数据,还要进行查找,所以存储的是键值对&#xff…

网络编程(17)——asio多线程模型IOThreadPool

十七、day17 之前我们介绍了IOServicePool的方式,一个IOServicePool开启n个线程和n个iocontext,每个线程内独立运行iocontext, 各个iocontext监听各自绑定的socket是否就绪,如果就绪就在各自线程里触发回调函数。为避免线程安全问题&#xf…

腾讯云SDK点播播放数据

点播播放质量监控提供点播播放全链路的数据统计、质量监控及可视化分析服务。支持实时数据上报、数据聚合、多维筛选和精细化定向分析,可帮助企业实时掌控大盘运营状况、了解用户习惯和行为特征,有效指导运营决策、驱动业务增长。 注意事项 点播播放质…

Python 工具库每日推荐 【Pandas】

文章目录 引言Python数据处理库的重要性今日推荐:Pandas工具库主要功能:使用场景:安装与配置快速上手示例代码代码解释实际应用案例案例:销售数据分析案例分析高级特性数据合并和连接时间序列处理数据透视表扩展阅读与资源优缺点分析优点:缺点:总结【 已更新完 TypeScrip…

基于 CSS Grid 的简易拖拉拽 Vue3 组件,从代码到NPM发布(1)- 拖拉拽交互

基于特定的应用场景,需要在页面中以网格的方式,实现目标组件在网格中可以进行拖拉拽、修改大小等交互。本章开始分享如何一步步从代码设计,最后到如何在 NPM 上发布。 请大家动动小手,给我一个免费的 Star 吧~ 大家如果发现了 Bug…

探索未来:mosquitto-python,AI领域的新宠

文章目录 探索未来:mosquitto-python,AI领域的新宠背景:为何选择mosquitto-python?库简介:mosquitto-python是什么?安装指南:如何安装mosquitto-python?函数用法:5个简单…

代码随想录算法训练营第四十六天 | 647. 回文子串,516.最长回文子序列

四十六天打卡,今天用动态规划解决回文问题,回文问题需要用二维dp解决 647.回文子串 题目链接 解题思路 没做出来,布尔类型的dp[i][j]:表示区间范围[i,j] (注意是左闭右闭)的子串是否是回文子串&#xff0…

深入理解Transformer的笔记记录(精简版本)---- Transformer

自注意力机制开启大规模预训练时代 1 从机器翻译模型举例 1.1把编码器和解码器联合起来看待的话,则整个流程就是(如下图从左至右所示): 1.首先,从编码器输入的句子会先经过一个自注意力层(即self-attention),它会帮助编码器在对每个单词编码时关注输入句子中的的其他单…

【JavaEE】——回显服务器的实现

阿华代码,不是逆风,就是我疯 你们的点赞收藏是我前进最大的动力!! 希望本文内容能够帮助到你!! 目录 一:引入 1:基本概念 二:UDP socket API使用 1:socke…

2-118 基于matlab的六面体建模和掉落仿真

基于matlab的六面体建模和掉落仿真,将对象建模为刚体来模拟将立方体扔到地面上。同时考虑地面摩擦力、刚度和阻尼所施加的力,在三个维度上跟踪平移运动和旋转运动。程序已调通,可直接运行。 下载源程序请点链接:2-118 基于matla…

基于SpringBoot“花开富贵”花园管理系统【附源码】

效果如下: 系统注册页面 系统首页界面 植物信息详细页面 后台登录界面 管理员主界面 植物分类管理界面 植物信息管理界面 园艺记录管理界面 研究背景 随着城市化进程的加快和人们生活质量的提升,越来越多的人开始追求与自然和谐共生的生活方式&#xf…

使用激光跟踪仪提升码垛机器人精度

标题1.背景 码垛机器人是一种用于工业自动化的机器人,专门设计用来将物品按照一定的顺序和结构堆叠起来,通常用于仓库、物流中心和生产线上,它们可以自动执行重复的、高强度的搬运和堆垛任务。 图1 码垛机器人 传统调整码垛机器人的方法&a…

通信工程学习:什么是DIP数据集成点

DIP:数据集成点 DIP数据集成点(Data Integration Point),简称DIP,是物联网技术(IoT)和机器到机器(M2M)通信中的一个重要组成部分。DIP在数据集成和传输过程中扮演着关键角…

【笔记】6.2 玻璃的成型

玻璃熔体的成型方法,有压制法(例如,制作水杯、烟灰缸等)、压延法(例如,制作压花玻璃等)、浇铸法(例如,制作光学玻璃、熔铸耐火材料、铸石等) 、吹制法(例如,制作瓶罐等空心玻璃)、拉制法(例如,制作窗用玻璃、玻璃管、玻璃纤维等)、离心法(例如,制作玻璃棉等)、喷吹法(例如,制作…

Ansible 工具从入门到使用

1. Ansible概述 Ansible是一个基于Python开发的配置管理和应用部署工具,现在也在自动化管理领域大放异彩。它融合了众多老牌运维工具的优点,Pubbet和Saltstack能实现的功能,Ansible基本上都可以实现。 Ansible能批量配置、部署、管理上千台主…

各类排序详解

前言 本篇博客将为大家介绍各类排序算法,大家知道,在我们生活中,排序其实是一件很重要的事,我们在网上购物,需要根据不同的需求进行排序,异或是我们在高考完报志愿时,需要看看院校的排名&#…

qt QGraphicsItem详解

一、概述 QGraphicsItem是Qt框架中图形视图框架(Graphics View Framework)的一个核心组件,它是用于表示2D图形元素的基类。 它支持的功能包括: 设置和获取图形项的位置和尺寸。控制图形项的外观,如颜色、笔刷、边框…

京东web 京东e卡绑定 第二部分分析

声明 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关! 有相关问题请第一时间头像私信联系我删…