使用 PyTorch 实现逻辑回归并评估模型性能

1. 逻辑回归简介

逻辑回归是一种用于解决二分类问题的算法。它通过一个逻辑函数(Sigmoid 函数)将线性回归的输出映射到 [0, 1] 区间内,从而将问题转化为概率预测问题。如果预测概率大于 0.5,则将样本分类为正类;否则分类为负类。

2. 数据准备

为了演示逻辑回归的效果,我们构造了一个简单的二维数据集,包含两类样本。每类样本有 7 个数据点,特征维度为 2。

class1_points = np.array([[1.9, 1.2],[1.5, 2.1],[1.9, 0.5],[1.5, 0.9],[0.9, 1.2],[1.1, 1.7],[1.4, 1.1]])class2_points = np.array([[3.2, 3.2],[3.7, 2.9],[3.2, 2.6],[1.7, 3.3],[3.4, 2.6],[4.1, 2.3],[3.0, 2.9]])

我们将这两类数据点的特征合并,并为每个数据点分配标签(0 表示第一类,1 表示第二类)。

3. 模型构建

我们使用 PyTorch 框架来实现逻辑回归模型。模型结构非常简单,仅包含一个线性层和一个 Sigmoid 激活函数。

class LogisticRegression(nn.Module):def __init__(self):super(LogisticRegression, self).__init__()self.linear = nn.Linear(2, 1)  # 输入特征维度为 2,输出为 1def forward(self, x):return torch.sigmoid(self.linear(x))

4. 模型训练

我们使用二分类交叉熵损失函数(BCELoss)和随机梯度下降优化器(SGD)来训练模型。训练过程如下:

epochs = 5000
for epoch in range(epochs):model.train()optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')

训练过程中,我们每 100 个 epoch 打印一次损失值,以便观察模型的收敛情况。

5. 模型保存与加载

训练完成后,我们将模型的权重保存到文件中,方便后续加载和使用。

torch.save(model.state_dict(), 'model3.pth')
print("模型已保存")

加载模型时,我们创建一个新的模型实例,并使用 load_state_dict 方法加载保存的权重。

loaded_model = LogisticRegression()
loaded_model.load_state_dict(torch.load('model3.pth', map_location=torch.device('cpu')))
loaded_model.eval()

6. 模型预测与性能评估

加载模型后,我们使用模型对训练数据进行预测,并计算精确度、召回率和 F1 分数。

with torch.no_grad():predictions = loaded_model(X)predicted_labels = (predictions > 0.5).float()print("实际结果:", y.numpy().flatten())
print("预测结果:", predicted_labels.numpy().flatten())precision = precision_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
recall = recall_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
f1 = f1_score(y.numpy().flatten(), predicted_labels.numpy().flatten())print(f"精确度(Precision): {precision:.4f}")
print(f"召回率(Recall): {recall:.4f}")
print(f"F1 分数: {f1:.4f}")

7. 运行结果

运行上述代码后,我们得到了以下结果:

  • 实际结果:[0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]

  • 预测结果:[0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]

  • 精确度(Precision):1.0000

  • 召回率(Recall):1.0000

  • F1 分数:1.0000

从结果可以看出,模型在训练集上表现良好,精确度、召回率和 F1 分数均为 1.0000。

