机器学习7:pytorch的逻辑回归

一、说明

        逻辑回归模型是处理分类问题的最常见机器学习模型之一。二项式逻辑回归只是逻辑回归模型的一种类型。它指的是两个变量的分类,其中概率用于确定二元结果,因此“二项式”中的“bi”。结果为真或假 — 0 或 1。

        二项式逻辑回归的一个例子是预测人群中 COVID-19 的可能性。一个人要么感染了COVID-19,要么没有,必须建立一个阈值以尽可能准确地区分这些结果。

二、sigmoid函数

        这些预测不适合一条线,就像线性回归模型一样。相反,逻辑回归模型拟合到右侧所示的 sigmoid 函数。

        对于每个 x,生成的 y 值表示结果为 True 的概率。在 COVID-19 示例中,这表示医生对某人感染病毒的信心。在右图中,阴性结果为蓝色,阳性结果为红色。

图片来源:作者

三、过程

        要进行二项式逻辑回归,我们需要做各种事情:

  1. 创建训练数据集。
  2. 使用 PyTorch 创建我们的模型。
  3. 将我们的数据拟合到模型中。

        逻辑回归问题的第一步是创建训练数据集。首先,我们应该设置一个种子来确保我们的随机数据的可重复性。

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42)   # set a random seed

我们必须使用 PyTorch 的线性模型,因为我们正在处理一个输入 x 和一个输出 y。因此,我们的模型是线性的。为此,我们将使用 PyTorch 的函数:Linear

model = Linear(in_features=1, out_features=1) # use a linear model

接下来,我们必须生成蓝色 X 和红色 X 数据,确保将它们从行向量重塑为列向量。蓝色的在 0 到 7 之间,红色的在 7 到 10 之间。对于 y 值,蓝点表示 COVID-19 测试阴性,因此它们都将是

  1. 对于红点,它们代表 COVID-19 测试呈阳性,因此它们将为 1。下面是代码及其输出:
blue_x = (torch.rand(20) * 7).reshape(-1,1)   # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1)  # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x])   # matrix of x values
Y = torch.vstack([blue_y, red_y])   # matrix of y values

现在,我们的代码应如下所示:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42)   # set a random seedmodel = Linear(in_features=1, out_features=1) # use a linear modelblue_x = (torch.rand(20) * 7).reshape(-1,1)   # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1)  # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x])   # matrix of x values
Y = torch.vstack([blue_y, red_y])   # matrix of y values

四、优化

        我们将使用梯度下降过程来优化 S 形函数的损失。损失是根据函数拟合数据的优度计算的,数据由 S 形曲线的斜率和截距控制。我们需要梯度下降来找到最佳斜率和截距。

        我们还将使用二进制交叉熵(BCE)作为我们的损失函数,或对数损失函数。对于一般的逻辑回归,不包含对数的损失函数将不起作用。

        为了实现BCE作为我们的损失函数,我们将它设置为我们的标准,并将随机梯度下降作为我们优化它的手段。由于这是我们将要优化的函数,我们需要传入模型参数和学习率。

epochs = 2000   # run 2000 iterations
criterion = nn.BCELoss()    # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descent

        现在,我们准备开始梯度下降以优化我们的损失。我们必须将梯度归零,通过将我们的数据插入 sigmoid 函数来找到 y-hat 值,计算损失,并找到损失函数的梯度。然后,我们必须迈出一步,确保存储我们的新斜率并为下一次迭代进行拦截。

optimizer.zero_grad()
Yhat = torch.sigmoid(model(X)) 
loss = criterion(Yhat,Y)
loss.backward()
optimizer.step() 

五、收尾

        为了找到最佳斜率和截距,我们本质上是在训练我们的模型。我们必须对多次迭代或纪元应用梯度下降。在此示例中,我们将使用 2,000 个纪元进行演示。

epochs = 2000   # run 2000 iterations
criterion = nn.BCELoss()    # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descentfor i in range(epochs):optimizer.zero_grad()Yhat = torch.sigmoid(model(X))loss = criterion(Yhat,Y)loss.backward()optimizer.step()print(f"epoch: {i+1}")print(f"loss: {loss: .5f}")print(f"slope: {model.weight.item(): .5f}")print(f"intercept: {model.bias.item(): .5f}")print()

