pytorch(四、五)用pytorch实现线性回归和逻辑斯蒂回归(分类)

文章目录

  • 线性回归
    • 代码过程
    • 准备数据
    • 设计模型
    • 设计构造函数与优化器
    • 训练过程
    • 训练代码和结果
    • pytorch中的Linear层的底层原理(个人喜欢,不用看)
      • 普通矩阵乘法实现
      • Linear层实现
    • 回调机制
  • 逻辑斯蒂回归
    • 模型
    • 损失函数
    • 代码和结果

线性回归

代码过程

训练过程:

  1. 准备数据集
  2. 设计模型(用来计算 y ^ \hat y y^
  3. 构造损失函数和优化器(API)
  4. 训练周期(前馈、反馈、更新)

准备数据

这里的输入输出数据均表示为3×1的,也就是维度均为1

# 行表示实例数量,列表示维度feature
import torch
x_data=torch.Tensor([[1.0],[2.0],[3.0]])
y_data=torch.Tensor([[2.0],[4.0],[6.0]])

设计模型

模型继承Module类,并且必须要实现 init 和 forward 两个方法,其中self.linear=torch.nn.Linear(1,1)表示实例化Linear类,这个类是可调用的,其__call__函数调用了 forward 方法

class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel,self).__init__()# weight 和 bias 1 1 self.linear=torch.nn.Linear(1,1)def forward(self,x):# callabley_pred=self.linear(x)return y_pred# callable
model=LinearModel()

pytorch中的linear类是在某一个数据上应用线性转换,其公式表达为 y = x w T + b y=xw^T+b y=xwT+b

class torch.nn.Linear(in_features,out_features,bias=True) :其中in_features和out_features分别表示输入和输出的数据的维度(列的数量),bias表示偏置,默认是true,该类有两个参数

  • weight:可学习参数,值从均匀分布 U ( − k , k ) U(-\sqrt k,\sqrt k) U(k ,k )中获取,其中 k = 1 i n _ f e a t u r e s k=\frac{1}{in\_features} k=in_features1
  • bias:shape和输出的维度一样,也是从分布 U ( − k , k ) U(-\sqrt k,\sqrt k) U(k ,k )中初始化的
    在这里插入图片描述

设计构造函数与优化器

# 构造损失函数和优化器
criterion=torch.nn.MSELoss(size_average=False)# w和b--->parameters
opyimizer=torch.optim.SGD(model.parameters(),lr=0.01)

在这里插入图片描述
在这里插入图片描述

训练过程

# 训练过程
for epoch in range(100):y_pred=model(x_data)loss=criterion(y_pred,y_data)# loss标量,自动调用__str__()print(epoch,loss)optimizer.zero_grad()# backwardloss.backward()# updateoptimizer.step()

训练代码和结果

# 行表示实例数量,列表示维度feature
import torch
x_data=torch.Tensor([[1.0],[2.0],[3.0]])
y_data=torch.Tensor([[2.0],[4.0],[6.0]])class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel,self).__init__()# weight 和 bias 1 1 self.linear=torch.nn.Linear(1,1)def forward(self,x):# callabley_pred=self.linear(x)return y_pred# callable
model=LinearModel()# 构造损失函数和优化器
criterion=torch.nn.MSELoss(size_average=False)
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)# 训练过程
for epoch in range(100):y_pred=model(x_data)loss=criterion(y_pred,y_data)# loss标量,自动调用__str__()print(epoch,loss)optimizer.zero_grad()# backwardloss.backward()# updateoptimizer.step()# 打印信息
print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())x_test=torch.Tensor([4.0])
y_test=model(x_test)
print('y_pred=',y_test.data)

在这里插入图片描述


pytorch中的Linear层的底层原理(个人喜欢,不用看)

我们在课本使用到的线性函数的基本公式表达为 y = x w T + b y=xw^T+b y=xwT+b,但是在Linear层中,当输入特征被Linear层接收是,它会接收后转置,然后乘以权重矩阵,得到的是输出特征的转置,换句话说可以在底层使用Linear,它实际上做的是 y T = w x T + b y^T=wx^T+b yT=wxT+b。可以使用下面的案例进行验证:

在这里插入图片描述

普通矩阵乘法实现

很明显,上面的图标表示一个 3×4 的矩阵乘以 4×1 的矩阵,得到一个 3×1 的输出矩阵,使用普通矩阵的乘法实现如下。

