第100+28步 ChatGPT学习:概率校准 Bayesian Calibration

基于Python 3.9版本演示

一、写在前面

最近看了一篇在Lancet子刊《eClinicalMedicine》上发表的机器学习分类的文章:《Development of a novel dementia risk prediction model in the general population: A large, longitudinal, population-based machine-learning study》。

学到一种叫做“概率校准”的骚操作,顺手利用GPT系统学习学习。

文章中用的技术是:保序回归(Isotonic regression)。

为了体现举一反三,顺便问了GPT还有哪些方法也可以实现概率校准。它给我列举了很多,那么就一个一个学习吧。

这一期,介绍一个叫做 Bayesian Calibration 的方法。

二、Bayesian Calibration

Bayesian Calibration是一种统计方法,用于调整模型的输出,通常是概率预测,以更好地与观察到的结果对齐。这种方法利用贝叶斯推断,通过结合新的数据更新先验的模型参数,从而确保预测结果具有良好的校准性。

(1)主要步骤

1)先验分布:定义需要校准的参数的先验分布。这个先验分布反映了在考虑当前数据之前对参数值的初始信念。

2)似然函数:定义一个似然函数,它描述了在给定参数值的情况下,观察到的数据的可能性。

3)后验分布:使用贝叶斯定理,将先验分布和似然函数结合起来,得到参数的后验分布。

4)预测和调整:使用后验分布来调整模型的预测结果,改善其与观察数据的校准性。

三、Bayesian Calibration代码实现

下面,我编一个1比3的不太平衡的数据进行测试,对照组使用不进行校准的SVM模型,实验组就是加入校准的SVM模型,看看性能能够提高多少?

(1)不进行校准的SVM模型(默认参数)

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve# 加载数据
dataset = pd.read_csv('8PSMjianmo.csv')
X = dataset.iloc[:, 1:20].values
Y = dataset.iloc[:, 0].values# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.30, random_state=666)# 标准化数据
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)# 使用SVM分类器
classifier = SVC(kernel='linear', probability=True)
classifier.fit(X_train, y_train)# 预测结果
y_pred = classifier.predict(X_test)
y_testprba = classifier.decision_function(X_test)y_trainpred = classifier.predict(X_train)
y_trainprba = classifier.decision_function(X_train)# 混淆矩阵
cm_test = confusion_matrix(y_test, y_pred)
cm_train = confusion_matrix(y_train, y_trainpred)
print(cm_train)
print(cm_test)# 绘制测试集混淆矩阵
classes = list(set(y_test))
classes.sort()
plt.imshow(cm_test, cmap=plt.cm.Blues)
indices = range(len(cm_test))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_test)):for second_index in range(len(cm_test[first_index])):plt.text(first_index, second_index, cm_test[first_index][second_index])plt.show()# 绘制训练集混淆矩阵
classes = list(set(y_train))
classes.sort()
plt.imshow(cm_train, cmap=plt.cm.Blues)
indices = range(len(cm_train))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_train)):for second_index in range(len(cm_train[first_index])):plt.text(first_index, second_index, cm_train[first_index][second_index])plt.show()# 计算并打印性能参数
def calculate_metrics(cm, y_true, y_pred_prob):a = cm[0, 0]b = cm[0, 1]c = cm[1, 0]d = cm[1, 1]acc = (a + d) / (a + b + c + d)error_rate = 1 - accsen = d / (d + c)sep = a / (a + b)precision = d / (b + d)F1 = (2 * precision * sen) / (precision + sen)MCC = (d * a - b * c) / (np.sqrt((d + b) * (d + c) * (a + b) * (a + c)))auc_score = roc_auc_score(y_true, y_pred_prob)metrics = {"Accuracy": acc,"Error Rate": error_rate,"Sensitivity": sen,"Specificity": sep,"Precision": precision,"F1 Score": F1,"MCC": MCC,"AUC": auc_score}return metricsmetrics_test = calculate_metrics(cm_test, y_test, y_testprba)
metrics_train = calculate_metrics(cm_train, y_train, y_trainprba)print("Performance Metrics (Test):")
for key, value in metrics_test.items():print(f"{key}: {value:.4f}")print("\nPerformance Metrics (Train):")
for key, value in metrics_train.items():
print(f"{key}: {value:.4f}")

结果输出:

记住这些个数字。

这个参数的SVM还没有LR好。

