第100+22步 ChatGPT学习:概率校准 Platt Scaling

基于Python 3.9版本演示

一、写在前面

最近看了一篇在Lancet子刊《eClinicalMedicine》上发表的机器学习分类的文章:《Development of a novel dementia risk prediction model in the general population: A large, longitudinal, population-based machine-learning study》。

学到一种叫做“概率校准”的骚操作,顺手利用GPT系统学习学习。

文章中用的技术是:保序回归(Isotonic regression)。

为了体现举一反三,顺便问了GPT还有哪些方法也可以实现概率校准。它给我列举了很多,那么就一个一个学习吧。

首先,是一个叫做 Platt Scaling 的方法。

二、Platt Scaling

Platt Scaling 是一种后处理方法,用于将机器学习分类器(尤其是二分类器)的输出分数(scores)转换为概率。它通过使用一个简单的逻辑回归模型将分类器的输出分数映射到 [0, 1] 的区间,从而提供概率估计。这个方法最初由 John Platt 在1999年提出,主要用于支持向量机(SVM)的概率估计。

(1)Platt Scaling 的基本步骤

1)训练分类器:首先,训练一个分类器(例如支持向量机、神经网络或决策树),并获得其输出分数(未归一化的分数或logits)。

2)拟合逻辑回归模型:将分类器的输出分数作为逻辑回归模型的输入。使用训练数据中的真实标签作为逻辑回归模型的输出。训练逻辑回归模型,找到最佳的参数(通常是通过最小化对数损失函数来实现)。

3)生成概率估计:使用训练好的逻辑回归模型,将新的分类器输出分数转换为概率估计。

(2)Platt Scaling 在分类中的作用

1)概率校准:Platt Scaling 可以将分类器的输出分数校准为概率,使得输出更具解释性。例如,一个输出概率为0.7的样本可以解释为该样本属于正类的概率为70%。

2)提升模型性能:通过将分数转换为概率,可以更好地进行决策边界的选择,从而提高分类器的性能。

3)适用于多种分类器:尽管Platt Scaling最初是为SVM设计的,但它可以广泛应用于各种分类器,如神经网络、决策树等。

4)融合不同模型的概率输出:在集成学习中,可以使用Platt Scaling将不同模型的输出分数转换为概率,然后结合这些概率进行加权平均,从而提升集成模型的性能。

三、Platt Scaling代码实现

下面,我编一个1比3的不太平衡的数据进行测试,对照组使用不进行校准的LR模型,实验组就是加入校准的LR模型,看看性能能够提高多少?

(1)不进行校准的SVM模型(默认参数)

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve# 加载数据
dataset = pd.read_csv('8PSMjianmo.csv')
X = dataset.iloc[:, 1:20].values
Y = dataset.iloc[:, 0].values# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.30, random_state=666)# 标准化数据
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)# 使用SVM分类器
classifier = SVC(kernel='linear', probability=True)
classifier.fit(X_train, y_train)# 预测结果
y_pred = classifier.predict(X_test)
y_testprba = classifier.decision_function(X_test)y_trainpred = classifier.predict(X_train)
y_trainprba = classifier.decision_function(X_train)# 混淆矩阵
cm_test = confusion_matrix(y_test, y_pred)
cm_train = confusion_matrix(y_train, y_trainpred)
print(cm_train)
print(cm_test)# 绘制测试集混淆矩阵
classes = list(set(y_test))
classes.sort()
plt.imshow(cm_test, cmap=plt.cm.Blues)
indices = range(len(cm_test))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_test)):for second_index in range(len(cm_test[first_index])):plt.text(first_index, second_index, cm_test[first_index][second_index])plt.show()# 绘制训练集混淆矩阵
classes = list(set(y_train))
classes.sort()
plt.imshow(cm_train, cmap=plt.cm.Blues)
indices = range(len(cm_train))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_train)):for second_index in range(len(cm_train[first_index])):plt.text(first_index, second_index, cm_train[first_index][second_index])plt.show()# 计算并打印性能参数
def calculate_metrics(cm, y_true, y_pred_prob):a = cm[0, 0]b = cm[0, 1]c = cm[1, 0]d = cm[1, 1]acc = (a + d) / (a + b + c + d)error_rate = 1 - accsen = d / (d + c)sep = a / (a + b)precision = d / (b + d)F1 = (2 * precision * sen) / (precision + sen)MCC = (d * a - b * c) / (np.sqrt((d + b) * (d + c) * (a + b) * (a + c)))auc_score = roc_auc_score(y_true, y_pred_prob)metrics = {"Accuracy": acc,"Error Rate": error_rate,"Sensitivity": sen,"Specificity": sep,"Precision": precision,"F1 Score": F1,"MCC": MCC,"AUC": auc_score}return metricsmetrics_test = calculate_metrics(cm_test, y_test, y_testprba)
metrics_train = calculate_metrics(cm_train, y_train, y_trainprba)print("Performance Metrics (Test):")
for key, value in metrics_test.items():print(f"{key}: {value:.4f}")print("\nPerformance Metrics (Train):")
for key, value in metrics_train.items():
print(f"{key}: {value:.4f}")