import torchin_features=torch.tensor([1,2,3,4],dtype=torch.float32)
weight_matrix=torch.tensor([[1,2,3,4],[2,3,4,5],[3,4,5,6]
],dtype=torch.float32)weight_matrix.matmul(in_features)# 矩阵乘法

实现截图:
在这里插入图片描述

Linear层实现

# 这里还是使用上面使用过的数据
import torch
in_features=torch.tensor([1,2,3,4],dtype=torch.float32)
weight_matrix=torch.tensor([[1,2,3,4],[2,3,4,5],[3,4,5,6]
],dtype=torch.float32)print(weight_matrix.matmul(in_features))# 矩阵乘法fc = torch.nn.Linear(in_features=4, out_features=3, bias=False)
# 这里是随机一个权重矩阵
print('fc.weight',fc.weight)
fc(in_features)

输出结果:
在这里插入图片描述

print('fc.weight',fc.weight)# 使用上面的权重矩阵进行计算
fc.weight = torch.nn.Parameter(weight_matrix)
print('fc.weight',fc.weight)
fc(in_features)

结果截图:
在这里插入图片描述

可以看到上面截图与下面的截图的区别,一开始随机一个权重的时候,进行运算,使用到前面提及到的权重矩阵后,Linear层进行运算之后,得到与使用普通矩阵乘法一样的结果,相同的结果说明,Linear底层的实现与上面的矩阵乘法的逻辑是一致的

以上的论证可以说明,Linear的底层实现其实是 y T = w x T + b y^T=wx^T+b yT=wxT+b,而不是 y = x w T + b y=xw^T+b y=xwT+b,可能会有人好奇,为什么书本上都是写的后者而不是写前者,其实本质上二者都一样,前者的转置就是后者。

回调机制

在pytorch学习(一)线性模型中,第一个代码中,我们没有通过pytorch实现线性模型的时候,我们会显式调用forward函数,计算前馈的值,我们是这样写的y_pred_val=forward(x_val),但是在使用pytorch之后,我们是这样写的y_pred=model(x_data),直接实例化一个对象,然后通过对象直接计算预测值(前馈值),但是并没有使用到forward函数。这是因为pytorch模块类中实现了python中一个特殊的函数,也就是回调函数

如果一个类实现了回调方法,那么只要对象实例被调用,这个特殊的方法也会被调用。我们不直接调用forward()方法,而是调用对象实例。在对象实例被调用之后,在底层调用了__ call __方法,然后调用了forward()方法。这适用于所有的PyTorch神经网络模块。

以上仅代表小白个人学习观点,如有错误欢迎批评指正。

参考



逻辑斯蒂回归

逻辑斯蒂回归解决的事分类问题,分类输出的是类别的概率

模型

在线性模型中,通过 y = w x + b y=wx+b y=wx+b输出的是一个实数值,但是在分类问题中,输出的是类别的概率,所以需要一个函数,把实数值映射到[0,1]之间,表示概率,这个函数为 sigmoid 函数 y = 1 1 + e − x ∈ [ 0 , 1 ] y=\frac{1}{1+e^{-x}}∈[0,1] y=1+ex1[0,1],sigmoid函数属于饱和函数。

以上,逻辑斯蒂回归模型的公式为: y ^ = σ ( x ∗ w + b ) \hat{y}=\sigma(x*w+b) y^=σ(xw+b)

在这里插入图片描述

损失函数

在线性回归中,计算损失一般是使用均方误差(预测值与真实值的差值的平方和的累加),在回归问题中,均方误差表示数轴上两个值之间的距离,但是分类问题中,输出的结果表示的是概率(分布),使用距离是没有意义的,所以分类问题中的损失函数并不是均方误差。

在逻辑斯蒂回归中,使用的是BCE
在这里插入图片描述

