第100+15步 ChatGPT学习:R实现Ababoost分类

基于R 4.2.2版本演示

一、写在前面

有不少大佬问做机器学习分类能不能用R语言,不想学Python咯。

答曰:可!用GPT或者Kimi转一下就得了呗。

加上最近也没啥内容写了,就帮各位搬运一下吧。

二、R代码实现Ababoost分类

(1)导入数据

我习惯用RStudio自带的导入功能:

(2)建立Ababoost模型(默认参数)

# Load necessary libraries
library(caret)
library(pROC)
library(ggplot2)# Assume 'data' is your dataframe containing the data
# Set seed to ensure reproducibility
set.seed(123)# Split data into training and validation sets (80% training, 20% validation)
trainIndex <- createDataPartition(data$X, p = 0.8, list = FALSE)
trainData <- data[trainIndex, ]
validData <- data[-trainIndex, ]# Convert the target variable to a factor for classification
trainData$X <- as.factor(trainData$X)
validData$X <- as.factor(validData$X)# Define control method for training with cross-validation
trainControl <- trainControl(method = "cv", number = 10)# Fit Random Forest model on the training set
model <- train(X ~ ., data = trainData, method = "ada", trControl = trainControl)# Print the best parameters found by the model
best_params <- model$bestTune
cat("The best parameters found are:\n")
print(best_params)# Predict on the training and validation sets
trainPredict <- predict(model, trainData, type = "prob")[,2]
validPredict <- predict(model, validData, type = "prob")[,2]# Calculate ROC curves and AUC values
trainRoc <- roc(response = trainData$X, predictor = trainPredict)
validRoc <- roc(response = validData$X, predictor = validPredict)# Plot ROC curves with AUC values
ggplot(data = data.frame(fpr = trainRoc$specificities, tpr = trainRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +geom_line(color = "blue") +geom_area(alpha = 0.2, fill = "blue") +geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +ggtitle("Training ROC Curve") +xlab("False Positive Rate") +ylab("True Positive Rate") +annotate("text", x = 0.5, y = 0.1, label = paste("Training AUC =", round(auc(trainRoc), 2)), hjust = 0.5, color = "blue")ggplot(data = data.frame(fpr = validRoc$specificities, tpr = validRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +geom_line(color = "red") +geom_area(alpha = 0.2, fill = "red") +geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +ggtitle("Validation ROC Curve") +xlab("False Positive Rate") +ylab("True Positive Rate") +annotate("text", x = 0.5, y = 0.2, label = paste("Validation AUC =", round(auc(validRoc), 2)), hjust = 0.5, color = "red")# Calculate confusion matrices based on 0.5 cutoff for probability
confMatTrain <- table(trainData$X, trainPredict >= 0.5)
confMatValid <- table(validData$X, validPredict >= 0.5)# Function to plot confusion matrix using ggplot2
plot_confusion_matrix <- function(conf_mat, dataset_name) {conf_mat_df <- as.data.frame(as.table(conf_mat))colnames(conf_mat_df) <- c("Actual", "Predicted", "Freq")p <- ggplot(data = conf_mat_df, aes(x = Predicted, y = Actual, fill = Freq)) +geom_tile(color = "white") +geom_text(aes(label = Freq), vjust = 1.5, color = "black", size = 5) +scale_fill_gradient(low = "white", high = "steelblue") +labs(title = paste("Confusion Matrix -", dataset_name, "Set"), x = "Predicted Class", y = "Actual Class") +theme_minimal() +theme(axis.text.x = element_text(angle = 45, hjust = 1), plot.title = element_text(hjust = 0.5))print(p)
}# Now call the function to plot and display the confusion matrices
plot_confusion_matrix(confMatTrain, "Training")
plot_confusion_matrix(confMatValid, "Validation")# Extract values for calculations
a_train <- confMatTrain[1, 1]
b_train <- confMatTrain[1, 2]
c_train <- confMatTrain[2, 1]
d_train <- confMatTrain[2, 2]a_valid <- confMatValid[1, 1]
b_valid <- confMatValid[1, 2]
c_valid <- confMatValid[2, 1]
d_valid <- confMatValid[2, 2]# Training Set Metrics
acc_train <- (a_train + d_train) / sum(confMatTrain)
error_rate_train <- 1 - acc_train
sen_train <- d_train / (d_train + c_train)
sep_train <- a_train / (a_train + b_train)
precision_train <- d_train / (b_train + d_train)
F1_train <- (2 * precision_train * sen_train) / (precision_train + sen_train)
MCC_train <- (d_train * a_train - b_train * c_train) / sqrt((d_train + b_train) * (d_train + c_train) * (a_train + b_train) * (a_train + c_train))
auc_train <- roc(response = trainData$X, predictor = trainPredict)$auc# Validation Set Metrics
acc_valid <- (a_valid + d_valid) / sum(confMatValid)
error_rate_valid <- 1 - acc_valid
sen_valid <- d_valid / (d_valid + c_valid)
sep_valid <- a_valid / (a_valid + b_valid)
precision_valid <- d_valid / (b_valid + d_valid)
F1_valid <- (2 * precision_valid * sen_valid) / (precision_valid + sen_valid)
MCC_valid <- (d_valid * a_valid - b_valid * c_valid) / sqrt((d_valid + b_valid) * (d_valid + c_valid) * (a_valid + b_valid) * (a_valid + c_valid))
auc_valid <- roc(response = validData$X, predictor = validPredict)$auc# Print Metrics
cat("Training Metrics\n")
cat("Accuracy:", acc_train, "\n")
cat("Error Rate:", error_rate_train, "\n")
cat("Sensitivity:", sen_train, "\n")
cat("Specificity:", sep_train, "\n")
cat("Precision:", precision_train, "\n")
cat("F1 Score:", F1_train, "\n")
cat("MCC:", MCC_train, "\n")
cat("AUC:", auc_train, "\n\n")cat("Validation Metrics\n")
cat("Accuracy:", acc_valid, "\n")
cat("Error Rate:", error_rate_valid, "\n")
cat("Sensitivity:", sen_valid, "\n")
cat("Specificity:", sep_valid, "\n")
cat("Precision:", precision_valid, "\n")
cat("F1 Score:", F1_valid, "\n")
cat("MCC:", MCC_valid, "\n")
cat("AUC:", auc_valid, "\n")

在R语言中,使用 caret 包训练Ababoost模型时,最关键的可调参数不多,下面是一些可以调整的关键参数:

①Iter: 这是最重要的参数之一,代表弱学习器的数量,即AdaBoost算法中的迭代次数。较大的nIter值通常可以提高模型的复杂度和拟合能力,但也可能导致过拟合。

②maxdepth: 这是决策树的最大深度。AdaBoost通常使用决策树作为其弱学习器。通过调整maxdepth可以控制单个决策树的复杂度,从而影响整个集成模型的复杂度。

③nu: 这个参数是学习率(也称为收缩参数或步长)。它用于更新每次迭代中模型权重。较小的nu值可以使模型学习得更加谨慎,通常可以减少过拟合的风险,但可能需要更多的迭代次数来收敛。

结果输出(默认参数):

在默认参数中,caret包已经默默帮我们吧上面三个参数进行测试和寻优。

从AUC来看,Ababoost随便一跑,就跑出个不错的结果。不过有些过拟合了,验证集的性能稍微差些。

三、Ababoost手动调参方法(3个值)

设置iter值取值50、100、200、400、600;maxdepth取值1、2、5、7和9;nu取值0.01、0.1、0.5:

# Load necessary libraries
library(caret)
library(pROC)
library(ggplot2)# Assume 'data' is your dataframe containing the data
# Set seed to ensure reproducibility
set.seed(123)# Split data into training and validation sets (80% training, 20% validation)
trainIndex <- createDataPartition(data$X, p = 0.8, list = FALSE)
trainData <- data[trainIndex, ]
validData <- data[-trainIndex, ]# Convert the target variable to a factor for classification
trainData$X <- as.factor(trainData$X)
validData$X <- as.factor(validData$X)# Define control method for training with cross-validation
trainControl <- trainControl(method = "cv", number = 10)# Define the tuning grid with correct parameter names
tuneGrid <- expand.grid(iter = c(50, 100, 200, 400, 600),maxdepth = c(1, 2, 5, 7, 9),nu = c(0.01, 0.1, 0.5))# Train the model using the ada method and the corrected tuning grid
model <- train(X ~ ., data = trainData, method = "ada", trControl = trainControl, tuneGrid = tuneGrid)# Print the best parameters found by the model
best_params <- model$bestTune
cat("The best parameters found are:\n")
print(best_params)# Predict on the training and validation sets
trainPredict <- predict(model, trainData, type = "prob")[,2]
validPredict <- predict(model, validData, type = "prob")[,2]# Calculate ROC curves and AUC values
trainRoc <- roc(response = trainData$X, predictor = trainPredict)
validRoc <- roc(response = validData$X, predictor = validPredict)# Plot ROC curves with AUC values
ggplot(data = data.frame(fpr = trainRoc$specificities, tpr = trainRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +geom_line(color = "blue") +geom_area(alpha = 0.2, fill = "blue") +geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +ggtitle("Training ROC Curve") +xlab("False Positive Rate") +ylab("True Positive Rate") +annotate("text", x = 0.5, y = 0.1, label = paste("Training AUC =", round(auc(trainRoc), 2)), hjust = 0.5, color = "blue")ggplot(data = data.frame(fpr = validRoc$specificities, tpr = validRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +geom_line(color = "red") +geom_area(alpha = 0.2, fill = "red") +geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +ggtitle("Validation ROC Curve") +xlab("False Positive Rate") +ylab("True Positive Rate") +annotate("text", x = 0.5, y = 0.2, label = paste("Validation AUC =", round(auc(validRoc), 2)), hjust = 0.5, color = "red")# Calculate confusion matrices based on 0.5 cutoff for probability
confMatTrain <- table(trainData$X, trainPredict >= 0.5)
confMatValid <- table(validData$X, validPredict >= 0.5)# Function to plot confusion matrix using ggplot2
plot_confusion_matrix <- function(conf_mat, dataset_name) {conf_mat_df <- as.data.frame(as.table(conf_mat))colnames(conf_mat_df) <- c("Actual", "Predicted", "Freq")p <- ggplot(data = conf_mat_df, aes(x = Predicted, y = Actual, fill = Freq)) +geom_tile(color = "white") +geom_text(aes(label = Freq), vjust = 1.5, color = "black", size = 5) +scale_fill_gradient(low = "white", high = "steelblue") +labs(title = paste("Confusion Matrix -", dataset_name, "Set"), x = "Predicted Class", y = "Actual Class") +theme_minimal() +theme(axis.text.x = element_text(angle = 45, hjust = 1), plot.title = element_text(hjust = 0.5))print(p)
}# Now call the function to plot and display the confusion matrices
plot_confusion_matrix(confMatTrain, "Training")
plot_confusion_matrix(confMatValid, "Validation")# Extract values for calculations
a_train <- confMatTrain[1, 1]
b_train <- confMatTrain[1, 2]
c_train <- confMatTrain[2, 1]
d_train <- confMatTrain[2, 2]a_valid <- confMatValid[1, 1]
b_valid <- confMatValid[1, 2]
c_valid <- confMatValid[2, 1]
d_valid <- confMatValid[2, 2]# Training Set Metrics
acc_train <- (a_train + d_train) / sum(confMatTrain)
error_rate_train <- 1 - acc_train
sen_train <- d_train / (d_train + c_train)
sep_train <- a_train / (a_train + b_train)
precision_train <- d_train / (b_train + d_train)
F1_train <- (2 * precision_train * sen_train) / (precision_train + sen_train)
MCC_train <- (d_train * a_train - b_train * c_train) / sqrt((d_train + b_train) * (d_train + c_train) * (a_train + b_train) * (a_train + c_train))
auc_train <- roc(response = trainData$X, predictor = trainPredict)$auc# Validation Set Metrics
acc_valid <- (a_valid + d_valid) / sum(confMatValid)
error_rate_valid <- 1 - acc_valid
sen_valid <- d_valid / (d_valid + c_valid)
sep_valid <- a_valid / (a_valid + b_valid)
precision_valid <- d_valid / (b_valid + d_valid)
F1_valid <- (2 * precision_valid * sen_valid) / (precision_valid + sen_valid)
MCC_valid <- (d_valid * a_valid - b_valid * c_valid) / sqrt((d_valid + b_valid) * (d_valid + c_valid) * (a_valid + b_valid) * (a_valid + c_valid))
auc_valid <- roc(response = validData$X, predictor = validPredict)$auc# Print Metrics
cat("Training Metrics\n")
cat("Accuracy:", acc_train, "\n")
cat("Error Rate:", error_rate_train, "\n")
cat("Sensitivity:", sen_train, "\n")
cat("Specificity:", sep_train, "\n")
cat("Precision:", precision_train, "\n")
cat("F1 Score:", F1_train, "\n")
cat("MCC:", MCC_train, "\n")
cat("AUC:", auc_train, "\n\n")cat("Validation Metrics\n")
cat("Accuracy:", acc_valid, "\n")
cat("Error Rate:", error_rate_valid, "\n")
cat("Sensitivity:", sen_valid, "\n")
cat("Specificity:", sep_valid, "\n")
cat("Precision:", precision_valid, "\n")
cat("F1 Score:", F1_valid, "\n")
cat("MCC:", MCC_valid, "\n")
cat("AUC:", auc_valid, "\n")

结果输出:

以上是找到的相对最优参数组合,看看具体性能:

还不让入默认的性能好呢。

看看GPT给的参数的取值建议,祝各位调得开心:

iter (迭代次数): 这个参数通常设置在10到1000之间。较小的数据集可能需要较少的迭代,而较大或较复杂的数据集可能需要更多的迭代。通常开始可以尝试50, 100, 200等值,然后根据模型的性能来调整。

maxdepth (树的最大深度): 这个参数一般设置在1到10之间。深度为1意味着使用决策树桩(仅一个决策点),这有助于防止过拟合,是AdaBoost中常用的设置。但对于更复杂的数据模式,可能需要更深的树。可以尝试的值包括1, 2, 3, 5等。

nu (学习率): 学习率的典型取值范围是0.01到1。较小的学习率(如0.01, 0.1)可以使模型学习得更稳健,但收敛速度可能较慢,需要更多的迭代次数。较高的学习率可以加快学习速度,但可能导致模型在训练过程中不稳定。

四、最后

数据嘛:

链接:https://pan.baidu.com/s/1rEf6JZyzA1ia5exoq5OF7g?pwd=x8xm

提取码:x8xm

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/375660.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mybatis动态传入参数 pgsql 日期 Interval ,day,minute

mybatis动态传入参数 pgsql 日期 Interval 在navicat中&#xff0c;标准写法 SELECT * FROM test WHERE time > (NOW() - INTERVAL 5 day)在mybatis中&#xff0c;错误写法 SELECT * FROM test WHERE time > (NOW() - INTERVAL#{numbers,jdbcTypeINTEGER} day)报错内…

根据脚手架archetype快速构建spring boot/cloud项目

1、找到archetype&#xff0c;并从私仓下载添加archetype到本地 点击IDEA的file&#xff0c;选择new project 选择maven项目&#xff0c;勾选create from archetype 填写archetype信息&#xff0c;&#xff08;repository填写私仓地址&#xff09; 2、选择自定义的脚手架arche…

C++进阶:继承和多态

文章目录 ❤️继承&#x1fa77;继承与友元&#x1f9e1;继承和静态成员&#x1f49b;菱形继承及菱形虚拟继承&#x1f49a;继承和组合 ❤️多态&#x1fa77;什么是多态&#xff1f;&#x1f9e1;多态的定义以及实现&#x1f49b;虚函数&#x1f49a;虚函数的重写&#x1f499…

海外媒体发稿-全媒体百科

全球知名媒体机构 在全球范围内&#xff0c;有许多知名的新闻机构负责报道世界各地的新闻事件。以下是一些国外常见的媒体机构&#xff1a; AP&#xff08;美联社&#xff09;合众国际社&#xff08;UPI&#xff09;AFP(法新社)EFE&#xff08;埃菲通讯社&#xff09;Europa …

iPhone删除所有照片的高效三部曲

苹果手机用久了&#xff0c;系统缓存包括自己使用手机留下的内存肯定会越来越多。其中&#xff0c;相册中的照片数量可能会急剧增加&#xff0c;占据大量的存储空间。当用户们想要对相册进行彻底清理&#xff0c;实现iPhone删除所有照片时&#xff0c;不妨跟随以下详细的三部曲…

【Redis】哨兵(sentinel)

文章目录 一、哨兵是什么&#xff1f;二、 哨兵sentinel文件参数三、 模仿主机redis宕机四、哨兵运行流程和选举原理SDOWN主观下线ODOWN客观下线 五、 使用建议 以下是本篇文章正文内容 一、哨兵是什么&#xff1f; 哨兵巡查监控后台master主机是否故障&#xff0c;如果故障了…

Elasticsearch 更新指定字段

Elasticsearch 更新指定字段 准备条件查询数据更新指定字段更新子级字段 准备条件 以下查询操作都基于索引crm_clue来操作&#xff0c;索引已经建过了&#xff0c;本文主要讲Elasticsearch更新指定字段语句&#xff0c;下面开始写更新语句执行更新啦&#xff01; 查询数据 查…

Linux(Ubuntu)/Windows-C++云备份实现

目录 项目介绍&#xff1a;概要设计&#xff1a;技术调研 详细设计&#xff1a;目录监控模块样例编写&#xff08;c17的filesystem中的文件遍历功能&#xff09; 数据管理模块样例编写&#xff08;unordered_map格式存储&#xff09; 文件压缩与解压缩模块bundle数据压缩库使用…

【C++】开源:paho-mqtt-cpp库配置与使用

&#x1f60f;★,:.☆(&#xffe3;▽&#xffe3;)/$:.★ &#x1f60f; 这篇文章主要介绍paho-mqtt-cpp库配置与使用。 无专精则不能成&#xff0c;无涉猎则不能通。——梁启超 欢迎来到我的博客&#xff0c;一起学习&#xff0c;共同进步。 喜欢的朋友可以关注一下&#xff…

【qt】客户端连接到服务器

获取到IP地址和端口号. 通过connectToHost() 来进行连接. 对于客户端来讲,只需要socket即可. 客户端连接服务端只需要使用套接字(Socket)来进行通信。客户端通过创建一个套接字来连接服务端&#xff0c;然后可以通过套接字发送和接收数据。套接字提供了一种简单而灵活的方式来…

pycharm 占满磁盘

磁盘里没装什么大文件&#xff0c;发现磁盘被占的越来越满&#xff0c;使用工具查看到底是哪个文件如此之大。 发现罪魁祸首是pycharm&#xff01;&#xff01;&#xff01; 根据工具的提示找到对应的路径文件&#xff1a;E:\pycharm\PyCharmCE2022.3\python_packages 发现pa…

海外短剧开源系统UNIAPP源码(支持多语言/海外支付/快捷登录)

给大家推荐一款非常好的海外短剧系统源码&#xff0c;这款源码案例应该最多的。 他们的开源地址&#xff1a;https://gitee.com/qiao-yonggang/haiwaiduanju 这套系统的优势&#xff1a; 支持世界的所有语言&#xff0c;可后台自定义设置支持paypal stripe pix等多种海外支付…

Redis的配置优化、数据类型、消息队列

文章目录 一、Redis的配置优化redis主要配置项CONFIG 动态修改配置慢查询持久化RDB模式AOF模式 Redis多实例Redis命令相关 二、Redis数据类型字符串string列表list集合 set有序集合sorted set哈希hash 三、消息队列生产者消费者模式发布者订阅者模式 一、Redis的配置优化 redi…

easyExcel 不规则模板导入数据

文章目录 前言一、需求和效果二、难点和思路三、全部代码踩坑 前言 之前分享的 EasyExcel 批量导入并校验数据&#xff0c;仅支持规则excel&#xff0c;即首行表头&#xff0c;下面对应数据&#xff0c;无合并单元格情况。 本篇主要解决问题&#xff1a; 模板excel 表头不在首…

基于PCIe总线架构的2路1GSPS AD、4路1GSPS DA信号处理平台(100%国产化)

板卡概述 PCIE723-165是基于PCIE总线架构的2通道1GSPS采样率14位分辨率、4通道1GSPS采样率16位分辨率信号处理平台&#xff0c;该板卡采用国产16nm FPGA作为实时处理器&#xff0c;支持2路高速采集以及4路高速数据回放&#xff0c;板载2组DDR4 SDRAM大容量数据缓存&#xff0c;…

ElasticSearch 深度分页详解

原文链接&#xff1a;https://zhuanlan.zhihu.com/p/667036768 1 前言 ElasticSearch 是一个实时的分布式搜索与分析引擎&#xff0c;常用于大量非结构化数据的存储和快速检索场景&#xff0c;具有很强的扩展性。纵使其有诸多优点&#xff0c;在搜索领域远超关系型数据库&…

IDEA 中的调试方式(以 java 为例)

文章目录 IDEA 中的调试方式(以 java 为例)1. 基本介绍2. 断点调试的快捷键2.1 设置断点并启动调试2.3 快捷键 IDEA 中的调试方式(以 java 为例) 在开发中查找错误的时候&#xff0c;我们可以用断点调试&#xff0c;一步一步的看源码执行的过程&#xff0c;从而发现错误所在。 …

Linux --- 高级IO

一、基本概念 五种IO模型 阻塞IO&#xff1a;在内核将数据准备好之前&#xff0c;系统调用会一直等待&#xff0c;所有的套接字&#xff0c;默认都是阻塞方式 非阻塞IO&#xff1a;如果内核还未将数据准备好&#xff0c;系统调用仍然会直接返回&#xff0c;并且返回EWOULDBLOCK…

java—Spring框架

Spring 简介 Spring框架由Rod Johnson开发&#xff0c;2004年发布了Spring框架的第一版。Spring是一个从实际开发中抽取出来的框架&#xff0c;因此它完成了大量开发中的通用步骤&#xff0c;留给开发者的仅仅是与特定应用相关的部分&#xff0c;从而大大提高了企业应用的开发…

MacOS 开发 — Packages 程序 macOS新版本 演示选项卡无法显示

MacOS 开发 — Packages 程序 macOS新版本 演示选项卡无法显示 问题描述 &#xff1a; 之前写过 Packages 的使用以及如何打包macOS程序。最近更新了新的macOS系统&#xff0c;发现Packages的演示选项卡无法显示&#xff0c;我尝试从新安转了Packages 也是没作用&#xff0c;…