ROC 曲线详解

前言

ROC 曲线是一种坐标图式的分析工具,是由二战中的电子和雷达工程师发明的,发明之初是用来侦测敌军飞机、船舰,后来被应用于医学、生物学、犯罪心理学。

如今,ROC 曲线已经被广泛应用于机器学习领域的模型评估,说到这里就不得不提到 Tom Fawcett 大佬,他一直在致力于推广 ROC 在机器学习领域的应用,他发布的论文《An introduction to ROC analysis》[1]更是被奉为 ROC 的经典之作(引用 2.2w 次),知名机器学习库 scikit-learn 中的 ROC 算法就是参考此论文实现,可见其影响力!

不知道大多数人是否和我一样,对于 ROC 曲线的理解只停留在调用 scikit-learn 库的函数,对于它的背后原理和公式所知甚少。

前几天我重读了《An introduction to ROC analysis》终于将 ROC 曲线彻底搞清楚了,独乐乐不如众乐乐!如果你也对 ROC 的算法及实现感兴趣,不妨花些时间看完全文,相信你一定会有所收获!

图片

一、什么是 ROC 曲线

下图中的蓝色曲线就是 ROC 曲线,它常被用来评价二值分类器的优劣,即评估模型预测的准确度。

二值分类器,就是字面意思它会将数据分成两个类别(正/负样本)。例如:预测银行用户是否会违约、内容分为违规和不违规,以及广告过滤、图片分类等场景。篇幅关系这里不做多分类 ROC 的讲解。

图片

坐标系中纵轴为 TPR(真阳率/命中率/召回率)最大值为 1,横轴为 FPR(假阳率/误判率)最大值为 1,虚线为基准线(最低标准),蓝色的曲线就是 ROC 曲线。其中 ROC 曲线距离基准线越远,则说明该模型的预测效果越好。(TPR: True positive rate; FPR: False positive rate)

  • ROC 曲线接近左上角:模型预测准确率很高

  • ROC 曲线略高于基准线:模型预测准确率一般

  • ROC 低于基准线:模型未达到最低标准,无法使用

二、背景知识

考虑一个二分类模型, 负样本(Negative) 为 0,正样本(Positive) 为 1。即:

  • 标签 y 的取值为 0 或 1。

  • 模型预测的标签为 \hat{y},取值也是 0 或 1。

因此,将 y\hat{y} 两两组合就会得到 4 种可能性,分别称为:

图片

2.1 公式

ROC 曲线的横坐标为 FPR(False Positive Rate),纵坐标为 TPR(True Positive Rate)。FPR 统计了所有负样本中 预测错误(FP) 的比例,TPR 统计了所有正样本中 预测正确(TP) 的比例,其计算公式如下,其中 # 表示统计个数,例如 #N 表示负样本的个数,#P 表示正样本的个数

