分类模型评估:混淆矩阵与ROC曲线

  • 1.混淆矩阵
  • 2.ROC曲线 & AUC指标

理解混淆矩阵和ROC曲线之前,先区分几个概念。对于分类问题,不论是多分类还是二分类,对于某个关注类来说,都可以看成是二分类问题,当前的这个关注类为正类,所有其他非关注类为负类。因为样本的真实值有正负两类,而模型的预测值也有正负两类,因此样本的真实值和模型的预测值之间产生了下面4种组合:

  • 真正例(True Positives/TP):在所有真实值为正类的样本中,模型预测值也为正类的样本数。
  • 假正例(False Positives/FP):在所有真实值为负类的样本中,模型预测值为正类的样本数。
  • 真负例(True Negatives/TN):所有真实值为负类的样本中,模型预测值也为负类的样本数。
  • 假负例(False Negatives/FN):所有真实值为正类的样本中,模型预测值为负类的样本数。

从上面几个定义可以知道:
1)样本总数 = TP+FP+TN+FN
2)所有真实值为正类的样本总数 = TP+FN
3)所有真实值为负类的样本总数 = TN+FP

1.混淆矩阵

使用sklearn自带的鸢尾花数据集,数据集里鸢尾花包含3个分类。

import numpy as np
from sklearn.datasets import load_iris
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix# 获取特征值与目标值
data = load_iris()
X, y = data['data'], data['target']# 自带的数据集分类准确率为1,为了后面更好的基于混淆矩阵验证相关指标的计算,为训练集添加均值0,标准差2的高斯噪声
np.random.seed(42)
noise = np.random.normal(0, 2, (len(X), len(X[0])))
X += noise# 特征值归一化到区间[-1,1]
scaler = MinMaxScaler(feature_range=(-1, 1))
X_scaled = scaler.fit_transform(X)# 划分训练集与测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)# 创建逻辑回归模型、训练并预测
model = LogisticRegression(multi_class='multinomial', max_iter=1000)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)# 获取模型混淆矩阵、分类报告、准确率
print(f"混淆矩阵:\n{confusion_matrix(y_test, y_pred)}")
print(f"分类报告:\n{classification_report(y_test, y_pred)}")
print(f"准确率:\n{accuracy_score(y_test, y_pred)}")

output:
在这里插入图片描述

混淆矩阵中,横向表示真实值,纵向表示预测值。比如第一个位置7,表示实际类别为0且预测类别为0的样本有7个。基于此混淆矩阵,可以衍生下面相关指标:
1)准确率(accuracy):准确率表示模型对一个样本类别预测正确的可能性,是相对整体来说的。计算方式为所有预测正确的样本(斜对角线之和)/ 样本总数,本例中accuracy=(7+4+8)/30=0.63…
2)精确率(precision):精确率是针对某个具体关注类来说的,精确率关注的是,对于所有预测值为该类的样本中,真实值也属于该类的样本所占的比例,计算公式为 T P T P + F P \frac{TP}{TP+FP} TP+FPTP。比如对于类别0,模型预测的该类样本数=(7+0+2)=9,而真实值为该类的样本数为7,那么类别0的precision=7/9=0.77…。精确率反映了“模型找的对不对”
3)召回率(recall):同样召回率也是针对某个具体关注类来说的,关注的是所有真实值为该类的样本中,模型能正确预测为该类的样本所占的比例,计算公式 T P T P + F N \frac{TP}{TP+FN} TP+FNTP。还是拿类别0来说,真实值为0的样本总数=(7+3+0)=10,模型能正确预测为该类的样本总数为7,所以类别0的召回率=7/10=0.7。召回率代表了“模型找的全不全”
4)F1-score:F1分数是精确率与召回率的调和平均数,计算公式为 2 ∗ p r e c i s i o n ∗ r e c a l l p r e c i s i o n + r e c a l l \frac{2*precision*recall}{precision+recall} precision+recall2precisionrecall。对于类别0的F1-score= 2 ∗ 0.78 ∗ 0.7 0.78 + 0.7 \frac{2*0.78*0.7}{0.78+0.7} 0.78+0.720.780.7=0.737…,F1分数用来表示模型在关注的类上识别正类的综合表现,最大值1表示分类效果最好完全正确,最小值0表示分类效果最差完全错误。