(2)进行校准的SVM模型(默认参数)

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix, roc_auc_score, brier_score_loss, precision_score, f1_score, matthews_corrcoef
from sklearn.calibration import calibration_curve
import statsmodels.api as sm# 加载数据
dataset = pd.read_csv('8PSMjianmo.csv')
X = dataset.iloc[:, 1:20].values
Y = dataset.iloc[:, 0].values# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.30, random_state=666)# 标准化数据
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)# 使用SVM分类器
classifier = SVC(kernel='rbf', C=0.1, probability=True)
classifier.fit(X_train, y_train)# 获取未校准的概率预测
y_train_probs = classifier.predict_proba(X_train)[:, 1]
y_test_probs = classifier.predict_proba(X_test)[:, 1]# 贝叶斯二项回归校准类
class BayesianBinomialCalibration:def __init__(self):self.model = Nonedef fit(self, probs, true_labels):# 使用逻辑回归模型进行校准self.model = sm.GLM(true_labels, sm.add_constant(probs), family=sm.families.Binomial()).fit()def predict(self, probs):# 使用校准模型预测校准后的概率calibrated_probs = self.model.predict(sm.add_constant(probs))return calibrated_probs# 训练贝叶斯二项回归校准模型
bbc = BayesianBinomialCalibration()
bbc.fit(y_train_probs, y_train)# 进行校准
calibrated_train_probs = bbc.predict(y_train_probs)
calibrated_test_probs = bbc.predict(y_test_probs)# 预测结果
y_train_pred = (calibrated_train_probs >= 0.5).astype(int)
y_test_pred = (calibrated_test_probs >= 0.5).astype(int)# 混淆矩阵
cm_test = confusion_matrix(y_test, y_test_pred)
cm_train = confusion_matrix(y_train, y_train_pred)
print(cm_train)
print(cm_test)# 绘制混淆矩阵函数
def plot_confusion_matrix(cm, classes, title='Confusion Matrix'):plt.imshow(cm, cmap=plt.cm.Blues)indices = range(len(cm))plt.xticks(indices, classes)plt.yticks(indices, classes)plt.colorbar()plt.xlabel('Predicted')plt.ylabel('Actual')for first_index in range(len(cm)):for second_index in range(len(cm[first_index])):plt.text(second_index, first_index, cm[first_index][second_index])plt.title(title)plt.show()# 绘制测试集混淆矩阵
plot_confusion_matrix(cm_test, list(set(y_test)), 'Confusion Matrix (Test)')# 绘制训练集混淆矩阵
plot_confusion_matrix(cm_train, list(set(y_train)), 'Confusion Matrix (Train)')# 计算并打印性能参数
def calculate_metrics(cm, y_true, y_pred, y_pred_prob):a = cm[0, 0]b = cm[0, 1]c = cm[1, 0]d = cm[1, 1]acc = (a + d) / (a + b + c + d)error_rate = 1 - accsen = d / (d + c) if (d + c) > 0 else 0sep = a / (a + b) if (a + b) > 0 else 0precision = precision_score(y_true, y_pred, zero_division=0)F1 = f1_score(y_true, y_pred, zero_division=0)MCC = matthews_corrcoef(y_true, y_pred)auc_score = roc_auc_score(y_true, y_pred_prob)brier_score = brier_score_loss(y_true, y_pred_prob)metrics = {"Accuracy": acc,"Error Rate": error_rate,"Sensitivity": sen,"Specificity": sep,"Precision": precision,"F1 Score": F1,"MCC": MCC,"AUC": auc_score,"Brier Score": brier_score}return metricsmetrics_test = calculate_metrics(cm_test, y_test, y_test_pred, calibrated_test_probs)
metrics_train = calculate_metrics(cm_train, y_train, y_train_pred, calibrated_train_probs)print("Performance Metrics (Test):")
for key, value in metrics_test.items():print(f"{key}: {value:.4f}")print("\nPerformance Metrics (Train):")
for key, value in metrics_train.items():print(f"{key}: {value:.4f}")

看看结果:

大同小异吧。

四、换个策略

参考那篇文章的策略:采用五折交叉验证来建立和评估模型,其中四折用于训练,一折用于评估,在训练集中,其中三折用于建立SVM模型,另一折采用Bayesian Calibration概率校正,在训练集内部采用交叉验证对超参数进行调参。

