【动手学深度学习Pytorch】1. 线性回归代码

零实现

        导入所需要的包:

# %matplotlib inline
import random
import torch
from d2l import torch as d2l
import matplotlib.pyplot as plt
import matplotlib
import os

        构造人造数据集:假设w=[2, -3.4],b=4.2,存在随机噪音(均值为0,方差为0.001的正态分布噪声),函数拟合为y = w^{T}X + b + n。在构造数据集的过程中,首先X为正态分布(均值为0,方差为1,样本数/行数为num_examples,列数为len(w))

torch.normal(mean, std, *, generator=None, out=None):生成指定输出尺寸的正态分布随机数张量

torch.mv():矩阵和向量的乘积,此处X为矩阵,w为向量

def synthetic_data(w, b, num_examples):X = torch.normal(0, 1, (num_examples, len(w))) #均值为0方差为1的随机数,样本数,列数y = torch.mv(X, w) + b #y关于x的公式y += torch.normal(0, 0.001, y.shape) # 加入噪声项return X, y.reshape((-1,1)) #做成列向量返回
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

        查看数据集样本分布:

matplotlib.pyplot.scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, *, edgecolors=None, plotnonfinite=False, data=None, **kwargs):

        x,y:长度相同的数组,也就是我们即将绘制散点图的数据点,输入数据。

        s:点的大小,默认 20,也可以是个数组,数组每个参数为对应点的大小。

        c:点的颜色,默认蓝色 'b',也可以是个 RGB 或 RGBA 二维行数组。

        marker:点的样式,默认小圆圈 'o'。

        cmap:Colormap,默认 None,标量或者是一个 colormap 的名字,只有 c 是一个浮点数数组的时才使用。如果没有申明就是 image.cmap。

        norm:Normalize,默认 None,数据亮度在 0-1 之间,只有 c 是一个浮点数的数组的时才使用。

        vmin,vmax:亮度设置,在 norm 参数存在时会忽略。

        alpha:透明度设置,0-1 之间,默认 None,即不透明。

        linewidths:标记点的长度。

        edgecolors:颜色或颜色序列,默认为 'face',可选值有 'face', 'none', None。

        plotnonfinite:布尔值,设置是否使用非限定的 c ( inf, -inf 或 nan) 绘制点。

        **kwargs:其他参数。

detach():允许我们从计算图中分离出张量。当对一个张量调用detach()方法时,它会创建一个新的张量,这个新张量与原始张量共享数据,但它不再参与计算图的任何操作,对分离后的张量进行的任何操作都不会影响原始张量,也不会在计算图中留下任何痕迹。

plt.scatter(features[:,(1)].detach().numpy(),labels.detach().numpy(),1);
plt.show()

        遍历数据集,输出数据集内容:

len(): 返回对象(字符、列表、元组等)长度或项目个数(此处是张量的行数)

list(): 将元组转换为列表

range():创建一个整数列表

shuffle(): 随机打乱列表

def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples)) #生成样本索引random.shuffle(indices) #样本随机读取没有特定顺序# 进行batch划分for i in range(0, num_examples, batch_size): #从i开始到i+batchsizebatch_indices =  torch.tensor(indices[i:min(i + batch_size, num_examples)])# 截取切片:开始位置为i,结束位置为min函数的返回值# 返回值为i+batch_size和num_examples的值比较小的那个yield features[batch_indices], labels[batch_indices] #产生随机顺序的特征&标号batch_size = 10for X, y in data_iter(batch_size, features, labels):print(X, '\n', y)break

 

      定义参数、模型、损失函数以及优化算法:

torch.mutual():矩阵相乘

with torch.no_grad():所有计算得出的tensor的requires_grad都自动设置为False,不会进行自动求导

grad.zero_():将梯度置零(不然会发生累计的情况)

# 定义初始化模型参数
w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
# 定义模型
def linreg(X, w, b):return torch.matmul(X, w) + b
# 定义损失函数
def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape))**2/2
# 定义优化算法
def sgd(params, lr, batch_size):with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()

        定义训练过程:

# 训练过程
lr = 0.01
num_epochs = 10
net = linreg
loss = squared_lossfor epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y)l.sum().backward()sgd([w,b], lr, batch_size)with torch.no_grad():train_1= loss(net(features, w, b), labels)print(f'epoch{epoch + 1}, loss{float(train_1.mean()):f}')

简介实现

        导入所需要的包:

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
import matplotlib.pyplot as plt

        创建人造数据集:

data.TensorDataset():将数据进行封装