分类报告直接提供了每个分类下的精确率、召回率、F1分数等指标。最下面两行的macro avgweighted avg分别表示对每个指标的算术平均和加权平均,最后一列的support表示对应的样本数量。

2.ROC曲线 & AUC指标

ROC:Receiver Operating Characteristic。
AUC:Area Under the [ROC] Curve,ROC曲线下的面积。

ROC曲线的绘制中需要用到两个指标:

  • 真正率(True Positive Rate/TPR):在所有真实类别为正类的样本中,模型正确识别为正类的样本所占的比例,也就是把正类样本识别成正类样本的概率。反映了模型识别正类的能力,可以看成是模型在识别正类样本时的收获能力,计算公式 T P T P + F N \frac{TP}{TP+FN} TP+FNTP
  • 假正率(False Positive Rate/FPR):在所有真实类别为负类的样本中,模型错误识别为正类的样本所占的比例,即把负类样本识别为正类样本的概率。反映了模型识别为正类样本时的错误程度,可以理解成模型在识别正类样本时付出的代价,计算公式 F P T N + F P \frac{FP}{TN+FP} TN+FPFP

大多数分类模型都是通过计算出每个样本属于正类的概率,和属于正类的概率阈值进行比较来对样本进行分类的。正类的概率>=阈值,判定为正类,反之判定为负类。

ROC曲线是由不同概率阈值下真正率(y轴)和假正率(x轴)对应的一系列点所构成的曲线,x轴从左到右判定为正类的概率阈值从1到0逐渐递减。ROC曲线用来描述二分类模型预测效果,对于多分类问题,是将关注类视为正类,其他类视为负类。

ROC曲线的具体绘制过程可以理解为:

  1. 对于测试集中的每个样本,利用分类器预测其为正类的概率值。
  2. 将这些概率值按照从大到小的顺序排列,作为阈值。
  3. 对于每个阈值,分别计算真正率和假正率,对应坐标轴上的一个点。
  4. 连接这些点。

从真正率和假正率的计算,可以看出曲线越往右,判定为正类的概率阈值越低,那么就有更多的样本被归类到正类当中,因为分母是不变的,分子(TP/FP)随着正类样本增多都会逐渐增大,因此ROC的曲线走势应该是一个从(0, 0)到(1, 1)逐渐上升的曲线。

同时,因为x轴代表了在识别正类时付出的代价,y轴代表了在识别正类时的收获,因此当x值越小,y值越大,即曲线越靠近左上角(0, 1),说明模型的分类效果越好。

而AUC,是ROC曲线下的面积,它衡量的是模型在所有概率阈值下识别正类时“收获”与“代价”的比重,因此AUC值越大越好,值域范围[0, 1]。
AUC=0.5:模型不具有分类效果,相当于盲猜。
AUC<0.5:分类效果最差,不如盲猜。
AUC>0.5:有一定的分类效果,值越接近1分类效果越好。