代码:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV, KFold
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix, roc_auc_score, brier_score_loss, precision_score, f1_score, matthews_corrcoef
from sklearn.calibration import calibration_curve
import statsmodels.api as sm# 加载数据
dataset = pd.read_csv('8PSMjianmo.csv')
X = dataset.iloc[:, 1:20].values
Y = dataset.iloc[:, 0].values# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.30, random_state=666)# 标准化数据
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)# 定义五折交叉验证
kf = KFold(n_splits=5, shuffle=True, random_state=666)
calibrated_probs = []
true_labels = []# 贝叶斯二项回归校准类
class BayesianBinomialCalibration:def __init__(self):self.model = Nonedef fit(self, probs, true_labels):# 使用逻辑回归模型进行校准self.model = sm.GLM(true_labels, sm.add_constant(probs), family=sm.families.Binomial()).fit()def predict(self, probs):# 使用校准模型预测校准后的概率calibrated_probs = self.model.predict(sm.add_constant(probs))return calibrated_probsbest_params = None  # 用于存储最优参数for train_index, val_index in kf.split(X_train):X_train_fold, X_val_fold = X_train[train_index], X_train[val_index]y_train_fold, y_val_fold = y_train[train_index], y_train[val_index]# 内部三折交叉验证用于超参数调优inner_kf = KFold(n_splits=3, shuffle=True, random_state=666)param_grid = {'C': [0.01, 0.1, 1, 10, 100], 'kernel': ['rbf']}svm = SVC(probability=True)clf = GridSearchCV(svm, param_grid, cv=inner_kf, scoring='roc_auc')clf.fit(X_train_fold, y_train_fold)best_params = clf.best_params_# 使用最佳参数训练SVMclassifier = SVC(kernel=best_params['kernel'], C=best_params['C'], probability=True)classifier.fit(X_train_fold, y_train_fold)# 获取未校准的概率预测y_val_fold_probs = classifier.predict_proba(X_val_fold)[:, 1]# 贝叶斯二项回归校准bbc = BayesianBinomialCalibration()bbc.fit(y_val_fold_probs, y_val_fold)calibrated_val_fold_probs = bbc.predict(y_val_fold_probs)calibrated_probs.extend(calibrated_val_fold_probs)true_labels.extend(y_val_fold)# 用于测试集的SVM模型训练和校准
classifier_final = SVC(kernel=best_params['kernel'], C=best_params['C'], probability=True)
classifier_final.fit(X_train, y_train)
y_test_probs = classifier_final.predict_proba(X_test)[:, 1]# 贝叶斯二项回归校准
bbc_final = BayesianBinomialCalibration()
bbc_final.fit(y_test_probs, y_test)
calibrated_test_probs = bbc_final.predict(y_test_probs)# 预测结果
y_train_pred = (np.array(calibrated_probs) >= 0.5).astype(int)
y_test_pred = (calibrated_test_probs >= 0.5).astype(int)# 混淆矩阵
cm_test = confusion_matrix(y_test, y_test_pred)
cm_train = confusion_matrix(true_labels, y_train_pred)
print("Training Confusion Matrix:\n", cm_train)
print("Testing Confusion Matrix:\n", cm_test)# 绘制混淆矩阵函数
def plot_confusion_matrix(cm, classes, title='Confusion Matrix'):plt.imshow(cm, cmap=plt.cm.Blues)indices = range(len(cm))plt.xticks(indices, classes)plt.yticks(indices, classes)plt.colorbar()plt.xlabel('Predicted')plt.ylabel('Actual')for first_index in range(len(cm)):for second_index in range(len(cm[first_index])):plt.text(second_index, first_index, cm[first_index][second_index])plt.title(title)plt.show()# 绘制测试集混淆矩阵
plot_confusion_matrix(cm_test, list(set(y_test)), 'Confusion Matrix (Test)')# 绘制训练集混淆矩阵
plot_confusion_matrix(cm_train, list(set(true_labels)), 'Confusion Matrix (Train)')# 计算并打印性能参数
def calculate_metrics(cm, y_true, y_pred, y_pred_prob):a = cm[0, 0]b = cm[0, 1]c = cm[1, 0]d = cm[1, 1]acc = (a + d) / (a + b + c + d)error_rate = 1 - accsen = d / (d + c) if (d + c) > 0 else 0sep = a / (a + b) if (a + b) > 0 else 0precision = precision_score(y_true, y_pred, zero_division=0)F1 = f1_score(y_true, y_pred, zero_division=0)MCC = matthews_corrcoef(y_true, y_pred)auc_score = roc_auc_score(y_true, y_pred_prob)brier_score = brier_score_loss(y_true, y_pred_prob)metrics = {"Accuracy": acc,"Error Rate": error_rate,"Sensitivity": sen,"Specificity": sep,"Precision": precision,"F1 Score": F1,"MCC": MCC,"AUC": auc_score,"Brier Score": brier_score}return metricsmetrics_test = calculate_metrics(cm_test, y_test, y_test_pred, calibrated_test_probs)
metrics_train = calculate_metrics(cm_train, true_labels, y_train_pred, np.array(calibrated_probs))print("Performance Metrics (Test):")
for key, value in metrics_test.items():print(f"{key}: {value:.4f}")print("\nPerformance Metrics (Train):")
for key, value in metrics_train.items():print(f"{key}: {value:.4f}")# 绘制校准曲线
def plot_calibration_curve(y_true, probs, title='Calibration Curve'):fraction_of_positives, mean_predicted_value = calibration_curve(y_true, probs, n_bins=10)plt.plot(mean_predicted_value, fraction_of_positives, "s-", label="Bayesian Binomial Calibration")plt.plot([0, 1], [0, 1], "k--")plt.xlabel('Mean predicted value')plt.ylabel('Fraction of positives')plt.title(title)plt.legend()plt.show()# 绘制校准曲线
plot_calibration_curve(y_test, calibrated_test_probs, title='Calibration Curve (Test)')
plot_calibration_curve(true_labels, np.array(calibrated_probs), title='Calibration Curve (Train)')