结果输出:

记住这些个数字。

这个参数的SVM还没有LR好。

(2)进行校准的SVM模型(默认参数)

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix, roc_auc_score
from sklearn.calibration import CalibratedClassifierCV# 加载数据
dataset = pd.read_csv('8PSMjianmo.csv')
X = dataset.iloc[:, 1:20].values
Y = dataset.iloc[:, 0].values# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.30, random_state=666)# 标准化数据
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)# 使用SVM分类器并进行Platt Scaling
svm = SVC(kernel='linear', probability=True)
calibrated_svm = CalibratedClassifierCV(svm, method='sigmoid')
calibrated_svm.fit(X_train, y_train)# 预测结果
y_pred = calibrated_svm.predict(X_test)
y_testprba = calibrated_svm.predict_proba(X_test)[:, 1]y_trainpred = calibrated_svm.predict(X_train)
y_trainprba = calibrated_svm.predict_proba(X_train)[:, 1]# 混淆矩阵
cm_test = confusion_matrix(y_test, y_pred)
cm_train = confusion_matrix(y_train, y_trainpred)
print(cm_train)
print(cm_test)# 绘制测试集混淆矩阵
classes = list(set(y_test))
classes.sort()
plt.imshow(cm_test, cmap=plt.cm.Blues)
indices = range(len(cm_test))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_test)):for second_index in range(len(cm_test[first_index])):plt.text(first_index, second_index, cm_test[first_index][second_index])plt.show()# 绘制训练集混淆矩阵
classes = list(set(y_train))
classes.sort()
plt.imshow(cm_train, cmap=plt.cm.Blues)
indices = range(len(cm_train))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_train)):for second_index in range(len(cm_train[first_index])):plt.text(first_index, second_index, cm_train[first_index][second_index])plt.show()# 计算并打印性能参数
def calculate_metrics(cm, y_true, y_pred_prob):a = cm[0, 0]b = cm[0, 1]c = cm[1, 0]d = cm[1, 1]acc = (a + d) / (a + b + c + d)error_rate = 1 - accsen = d / (d + c)sep = a / (a + b)precision = d / (b + d)F1 = (2 * precision * sen) / (precision + sen)MCC = (d * a - b * c) / (np.sqrt((d + b) * (d + c) * (a + b) * (a + c)))auc_score = roc_auc_score(y_true, y_pred_prob)metrics = {"Accuracy": acc,"Error Rate": error_rate,"Sensitivity": sen,"Specificity": sep,"Precision": precision,"F1 Score": F1,"MCC": MCC,"AUC": auc_score}return metricsmetrics_test = calculate_metrics(cm_test, y_test, y_testprba)
metrics_train = calculate_metrics(cm_train, y_train, y_trainprba)print("Performance Metrics (Test):")
for key, value in metrics_test.items():print(f"{key}: {value:.4f}")print("\nPerformance Metrics (Train):")
for key, value in metrics_train.items():
print(f"{key}: {value:.4f}")

看看结果:

总体来看,起作用了。训练集和验证集的AUC均有点提升,但不多。

四、换个策略