代码和结果

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as pltx_data=torch.Tensor([[1.0],[2.0],[3.0]])
y_data=torch.Tensor([[0],[0],[1]])class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel,self).__init__()self.linear=torch.nn.Linear(1,1)def forward(self,x):y_pred=F.sigmoid(self.linear(x))return y_predmodel=LogisticRegressionModel()# 损失函数
criterion=torch.nn.BCELoss(size_average=False)
# 优化器
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)# 训练
for epoch in range(1000):y_pred=model(x_data)loss=criterion(y_pred,y_data)print(epoch,loss.item())optimizer.zero_grad()loss.backward()optimizer.step()# linspace与range函数类似,用于生成均匀分布的数值序列
# np.linspace(start=0,stop=10,num=200)
x=np.linspace(0,10,200)
# 数据集,test,生成200*1的矩阵
x_t=torch.Tensor(x).view((200,1))
y_t=model(x_t)
y=y_t.data.numpy()
plt.plot(x,y)
plt.plot([0,10],[0.5,0.5],c='r')
plt.xlabel("Hours")
plt.ylabel("Probability of Pass")
# 显示网格线
plt.grid()
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/271288.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一键清除JavaScript代码中的注释:使用正则表达式实现

这个正则表达式可以有效地匹配 JavaScript 代码中的各种注释&#xff0c;并且跳过了以 http: 或 https: 开头的链接。 /\/\*[\s\S]*?\*\/|\/\/[^\n]*|<!--[\s\S]*?-->|(?<!http:|https:)\/\/[^\n]*/gvscode 实战&#xff0c;ctrlF 调出查找替换工具&#xff0c;点…

(五)关系数据库标准语言SQL

注&#xff1a;课堂讲义使用的数据库 5.1利用SQL语言建立数据库 5.1.1 create Database 5.1.2 create schema...authorization... 创建数据库和创建模式的区别&#xff1a; 数据库是架构的集合&#xff0c;架构是表的集合。但在MySQL中&#xff0c;他们使用的方式是相同的。 …

设计模式学习笔记(二):工厂方法模式

一、定义 工厂方法模式&#xff08;Factory Method Pattern&#xff09;是一种创建型设计模式&#xff0c;它提供了一种在不指定具体类的情况下创建对象的方法。工厂方法模式定义了一个用于创建对象的接口&#xff0c;让子类决定实例化哪一个类。工厂方法使一个类的实例化延迟…

maven项目结构管理统一项目配置操作

一、maven分模块开发 Maven 分模块开发 1.先创建父工程&#xff0c;pom.xml文件中&#xff0c;打包方式为pom 2.然后里面有许多子工程 3.我要对父工程的maven对所有子工程进行操作 二、解读maven的结构 1.模块1 <groupId>org.TS</groupId><artifactId>TruthS…

Dubbo基础入门二

8、Dubbo协议 服务调用 8.1 服务端 启动过程深入分析 我们查看一下服务启动的过程 ProtocolFilterWrapper.export 好我们进入DubboProtocol.export 创建服务 分析我们的Handler 我们接着返回刚才位置 下面的super方法里面会创建服务&#xff0c;ChannelHandlers.wrap会对hand…

19-Java中介者模式 ( Mediator Pattern )

Java中介者模式 摘要实现范例 中介者模式&#xff08;Mediator Pattern&#xff09;提供了一个中介类&#xff0c;该类通常处理不同类之间的通信&#xff0c;并支持松耦合&#xff0c;使代码易于维护中介者模式是用来降低多个对象和类之间的通信复杂性中介者模式属于行为型模式…

每周一算法:A*(A Star)算法

八数码难题 题目描述 在 3 3 3\times 3 33 的棋盘上&#xff0c;摆有八个棋子&#xff0c;每个棋子上标有 1 1 1 至 8 8 8 的某一数字。棋盘中留有一个空格&#xff0c;空格用 0 0 0 来表示。空格周围的棋子可以移到空格中。要求解的问题是&#xff1a;给出一种初始布局…

【RT-DETR有效改进】全新的SOATA轻量化下采样操作ADown(轻量又涨点,附手撕结构图)

一、本文介绍 本文给大家带来的改进机制是利用2024/02/21号最新发布的YOLOv9其中提出的ADown模块来改进我们的Conv模块,其中YOLOv9针对于这个模块并没有介绍,只是在其项目文件中用到了,我将其整理出来用于我们的RT-DETR的项目,经过实验我发现该卷积模块(作为下采样模块)…

nicegui学习使用

https://www.douyin.com/shipin/7283814177230178363 python轻量级高自由度web框架 - NiceGUI (6) - 知乎 python做界面&#xff0c;为什么我会强烈推荐nicegui 秒杀官方实现&#xff0c;python界面库&#xff0c;去掉90%事件代码的nicegui python web GUI框架-NiceGUI 教程…

嵌入式学习第二十五天!(网络的概念、UDP编程)

