【机器学习】Lesson3 - 逻辑回归(LR)二分类

目录

背景

一、适用数据集

1. 数据集选择

1.1 领域

1.2 数据集维度

1.3 记录行(样本数量)

2. 本文数据集介绍

3. 数据集下载

注意

二、逻辑回归的基本原理

1. 目的

2. Sigmoid 函数

3. 类别划分

4. 召回率

三、代码

1. 导入所需包&数据集

2. 数据预处理&数据编码

2.1 数据预处理

2.2 数据预处理&编码(Titanic dataset版,可跳过)

3. 绘制热力图

4. 数据拆分&标准化

5. 逻辑回归分类算法

6. 绘制 PR / ROC 曲线

6.1 PR 曲线

6.2 ROC 曲线


背景

逻辑回归(Logistic Regression)是一种用于分类问题的统计模型,尽管名字里有“回归”,但它主要用于解决二元分类(binary classification)或多类分类问题。与线性回归不同,逻辑回归输出的是一个概率,通过设置阈值将概率转换为类别。

一、适用数据集

1. 数据集选择

逻辑回归适用于二分类问题,即输出为两个类别(如【是/否】或【正/负】)。若要练习逻辑回归分析,在选择数据集时可以参考以下 3 个方面。反过来,如果在为项目选择合适的算法模型时,数据集符合以下条件,则可以选用逻辑回归进行数据分析。

1.1 领域

逻辑回归广泛应用于各种领域,包括但不限于:

  • 医疗健康:预测某人是否患病(如患病/不患病,阳性/阴性)。
  • 金融:预测贷款申请是否会被批准,或是否会违约(如信用风险评估)。
  • 市场营销:根据用户行为预测是否会购买产品或订阅服务。
  • 社交媒体:通过用户行为预测某个帖子是否会获得点赞或分享。
  • 人力资源:预测员工是否会离职。
  • 教育:预测学生是否会通过考试或被录取。

1.2 数据集维度

  • 特征数(维度):逻辑回归适用于中低维度的数据集。一般情况下,输入特征数从 2 到 20 维是常见的范围。过多的特征可能会导致过拟合,尤其是当样本数量不足时。
  • 类别标签:逻辑回归特别适用于二分类任务,因此类别标签只有两个(如 0 和 1,或者 “A 类” 和 “B 类”)。

1.3 记录行(样本数量)

样本量大小:逻辑回归的一个优势是即使在较小的数据集上也能有效工作,通常数百到几千条记录的样本集即可获得良好的训练效果。典型样本数量可以在 100 到 10,000 行之间:

  • 小型数据集(100-1000 行):适合初学者理解逻辑回归的基本概念和快速训练模型。
  • 中型数据集(1000-10,000 行):适合做一些更复杂的练习,如交叉验证、正则化等。

2. 本文数据集介绍

Titanic 数据集 基于 1912 年泰坦尼克号沉船事故的真实乘客数据,泰坦尼克号是当时最大的远洋客轮之一,但在其首次航行中不幸撞上冰山沉没,超过1500人遇难。该数据集包含了泰坦尼克号部分乘客的个人信息以及他们是否在船难中幸存,分析目的是:预测哪些乘客能够在这场灾难中幸存下来

数据集中包含 12 列的维度信息,详细如下:

  • PassengerId: 乘客的唯一标识符。
  • Survived: 是否幸存,目标变量(0 = 未幸存,1 = 幸存)。
  • Pclass: 乘客舱位等级(1 = 一等舱,2 = 二等舱,3 = 三等舱)。
  • Name: 乘客姓名。
  • Sex: 性别(male = 男性,female = 女性)。
  • Age: 年龄(有部分缺失值)。
  • SibSp: 同行的兄弟姐妹或配偶人数。
  • Parch: 同行的父母或子女人数。
  • Ticket: 票号。
  • Fare: 票价。
  • Cabin: 船舱号(有部分缺失值)。
  • Embarked: 登船地点(C = Cherbourg, Q = Queenstown, S = Southampton)。