参考那篇文章的策略:采用五折交叉验证来建立和评估模型,其中四折用于训练,一折用于评估,在训练集中,其中三折用于建立SVM模型,另一折采用Platt Scaling概率校正,在训练集内部采用交叉验证对超参数进行调参。

代码:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split, KFold, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import confusion_matrix, roc_auc_score, make_scorer# 加载数据
dataset = pd.read_csv('8PSMjianmo.csv')
X = dataset.iloc[:, 1:20].values
Y = dataset.iloc[:, 0].values# 标准化数据
sc = StandardScaler()
X = sc.fit_transform(X)# 五折交叉验证
kf = KFold(n_splits=5, shuffle=True, random_state=666)# 超参数调优参数网格
param_grid = {'C': [0.1, 1, 10, 100],'kernel': ['linear', 'rbf']
}# 计算并打印性能参数
def calculate_metrics(cm, y_true, y_pred_prob):a = cm[0, 0]b = cm[0, 1]c = cm[1, 0]d = cm[1, 1]acc = (a + d) / (a + b + c + d)error_rate = 1 - accsen = d / (d + c)sep = a / (a + b)precision = d / (b + d)F1 = (2 * precision * sen) / (precision + sen)MCC = (d * a - b * c) / (np.sqrt((d + b) * (d + c) * (a + b) * (a + c)))auc_score = roc_auc_score(y_true, y_pred_prob)metrics = {"Accuracy": acc,"Error Rate": error_rate,"Sensitivity": sen,"Specificity": sep,"Precision": precision,"F1 Score": F1,"MCC": MCC,"AUC": auc_score}return metrics# 初始化结果列表
results_train = []
results_test = []# 初始化变量以跟踪最优模型
best_auc = 0
best_model = None
best_X_train = None
best_X_test = None
best_y_train = None
best_y_test = None# 交叉验证过程
for train_index, test_index in kf.split(X):X_train, X_test = X[train_index], X[test_index]y_train, y_test = Y[train_index], Y[test_index]# 内部交叉验证进行超参数调优和模型训练inner_kf = KFold(n_splits=4, shuffle=True, random_state=666)grid_search = GridSearchCV(SVC(probability=True), param_grid, cv=inner_kf, scoring='roc_auc')grid_search.fit(X_train, y_train)model = grid_search.best_estimator_# Platt Scaling 概率校准calibrated_svm = CalibratedClassifierCV(model, method='sigmoid', cv='prefit')calibrated_svm.fit(X_train, y_train)# 评估模型y_trainpred = calibrated_svm.predict(X_train)y_trainprba = calibrated_svm.predict_proba(X_train)[:, 1]cm_train = confusion_matrix(y_train, y_trainpred)metrics_train = calculate_metrics(cm_train, y_train, y_trainprba)results_train.append(metrics_train)y_pred = calibrated_svm.predict(X_test)y_testprba = calibrated_svm.predict_proba(X_test)[:, 1]cm_test = confusion_matrix(y_test, y_pred)metrics_test = calculate_metrics(cm_test, y_test, y_testprba)results_test.append(metrics_test)# 更新最优模型if metrics_test['AUC'] > best_auc:best_auc = metrics_test['AUC']best_model = calibrated_svmbest_X_train = X_trainbest_X_test = X_testbest_y_train = y_trainbest_y_test = y_testbest_params = grid_search.best_params_print("Performance Metrics (Train):")for key, value in metrics_train.items():print(f"{key}: {value:.4f}")print("\nPerformance Metrics (Test):")for key, value in metrics_test.items():print(f"{key}: {value:.4f}")print("\n" + "="*40 + "\n")# 使用最优模型评估性能
y_trainpred = best_model.predict(best_X_train)
y_trainprba = best_model.predict_proba(best_X_train)[:, 1]
cm_train = confusion_matrix(best_y_train, y_trainpred)
metrics_train = calculate_metrics(cm_train, best_y_train, y_trainprba)y_pred = best_model.predict(best_X_test)
y_testprba = best_model.predict_proba(best_X_test)[:, 1]
cm_test = confusion_matrix(best_y_test, y_pred)
metrics_test = calculate_metrics(cm_test, best_y_test, y_testprba)print("Performance Metrics of the Best Model (Train):")
for key, value in metrics_train.items():print(f"{key}: {value:.4f}")print("\nPerformance Metrics of the Best Model (Test):")
for key, value in metrics_test.items():print(f"{key}: {value:.4f}")# 打印最优模型的参数
print("\nBest Model Parameters:")
for key, value in best_params.items():print(f"{key}: {value}")

