Pytorch深度学习实践(5)逻辑回归

逻辑回归

逻辑回归主要是解决分类问题

  • 回归任务:结果是一个连续的实数
  • 分类任务:结果是一个离散的值

分类任务不能直接使用回归去预测,比如在手写识别中(识别手写 0 − − 9 0 -- 9 09),因为各个类别之间没有大小之差。

因此,对于分类问题,我们最终的输出是个概率,即属于某个类别的概率是多少,然后从概率集合里找最大值,作为当前预测的结果

下载MNIST数据集

import torchvision
train_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=True, download=True)
test_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=False, dowload=True)
  • 通过train参数来指定训练集和测试集

逻辑回归

将之前的学习时长—考试分数转化为二分类任务,即学习时长—是否通过考试

x(hours)y(pass/fail)
10(fail)
20(fail)
31(pass)
4?

其中, P ( y ^ = 1 ) + P ( y ^ = 0 ) = 1 P(\hat y = 1) + P(\hat y = 0) = 1 P(y^=1)+P(y^=0)=1

当输出的概率在 0.5 0.5 0.5附近时,即模型不确定,因此通常会输出一个不确定的值

对于二分类任务,逻辑回归会先使用回归,生成一个得分值,即落在实数集区间内,然后再使用 s i g m o i d sigmoid sigmoid函数,将得分值映射到 [ 0 , 1 ] [0, 1] [0,1]区间内,得到预测概率

s i g m o i d sigmoid sigmoid函数
σ ( x ) = 1 1 + e − x \sigma (x) = \frac{1}{1+e^{-x}} σ(x)=1+ex1
在这里插入图片描述

S i g m o i d Sigmoid Sigmoid常用来做二分类任务,常具备三个特征:

  • 饱和函数
  • 单调递增
  • 有极限

当我们使用线性回归来得到逻辑回归的得分值时,逻辑回归模型的函数定义就如下所示:
y ^ = σ ( x ∗ ω + b ) \hat y = \sigma (x*\omega + b) y^=σ(xω+b)

损失函数

线性回归使用的损失函数是计算预测值和真实值之差

而对于逻辑回归,由于我们得到的是概率,是一个 0 − 1 0-1 01分布,因此需要修改损失函数
l o s s = − ( y l o g y ^ + ( 1 − y ) l o g ( 1 − y ^ ) ) loss = -(ylog\hat y + (1-y)log(1-\hat y)) loss=(ylogy^+(1y)log(1y^))
即我们比较的是分布之间的差异

交叉熵 c r o s s − e n t r o p y cross-entropy crossentropy

存在两个分布 P D 1 ( x ) P_{D1}(x) PD1(x) P D 2 ( x ) P_{D2}(x) PD2(x)

两个分布的差异程度使用公式: ∑ i = 1 n P D 1 ( x i ) l n P D 2 ( x i ) \sum_{i=1}^{n}P_{D1}(x_i)lnP_{D2}(x_i) i=1nPD1(xi)lnPD2(xi) 来衡量

上述公式越大时,两个分布的差异越小

模型的改变

模型构造的改变
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可# 在这里不需要实例化def forward(self, x):y_pred = F.sigmoid(self.linear(x))return y_pred

需要先将输入写入到linear()线性模型中,再使用Sigmoid()函数

模型损失函数的改变

使用交叉熵函数BCELoss

criterion = torch.nn.BCELoss(size_average=False)

整体代码

import torch
import matplotlib.pyplot as plt
import numpy as np########## 数据集准备 ##########
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])########## 模型定义 ##########
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可# 在这里不需要实例化def forward(self, x):y_pred = torch.sigmoid(self.linear(x))return y_predmodel = LogisticRegressionModel()########## 损失函数和优化器的设置 ##########
criterion = torch.nn.BCELoss(size_average=False) # BCELoss -- 交叉熵函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)########## 模型训练 ##########
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data)print(epoch, loss.item())optimizer.zero_grad()loss.backward()optimizer.step()########## 模型测试 ##########
x = np.linspace(0, 10, 200)
x_test = torch.Tensor(x).view((200, 1)) # view()相当于reshape
y_test = model(x_test)
y = y_test.data.numpy()  # 转化为np类型
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], 'r--')
plt.xlabel("Hours")
plt.ylabel("Probability of Pass")
plt.grid()
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/384339.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CentOS7下操作iptables防火墙和firewalld防火墙

CentOS7下操作iptables防火墙和firewalld防火墙 💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、…

【OpenCV C++20 学习笔记】调节图片对比度和亮度(像素变换)

调节图片对比度和亮度(像素变换) 原理像素变换亮度和对比度调整 代码实现更简便的方法结果展示 γ \gamma γ校正及其实操案例线性变换的缺点 γ \gamma γ校正低曝光图片矫正案例代码实现 原理 关于OpenCV的配置和基础用法,请参阅本专栏的其…

HAL STM32 SPI/ABZ/PWM方式读取MT6816磁编码器数据

HAL STM32 SPI/ABZ/PWM方式读取MT6816磁编码器数据 📚MT6816相关资料(来自商家的相关资料): 资料:https://pan.baidu.com/s/1CAbdLBRi2dmL4D7cFve1XA?pwd8888 提取码:8888📍驱动代码编写&…

FastAPI(七十九)实战开发《在线课程学习系统》接口开发-- 加入课程和退出课程