数据记录行共有 1309 行,属于中型数据集。

3. 数据集下载

数据集下载地址:https://www.kaggle.com/datasets/heptapod/titanic

也可以在文首 绑定资源 中下载获取,原版下载出来有多列是无意义的【zero】列,作者这里直接在表里删了,上传资源为处理后的。自行在 Kaggle 下载的朋友记得处理【zero】列。

注意

Kaggle 上 Titanic 数据集有俩都点赞比较高,不要选择看起来更完整的 Titanic dataset!

博主先使用了 Titanic dataset,因为预测结果全都是 100%,于是开始 debug。。。

直到发现这个数据集的获救与否是由性别决定的(dead。。)

重复:千万要使用对的 Titanic 数据集啊!

二、逻辑回归的基本原理

1. 目的

逻辑回归的目的是预测输入数据属于某一类别的概率。它最常用于二元分类问题,例如:预测某个乘客在 Titanic 数据集中是否生还。

2. Sigmoid 函数

Sigmoid 函数的公式:

\sigma(z) = \frac{1}{1 + e^{-z}}

逻辑回归模型的核心是 Sigmoid 函数,它将线性回归的输出值(任何实数)压缩到 [0, 1] 的范围内,这样我们可以将输出解释为概率。

其中,z = w^T X + b ,即将特征向量 X 经过线性组合(权重 w 和偏置 b)后输入到 Sigmoid 函数中。

通过这种方式,我们将线性回归输出的值(通常是实数)转换为一个概率值:

P(y=1|X)

3. 类别划分

逻辑回归通常使用 0.5 作为默认的阈值,将输出的概率值转换为二元分类的类别:

\hat{y} = \begin{cases} 1 & \text{if } P(y=1|X) > 0.5 \\ 0 & \text{if } P(y=1|X) \leq 0.5 \end{cases}

这个阈值可以根据具体的需求调整。

4. 召回率

召回率(Recall)是分类模型评估指标之一,尤其在 二元分类 问题中常用。召回率表示在所有实际为正例(Positive)的样本中,模型正确识别出的正例样本的比例。

公式:

\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}

三、代码

完整版代码和初步处理的数据集见文章绑定资源

1. 导入所需包&数据集

#数据分析与可视化包
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as pltplt.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False #用来正常显示负号from sklearn.model_selection import learning_curve, train_test_split
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.linear_model import LogisticRegression #逻辑回归
from sklearn.metrics import roc_curve, auc, precision_recall_curve
from sklearn.metrics import log_loss
from sklearn.pipeline import make_pipelinefrom sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
#accuracy_score,准确率
#f1_score,精确率和召回率的调和平均数
#precision_score,精确率
#recall_score,召回率from sklearn.metrics import confusion_matrix, classification_report
#confusion_matrix混淆矩阵
#classification_report任务性能分析Titanic_data=pd.read_csv("D:/project/Jupyter/csdn/AI_ML/datasets/Lesson3_Titanic.csv")Titanic_dataTitanic_data.info()

2. 数据预处理&数据编码

2.1 数据预处理

# 将 float 类型转换为 int 类型
Titanic_data['Age'] = Titanic_data['Age'].astype('int64')
Titanic_data['Fare'] = Titanic_data['Fare'].astype('int64')# 用众数填充 Embarked(登船地点)列缺失值
Titanic_data['Embarked'].fillna(Titanic_data['Embarked'].mode()[0], inplace=True)# 将 float 类型转换为 int 类型
Titanic_data['Embarked'] = Titanic_data['Embarked'].astype('int64')

2.2 数据预处理&编码(Titanic dataset版,可跳过)

本部分是 Titanic dataset 的预处理,相对完整。参考一下就行,在 Titanic 数据集用不到。