将所有代码片段放在一起,我们应该得到以下代码:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42)   # set a random seedmodel = Linear(in_features=1, out_features=1) # use a linear modelblue_x = (torch.rand(20) * 7).reshape(-1,1)   # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1)  # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x])   # matrix of x values
Y = torch.vstack([blue_y, red_y])   # matrix of y valuesepochs = 2000   # run 2000 iterations
criterion = nn.BCELoss()    # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descentfor i in range(epochs):optimizer.zero_grad()Yhat = torch.sigmoid(model(X))loss = criterion(Yhat,Y)loss.backward()optimizer.step()print(f"epoch: {i+1}")print(f"loss: {loss: .5f}")print(f"slope: {model.weight.item(): .5f}")print(f"intercept: {model.bias.item(): .5f}")print()
两千个时期后的最终输出:epoch: 2000
loss:  0.53861
slope:  0.61276
intercept: -3.17314

两千个时期后的最终输出:

epoch: 2000
loss:  0.53861
slope:  0.61276
intercept: -3.17314 

六、可视化

        最后,我们可以将数据与 sigmoid 函数一起绘制,以获得以下可视化效果:

x = np.arange(0,10,.1)
y = model.weight.item()*x + model.bias.item()plt.plot(x, 1/(1 + np.exp(-y)), color="green")plt.xlim(0,10)
plt.scatter(blue_x, blue_y, color="blue")
plt.scatter(red_x, red_y, color="red")plt.show()

图片来源:作者

七、局限性

        二元分类的最大问题之一是需要阈值。在逻辑回归的情况下,此阈值应为 x 值,其中 y 为 50%。我们试图回答的问题是将阈值放在哪里?

        在 COVID-19 测试的情况下,原始示例说明了这种困境。如果我们将阈值设置为 x=5,我们可以清楚地看到应该是红色的蓝点和应该是蓝色的红点。

        悬垂的红点称为误报,即模型错误地预测正类的区域。悬垂的蓝点称为假阴性 - 模型错误地预测负类的区域。

 八、结论

        成功的二项式逻辑回归模型将减少假阴性的数量,因为这些假阴性通常会导致最大的危险。患有COVID-19但检测呈阴性对他人的健康和安全构成严重风险。

        通过对可用数据使用二项式逻辑回归,我们可以确定放置阈值的最佳位置,从而有助于减少不确定性并做出更明智的决策。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/150773.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Scala第十七章节

Scala第十七章节 scala总目录 文档资料下载 章节目标 了解集合的相关概念掌握Traversable集合的用法掌握随机学生序列案例 1. 集合 1.1 概述 但凡了解过编程的人都知道程序 算法 数据结构这句话, 它是由著名的瑞士计算机科学家尼古拉斯沃斯提出来的, 而他也是1984年图灵…

车险计算器微信小程序源码 带流量主功能

车险计算器微信小程序源码带流量主功能,可以精准的算出车险的书目,是一个非常实用的微信小程序源码。 简单的计算让你得知车险价值 另外也支持流量主,具体小编也就不多说了,大家自己搭建研究吧。 源码下载:https://d…

React xlsx(工具库) 处理表头合并