提升挺大的,不过可能是因为进行了超参数的搜索吧。

最优的SVM参数是:

Best Model Parameters:

C: 0.1

kernel: rbf

大家有空可以去使用这个参数试一试,不进行校准的SVM的性能如何?

不过也无所谓啦,结果不错就行。

五、最后

各位可以去试一试在其他数据或者在其他机器学习分类模型中使用的效果。

数据不分享啦。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/408717.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unet改进3:在不同位置添加NAMAttention注意力机制

本文内容:在不同位置添加NAMAttention注意力机制 目录 论文简介 1.步骤一 2.步骤二 3.步骤三 4.步骤四 论文简介 识别不太显著的特征是模型压缩的关键。然而,它在革命性的注意机制中尚未得到研究。在这项工作中,我们提出了一种新的基于归一化的注意力模块(NAM),它抑制…

2024/8/25 Nacos本机配置

目录 一、nacos下载 二、修改配置文件 2.1、配置数据库 2.2、配置鉴定密钥 2.3、配置启动脚本 三、nacos启动 3.1、启动运行脚本 3.2、新增配置 3.3、服务列表 记录一下本机nacos2.2.3配置全过程 一、nacos下载 下载地址:https://github.com/alibaba/nacos/r…

【通俗易懂】分库分表时分表字段应该如何选择?

目录 一、按买家ID分还是卖家ID分? 二、卖家查询怎么办? 三、订单查询怎么办? 在分库分表的过程中,我们需要有一个字段用来进行分表,比如按照用户分表、按照时间分表、按照地区分表。这里面的用户、时间、地区就是所…

MYSQL————数据库的约束

1.约束类型 1.not null:指示某列不能存储null值 2.unique:保证某列的每行必须有唯一值 3.default:规定没有给列赋值时的默认值 4.primary key:not null和unique的结合。确保某列(或两个或多个列的结合)有唯…

游戏开发设计模式之命令模式