# 填补缺失值
# Age(年龄)列用中位数填充,Embarked(登船地点)列用众数填充
# 按乘客的票舱等级(Pclass)进行分组,分别计算每个等级的均值填补 Fare(票价)列
Titanic_data['Age'].fillna(Titanic_data['Age'].median(), inplace=True)
Titanic_data['Embarked'].fillna(Titanic_data['Embarked'].mode()[0], inplace=True)
Titanic_data['Fare'] = Titanic_data.groupby('Pclass')['Fare'].apply(lambda x: x.fillna(x.median()))# 将 非int或float 类型列进行编码变为 int 类型
# 将性别和登船港口编码为数值
Titanic_data['Sex'] = Titanic_data['Sex'].map({'male': 0, 'female': 1})
Titanic_data['Embarked'] = Titanic_data['Embarked'].map({'C': 0, 'Q': 1, 'S': 2})# 丢弃无用的列
Titanic_data.drop(['Name', 'Ticket', 'Cabin'], axis=1, inplace=True)# 对 Age 和 Fare 列进行四舍五入,并转换为 int64 型
Titanic_data['Age'] = Titanic_data['Age'].round().astype('int64')
Titanic_data['Fare'] = Titanic_data['Fare'].round().astype('int64')# 查看处理后的数据
Titanic_data.head()Titanic_data.info()

由于 'Name'、 'Ticket' 列,即 “姓名”、“票号” 列都是唯一的,在判断乘客是否会遇难时无法起到帮助分析作用,即可删去。

由于 'Cabin' 列,即 “船舱号” 并无性质表现,且缺失过多,分析是可删去,以免大量的填补数据干扰模型预测结果。

3. 绘制热力图

显示各个特征之间的相关性,并显示不同特征之间的相关关系。观察数据各维度之间的相关性,同时具有验证数据集可用性的功效(如作者通过热力图发现 Titanic dataset 用不了)。

plt.figure(figsize=(10, 8))
sns.heatmap(Titanic_data.corr(), annot=True, cmap='coolwarm')
plt.title('Correlation Heatmap')
plt.show()

4. 数据拆分&标准化

# 特征选择
X = Titanic_data.drop('Survived', axis=1)
y = Titanic_data['Survived']# 拆分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 标准化特征(可选,但推荐逻辑回归中进行特征缩放)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

5. 逻辑回归分类算法