源码见:"fastapi_study_road-learning_system_online_courses: fastapi框架实战之--在线课程学习系统" 加入课程 我们先看下加入课程 1.是否登录 2.课程是否存在 3.是否已经存在 4.添加 首先实现逻辑 def get_student_course(db: Session, course: int…

如何开启或者关闭 Windows 安全登录?

什么是安全登录 什么是 Windows 安全登录呢?安全登录是 Windows 附加的一个组件,它可以在用户需要登录的之前先将登录界面隐藏,只有当用户按下 CtrlAltDelete 之后才出现登录屏幕,这样可以防止那些模拟登录界面的程序获取密码信息…

【9.PIE-Engine案例——加载Terra星全球500m植被指数16天合成产品(MOD13A1 V61)数据集】

加载Terra星全球500m植被指数16天合成产品(MOD13A1 V61)数据集 原始路径 欢迎大家登录航天宏图官网查看本案例原始来源 最终结果 具体代码 /*** File : MOD13A1* Time : 2020/7/21* Author : piesat* Version : 1.0* Contact : 400-890-0662* License : …

Interesting bug caused by getattr

题意:由 getattr 引起的有趣的 bug 问题背景: I try to train 8 CNN models with the same structures simultaneously. After training a model on a batch, I need to synchronize the weights of the feature extraction layers in other 7 models. …

WARNING: Ignoring invalid distribution -ip警告信息如何去掉?

查看已安装依赖列表的时候,出现了很多警告信息,如何去掉呢? 解决办法 打开这个路径:d:\software\python\python39\lib\site-packages 这种波浪线开头的,我们将它删除掉,就可以了。

Ubuntu设置网络

进入网络配置文件夹 cd /etc/netplan 使用 vim 打开下的配置文件 打开后的配置 配置说明: network:# 网络配置部分ethernets:# 配置名为ens33的以太网接口ens33:addresses:# 为ens33接口分配IP地址192.168.220.30,子网掩码为24位- 192.168.220.30/24n…

VS2019报错:找不到导入的项目,请确认import声明

解决办法 找到项目的.vcxproj文件 用记事本打开后使用ctrlF搜索import 发现import Project后面的.props文件路径不对,将路径改为相对路径 保存后重新加载项目,即可生成成功

AI发展下的伦理挑战:构建未来科技的道德框架

一、引言 随着人工智能(AI)技术的飞速发展,我们正处在一个前所未有的科技变革时代。AI不仅在医疗、教育、金融、交通等领域展现出巨大的应用潜力,也在日常生活中扮演着越来越重要的角色。然而,这一技术的迅猛进步也带来…

git实践汇总【配置+日常使用+问题解决】

**最初配置步骤:** git config --global user.name "yournemae" git config --global user.email "yourmail" git config -l ssh-keygen -t rsa -C “xxx.xxxx.EXTcccc.com” git config --global ssh.variant ssh $ git clone git仓库路径 git…

云盘高速检测的秘密:密封圈外观检测全解析!

密封圈是一种用于填塞、隔离或密封两个相互连接部件之间空隙的圆形密封装置。密封圈通常由橡胶、塑料、金属等材料制成,具有弹性并能在压力作用下填充间隙,防止液体、气体或固体物质泄漏。 密封圈可根据具体应用选择不同材料,如橡胶密封圈适…

「Unity3D」场景中的距离单位Unit与相关设置PixelsToUnits、PixelsPerUnit

GameObject在场景的位置Position,并没有明确是什么具体单位——如:Transform的x、y、z,或RectTransform的PosX、PosY、PosZ。而RectTransform在面板上显示的Width和Height,也没有具体单位,其实并不是像素。 事实上&am…

谷粒商城实战笔记-59-商品服务-API-品牌管理-使用逆向工程的前后端代码

文章目录 一, 使用逆向工程生成的代码二,生成品牌管理菜单三,几个小问题 在本次的技术实践中,我们利用逆向工程的方法成功地为后台管理系统增加了品牌管理功能。这种开发方式不仅能快速地构建起功能模块,还能在一定程度…

qemu上运行android-x86 (基于ubuntu)

安装qemu x86_64 下载qemu源代码 进入根目录,执行./configure --target-listx86_64-softmmu make -j4 sudo make install 期间可能遇到python相关问题,比如版本不对、库找不到,版本不对可自行安装或是修改环境变量,库找不到可以检…

如何检查代理IP地址是否被占用

使用代理IP时,有时候会发现IP仍然不可用,可能是因为已经被其他用户或者网络占用了。为了检测代理IP是否被占用,我们可以采用一些方法进行验证测试,以保证代理IP的有效性和稳定性。 1.ARP缓存方法 ARP缓存法是一种简单有效的检测代…

unity 实现图片的放大与缩小(根据鼠标位置拉伸放缩)

1创建UnityHelper.cs using UnityEngine.Events; using UnityEngine.EventSystems;public class UnityHelper {/// <summary>/// 简化向EventTrigger组件添加事件的操作。/// </summary>/// <param name"_eventTrigger">要添加事件监听的UI元素上…

将Android Library项目发布到JitPack仓库

将项目代码导入Github 1.将本地项目目录初始化为 Git 仓库。 默认情况下&#xff0c;初始分支称为 main; 如果使用 Git 2.28.0 或更高版本&#xff0c;则可以使用 -b 设置默认分支的名称。 git init -b main 如果使用 Git 2.27.1 或更低版本&#xff0c;则可以使用 git symbo…

ffmpeg更改视频的帧率

note 视频帧率调整 帧率(fps-frame per second) 例如&#xff1a;原来帧率为30&#xff0c;调整后为1 现象&#xff1a;原来是每秒有30张图像&#xff0c;调整后每秒1张图像&#xff0c;看着图像很慢 实现&#xff1a;在每秒的时间区间里&#xff0c;取一张图像…