【机器学习】机器学习的基本分类-监督学习-逻辑回归(Logistic Regression)

逻辑回归是一种分类算法,尽管名字中包含“回归”,但其主要用于解决二分类和多分类问题。它通过学习一个逻辑函数,预测输入属于某个类别的概率。


1. 逻辑回归的基本概念

目标

逻辑回归的目标是找到一个函数 h(x),输出一个概率值 P(y=1|x),表示输入样本 x 属于正类的概率。

逻辑函数(Sigmoid 函数)

逻辑回归使用 Sigmoid 函数将线性回归的结果映射到 (0, 1) 之间:

h(x) = \frac{1}{1 + e^{-z}}, \quad z = w^T x + b

其中:

  • z 是线性模型的结果。
  • h(x) 是预测为正类的概率。

【机器学习】机器学习的基本分类-监督学习-逻辑回归-Sigmoid 函数-CSDN博客


2. 逻辑回归的损失函数

为了优化模型参数 w 和 b,逻辑回归最小化 对数似然损失函数

L(w, b) = -\frac{1}{m} \sum_{i=1}^m \left[ y_i \log(h(x_i)) + (1 - y_i) \log(1 - h(x_i)) \right]

其中:

  • m 是样本数量。
  • y_i \in \{0, 1\}是第 i 个样本的真实标签。

逻辑回归通过梯度下降或其他优化算法最小化该损失函数。

【机器学习】机器学习的基本分类-监督学习-逻辑回归-对数似然损失函数(Log-Likelihood Loss Function)-CSDN博客


3. 逻辑回归的假设

  1. 数据集中的样本是独立的。
  2. 输入特征和目标变量之间是线性可分的(通过特征变换,可以扩展到非线性问题)。

4. Python 实现

4.1 数据生成

我们以二分类任务为例:

from sklearn.datasets import make_classification
import matplotlib.pyplot as plt# 生成二分类数据
# 参数说明:
# n_samples=100: 生成100个样本
# n_features=4: 每个样本有4个特征
# n_classes=2: 分为2个类别
# n_informative=2: 有2个特征是信息特征,对分类有帮助
# n_redundant=1: 有1个特征是冗余特征,对分类无直接帮助
# n_repeated=0: 没有重复的特征
# random_state=0: 设置随机种子,保证结果可重复
X, y = make_classification(n_samples=100, n_features=4, n_classes=2, n_informative=2, n_redundant=1, n_repeated=0,random_state=0)# 可视化生成的数据
# 这里只绘制了前两个特征,因为高维数据无法直接可视化
# c=y: 根据类别y上色
# cmap='viridis': 使用'viridis'颜色地图
# edgecolor='k': 设置点的边缘颜色为黑色
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='viridis', edgecolor='k')
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.title("Generated Data")
plt.show()

4.2 使用 Scikit-learn 实现
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report# 生成二分类数据
# 参数说明:n_samples=100表示生成100个样本,n_features=4表示数据有4个特征,n_classes=2表示二分类问题,
# n_informative=2表示其中2个特征是有信息的,n_redundant=1表示1个特征是冗余的,n_repeated=0表示没有重复的特征,
# random_state=0表示随机种子,保证结果可重复
X, y = make_classification(n_samples=100, n_features=4, n_classes=2, n_informative=2, n_redundant=1, n_repeated=0,random_state=0)# 数据集划分
# 将数据集划分为训练集和测试集,test_size=0.2表示测试集占20%,random_state=42保证划分结果可重复
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 逻辑回归模型
# 初始化逻辑回归模型
model = LogisticRegression()
# 使用训练集数据拟合模型
model.fit(X_train, y_train)# 模型预测
# 使用拟合好的模型对测试集进行预测
y_pred = model.predict(X_test)# 评估模型
# 输出模型的准确率
print("Accuracy:", accuracy_score(y_test, y_pred))
# 输出模型的混淆矩阵
print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
# 输出模型的分类报告,包括精确度、召回率、F1值等指标
print("Classification Report:\n", classification_report(y_test, y_pred))

输出结果

Accuracy: 0.9
Confusion Matrix:[[9 2][0 9]]
Classification Report:precision    recall  f1-score   support0       1.00      0.82      0.90        111       0.82      1.00      0.90         9accuracy                           0.90        20macro avg       0.91      0.91      0.90        20
weighted avg       0.92      0.90      0.90        20
4.3 自定义实现