目录 命令模式的定义和工作原理 应用场景 实现方式 优点 缺点 结论 命令模式在游戏开发中的具体实现案例是什么? 如何在Unity3D中有效地实现和管理命令模式以提高游戏性能? 命令模式与其他设计模式(如观察者模式、状态模式&#xff…

2024前端面试题-js篇

1.js有哪些数据类型 基础数据类型:string,number,boolean,null,undefined,bigInt,symbol 引用数据类型:Object 2.js检测数据类型的方式 typeof:其中数组、对象、null都会被判断为object&…

多动症的孩子有哪些症状表现?

在星启帆自闭症儿童康复机构,我们不仅关注自闭症儿童的成长与康复,也深刻认识到多动症对儿童日常生活、学习和社交的深远影响。多动症,全称注意缺陷多动障碍,是一种常见于儿童时期的神经发育性疾病,其症状表现多种多样…

Python酷库之旅-第三方库Pandas(097)

目录 一、用法精讲 416、pandas.DataFrame.memory_usage方法 416-1、语法 416-2、参数 416-3、功能 416-4、返回值 416-5、说明 416-6、用法 416-6-1、数据准备 416-6-2、代码示例 416-6-3、结果输出 417、pandas.DataFrame.empty属性 417-1、语法 417-2、参数 …

FTP协议-匿名用户登录 从0到1

前言 日常大家可能接触web漏洞比较多而对其他端口及协议不那么了解,其实其他协议漏洞在渗透中也同样重要只是平时可能接触得不多。本文将介绍FTP协议、FTP匿名用户登录及其具体流程分析和自动化利用demo。 FTP简介 FTP是File Transfer Protocol(文件传…

google浏览器chrome用户数据(拓展程序,书签等)丢失问题

一、问题背景 我出现这个情况的问题背景是:因为C盘块满了想清理一部分空间(具体看这:windows -- C盘清理_c盘softwaredistribution-CSDN博客),于是找到了更改AppDatta这个方法,但因为,当时做迁移…

opencv-python图像增强十五:高级滤镜实现

文章目录 前言二、鲜食滤镜三、巧克力滤镜三,冷艳滤镜: 前言 在之前两个滤镜文章中介绍了六种简单的滤镜实现,它们大多都是由一个单独函数实现的接下来介绍五种结合了之前图像增强文章提的的算法的复合滤镜。本案例中的算法来自于文章一&…

qt-PLC可视化编辑器

qt-PLC可视化编辑器 一、演示效果二、核心代码三、下载链接 一、演示效果 二、核心代码 #include "diagramitem.h" #include "arrow.h"#include <QDebug> #include <QGraphicsScene> #include <QGraphicsSceneContextMenuEvent> #includ…

[000-01-018].第3节:Linux环境下ElasticSearch环境搭建

我的后端学习笔记大纲 我的ElasticSearch学习大纲 1.Linux系统搭建ES环境&#xff1a; 1.1.单机版&#xff1a; a.安装ES-7.8版本 1.下载ES: 2.上传与解压&#xff1a;将下载的tar包上传到服务器software目录下&#xff0c;然后解压缩&#xff1a;tar -zxvf elasticsearch-7…

人工智能算法工程师(中级)课程21-深度学习中各种优化器算法的应用与实践、代码详解

大家好&#xff0c;我是微学AI&#xff0c;今天给大家介绍一下人工智能算法工程师(中级)课程21-深度学习中各种优化器算法的应用与实践、代码详解。本文将介绍PyTorch框架下的几种优化器&#xff0c;展示如何使用PyTorch中的优化器&#xff0c;我们将使用MNIST数据集和一个简单…

增材制造(3D打印):为何备受制造业瞩目?

在科技浪潮的推动下&#xff0c;增材制造——即3D打印技术&#xff0c;正逐步成为制造业领域的璀璨新星&#xff0c;吸引了航空航天、汽车、家电、电子等众多行业的目光。那么&#xff0c;是什么让3D打印技术如此引人注目并广泛应用于制造领域&#xff1f;其背后的核心优势又是…

Unity-可分组折叠的Editor

Unity-可分组折叠的Editor &#x1f957;功能介绍&#x1f36d;用法 &#x1f957;功能介绍 在序列化的字段上标记特性:[FoldoutGroup(“xxx”)]&#xff0c;inspector上就会被分组折叠显示。 &#xff08;没有被指定的字段自动放到Default组中&#xff09; 传送门&#x1f30…

【Python】1.基础语法(1)

文章目录 1.变量的语法1.1定义变量1.1.1硬性规则(务必遵守)1.1.2软性规则&#xff08;建议遵守&#xff09; 1.2使用变量 2.变量的类型2.1整型2.2浮点型2.3 字符串类型2.4布尔类型2.5其他类型2.6 动态类型特性 3.注释3.1 注释行3.2 文档字符串 3.3 如何批量注释3.4注释的规范 4…

深信服上半年亏损5.92亿,营收同比降低2.3亿

吉祥知识星球http://mp.weixin.qq.com/s?__bizMzkwNjY1Mzc0Nw&mid2247485367&idx1&sn837891059c360ad60db7e9ac980a3321&chksmc0e47eebf793f7fdb8fcd7eed8ce29160cf79ba303b59858ba3a6660c6dac536774afb2a6330#rd 《网安面试指南》http://mp.weixin.qq.com/s?…

如何使用ssm实现网络安全宣传网站设计

TOC ssm177网络安全宣传网站设计jsp 绪论 1.1 研究背景 当前社会各行业领域竞争压力非常大&#xff0c;随着当前时代的信息化&#xff0c;科学化发展&#xff0c;让社会各行业领域都争相使用新的信息技术&#xff0c;对行业内的各种相关数据进行科学化&#xff0c;规范化管…

【SQL】指定日期的产品价格

目录 题目 分析 代码 题目 产品数据表: Products ------------------------ | Column Name | Type | ------------------------ | product_id | int | | new_price | int | | change_date | date | ------------------------ (product_id, chang…