\text{FPR}=\frac{\#\text{FP}}{\#\text{N}} $$ $$\text{TPR}=\frac{\#\text{TP}}{\#\text{P}}

2.2 计算方法

下面举一个实际例子作为讲解,以下表 5 个样本为例,讲解如何计算 FPR 和 TPR

id真实标签  y预测标签 \hat{y}
111
210
300
411
501

正样本数 \#P=3,负样本数\#N=2

其中 y=0\hat{y}=1的样本有 1 个,即 \#FP=1,所以 FPR=1/2=0.5

其中 y=1\hat{y}=1 的样本有 2 个,即 \#TP=2,所以 FPR=2/3

FPR 和 TPR 的取值范围均是 0 到 1 之间。对于 FPR,我们希望其越小越好。而对于 TPR,我们希望其越大越好。

至此,我们已经介绍完如何计算 FPR 和 TPR 的值,下面将会讲解如何绘制 ROC 曲线。

三、绘制 ROC 曲线

讲到这里,可能有的同学会问:ROC 不是一条曲线吗?讲了这么多它到底应该怎么画呢?下面将分为两部分讲解如何绘制 ROC 曲线,直接打通你的“任督二脉”彻底拿下 ROC 曲线:

  • 第一部分:通过手绘的方式讲解原理

  • 第二部分:Python 代码实现,代码清爽易读

3.1 手绘 ROC 曲线

一般在二分类模型里(标签取值为 0 或 1),会默认设定一个阈值 (threshold)。当预测分数大于这个阈值时,输出 1,反之输出 0。我们可以通过调节这个阈值,改变模型预测的输出,进而画出 ROC 曲线。

以下面表格中的 20 个点为例,介绍如何人工画出 ROC 曲线,其中正样本和负样本都是 10 个,即 \#P = \#N = 10

id真实标签预测分数id真实标签预测分数
11.9111.4
21.8120.39
30.7131.38
41.6140.37
51.55150.36
61.54160.35
70.53171.34
80.52180.33
91.51191.30
100.505200.1

当设定阈值为 0.9 时,只有第一个点预测为 1,其余都为 0,故 \#FP=0\#TP=1,计算出 FPR=0/10=0TPR=1/10=0.1,画出点 (0,0.1)

当设定阈值为 0.8 时,只有前两个点预测为 1,其余都为 0,故 $\#FP=0、\#TP=2$,计算出 FPR=0/10=0 、TPR=2/10=0.2,画出点 (0,0.2)

当设定阈值为 0.7 时,只有前三个点预测为 1,其余都为 0,故 \#FP=1\#TP=2,计算出 FPR=1/10=0.1TPR=2/10=0.2,画出点 (0.1,0.2)。

以此类推,画出的 ROC 曲线如下:

图片

因此,在画 ROC 曲线前,需要将预测分数从大到小排序,然后将预测分数依次设定为阈值,分别计算 FPRTPR。而对于基准线,假设随机预测为正样本的概率为 x,即 \Pr(\hat{y}=1)=x 由于 FPR 计算的是负样本中,预测为正样本的概率,因此 FPR= x(同理,TPR= x)。所以,基准线为从点 (0, 0) 到 (1, 1) 的斜线

3.2 Python 代码

接下来,我们将结合代码讲解如何在 Python 中绘制 ROC 曲线。

下面的代码参考了《An Introduction to ROC Analysis》[2]中的算法 1(伪代码)。值得一提的是,知名机器学习库 scikit-learn 的 roc_curve 函数[3] 也参考了这个算法。

图片

下面我自己实现的 roc 函数可以理解为是简化版的 roc_curve,这里的代码逻辑更加简洁易懂,算法的时间复杂度 O ( n log ⁡ n ) O(n\log n) O(nlogn)。

完整的代码如下:

# import numpy as np
def roc(y_true, y_score, pos_label):"""y_true:真实标签y_score:模型预测分数pos_label:正样本标签,如“1”"""# 统计正样本和负样本的个数num_positive_examples = (y_true == pos_label).sum()num_negtive_examples = len(y_true) - num_positive_examplestp, fp = 0, 0tpr, fpr, thresholds = [], [], []score = max(y_score) + 1# 根据排序后的预测分数分别计算fpr和tprfor i in np.flip(np.argsort(y_score)):# 处理样本预测分数相同的情况if y_score[i] != score:fpr.append(fp / num_negtive_examples)tpr.append(tp / num_positive_examples)thresholds.append(score)score = y_score[i]if y_true[i] == pos_label:tp += 1else:fp += 1fpr.append(fp / num_negtive_examples)tpr.append(tp / num_positive_examples)thresholds.append(score)return fpr, tpr, thresholds

导入上面 3.1 表格中的数据,通过上面实现的 roc 方法,计算 ROC 曲线的坐标值。

import numpy as npy_true = np.array([1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0]
)
y_score = np.array([.9, .8, .7, .6, .55, .54, .53, .52, .51, .505,.4, .39, .38, .37, .36, .35, .34, .33, .3, .1
])fpr, tpr, thresholds = roc(y_true, y_score, pos_label=1)

最后,通过 Matplotlib 将计算出的 ROC 曲线坐标绘制成图。

import matplotlib.pyplot as pltplt.plot(fpr, tpr)
plt.axis("square")
plt.xlabel("False positive rate")
plt.ylabel("True positive rate")
plt.title("ROC curve")
plt.show()

图片

至此,ROC 的基础知识部分就全部讲完了,如果还想深入了解的同学可以继续往下看。

四、联邦学习中的 ROC 平均

图片

顾名思义,ROC 平均就是将多条 ROC 曲线“平均化”。那么,什么场景需要做 ROC 平均呢?例如:横向联邦学习中,由于样本都在用户本地,服务器可以采用 ROC 平均的方式,计算近似的全局 ROC 曲线

ROC 的平均有两种方法:垂直平均、阈值平均,下面将逐一进行讲解,并给出 Python 代码实现。

4.1 垂直平均

图片

垂直平均(Vertical averaging)的思想是,选取一些 FPR 的点,计算其平均的 TPR 值。下面是论文中的算法描述的伪代码,看不懂可直接略过看 Python 代码实现部分。

图片

下面是 Python 的代码实现:

# import numpy as np
def roc_vertical_avg(samples, FPR, TPR):"""samples:选取FPR点的个数FPR:包含所有FPR的列表TPR:包含所有TPR的列表"""nrocs = len(FPR)tpravg = []fpr = [i / samples for i in range(samples + 1)]for fpr_sample in fpr:tprsum = 0# 将所有计算的tpr累加for i in range(nrocs):tprsum += tpr_for_fpr(fpr_sample, FPR[i], TPR[i])# 计算平均的tprtpravg.append(tprsum / nrocs)return fpr, tpravg# 计算对应fpr的tpr
def tpr_for_fpr(fpr_sample, fpr, tpr):i = 0while i < len(fpr) - 1 and fpr[i + 1] <= fpr_sample:i += 1if fpr[i] == fpr_sample:return tpr[i]else:return interpolate(fpr[i], tpr[i], fpr[i + 1], tpr[i + 1], fpr_sample)# 插值
def interpolate(fprp1, tprp1, fprp2, tprp2, x):slope = (tprp2 - tprp1) / (fprp2 - fprp1)return tprp1 + slope * (x - fprp1)

4.2 阈值平均

图片

阈值平均(Threshold averaging)的思想是,选取一些阈值的点,计算其平均的 FPR 和 TPR。

图片

下面是 Python 的代码实现:

# import numpy as np
def roc_threshold_avg(samples, FPR, TPR, THRESHOLDS):"""samples:选取FPR点的个数FPR:包含所有FPR的列表TPR:包含所有TPR的列表THRESHOLDS:包含所有THRESHOLDS的列表"""nrocs = len(FPR)T = []fpravg = []tpravg = []for thresholds in THRESHOLDS:for t in thresholds:T.append(t)T.sort(reverse=True)for tidx in range(0, len(T), int(len(T) / samples)):fprsum = 0tprsum = 0# 将所有计算的fpr和tpr累加for i in range(nrocs):fprp, tprp = roc_point_at_threshold(FPR[i], TPR[i], THRESHOLDS[i], T[tidx])fprsum += fprptprsum += tprp# 计算平均的fpr和tprfpravg.append(fprsum / nrocs)tpravg.append(tprsum / nrocs)return fpravg, tpravg# 计算对应threshold的fpr和tpr
def roc_point_at_threshold(fpr, tpr, thresholds, thresh):i = 0while i < len(fpr) - 1 and thresholds[i] > thresh:i += 1return fpr[i], tpr[i]

五、最后

本文由浅入深地详细介绍了 ROC 曲线算法,包含算法原理、公式、计算、源码实现和讲解,希望能够帮助读者一口气搞懂 ROC。

虽然 ROC 是个不起眼的知识点,但能网上能彻底讲清楚 ROC 的文章并不多。所以我又花时间重温了一遍 Tom Fawcett 的经典论文《An introduction to ROC analysis》[4],并将论文的内容抽丝剥茧、配上通俗易懂的 Python 代码,最终写出了这篇文章。再次致敬🫡 Tom Fawcett,感谢他在机器学习领域的贡献!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/188923.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Leetcode—2471.逐层排序二叉树所需的最少操作数目【中等】(置换环解法!)

2023每日刷题&#xff08;二十七&#xff09; Leetcode—2471.逐层排序二叉树所需的最少操作数目 置换环解题思想 参考自网络 总交换次数 每一层最小交换次数之和 每一层元素个数 - 置换环数 实现代码 /*** Definition for a binary tree node.* struct TreeNode {* …

基于SSM的旅游管理系统的设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用JSP技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…

【第2章 Node.js基础】2.4 Node.js 全局对象...持续更新

什么是Node.js 全局对象 对于浏览器引擎来说&#xff0c;JavaScript 脚本中的 window 是全局对象&#xff0c;而Node.js程序中的全局对象是 global&#xff0c;所有全局变量(除global本身外)都是global 对象的属性。全局变量和全局对象是所有模块都可以调用的。Node.is 的全局…

axios请求的问题

本来不想记录&#xff0c;但是实在没有办法&#xff0c;因为总是会出现post请求&#xff0c;后台接收不到数据的情况,还是记录一下如何的解决的比较好。 但是我使用export const addPsiPurOrder data > request.post(/psi/psiPurOrder/add, data); 下面是封装的代码。后台接…

头歌答案--数据持久化(非数据库)

目录 ​编辑 数据持久化&#xff08;非数据库&#xff09; 第1关&#xff1a;数据持久化&#xff08;非数据库&#xff09; 任务描述 多线程、多进程爬虫 第1关&#xff1a;多线程、多进程爬虫 任务描述 Scrapy爬虫基础 任务描述 MySQL数据库编程 第1关&#xff1a;…

PowerPoint to HTML5 SDK Crack

Convert PowerPoint to HTML5 Retaining Animations, Transitions, Hyperlinks, Smartart, Triggers and other multimedia effects World’s first and industry best technology for building web/mobile based interactive presentations directly from PowerPoint – that …

2.0 熟悉CheatEngine修改器

Cheat Engine 一般简称为CE&#xff0c;它是一款功能强大的开源内存修改工具&#xff0c;其主要功能包括、内存扫描、十六进制编辑器、动态调试功能于一体&#xff0c;且该工具自身附带了脚本工具&#xff0c;可以用它很方便的生成自己的脚本窗体&#xff0c;CE工具可以帮助用户…

ARM 基础学习记录 / 异常与GIC介绍

GIC概念 念课本&#xff08;以下内容都是针对"通用中断控制器&#xff08;GIC&#xff09;"而言&#xff0c;直接摘录的&#xff0c;有的地方可能不符人类的理解方式&#xff09;&#xff1a; 通用中断控制器&#xff08;GIC&#xff09;架构提供了严格的规范&…

Python文件、文件夹操作汇总

目录 一、概览 二、文件操作 2.1 文件的打开、关闭 2.2 文件级操作 2.3 文件内容的操作 三、文件夹操作 四、常用技巧 五、常见使用场景 5.1 查找指定类型文件 5.2 查找指定名称的文件 5.3 查找指定名称的文件夹 5.4 指定路径查找包含指定内容的文件 一、概览 ​在…

Spring Boot(二)

1、运行维护 1.1、打包程序 SpringBoot程序是基于Maven创建的&#xff0c;在Maven中提供有打包的指令&#xff0c;叫做package。本操作可以在Idea环境下执行。 mvn package 打包后会产生一个与工程名类似的jar文件&#xff0c;其名称是由模块名版本号.jar组成的。 1.2、程序…

YOLO目标检测——猫狗目标检测数据集下载分享【含对应voc、coco和yolo三种格式标签】

实际项目应用&#xff1a;宠物识别、猫狗分类数据集说明&#xff1a;猫狗分类检测数据集&#xff0c;真实场景的高质量图片数据&#xff0c;数据场景丰富&#xff0c;含有猫和狗图片标签说明&#xff1a;使用lableimg标注软件标注&#xff0c;标注框质量高&#xff0c;含voc(xm…

[量化投资-学习笔记009]Python+TDengine从零开始搭建量化分析平台-KDJ

技术分析有点像烹饪&#xff0c;收盘价、最值、成交量等是食材&#xff1b;均值&#xff0c;移动平均&#xff0c;方差等是烹饪方法。随意组合一下就是一个技术指标。 KDJ又称随机指标&#xff08;随机这个名字起的很好&#xff09;。KDJ的计算依据是最高价、最低价和收盘价。…

微信小程序隐私政策不合规,应当由用户自主阅读后自行选择是否同意隐私政策协议,不得默认强制用户同意

小程序隐私政策不合规&#xff0c;默认自动同意《用户服务协议》及《隐私政策》&#xff0c;应当由用户自主阅读后自行选择是否同意隐私政策协议&#xff0c;不得默认强制用户同意&#xff0c;请整改后再重新提交。 把 登录代表同意《用户协议》和《隐私政策》 改为 同意《用…

git基础知识

1.git的必要配置 所有的配置文件&#xff0c;其实都保存在本地&#xff01; 查看所有配置 git config -l 即把 系统配置(system)和当前用户&#xff08;global&#xff09;配置都 列出来 以直接编辑配置文件&#xff0c;通过命令设置后会响应到这里。 注意&#xff1a; 如果…

DevOps简介

DevOps简介 1、DevOps的起源2、什么是DevOps3、DevOps的发展现状4、DevOps与虚拟化、容器 1、DevOps的起源 上个世纪40年代&#xff0c;世界上第一台计算机诞生。计算机离不开程序&#xff08;Program&#xff09;驱动&#xff0c;而负责编写程序的人&#xff0c;被称为程序员&…

【数据结构】:红黑树

1、红黑树的简介 红黑树&#xff08;Red Black Tree&#xff09; 是一种自平衡二叉查找树&#xff0c;是在计算机科学中用到的一种数据结构。 红黑树是在1972年由Rudolf Bayer发明的&#xff0c;当时被称为平衡二叉B树&#xff08;symmetric binary B-trees&#xff09;。后来…

【已验证-直接用】微信小程序wx.request请求服务器json数据并渲染到页面

微信小程序的数据总不能写死吧&#xff0c;肯定是要结合数据库来做数据更新&#xff0c;而小程序数据主要是json数据格式&#xff0c;所以我们可以利用php操作数据库&#xff0c;把数据以json格式数据输出即可。 现在给大家讲一下微信小程序的wx.request请求服务器获取数据的用…

【MySQL】列属性

文章目录 CHAR和VARCHAR插入单行 INSERT INTO插入多行插入分层行 LAST_INSERT_IN()创建表复制 CREAT TABLE AS更新单行 UPDATE...SET更新多行在UPDATES中使用子查询【需着重复习】删除行 DELETE恢复数据库到原始状态 CHAR和VARCHAR CHAR(50)&#xff1a;存储文本占5个字符&…

计算机网络基础知识-网络协议

一:计算机网络层次划分 1. 网络层次划分 2. OSI七层网络模型 1)物理层(Physical Layer):及硬件设备,物理层确保原始的数据可在各种物理媒体上传输,常见的设备名称如中继器(Repeater,也叫放大器)和集线器; 2)数据链路层(Data Link Layer):数据链路层在物理层提…

【入门Flink】- 09Flink水位线Watermark

在窗口的处理过程中&#xff0c;基于数据的时间戳&#xff0c;自定义一个“逻辑时钟”。这个时钟的时间不会自动流逝&#xff1b;它的时间进展&#xff0c;就是靠着新到数据的时间戳来推动的。 什么是水位线 用来衡量事件时间进展的标记&#xff0c;就被称作“水位线”&#x…