使用 NumPy 手动实现逻辑回归的梯度下降:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score# 生成二分类数据
# 参数说明:n_samples=100表示生成100个样本,n_features=4表示数据有4个特征,n_classes=2表示二分类问题,
# n_informative=2表示其中2个特征是有信息的,n_redundant=1表示1个特征是冗余的,n_repeated=0表示没有重复的特征,
# random_state=0表示随机种子,保证结果可重复
X, y = make_classification(n_samples=100, n_features=4, n_classes=2, n_informative=2, n_redundant=1, n_repeated=0,random_state=0)# 数据集划分
# 将数据集划分为训练集和测试集,test_size=0.2表示测试集占20%,random_state=42保证划分结果可重复
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# Sigmoid 函数
# 用于将预测值映射到0到1之间,作为概率解释
def sigmoid(z):return 1 / (1 + np.exp(-z))# 损失函数
# 计算模型预测值与真实值之间的交叉熵损失
def compute_loss(y, y_pred):return -np.mean(y * np.log(y_pred) + (1 - y) * np.log(1 - y_pred))# 梯度下降实现逻辑回归
# 参数说明:X为特征数据,y为目标标签,learning_rate为学习率,n_iterations为迭代次数
# 返回值说明:weights为权重参数,bias为偏置参数
def logistic_regression(X, y, learning_rate=0.01, n_iterations=1000):m, n = X.shapeweights = np.zeros(n)bias = 0for _ in range(n_iterations):# 计算预测值z = np.dot(X, weights) + biasy_pred = sigmoid(z)# 计算梯度dw = (1 / m) * np.dot(X.T, (y_pred - y))db = (1 / m) * np.sum(y_pred - y)# 更新参数weights -= learning_rate * dwbias -= learning_rate * dbreturn weights, bias# 使用模型
# 训练逻辑回归模型,得到权重参数和偏置参数
weights, bias = logistic_regression(X_train, y_train)
# 对测试集进行预测
z = np.dot(X_test, weights) + bias
y_pred = (sigmoid(z) >= 0.5).astype(int)# 输出准确率
print("Accuracy:", accuracy_score(y_test, y_pred))

输出结果

Accuracy: 0.9

5. 多分类逻辑回归

逻辑回归也可以扩展到多分类问题,通过 一对多(One-vs-Rest)策略多项逻辑回归(Softmax Regression) 来处理多类别任务。

Softmax 函数

P(y=k|x) = \frac{e^{z_k}}{\sum_{j=1}^K e^{z_j}}

其中 z_k = w_k^T x + b_k,是类别 k 的得分。


6. 优缺点

优点
  1. 简单易用:实现简单,计算速度快。
  2. 结果解释性强:可以直接查看每个特征的权重。
  3. 适合小数据集:在小规模数据上效果良好。
缺点
  1. 线性可分性假设:当数据不可线性分割时,效果较差。
  2. 对异常值敏感:异常数据会极大影响模型。
  3. 扩展性有限:不能直接处理复杂的非线性关系。

7. 应用场景

  1. 二分类问题
    • 邮件分类(垃圾邮件/正常邮件)。
    • 医疗诊断(患病/健康)。
  2. 多分类问题
    • 文本情感分析(积极/消极/中立)。
    • 图片分类(猫/狗/鸟)。

8. 优化与改进

  1. 特征工程:通过多项式特征扩展、标准化等提高模型效果。
  2. 正则化:通过 L1(Lasso)或 L2(Ridge)正则化减少过拟合。
    • L1 可用于特征选择。
    • L2 提高模型的泛化能力。
  3. 核方法:结合核技巧(如支持向量机)处理非线性问题。

拓展内容

【机器学习】分类任务: 二分类与多分类-CSDN博客

【机器学习】机器学习的基本分类-监督学习-逻辑回归-Sigmoid 函数-CSDN博客

【机器学习】机器学习的基本分类-监督学习-逻辑回归-对数似然损失函数(Log-Likelihood Loss Function)-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/480800.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PyMOL操作手册

PyMOL 操作手册 The man will be silent, the woman will be tears. – itwangyang ​ 翻译整理:itwangyanng 2024 年 11月 29 日 目录 初识 PyMOL… 5 0.1 安装 PyMOL… 5 0.1.1 Windows 系统开源版 PyMOL 的安装… 5 0.1.2 教育版 PyMOL 的下载安装……

麒麟系统x86安装达梦数据库

一、安装准备前工作 操作系统:银河麒麟V10,CPU: x86_64 架构 下载地址,麒麟官网:https://www.kylinos.cn/ 数据库:dm8_20220915_x86_kylin10_64 下载地址,达梦数据库官网:https://…

Hot100 - 搜索二维矩阵II

Hot100 - 搜索二维矩阵II 最佳思路: 利用矩阵的特性,针对搜索操作可以从右上角或者左下角开始。通过判断当前位置的元素与目标值的关系,逐步缩小搜索范围,从而达到较高的效率。 从右上角开始:假设矩阵是升序排列的&a…

docker服务容器化

docker服务容器化 1 引言2 多个容器间网络联通2.1 单独创建关联2.2 创建时关联 3 服务搭建3.1 镜像清单3.2 容器创建 4 联合实战4.2 flink_sql之kafka到starrocks4.2 flink_sql之mysql到starrocks 5 文献借鉴 1 引言 ​ 利用docker可以很效率地搭建服务,本文在win1…

011变长子网掩码

变长子网掩码: 使用变长子网掩码(VLSM)优化地址分配 目标: 根据需求使用VLSM分配IP地址,减少浪费,并配置静态路由。 网络拓扑 创建一个包含三台路由器(R1、R2、R3)和五台PC&#x…

SpringBoot小知识(2):日志