8. 完整代码

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score"""使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数"""
# 提取特征和标签
class1_points = np.array([[1.9, 1.2],[1.5, 2.1],[1.9, 0.5],[1.5, 0.9],[0.9, 1.2],[1.1, 1.7],[1.4, 1.1]])class2_points = np.array([[3.2, 3.2],[3.7, 2.9],[3.2, 2.6],[1.7, 3.3],[3.4, 2.6],[4.1, 2.3],[3.0, 2.9]])# 提取两类特征,输入特征维度为2
x1_data = np.concatenate((class1_points[:, 0], class2_points[:, 0]), axis=0)
x2_data = np.concatenate((class1_points[:, 1], class2_points[:, 1]), axis=0)
label = np.concatenate((np.zeros(len(class1_points)), np.ones(len(class2_points))), axis=0)# 将数据转换为 PyTorch 张量
X = torch.tensor(np.column_stack((x1_data, x2_data)), dtype=torch.float32)
y = torch.tensor(label, dtype=torch.float32).view(-1, 1)# 定义逻辑回归模型
class LogisticRegression(nn.Module):def __init__(self):super(LogisticRegression, self).__init__()self.linear = nn.Linear(2, 1)  # 输入特征维度为 2,输出为 1def forward(self, x):return torch.sigmoid(self.linear(x))# 初始化模型、损失函数和优化器
model = LogisticRegression()
criterion = nn.BCELoss()  # 二分类交叉熵损失
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
epochs = 5000
for epoch in range(epochs):model.train()optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')# 保存模型
torch.save(model.state_dict(), 'model3.pth')
print("模型已保存")# 加载模型
loaded_model = LogisticRegression()
loaded_model.load_state_dict(torch.load('model3.pth',map_location=torch.device('cpu'),weights_only=True))
loaded_model.eval()# 进行预测
with torch.no_grad():predictions = loaded_model(X)predicted_labels = (predictions > 0.5).float()# 展示预测结果和实际结果
print("实际结果:", y.numpy().flatten())
print("预测结果:", predicted_labels.numpy().flatten())# 计算精确度、召回率和 F1 分数
precision = precision_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
recall = recall_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
f1 = f1_score(y.numpy().flatten(), predicted_labels.numpy().flatten())print(f"精确度(Precision): {precision:.4f}")
print(f"召回率(Recall): {recall:.4f}")
print(f"F1 分数: {f1:.4f}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/11687.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用PyQt5绘制带有刻度的温度计控件

前言:进入学习Python开发上位机界面的第二阶段,学习如何开发自定义控件,从常用的控件入手学习,本期主要学习如何使用PyQt5绘制带有刻度的温度计控件。 1. 先找到一篇参考文章 参考文章:Qt编写自定义控件5-柱状温度计…

问deepseek,如何看待ai降低学习成本而导致软件开发岗位需求降低,和工资下降。 软件从业人员何去何从?

它给我的回答是这样的: 思考逻辑 嗯,用户问的是AI如何降低学习成本,进而导致软件开发岗位需求减少和工资下降,以及软件从业人员该怎么办。这个问题挺复杂的,我得先理清楚各个部分。首先,AI确实在改变很多行…

Error: Expected a mutable image

你的函数用了不支持的图片格式比如我的人脸检测,本来要RGB565我却用JPEG所以报错

海思ISP开发说明

1、概述 ISP(Image Signal Processor)图像信号处理器是专门用于处理图像信号的硬件或处理单元,广泛应用于图像传感器(如 CMOS 或 CCD 传感器)与显示设备之间的信号转换过程中。ISP通过一系列数字图像处理算法完成对数字…

2.攻防世界PHP2及知识点

进入题目页面如下 意思是你能访问这个网站吗? ctrlu、F12查看源码,什么都没有发现 用kali中的dirsearch扫描根目录 命令如下,根据题目提示以及需要查看源码,扫描以php、phps、html为后缀的文件 dirsearch -u http://61.147.17…

线性数据结构:单向链表

放弃眼高手低,你真正投入学习,会因为找到一个新方法产生成就感,学习不仅是片面的记单词、学高数......只要是提升自己的过程,探索到了未知,就是学习。 目录 一.链表的理解 二.链表的分类(重点理解&#xf…

【AI】探索自然语言处理(NLP):从基础到前沿技术及代码实践

Hi ! 云边有个稻草人-CSDN博客 必须有为成功付出代价的决心,然后想办法付出这个代价。 目录 引言 1. 什么是自然语言处理(NLP)? 2. NLP的基础技术 2.1 词袋模型(Bag-of-Words,BoW&#xff…

书生大模型实战营7

文章目录 L1——基础岛提示词工程实践什么是Prompt(提示词)什么是提示工程提示设计框架CRISPECO-STAR LangGPT结构化提示词LangGPT结构编写技巧构建全局思维链保持上下文语义一致性有机结合其他 Prompt 技巧 常用的提示词模块 浦语提示词工程实践(LangGPT版)自动化生成LangGPT提…

一个开源 GenBI AI 本地代理(确保本地数据安全),使数据驱动型团队能够与其数据进行互动,生成文本到 SQL、图表、电子表格、报告和 BI

一、GenBI AI 代理介绍(文末提供下载) github地址:https://github.com/Canner/WrenAI 本文信息图片均来源于github作者主页 在 Wren AI,我们的使命是通过生成式商业智能 (GenBI) 使组织能够无缝访问数据&…

41. 缺失的第一个正数

参考题解:https://leetcode.cn/problems/first-missing-positive/solutions/7703/tong-pai-xu-python-dai-ma-by-liweiwei1419 难点在于时间复杂度控制在O(n),空间复杂度为常数级。 哈希表时间复杂度符合,但是空间复杂度为O(n) 排序空间复杂…

深入核心:一步步手撕Tomcat搭建自己的Web服务器

介绍: servlet:处理 http 请求 tomcat:服务器 Servlet servlet 接口: 定义 Servlet 声明周期初始化:init服务:service销毁:destory 继承链: Tomcat Tomcat 和 servlet 原理&#x…

final-关键字

一、final修饰的类不能被继承 当final修饰一个类时,表明这个类不能被其他类继承。例如,在 Java 中,String类就是被final修饰的,这保证了String类的不可变性和安全性,防止其他类通过继承来改变String类的行为。 final…

51单片机 01 LED

一、点亮一个LED 在STC-ISP中单片机型号选择 STC89C52RC/LE52RC;如果没有找到hex文件(在objects文件夹下),在keil中options for target-output- 勾选 create hex file。 如果要修改编程 :重新编译-下载/编程-单片机重…

知识库建设与知识管理实践对企业发展的助推作用探索

内容概要 在当今瞬息万变的商业环境中,知识库建设与知识管理实践日益成为企业发展的重要驱动力。知识库作为组织内信息和知识的集成,起着信息存储、整理和共享的关键作用。通过有效的知识库建设,企业不仅能够提升员工获取信息的便利性&#…

【Pytorch和Keras】使用transformer库进行图像分类

目录 一、环境准备二、基于Pytorch的预训练模型1、准备数据集2、加载预训练模型3、 使用pytorch进行模型构建 三、基于keras的预训练模型四、模型测试五、参考 现在大多数的模型都会上传到huggface平台进行统一的管理,transformer库能关联到huggface中对应的模型&am…

如何使用 DeepSeek 和 Dexscreener 构建免费的 AI 加密交易机器人?

我使用DeepSeek AI和Dexscreener API构建的一个简单的 AI 加密交易机器人实现了这一目标。在本文中,我将逐步指导您如何构建像我一样的机器人。 DeepSeek 最近发布了R1,这是一种先进的 AI 模型。您可以将其视为 ChatGPT 的免费开源版本,但增加…

ArkTS渲染控制

文章目录 if/else:条件渲染ArkUI通过自定义组件的build()函数和@Builder装饰器中的声明式UI描述语句构建相应的UI。在声明式描述语句中开发者除了使用系统组件外,还可以使用渲染控制语句来辅助UI的构建,这些渲染控制语句包括控制组件是否显示的条件渲染语句,基于数组数据快…

potplayer字幕

看视频学习,实时字幕可以快速过滤水字数阶段,提高效率,但是容易错过一些信息。下面就是解决这一问题。 工具ptoplayer 一.生成字幕 打开学习视频,右键点击视频画面,点选字幕。勾选显示字幕。点选创建有声字幕&#…

deepseek的两种本地使用方式

总结来说 ollama是命令行 GPT4ALL桌面程序。 然后ollamaAnythingLLM可以达到桌面或web的两种接入方式。 一. ollama和deepseek-r1-1.5b和AnythingLLM 本文介绍一个桌面版的deepseek的本地部署过程,其中ollama可以部署在远程。 1. https://www.cnblogs.com/janeysj/p…

海外问卷调查渠道查,如何影响企业的运营

我们注意到,随着信息资源和传播的变化,海外问卷调查渠道查已发生了深刻的变化。几年前,市场调研是业内专家们的事,即使是第二手资料也需要专业人士来完成;但如今的因特网和许许多多的信息数据库,使每个人都…