输出:

效果甚至还变差了呢。

五、最后

各位可以去试一试在其他数据或者在其他机器学习分类模型中使用的效果。

数据不分享啦。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/448040.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt-窗口对话框相关操作(50)

目录 描述 创建 使用 点击弹出对话框 内存泄漏问题 自定义对话框 纯代码 界面操作 模态和非模态对话框 描述 对话框是 GUI 程序中不可或缺的组成部分。⼀些不适合在主窗⼝实现的功能组件可以设置在对话框中。对话框通常是⼀个顶层窗⼝,出现在程序最上层&am…

2024年腾讯外包面试题(微创公司)

笔试&#xff1a; 1、判断异步执行顺序 console.log(1);setTimeout(()>{Promise.resolve().then(()>{console.log(2);})console.log(3);},0);new Promise ((resolve)>{for(let i0; i<1000;i ){if(i1000){resolve();}}console.log(4);}).then(()>{console.log(5…

高可用之限流 08-leaky bucket漏桶算法

限流系列 开源组件 rate-limit: 限流 高可用之限流-01-入门介绍 高可用之限流-02-如何设计限流框架 高可用之限流-03-Semaphore 信号量做限流 高可用之限流-04-fixed window 固定窗口 高可用之限流-05-slide window 滑动窗口 高可用之限流-06-slide window 滑动窗口 sen…

SCALABLEANDEFFECTIVE IMPLICIT GRAPH NEURALNETWORKS ON LARGEGRAPHS

ICLR24 推荐指数&#xff1a; #paper/⭐⭐ 领域&#xff1a; 大图&#xff0c;图扩展 大概的工作&#xff1a;提出了针对子图的虚拟节点&#xff0c;让所有点都与其相连 相关工作&#xff1a; 传统GNN与Inplicit gnn 传统GNN的传播函数&#xff1a; Z ( l 1 ) ϕ ( W ( …

Karmada核心概念

以下内容为翻译&#xff0c;原文地址 Karmada 是什么&#xff1f; | karmada 一、Karmada核心概念 一&#xff09;什么是Karmada 1、Karmada&#xff1a;开放&#xff0c;多云&#xff0c;多集群Kubernetes业务流程 Karmada (Kubernetes Armada)是一个Kubernetes管理系统&…

【OpenCV】(六)—— 阈值处理

阈值处理&#xff08;Thresholding&#xff09;用于将灰度图像转换为二值图像。通过设定一个或多个阈值&#xff0c;可以将图像中的像素分为不同的类别&#xff0c;通常用于分割前景和背景、简化图像、去除噪声等任务。OpenCV 提供了多种阈值处理方法&#xff0c;下面介绍基本阈…

让AI像人一样思考和使用工具,reAct机制详解

reAct机制详解 reAct是什么reAct的关键要素reAct的思维过程reAct的代码实现查看效果引入依赖&#xff0c;定义模型定义相关工具集合工具创建代理启动测试完整代码 思考 reAct是什么 reAct的核心思想是将**推理&#xff08;Reasoning&#xff09;和行动&#xff08;Acting&…

探索人工智能:深度解析未来科技的核心驱动力

目录 &#x1f354; 人工智能的应用方向 &#x1f354; 人工智能的发展历史 &#x1f354; 人工智能、机器学习、深度学习关系 &#x1f354; 为什么学习机器学习&#xff1f; &#x1f354; 小节 学习目标 &#x1f340; 了解人工智能的应用方向 &#x1f340; 了解人工智…

【千库网-注册安全分析报告】

前言 由于网站注册入口容易被黑客攻击&#xff0c;存在如下安全问题&#xff1a; 暴力破解密码&#xff0c;造成用户信息泄露短信盗刷的安全问题&#xff0c;影响业务及导致用户投诉带来经济损失&#xff0c;尤其是后付费客户&#xff0c;风险巨大&#xff0c;造成亏损无底洞…