前端导出excel表格 引入xlsx插件,不然应该是运行不起来的 npm i xlsx xlsx中文文档 插件2 exceljs npm i exceljs exceljs中文文档 导出 例子 import { ExportExcel } from ./exportExcel/index;const columns[{title: id,dataIndex: item1,},{title: 序号,dataInd…

开环模块化多电平换流器仿真(MMC)N=6(Simulink仿真)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

RabbitMQ之Fanout(扇形) Exchange解读

目录 基本介绍 适用场景 springboot代码演示 演示架构 工程概述 RabbitConfig配置类:创建队列及交换机并进行绑定 MessageService业务类:发送消息及接收消息 主启动类RabbitMq01Application:实现ApplicationRunner接口 基本介绍 Fa…

12.1 使用键盘鼠标监控钩子

本节将介绍如何使用Windows API中的SetWindowsHookEx和RegisterHotKey函数来实现键盘鼠标的监控。这些函数可以用来设置全局钩子,通过对特定热键挂钩实现监控的效果,两者的区别在于SetWindowsHookEx函数可以对所有线程进行监控,包括其他进程中…

Spring Cloud Gateway2之路由详解

Spring Cloud Gateway路由 文章目录 1. 前言2. Gateway路由的基本概念3. 三种路由1. 静态路由2. 动态路由1. 利用外部存储2. API动态路由 3. 服务发现路由(自动路由)3.1. 配置方式3.2 自动路由(服务发现)原理核心源码GatewayDiscoveryClientAutoConfigur…

机器学习算法基础--KNN分类算法

文章目录 1.KNN算法原理介绍2.KNN分类决策原则3.KNN度量距离介绍3.1.闵可夫斯基距离3.2.曼哈顿距离3.3.欧式距离 4.KNN分类算法实现5.KNN分类算法效果6.参考文章与致谢 1.KNN算法原理介绍 KNN(K-Nearest Neighbor)工作原理: 在一个存在标签的…

Springcloud支付模块

客户端消费者80 order(订单模块) 微服务提供者8001 payment(支付模块) 订单模块可以调动支付模块 步骤: 1、建moudle 2、改写pom 3、写yml 4、主启类 5、业务类 1、建父工程(创建spri…

架构师选择题--计算机网络

架构师选择题--计算机网络 22年考题21年考题20年考题19年真题2017考题 22年考题 d http:80 https:httpssl :443 b b pop3是邮件接收协议:110 SMTP是邮件发送协议:25 http:80 A 网络隔离:防火墙(逻辑)&…

C++(反向迭代器)

前言: 上一章我们介绍了适配器,也提了一下迭代器适配器,今天我们就从反向迭代器把迭代器适配器给解释一下。 既然 都叫迭代器容器了 就说名只要接口合适他可以封装实现各种容器需求包括vector list 。 目录 1.反向迭代器设计 1.1反向迭代…

Spring面试题学习: 单例Bean是单例模式吗?

单例Bean是单例模式吗 学习背景答案扩展知识单例模式Spring BeanJava Bean单例Bean 个人评价我的回答 学习背景 想换工作. 学习记录, 算是一个输出. 答案 通常来说, 单例模式是指在一个JVM中, 一个类只能构造出一个对象. 有很多方法来实现单例模式, 比如饿汉模式. 但是我们通…

【JVM】 类加载机制、类加载器、双亲委派模型详解

文章目录 前言一、类加载机制二、类加载器三、双亲委派模型总结 前言 📕各位读者好, 我是小陈, 这是我的个人主页 📗小陈还在持续努力学习编程, 努力通过博客输出所学知识 📘如果本篇对你有帮助, 烦请点赞关注支持一波, 感激不尽 &#x1f4d…

配置文件生成器-秒杀SSM的xml整合

配置文件生成器-秒杀SSM的xml整合 思路&#xff1a; 通过简单的配置&#xff0c;直接生成对应配置文件。 maven坐标 <dependencies><!-- 配置文件生成 --><dependency><groupId>org.freemarker</groupId><artifactId>freemarker<…

Java中栈实现怎么选?Stack、Deque、ArrayDeque、LinkedList(含常用Api积累)

目录 Java中的Stack类 不用Stack有以下两点原因 1、从性能上来说应该使用Deque代替Stack。 2、Stack从Vector继承是个历史遗留问题&#xff0c;JDK官方已建议优先使用Deque的实现类来代替Stack。 该用ArrayDeque还是LinkedList&#xff1f; ArrayDeque与LinkList区别&#xff1…

mysql MVCC(多版本并发控制)理解

最近看MVCC相关资料&#xff0c;这边做一个记录总结&#xff0c;方便后续理解。 目录 一、MVCC相关概念 二、MVCC实现原理 1.隐藏字段 2.undo log 3.Read View 4.MVCC的整体处理流程 5. RC&#xff0c;RR级级别下的innoDB快照读有什么不同 6.总结 一、MVCC相关概念 1…

Mysql——创建数据库,对表的创建及字段定义、数据录入、字段增加及删除、重命名表。

一.创建数据库 create database db_classics default charsetutf8mb4;//创建数据库 use db_classics;//使用该数据库二.对表的创建及字段定义 create table if not exists t_hero ( id int primary key auto_increment, Name varchar(100) not null unique, Nickname varchar(1…

机器视觉工程师,公司设置奖金,真的为了奖励你吗?其实和你没关系

​据说某家大厂&#xff0c;超额罚款&#xff0c;有奖有罚很正常&#xff0c;但是我觉得你罚款代理商员工就不一样了&#xff0c;把代理商当成你的员工&#xff0c;我就觉得这些大厂的脑回路有问题。 有人从来没听说过项目奖金&#xff0c;更没有奖金。那么为什么设置奖金呢&a…

深入浅出,SpringBoot整合Quartz实现定时任务与Redis健康检测(二)

前言 在上一篇深入浅出&#xff0c;SpringBoot整合Quartz实现定时任务与Redis健康检测&#xff08;一&#xff09;_往事如烟隔多年的博客-CSDN博客 文章中对SpringBoot整合Quartz做了初步的介绍以及提供了一个基本的使用例子&#xff0c;因为实际各自的需求任务不尽相同因此并…

Neo4j深度学习

Neo4j的简介 Neo4j是用Java实现的开源NoSQL图数据库。从2003年开始开发&#xff0c;2007年正式发布第一版&#xff0c;其源码托管于GitHtb。Neo4j作为图数据库中的代表产品&#xff0c;已经在众多的行业项目中进行了应用&#xff0c;如&#xff1a;网络管理、软件分析、组织和…