网络&#xff1a; 可以用来&#xff1a;数据传输、数据共享 1. 网络协议模型&#xff1a; 1. OSI协议模型&#xff1a; 应用层实际收发的数据表示层发送的数据是否加密会话层是否建立会话连接传输层数据传输的方式&#xff08;数据包&#xff0c;流式&#xff09;网络层数据的…

吴恩达深度学习笔记:神经网络的编程基础2.1-2.3

目录 第一门课&#xff1a;神经网络和深度学习 (Neural Networks and Deep Learning)第二周&#xff1a;神经网络的编程基础 (Basics of Neural Network programming)2.1 二分类(Binary Classification)2.2 逻辑回归(Logistic Regression) 第一门课&#xff1a;神经网络和深度学…

基于YOLOv5的无人机视角水稻杂草识别检测

&#x1f4a1;&#x1f4a1;&#x1f4a1;本文主要内容:详细介绍了无人机视角水稻杂草识别检测整个过程&#xff0c;从数据集到训练模型到结果可视化分析。 博主简介 AI小怪兽&#xff0c;YOLO骨灰级玩家&#xff0c;1&#xff09;YOLOv5、v7、v8优化创新&#xff0c;轻松涨点…

Github 2024-03-08 Java开源项目日报 Top10

根据Github Trendings的统计,今日(2024-03-08统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Java项目9C++项目1非开发语言项目1《Hello 算法》:动画图解、一键运行的数据结构与算法教程 创建周期:476 天协议类型:OtherStar数量:63556…

网络原理初识(2)

目录 一、协议分层 1、分层的作用 2、OSI七层模型 3、TCP / IP五层&#xff08;或四层&#xff09;模型 4、网络设备所在分层 5、网络分层对应 二、封装和分用 发送过程&#xff08;封装&#xff09; 1、应用层(应用程序) QQ 2、输入层 3、网络层 4、数据链路层 5、物理…

金融行业专题|基金超融合架构转型与场景探索合集(2023版)

更新内容 更新 SmartX 超融合在基金行业的覆盖范围、部署规模与应用场景。更新信创云资源池、关键业务系统性能优化等场景实践。更多超融合金融核心生产业务场景实践&#xff0c;欢迎下载阅读电子书《金融核心生产业务场景探索文章合集》。 随着数字化经济的蓬勃发展&#xf…

记一次Flink任务无限期INITIALIZING排查过程

1.前言 环境&#xff1a;Flink-1.16.1&#xff0c;部署模式&#xff1a;Flink On YARN&#xff0c;现象&#xff1a;Flink程序能正常提交到 YARN&#xff0c;Job状态是 RUNNING&#xff0c;而 Task状态一直处于 INITIALIZING&#xff0c;如下图&#xff1a; 通过界面可以看到…

ue4.27 发现 getRandomReachedLocation 返回 false

把这个玩意儿删掉&#xff0c;重启工程&#xff0c;即可 如果还不行 保证运动物体在 volum 内部&#xff0c;也就是绿色范围内确保 project setting 里面的 navigation system 中 auto create navigation data 是打开的(看到过博客说关掉&#xff0c;不知道为啥) 如果还不行&…

智奇科技工业 Linux 屏更新开机logo

智奇科技工业 Linux 屏更新开机logo 简介制作logo.img文件1、转换格式得到logo.bmp2、使用Linux命令生成img文件 制作rootfs.img文件替换rootfs.img中的logo 生成update.img固件附件 简介 智奇科技的 Linux 屏刷开机logo必须刷img镜像文件&#xff0c;比较复杂。 制作logo.i…

深入浅出(二)MVVM

MVVM 1. 简介2. 示例 1. 简介 2. 示例 示例下载地址&#xff1a;https://download.csdn.net/download/qq_43572400/88925141 创建C# WPF应用(.NET Framework)工程&#xff0c;WpfApp1 添加程序集 GalaSoft.MvvmLight 创建ViewModel文件夹&#xff0c;并创建MainWindowV…

通过对话式人工智能打破语言障碍

「AI突破语言障碍」智能人工智能如何让全球交流无障碍 在当今互联的世界中&#xff0c;跨越语言界限进行交流的能力比以往任何时候都更加重要。 对话式人工智能&#xff08;包括聊天机器人和语音助手等技术&#xff09;在打破这些语言障碍方面发挥着关键作用。 在这篇博文中&am…