iPad备份软件哪个好?好用的苹果备份软件推荐

苹果手机在将数据备份到电脑时&#xff0c;需要通过第三方的管理软件&#xff0c;才可以将手机连接到电脑进行备份。苹果手机备份软件有很多&#xff0c;常用的有&#xff1a;爱思助手、iMazing、iTuns等。那么这三款常用的备份软件究竟哪款更好呢&#xff1f;下面就给大家盘点…

uniapp学习(004-2 组件 Part.2生命周期)

零基础入门uniapp Vue3组合式API版本到咸虾米壁纸项目实战&#xff0c;开发打包微信小程序、抖音小程序、H5、安卓APP客户端等 总时长 23:40:00 共116P 此文章包含第31p-第p35的内容 文章目录 组件生命周期我们主要使用的三种生命周期setup(创建组件时执行)不可以操作dom节点…

Kimi AI助手重大更新:语音通话功能闪亮登场!

Kimi人工智能助手近日发布了一项令人瞩目的重大更新&#xff0c;其中最引人注目的是新增的语音通话功能。这一创新不仅拓展了用户与AI互动的方式&#xff0c;还为学习和工作场景提供了突破性的解决方案。 Ai 智能办公利器 - Ai-321.com 人工智能 - Ai工具集 - 全球热门人工智能…

使用 python 下载 bilibili 视频

本文想要达成的目标为&#xff1a;运行 python 代码之后&#xff0c;在终端输入视频链接&#xff0c;可自动下载高清 1080P 视频并保存到相应文件夹。 具体可分为两大步&#xff1a;首先&#xff0c;使用浏览器开发者工具 F12 获取请求链接相关信息&#xff08;根据 api 接口下…

性能测试持续继承 CICD

目录 一、如何实现性能测试持续继承操作 下载ant 验证ant是否安装成功 二、jmeterant结合 1、我们需要把jmeter中extres 中的ant-jmeter-1.1.1.jar 复制到ant的安装目录中的lib目录中 2、把jmeter中extres中的build.xml 复制到ant的安装目录中的bin目录 3、编辑build.x…

uniapp 设置 tabbar 的 midButton 按钮

效果展示&#xff1a; 中间的国际化没生效&#xff08;忽略就行&#xff09; 示例代码&#xff1a; 然后在 App.vue 中进行监听&#xff1a; <script>export default {onLaunch(e) {// #ifdef APPuni.onTabBarMidButtonTap(()>{console.log("中间按钮点击回调…

Nacos安装指南

1.Windows安装 开发阶段采用单机安装即可。 1.1.下载安装包 在Nacos的GitHub页面,提供有下载链接,可以下载编译好的Nacos服务端或者源代码: GitHub主页:https://github.com/alibaba/nacos GitHub的Release下载页:https://github.com/alibaba/nacos/releases 如图: …

C1. Adjust The Presentation (Easy Version) 双指针

C1. Adjust The Presentation (Easy Version) 妈呀, 最难读懂的一道题(英语不好) 原题 思路 这道题读懂之后就是双指针. 不难想到只要之前出现过, 就一定可以展示出来, 唯一需要注意的时不能在a里有多余的科幻片 代码 #include <bits/stdc.h> #define int long long…

Python爬虫之正则表达式于xpath的使用教学及案例

正则表达式 常用的匹配模式 \d # 匹配任意一个数字 \D # 匹配任意一个非数字 \w # 匹配任意一个单词字符&#xff08;数字、字母、下划线&#xff09; \W # 匹配任意一个非单词字符 . # 匹配任意一个字符&#xff08;除了换行符&#xff09; [a-z] # 匹配任意一个小写字母 […

牛客:Holding Two,Inverse Pair,Counting Triangles

Holding Two 题目描述 登录—专业IT笔试面试备考平台_牛客网 ​​运行代码 #include<bits/stdc.h> using namespace std; const int N3e45; string s1,s2; int main(){int n,m;cin>>n>>m;for(int i0;i<m;i){if(i&1){s10;s21;} else{s11;s20;} }fo…

jvm内存溢出问题排查Java服务自动停止问题排查

Java服务自动停止&#xff0c;Java服务内存溢出问题解决记录。 过程描述 服务器上的一个项目突然服务不了了&#xff0c;登录服务器一看&#xff0c;服务被停了&#xff0c;第一反应大概率就是内存溢出导致的&#xff0c;结果查看日志没有任何报错&#xff0c;就很奇怪&#…