日志是开发项目中非常重要的一个环节,它是程序员在检查程序运行的手段之一。 1.日志的基础操作 1.1 日志的作用 编程期调试代码运营期记录信息: * 记录日常运营重要信息(峰值流量、平均响应时长……) * 记录应用报错信息(错误堆栈) * 记录运维过程数据(…

大数据新视界 -- 大数据大厂之 Hive 数据安全:权限管理体系的深度解读(上)(15/ 30)

💖💖💖亲爱的朋友们,热烈欢迎你们来到 青云交的博客!能与你们在此邂逅,我满心欢喜,深感无比荣幸。在这个瞬息万变的时代,我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

智能探针技术:实现可视、可知、可诊的主动网络运维策略

网络维护的重要性 网络运维是确保网络系统稳定、高效、安全运行的关键活动。在当今这个高度依赖信息技术的时代,网络运维的重要性不仅体现在技术层面,更关乎到企业运营的方方面面。网络运维具有保障网络的稳定性、提升网络运维性能、降低企业运营成本等…

RT-DETR融合Inner-IoU及相关改进思路

RT-DETR使用教程: RT-DETR使用教程 RT-DETR改进汇总贴:RT-DETR更新汇总贴 《Inner-IoU: More Effective Intersection over Union Loss with Auxiliary Bounding Box》 一、 模块介绍 论文链接:https://arxiv.org/abs/2311.02877 代码链接&a…

在Springboot项目中实现将文件上传至阿里云 OSS

oss介绍 阿里云对象存储服务(OSS)是一种高效、安全和成本低廉的数据存储服务,可以用来存储和管理海量的数据文件。本文将教你如何使用 Java 将文件上传到阿里云 OSS,并实现访问文件。 1. 准备工作 1.1 开通 OSS 服务 登录阿里云…

Java项目中加缓存

Java项目中加缓存 1.更新频率低;但读写频率高的数据很适合加缓存; 2.可以加缓存的地方很多:浏览器的缓存;CDN的缓存;服务器的缓存; 本地内存;分布式远端缓存; 加缓存的时候不要…

Vuex —— Day1

vuex概述 vuex是vue的状态管理工具,可以帮我们管理vue通用的数据(多组件共享的数据) vuex的应用场景: 某个状态在很多个组件中都会使用(eg.个人信息)多个组件共同维护一份数据(eg.购物车&…

【前端】Next.js 服务器端渲染(SSR)与客户端渲染(CSR)的最佳实践

关于Next.js 服务器端渲染(SSR)与客户端渲染(CSR)的实践内容方面,我们按下面几点进行阐述。 1. 原理 服务器端渲染 (SSR): 在服务器上生成完整的HTML页面,然后发送给客户端。这使得用户在首次访问时能够…

基于FPGA的FM调制(载波频率、频偏、峰值、DAC输出)-带仿真文件-上板验证正确

基于FPGA的FM调制-带仿真文件-上板验证正确 前言一、FM调制储备知识载波频率频偏峰值个人理解 二、代码分析1.模块分析2.波形分析 总结 前言 FM、AM等调制是学习FPGA信号处理一个比较好的小项目,通过学习FM调制过程熟悉信号处理的一个简单流程,进而熟悉…

Scala学习记录,统计成绩

统计成绩练习 1.计算每个同学的总分和平均分 2.统计每个科目的平均分 3.列出总分前三名和单科前三名,并保存结果到文件中 解题思路如下: 1.读入txt文件,按行读入 2.处理数据 (1)计算每个同学的总分平均分 import s…

路由策略与路由控制实验

AR1、AR2、AR3在互联接口、Loopback0接口上激活OSPF。AR3、AR4属于IS-IS Area 49.0001,这两者都是Level-1路由器,AR3、AR4的系统ID采用0000.0000.000x格式,其中x为设备编号 AR1上存在三个业务网段A、B、C(分别用Loopback1、2、3接…

第J7周:对于RenseNeXt-50算法的思考

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 文章目录 一、前言1、导入包2、分组卷积模块3、残差单元4、堆叠残差单元5、搭建ResNeXt-50网络 二、问题思考 电脑环境: 语言环境:Pyth…

某充电桩业务服务内存监控和程序行为分析

原作者:展贝 原文地址:https://mp.weixin.qq.com/s/nnYCcVtwowvmj7Zn9XLIUg 在当今数据驱动的环境中,理解内存指标和程序行为对于确保应用程序的性能和可靠性至关重要。在依赖实时数据处理和高可用性的行业中尤其如此。通过利用可观测工具&am…

基于SpringBoot共享汽车管理系统【附源码】

基于SpringBoot共享汽车管理系统 效果如下: 系统注册页面 系统登陆页面 系统管理员主页面 用户信息管理页面 汽车投放管理页面 使用订单页面 汽车归还管理页面 研究背景 随着计算机技术和计算机网络的逐渐普及,互联网成为人们查找信息的重要场所。二十…

计算机网络基础(2):网络安全/ 网络通信介质

1. 网络安全威胁 网络安全:目的就是要让网络入侵者进不了网络系统,及时强行攻入网络,也拿不走信息,改不了数据,看不懂信息。 事发后能审查追踪到破坏者,让破坏者跑不掉。 网络威胁来自多方面&#xff1a…