# 使用 scikit-learn 的逻辑回归模型
model = LogisticRegression(max_iter=10000,penalty='l2',#C=0.1,#solver='liblinear',class_weight="balanced")
model.fit(X_train, y_train)# 预测
y_pred_test = model.predict(X_test)
print(classification_report(y_test,y_pred_test))y_pred_train = model.predict(X_train)
print(classification_report(y_train,y_pred_train))confusion_matrix(y_train,y_pred_train)#生成混淆矩阵#画热力图
sns.heatmap(confusion_matrix(y_test,y_pred_test),annot=True,fmt="d")accuracy_score(y_pred_test,y_test)#准确率
f1_score(y_pred_test,y_test)#精确率和召回率的调和平均数# 从负类,计算模型精确率和召回率
precision_score(y_pred_test,y_test,pos_label=0)#精确率
recall_score(y_pred_test,y_test,pos_label=0)#召回率print(model.coef_)print(model.intercept_)# 显示模型系数
coefficients = pd.DataFrame(model.coef_.flatten(), index=X.columns, columns=['Coefficient'])
print(coefficients)pd.DataFrame(model.coef_,columns=X.columns).T.plot(kind="barh")#显示系数大小model.predict_proba(X_test)model.predict(X_test)# 从正类,计算精确率和召回率
precision = precision_score(y_test, y_pred_test,pos_label=1)
recall = recall_score(y_test, y_pred_test,pos_label=1)print(f"Precision: {precision}")
print(f"Recall: {recall}")

6. 绘制 PR / ROC 曲线

6.1 PR 曲线

显示模型在不同阈值下的精确率和召回率的权衡关系。

probs=model.predict_proba(X_test)[:,0]
precision,recall,thresholds=precision_recall_curve(y_test,probs,pos_label=0)
plt.plot(recall,precision)
plt.title("Precision-Recall Curve")
plt.xlabel("recall")
plt.ylabel("precision")pd.DataFrame([precision,recall,thresholds])

6.2 ROC 曲线

probs=model.predict_proba(X_test)[:,1]
fpr,tpr,thresholds=roc_curve(y_test,probs,pos_label=1)
roc_auc = auc(fpr, tpr)
plt.plot(fpr,tpr, label=f'ROC curve (area = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], linestyle='--')
plt.title("roc_curve")
plt.xlabel("fpr")
plt.ylabel("tpr")
plt.legend(loc='lower right')
plt.show()

AUC 在 0.8 到 0.9 之间,表明模型具有良好的性能,可以有效地区分正例和负例。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/460287.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

kubernetes——part2-3 使用RKE构建企业生产级Kubernetes集群

使用RKE构建企业生产级Kubernetes集群 一、RKE工具介绍 RKE是一款经过CNCF认证的开源Kubernetes发行版,可以在Docker容器内运行。 它通过删除大部分主机依赖项,并为部署、升级和回滚提供一个稳定的路径,从而解决了Kubernetes最常见的安装复杂…

重学SpringBoot3-Spring WebFlux之HttpHandler和HttpServer

更多SpringBoot3内容请关注我的专栏:《SpringBoot3》 期待您的点赞👍收藏⭐评论✍ 重学SpringBoot3-Spring WebFlux之HttpHandler和HttpServer 1. 什么是响应式编程?2. Project Reactor 概述3. HttpHandler概述3.1 HttpHandler是什么3.2 Http…

3D Gaussian Splatting代码详解(三):模型构建,实现3D 高斯椭球体的克隆和分裂

3 模型构建 3.4 根据梯度对3D gaussian 进行增加或删减 (1) 对3D高斯分布进行密集化和修剪的操作 def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):"""对3D高斯分布进行密集化和修剪的操作:param max_g…

无人机协同控制技术详解!

一、算法概述 无人机协同控制技术算法是指通过综合运用通信、控制、优化等多学科知识,实现对多个无人机的协同控制和任务规划。这些算法通常基于各种数学模型和控制理论,如线性代数、微分方程、最优控制等,旨在确保无人机能够相互协作&#…

【热门主题】000013 C++游戏开发全攻略

前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 目录 【热…

QT中的item views与Item widgets控件的用法总结

0、前言 在一般进行数据表格展示的时候,大多时候要用到表格、列表或者树形结构。 Qt中常见的控件显示有两大类: Item View (list View、Tree View、Table View、Column View和Undo View)Item widget(List Widget、Tree Widget和…

ssm+vue645基于web的电影购票系统设计与实现

博主介绍:专注于Java(springboot ssm 等开发框架) vue .net php phython node.js uniapp 微信小程序 等诸多技术领域和毕业项目实战、企业信息化系统建设,从业十五余年开发设计教学工作 ☆☆☆ 精彩专栏推荐订阅☆☆☆☆☆不…

Spark RDD

概念 RDD是一种抽象,是Spark对于分布式数据集的抽象,它用于囊括所有内存中和磁盘中的分布式数据实体 RDD 与 数组对比 对比项数组RDD概念类型数据结构实体数据模型抽象数据跨度单机进程内跨进程、跨计算节点数据构成数组元素数据分片(Partitions)数据…

java-数据结构

一.链表 单向链表 /单向链表 public class SinglyLinkedList implements Iterable<Integer> {//头节点private Node head null;//头指针//节点类private static class Node{int value;//值Node next;//下一个节点的指针public Node(int value, Node next) {this.val…

pycharm与anaconda下的pyside6的安装记录

一、打开anaconda虚拟环境的命令行窗口&#xff0c;pip install&#xff0c;加入清华源&#xff1a; pip install PySide6 -i https://pypi.tuna.tsinghua.edu.cn/simple 二、打开pycharm&#xff0c;在文件--设置--工具--外部工具中配置一下三项&#xff1a; 1、 QtDesigner…

成本累计曲线:项目预算的秘密武器

在项目管理的过程中&#xff0c;成本控制是影响项目成败的关键因素之一&#xff0c;而其中“成本累计曲线”就像是一位财务导航员&#xff0c;为项目的成本控制和进度监控提供了极大的帮助。那么&#xff0c;什么是成本累计曲线&#xff1f;它包含哪些步骤&#xff1f;如何应用…

[0152].第3节:IDEA中工程与模块

我的后端学习大纲 IDEA大纲 1、Project和Module的概念&#xff1a; 2、Module操作&#xff1a; 2.1.创建Module: 2.2.删除Module&#xff1a; 2.3.导入Module&#xff1a; 1.导入外来模块的代码&#xff1a; 查看Project Structure&#xff0c;选择import module&#xff1a…

【Linux网络】UdpSocket

目录 套接字 socket编程 源IP地址和目的IP地址 端口号 网络字节序 socket常用API socket结构 UDP UDP协议&#xff08;用户数据报协议&#xff09; 创建套接字 绑定 通信 udp_echo_server:简单的回显服务器和客户端代码 dict_server:简单的英译汉的网络字典 chat_…

双11猫咪好物盛典开启,线上抢购不停 购物清单新鲜出炉

双十一购物狂欢节终于到了&#xff01;铲屎官们想好要给猫咪添置什么好东西了吗&#xff1f;还不知道怎么选的看过来啦~这里整理了一份双十一购物清单&#xff0c;快看看有没有你需要的吧&#xff01; 双十一养猫必购1&#xff1a;CEWEY自动猫砂盆 CEWEY自动猫砂盆真的是我最爱…

magic-api简单使用二:自定义返回结果

背景 在上一章 中我们学习了搭建项目和导入文件&#xff0c; 这二天稍微有点时间&#xff0c;研究下这个magic-api的写法。 后续如果需要维护或者更改&#xff0c;也能在项目中尽快上手。 今天我们主要学习自定义返回结果&#xff0c;当然也可以使用官方的。不需要任何更改。…

二百七十、Kettle——ClickHouse中增量导入清洗数据错误表

一、目的 比如原始数据100条&#xff0c;清洗后&#xff0c;90条正确数据在DWD层清洗表&#xff0c;10条错误数据在DWD层清洗数据错误表&#xff0c;所以清洗数据错误表任务一定要放在清洗表任务之后。 更关键的是&#xff0c;Hive中原本的SQL语句&#xff0c;放在ClickHouse…

【Nas】X-Doc:jellyfin“该客户端与媒体不兼容,服务器未发送兼容的媒体格式”问题解决方案

【Nas】X-Doc&#xff1a;jellyfin“该客户端与媒体不兼容&#xff0c;服务器未发送兼容的媒体格式”问题解决方案 当使用Jellyfin播放视频时出现“该客户端与媒体不兼容&#xff0c;服务器未发送兼容的媒体格式”&#xff0c;这是与硬件解码和ffmpeg设置有关系&#xff0c;具体…

linux应急响应-1

声明&#xff1a;部分内容来源于网络&#xff0c;只是新手练习 靶场环境来自于知攻善防实验室 概述&#xff1a; 一、整体过程 初始环境设置 将Linux centOS 7配置为图形化界面&#xff0c;通过yum groupinstall “X Window System” -y和yum groupinstall “GNOME Desktop”&a…

视频去水印软件推荐:6款去水印工具值得一试

在视频创作和分享的过程中&#xff0c;水印往往会成为影响美观和平台推流。幸运的是&#xff0c;市面上有许多视频去水印软件能够帮助我们轻松解决这一问题。本文将为大家推荐几款实用的视频去水印软件&#xff0c;并详细介绍它们的功能和去除水印的方法。 1.影忆 功能介绍&…

MaxKB: 一款基于大语言模型的知识库问答系统

嗨, 大家好, 我是徐小夕. 之前一直在社区分享零代码&低代码的技术实践&#xff0c;最近也在研究多模态文档引擎相关的产品, 在社区发现一款非常有意思的知识库问答系统——MaxKB, 它支持通过大语言模型和RAG技术来为知识库赋能,今天就和大家分享一下这款项目. PS: 它提供了…