下面还是以鸢尾花的数据集为例,通过一个demo对ROC和AUC进行计算和绘制。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, auc# 获取特征值与目标值
data = load_iris()
X, y = data['data'], data['target']# 仅使用两个类别:0 & 1
X = X[y != 2]
y = y[y != 2]# 训练集添加噪声
np.random.seed(42)
noise = np.random.normal(0, 2, (len(X), len(X[0])))
X += noise# 归一化
scaler = MinMaxScaler(feature_range=(-1, 1))
X_scaled = scaler.fit_transform(X)# 划分数据集、创建逻辑回归模型、训练
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
model = LogisticRegression(multi_class='multinomial', max_iter=1000)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)# 获取每个样本预测为正类样本的概率,[:, 0]是负类样本的概率
y_pred_prob = model.predict_proba(X_test)[:, 1]# 计算FPR、TPR和AUC值
fpr, tpr, thresholds = roc_curve(y_test, y_pred_prob)
roc_auc = auc(fpr, tpr)# 绘制ROC曲线
plt.figure()
plt.plot(fpr, tpr, color='green', lw=1, label=f'ROC Curve (AUC={roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='red', lw=1, linestyle='--')
plt.xlim([0.0, 1])
plt.ylim([0.0, 1.05])
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.title('ROC Curve')
plt.legend(loc="lower right")
plt.show()

output:
在这里插入图片描述
绿线代表ROC曲线,红线相当于盲猜,绿线在红线上方距离红线越远模型分类效果越好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/290188.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java项目:78 springboot学生宿舍管理系统的设计与开发

作者主页&#xff1a;源码空间codegym 简介&#xff1a;Java领域优质创作者、Java项目、学习资料、技术互助 文中获取源码 项目介绍 系统的角色&#xff1a;管理员、宿管、学生 管理员管理宿管员&#xff0c;管理学生&#xff0c;修改密码&#xff0c;维护个人信息。 宿管员…

RegSeg 学习笔记(待完善)

论文阅读 解决的问题 引用别的论文的内容 可以用 controlf 寻找想要的内容 PPM 空间金字塔池化改进 SPP / SPPF / SimSPPF / ASPP / RFB / SPPCSPC / SPPFCSPC / SPPELAN &#xfffc; ASPP STDC&#xff1a;short-term dense concatenate module 和 DDRNet SE-ResNeXt …

Java代码混淆技术在应用程序保护中的关键作用与应用

摘要 本文探讨了代码混淆在保护Java代码安全性和知识产权方面的重要意义。通过混淆技术&#xff0c;可以有效防止代码被反编译、逆向工程或恶意篡改&#xff0c;提高代码的安全性。常见的Java代码混淆工具如IPAGuard、Allatori、DashO、Zelix KlassMaster和yGuard等&#xff0…

2. Java基本语法

文章目录 2. Java基本语法2.1 关键字保留字2.1.1 关键字2.1.2 保留字2.1.3 标识符2.1.4 Java中的名称命名规范 2.2 变量2.2.1 分类2.2.2 整型变量2.2.3 浮点型2.2.4 字符型 char2.2.5 Unicode编码2.2.6 UTF-82.2.7 boolean类型 2.3 基本数据类型转换2.3.1 自动类型转换2.2.2 强…

JVM篇详细分析

JVM总体图 程序计数器&#xff1a; 线程私有的&#xff0c;每个线程一份&#xff0c;内部保存字节码的行号&#xff0c;用于记录正在执行字节码指令的地址。&#xff08;可通过javap -v XX.class命令查看&#xff09; java堆&#xff1a; 线程共享的区域&#xff0c;用来保存对…

Java安全篇-Fastjson漏洞

前言知识&#xff1a; 一、json 概念&#xff1a; json全称是JavaScript object notation。即JavaScript对象标记法&#xff0c;使用键值对进行信息的存储。 格式&#xff1a; {"name":"wenda","age":21,} 作用&#xff1a; JSON 可以作为…

Git,GitHub,Gitee,GitLab 四者有什么区别?

目录 1. Git 2. GitHub 3. Gitee 4. GitLab 5. 总结概括 1. Git Git 是一个版本管理工具&#xff0c;常应用于本地代码的管理&#xff0c;下载完毕之后&#xff0c;我们可以使用此工具对本地的资料&#xff0c;代码进行版本管理。 下载链接&#xff1a; Git - Downlo…

见证实力 | 走进飞凌嵌入式研发实验室

研发实验室是一家高新技术企业技术实力与创新动能的核心&#xff0c;一个设备齐全、流程规范、标准严格的实验室&#xff0c;能够确保产品功能的先进性、运行的稳定性和质量的可靠性&#xff0c;使产品在激烈的市场竞争中脱颖而出。 十八年来&#xff0c;飞凌嵌入式已成功帮助…

Haproxy2.8.1+Lua5.1.4部署,haproxy.cfg配置文件详解和演示

目录 一.快速安装lua和haproxy 二.配置haproxy的配置文件 三.配置haproxy的全局日志 四.测试负载均衡、监控和日志效果 五.server常用可选项 1.check 2.weight 3.backup 4.disabled 5.redirect prefix和redir 6.maxconn 六.调度算法 1.静态 2.动态 一.快速安装lu…

如何解决 IntelliJ IDEA 中属性文件的编码问题

在使用 IntelliJ IDEA 进行开发过程中&#xff0c;我们经常会遇到属性文件&#xff08;.properties 文件&#xff09;的编码问题。如果属性文件的编码设置不正确&#xff0c;就会导致中文等特殊字符显示乱码。这是因为IntelliJ IDEA中默认的配置文件的编码格式是ISO-8859-1。 …

骗子查询系统源码

源码简介 小权云黑管理系统 V1.0 功能如下&#xff1a; 1.添加骗子&#xff0c;查询骗子 2.可添加团队后台方便审核用 3.在线反馈留言系统 4.前台提交骗子&#xff0c;后台需要审核才能过 5.后台使用光年UI界面 6.新增导航列表&#xff0c;可给网站添加导航友链 7.可添加云黑类…

Linux repo基本用法: 搭建自己的repo仓库[服务端]

概述 Repo的使用离不开Git, Git 和 Repo 都是版本控制工具&#xff0c;但它们在使用场景和功能上有明显区别… Git 定义&#xff1a;Git 是一个分布式的版本控制系统&#xff0c;由 Linus Torvalds 为 Linux 内核开发而设计&#xff0c;现已成为世界上最流行的版本控制软件之…

C语言--编译和链接

1.翻译环境 计算机能够执行二进制指令&#xff0c;我们的电脑不会直接执行C语言代码&#xff0c;编译器把代码转换成二进制的指令&#xff1b; 我们在VS上面写下printf("hello world");这行代码的时候&#xff0c;经过翻译环境&#xff0c;生成可执行的exe文件&…

【超图 SuperMap3D】【基础API使用示例】51、超图SuperMap3D - 绘制圆|椭圆形面标注并将视角定位过去

前言 引擎下载地址&#xff1a;[添加链接描述](http://support.supermap.com.cn/DownloadCenter/DownloadPage.aspx?id2524) 绘制圆形或者椭圆形效果 核心代码 entity viewer.entities.add({// 圆中心点position: { x: -1405746.5243351874, y: 4988274.8462937465, z: 370…

SpringBoot Redis 之Lettuce 驱动

一、前言 一直以为SpringBoot中 spring-boot-starter-data-redis使用的是Jredis连接池&#xff0c;直到昨天在部署报价系统生产环境时&#xff0c;因为端口配置错误造成无法连接&#xff0c;发现报错信息如下&#xff1a; 一了解才知道在SpringBoot2.X以后默认是使用Lettuce作…

物联网实战--入门篇之(一)物联网概述

目录 一、前言 二、知识梳理 三、项目体验 四、项目分解 一、前言 近几年很多学校开设了物联网专业&#xff0c;但是确却地讲&#xff0c;物联网属于一个领域&#xff0c;包含了很多的专业或者说技能树&#xff0c;例如计算机、电子设计、传感器、单片机、网…

公链角逐中突围,Solana 何以成为 Web3 世界的流量焦点?

在众多区块链公链中&#xff0c;Solana 凭借其创纪录的处理速度和极低的交易费用&#xff0c;成为了众多开发者和投资者的宠儿。就像网络上流行的那句话所说&#xff1a;“Why slow, when you can Solana?”&#xff0c;Solana 正以它的速度和强大的生态系统&#xff0c;重新定…

某某消消乐增加步数漏洞分析

一、漏洞简介 1&#xff09; 漏洞所属游戏名及基本介绍&#xff1a;某某消消乐&#xff0c;三消游戏&#xff0c;类似爱消除。 2&#xff09; 漏洞对应游戏版本及平台&#xff1a;某某消消乐Android 1.22.22。 3&#xff09; 漏洞功能&#xff1a;增加游戏步数。 4&#xf…

如何快速搭建一个ELK环境?

前言 ELK是Elasticsearch、Logstash和Kibana三个开源软件的统称&#xff0c;通常配合使用&#xff0c;并且都先后归于Elastic.co企业名下&#xff0c;故被简称为ELK协议栈。 Elasticsearch是一个实时的分布式搜索和分析引擎&#xff0c;它可以用于全文搜索、结构化搜索以及分…

政安晨:专栏目录【TensorFlow与Keras实战演绎机器学习】

政安晨的个人主页&#xff1a;政安晨 欢迎 &#x1f44d;点赞✍评论⭐收藏 收录专栏: TensorFlow与Keras实战演绎机器学习 希望政安晨的博客能够对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff01; 本篇是作者政安晨的专栏《TensorFlow与Keras…