data.DataLoader():将数据分批次处理

iter():获取列表的迭代器

next():获取下一个值

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b,1000)
def load_array(data_arrays, batch_size, is_train=True):dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle=is_train)batch_size = 10
data_iter = load_array((features, labels), batch_size)next(iter(data_iter))

初始化模型、模型参数、loss: 

nn.Sequential():实现模型层结构的简单排序

torch.optim.SGD():定义优化算法

torch.optim.SGD().step():进行模型的更新

# 使用框架的预定义好的层
from torch import nn
net = nn.Sequential(nn.Linear(2,1))
# 初始化模型参数
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)
# 计算均方误差使用的是MSELoss类
loss = nn.MSELoss()
trainer = torch.optim.SGD(net.parameters(),lr=0.01)

        定义训练过程:

num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X), y)trainer.zero_grad()l.backward()trainer.step()l = loss(net(features), labels)print(f'epoch{epoch + 1}, loss{1:f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/473355.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

论文笔记(五十六)VIPose: Real-time Visual-Inertial 6D Object Pose Tracking

VIPose: Real-time Visual-Inertial 6D Object Pose Tracking 文章概括摘要I. INTRODACTIONII. 相关工作III. APPROACHA. 姿态跟踪工作流程B. VIPose网络 文章概括 引用: inproceedings{ge2021vipose,title{Vipose: Real-time visual-inertial 6d object pose tra…

web——upload-labs——第三关——后缀黑名单绕过

上传一个正常的一句话木马,判断一下验证类型 响应后返回提示不允许上传.asp,.aspx,.php,.jsp后缀文件! 且查看网页源代码中并没有前端验证机制,所以可以判断这道题是后端验证 使用burp 提示无法上传.php结尾的文件,但我们的一句…

LeetCode题解:18.四数之和【Python题解超详细】,三数之和 vs. 四数之和

题目描述 给你一个由 n 个整数组成的数组 nums ,和一个目标值 target 。请你找出并返回满足下述全部条件且不重复的四元组 [nums[a], nums[b], nums[c], nums[d]] (若两个四元组元素一一对应,则认为两个四元组重复): …

如何利用SAP低代码平台快速构建企业级应用?

SAP作为全球领先的企业管理软件解决方案提供商,一直致力于为企业提供全面且高效的业务管理工具。随着技术的快速发展,传统的开发方式已经无法满足企业在快速变化的市场环境下的需求。低代码开发平台应运而生,它通过简化应用程序的创建过程&am…

Redis基础篇

文章目录 1.Redis的引入2.单机和分布式3.读写分离4.缓存服务器5.微服务 1.Redis的引入 我们的这个redis就是对于这个内存数据进行存储的,和我们的这个变量的这个性质是一样的,但是我们的这个redis主要是应用于这个分布式的这个系统上面的,如…

C++11(四)---可变参数模板

文章目录 可变参数模板 可变参数模板 参数包代表多个类型和参数 // Args是一个模板参数包&#xff0c;args是一个函数形参参数包 // 声明一个参数包Args...args&#xff0c;这个参数包中可以包含0到任意个模板参数。 template <class ...Args> void ShowList(Args... arg…

基于Springboot+Vue的中国蛇类识别系统 (含源码数据库)

1.开发环境 开发系统:Windows10/11 架构模式:MVC/前后端分离 JDK版本: Java JDK1.8 开发工具:IDEA 数据库版本: mysql5.7或8.0 数据库可视化工具: navicat 服务器: SpringBoot自带 apache tomcat 主要技术: Java,Springboot,mybatis,mysql,vue 2.视频演示地址 3.功能 这个系…

大数据新视界 -- 大数据大厂之 Impala 性能飞跃:分区修剪优化的应用案例(下)(22 / 30)

&#x1f496;&#x1f496;&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎你们来到 青云交的博客&#xff01;能与你们在此邂逅&#xff0c;我满心欢喜&#xff0c;深感无比荣幸。在这个瞬息万变的时代&#xff0c;我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

ES6标准-Promise对象

目录 Promise对象的含义 Promise对象的特点 Promise对象的缺点 Promise对象的基本用法 Promise对象的简单例子 Promise新建后就会立即执行 Promise对象回调函数的参数 Promise参数不会中断运行 Promise对象的then方法 Promise对象的catch()方法 Promise状态为resolv…

【目标检测】【Ultralytics-YOLO系列】Windows11下YOLOV5人脸目标检测

【目标检测】【Ultralytics-YOLO系列】Windows11下YOLOV5人脸目标检测 文章目录 【目标检测】【Ultralytics-YOLO系列】Windows11下YOLOV5人脸目标检测前言YOLOV5模型运行环境搭建YOLOV5模型运行数据集准备YOLOV5运行模型训练模型验证模型推理 总结 前言 Ultralytics YOLO 是一…

使用Axios函数库进行网络请求的使用指南

目录 前言1. 什么是Axios2. Axios的引入方式2.1 通过CDN直接引入2.2 在模块化项目中引入 3. 使用Axios发送请求3.1 GET请求3.2 POST请求 4. Axios请求方式别名5. 使用Axios创建实例5.1 创建Axios实例5.2 使用实例发送请求 6. 使用async/await简化异步请求6.1 获取所有文章数据6…

windows工具 -- 使用rustdesk和云服务器自建远程桌面服务, 手机, PC, Mac, Linux远程桌面 (简洁明了)

目的 向日葵最先放弃了, todesk某些功能需要收费, 不想用了想要 自己搭建远程桌面 自己使用希望可以电脑 控制手机分辨率高一些 原理理解 ubuntu云服务器配置 够买好自己的云服务器, 安装 Ubuntu操作系统 点击下载 hbbr 和 hbbs 两个 deb文件: https://github.com/rustdesk/…

MySQL-关联查询和子查询

目录 一、笛卡尔积 二、表连接 1、内部连接 1.1 等值连接 1.2 非等值连接 2、外部链接 2.1 左外连接-LEFT JOIN 2.2 右外连接-RIGHT JOIN 2.3 全关联-FULL JOIN/UNION 三、子查询 1、嵌套子查询 2、相关子查询 3、insert和select语句添加数据 4、update和select语…

AWTK-WIDGET-WEB-VIEW 实现笔记 (1) - 难点

webview 提供了一个跨平台的 webview 库&#xff0c;其接口简单&#xff0c;提供的例子也直观易懂。但是把它集成到 AWTK 里&#xff0c;还是遇到一些难题&#xff0c;这里记录一下&#xff0c;供有需要的朋友参考。 1. 作为 AWTK 控件 webview 提供的例子都是独立的程序&…

类与对象;

目录 一、认识类&#xff1b; 1、类的引入&#xff1b; 2、类的定义&#xff1b; 类的两种定义方式&#xff1a; 3、类的访问限定符及封装&#xff1b; 4、类的作用域&#xff1b; 5、类的实例化&#xff1b; 6、类对象模型&#xff1b; 计算类对象的大小&#xff1b; …

Ubuntu22.04LTS 部署前后端分离项目

一、安装mysql8.0 1. 安装mysql8.0 # 更新安装包管理工具 sudo apt-get update # 安装 mysql数据库&#xff0c;过程中的选项选择 y sudo apt-get install mysql-server # 启动mysql命令如下 &#xff08;停止mysql的命令为&#xff1a;sudo service mysql stop&#xff0…

使用 Ant Design Vue 自定渲染函数customRender实现单元格合并功能rowSpan

使用 Ant Design Vue 自定渲染函数customRender实现单元格合并功能rowSpan 背景 在使用Ant Design Vue 开发数据表格时&#xff0c;我们常常会遇到需要合并单元格的需求。 比如&#xff0c;某些字段的值可能会在多行中重复出现&#xff0c;而我们希望将这些重复的单元格合并为…

27.<Spring博客系统③(实现用户退出登录接口+发布博客+删除/编辑博客)>

PS&#xff1a;关于打印日志 1.建议在关键节点打印日志。 ①请求入口。 ②结果响应 2.在可能发生错误的节点打印日志 3.日志不是越多越好。因为打日志也会消耗性能。 日志也可以配置去除重复日志。 一、用户退出功能 判断用户退出。我们只需要在前端将token删掉就可以了。 由于…

[前端面试]javascript

js数据类型 简单数据类型 null undefined string number boolean bigint 任意精度的大整数 symbol 创建唯一且不变的值&#xff0c;常用来表示对象属性的唯一标识 复杂数据类型 object&#xff0c;数组&#xff0c;函数,正则,日期等 区别 存储区别 简单数据类型因为其大小固定…

uniapp自动注册机制:easycom

传统 Vue 项目中&#xff0c;我们需要注册、导入组件之后才能使用组件。 uniapp 框架提供了一种组件自动注册机制&#xff0c;只要你在 components 文件夹下新建的组件满足 /components/组件名/组件名.vue 的命名规范&#xff0c;就能直接使用。 注意&#xff1a;组件的文件夹…