随机森林算法(Random Forest)的二分类问题

二分类问题

  • 1. 数据导入
  • 2. RF模型构建
    • 2.1 调参:mtry和ntree
    • 2.2 运行模型
  • 3. 模型测试
  • 4.绘制混淆矩阵
  • 5.绘制ROC曲线
  • 6. 参考

1. 数据导入

library(dplyr) #数据处理使用
library(data.table) #数据读取使用
library(randomForest) #RF模型使用
library(caret) # 调参和计算模型评价参数使用
library(pROC) #绘图使用
library(ggplot2) #绘图使用
library(ggpubr) #绘图使用
library(ggprism) #绘图使用
library(skimr) #查看数据分布
library(caTools) #训练集和测试集划分
setwd("D:/BaiduNetdiskDownload")
# 读取数据
data <- fread("data.txt",data.table = F)  # 替换为你的数据文件名或路径
as.data.frame(data)
skim(data)#数据鸟瞰
colnames(data)
hist(data$feature7, breaks = 50) # 查看数据分布data <- data[1:1000, ] # 选择前1000个样本进行计算

在这里插入图片描述

  数据一共包含了35723个样本,214个特征,选择其中前1000个样本进行模型构建(数据太大,这样更快一些)。

在这里插入图片描述
  查看一下数据分布情况,是不是符合一定的规律,如正态性之类的。

2. RF模型构建

  数据集分割为训练集和测试集

# 分割数据为训练集和测试集
set.seed(123)  # 设置随机种子,保证结果可复现
split <- sample.split(data$type, SplitRatio = 0.8)  # 将数据按照指定比例分割
train_data <- subset(data, split == TRUE)  # 训练集
test_data <- subset(data, split == FALSE)  # 测试集# 定义训练集特征和目标变量
X_train <- train_data[, -1]
y_train <- as.factor(train_data[, 1]) #将第一列的标签转换为因子变量

2.1 调参:mtry和ntree

  mtry:随机选择特征数目

# 2.1 mtry的取值是平方根(对于分类问题)或总特征数的三分之一(对于回归问题)
# mtry: 表示每棵决策树在进行节点分裂时考虑的特征数量
# 创建训练控制对象
ctrl <- trainControl(method = "cv", number = 10) # 选择10折交叉验证。
# 定义参数网格
grid <- expand.grid(mtry = c(2: 6))  # 每棵树中用于分裂的特征数量,这里只是随便给的测试,主要为了介绍如何调参,并非最优选择。# 使用caret包进行调参
rf_model <- train(x = X_train, y = y_train,method = "rf",trControl = ctrl,tuneGrid = grid)# 输出最佳模型和参数
print(rf_model)

结果:

Random Forest 800 samples
213 predictors2 classes: 'malignant', 'normal' No pre-processing
Resampling: Cross-Validated (10 fold) 
Summary of sample sizes: 720, 720, 720, 721, 720, 719, ... 
Resampling results across tuning parameters:mtry  Accuracy   Kappa    2     0.9650768  0.90602323     0.9663268  0.90982194     0.9650768  0.90654335     0.9675768  0.91384086     0.9675768  0.9141520Accuracy was used to select the optimal model using the largest value.
The final value used for the model was mtry = 5.

选择mtry=5即可

  ntree:

# 调整Caret没有提供的参数
# 如果我们想调整的参数Caret没有提供,可以用下面的方式自己手动调参。
# 用刚刚调参的最佳mtry值固定mtry
grid <- expand.grid(mtry = c(5))  # 每棵树中用于分裂的特征数量# 定义模型列表,存储每一个模型评估结果
modellist <- list()# 调整的参数是决策树的数量
for (ntree in c(50, 70, 90)) {set.seed(123)fit <- train(x = X_train, y = y_train, method="rf", metric="Accuracy", tuneGrid=grid, trControl=ctrl, ntree=ntree)key <- toString(ntree)modellist[[key]] <- fitprint(ntree)
}# compare results
results <- resamples(modellist)
# 输出最佳模型和参数
summary(results)

结果:

Call:
summary.resamples(object = results)Models: 50, 70, 90 
Number of resamples: 10 Accuracy Min.  1st Qu.    Median      Mean 3rd Qu.      Max. NA's
50 0.9500 0.962500 0.9748418 0.9687492   0.975 0.9875000    0
70 0.9375 0.953125 0.9688233 0.9637647   0.975 0.9750000    0
90 0.9375 0.962500 0.9748418 0.9674838   0.975 0.9876543    0Kappa Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
50 0.8624248 0.9016393 0.9330971 0.9177543 0.9354318 0.9682035    0
70 0.8360656 0.8754931 0.9164091 0.9043374 0.9354318 0.9373532    0
90 0.8360656 0.8992677 0.9330971 0.9144723 0.9354318 0.9673519    0

选择ntree为90即可

2.2 运行模型

# 使用最佳参数训练最终模型
final_model <- randomForest(x = X_train, y = y_train,mtry = 5,ntree = 90)
# 输出最终模型
print(final_model)

结果:

Call:randomForest(x = X_train, y = y_train, ntree = 90, mtry = 5) Type of random forest: classificationNumber of trees: 90
No. of variables tried at each split: 5OOB estimate of  error rate: 2.38%
Confusion matrix:malignant normal class.error
malignant       586      5 0.008460237
normal           14    195 0.066985646

3. 模型测试

# 在测试集上进行预测
X_test <- test_data[, -1]
y_test <- as.factor(test_data[, 1])
test_predictions <- predict(final_model, newdata = test_data)# 计算模型指标
confusion_matrix <- confusionMatrix(test_predictions, y_test)
accuracy <- confusion_matrix$overall["Accuracy"]
precision <- confusion_matrix$byClass["Pos Pred Value"]
recall <- confusion_matrix$byClass["Sensitivity"]
f1_score <- confusion_matrix$byClass["F1"]# 输出模型指标
print(confusion_matrix)
print(paste("Accuracy:", accuracy))
print(paste("Precision:", precision))
print(paste("Recall:", recall)) # sensitivity
print(paste("F1 Score:", f1_score))

结果:

> print(confusion_matrix)
Confusion Matrix and StatisticsReference
Prediction  malignant normalmalignant       146      3normal            2     49Accuracy : 0.975           95% CI : (0.9426, 0.9918)No Information Rate : 0.74            P-Value [Acc > NIR] : <2e-16          Kappa : 0.9346          Mcnemar's Test P-Value : 1               Sensitivity : 0.9865          Specificity : 0.9423          Pos Pred Value : 0.9799          Neg Pred Value : 0.9608          Prevalence : 0.7400          Detection Rate : 0.7300          Detection Prevalence : 0.7450          Balanced Accuracy : 0.9644          'Positive' Class : malignant       > print(paste("Accuracy:", accuracy))
[1] "Accuracy: 0.975"
> print(paste("Precision:", precision))
[1] "Precision: 0.979865771812081"
> print(paste("Recall:", recall)) # sensitivity
[1] "Recall: 0.986486486486487"
> print(paste("F1 Score:", f1_score))
[1] "F1 Score: 0.983164983164983"

4.绘制混淆矩阵

# 绘制混淆矩阵热图
# 将混淆矩阵转换为数据框
confusion_matrix_df <- as.data.frame.matrix(confusion_matrix$table)
colnames(confusion_matrix_df) <- c("cluster1","cluster2")
rownames(confusion_matrix_df) <- c("cluster1","cluster2")
draw_data <- round(confusion_matrix_df / rowSums(confusion_matrix_df),2)
draw_data$real <- rownames(draw_data)
draw_data <- melt(draw_data)ggplot(draw_data, aes(real,variable, fill = value)) +geom_tile() +geom_text(aes(label = scales::percent(value))) +scale_fill_gradient(low = "#F0F0F0", high = "#3575b5") +labs(x = "True", y = "Guess", title = "Confusion matrix") +theme_prism(border = T)+theme(panel.border = element_blank(),axis.ticks.y = element_blank(),axis.ticks.x = element_blank(),legend.position="none")

在这里插入图片描述

5.绘制ROC曲线

# 绘制ROC曲线需要将预测结果以概率的形式输出
test_predictions <- predict(final_model, newdata = test_data,type = "prob")# 计算ROC曲线的参数
roc_obj <- roc(response = y_test, predictor = test_predictions[, 2])
roc_auc <- auc(roc_obj)# 将ROC对象转换为数据框
roc_data <- data.frame(1 - roc_obj$specificities, roc_obj$sensitivities)# 绘制ROC曲线
ggplot(roc_data, aes(x = 1 - roc_obj$specificities, y = roc_obj$sensitivities)) +geom_line(color = "#0073C2FF", size = 1.5) +geom_segment(aes(x = 0, y = 0, xend = 1, yend = 1), linetype = "dashed", color = "gray") +geom_text(aes(x = 0.8, y = 0.2, label = paste("AUC =", round(roc_auc, 2))), size = 4, color = "black") +coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +theme_pubr() +labs(x = "1 - Specificity", y = "Sensitivity") +ggtitle("ROC Curve") +theme(plot.title = element_text(size = 14, face = "bold"))+theme_prism(border = T)

在这里插入图片描述

6. 参考

  • 机器学习之分类器性能指标之ROC曲线、AUC值

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/169211.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

河南工业大学人工智能与大数据学院学子在第三届“火焰杯”软件测试开发选拔赛中 取得佳绩

近日&#xff0c;第三届“火焰杯”软件测试开发选拔赛落下帷幕&#xff0c;我校人工智能与大数据学院选派的多名参赛选手在王雪涛老师的指导下&#xff0c;经过激烈的角逐&#xff0c;取得优异成绩。其中&#xff0c;何鸿彬&#xff0c;贾文聪获得决赛二等奖&#xff0c;王静宇…

【前段基础入门之】=>CSS3新增渐变颜色属性

导语&#xff1a; CSS3 新增了&#xff0c;渐变色 的解决方案&#xff0c;这使得我们可以绘制出更加生动的炫酷的的配色效果 线性渐变 多个颜色之间的渐变&#xff0c; 默认从上到下渐变 background-image: linear-gradient(red,yellow,green); /*默认从上到下渐变*/默认从上…

常用Web安全扫描工具汇整

漏洞扫描是一种安全检测行为&#xff0c;更是一类重要的网络安全技术&#xff0c;它能够有效提高网络的安全性&#xff0c;而且漏洞扫描属于主动的防范措施&#xff0c;可以很好地避免黑客攻击行为&#xff0c;做到防患于未然。 1、AWVS Acunetix Web Vulnerability Scanner&a…

javaEE -5(8000字详解多线程)

一&#xff1a;JUC(java.util.concurrent) 的常见类 1.1 ReentrantLock 可重入互斥锁. 和 synchronized 定位类似, 都是用来实现互斥效果, 保证线程安全&#xff0c;ReentrantLock 也是可重入锁. “Reentrant” 这个单词的原意就是 “可重入” ReentrantLock 的用法&#xf…

Java实现连接SQL Server解决方案及代码

下面展示了连接SQL Server数据库的整个流程&#xff1a; 加载数据库驱动建立数据库连接执行SQL语句处理结果关闭连接 在连接之前&#xff0c;前提是确保数据库成功的下载&#xff0c;创建&#xff0c;配置好账号密码。 运行成功的代码&#xff1a; import java.sql.*;publi…

【1024一起敲代码!】C#mysql/Sqlserver图书借阅管理系统期末设计源代码

本系统自带7800字报告,有两个版本&#xff0c;Mysql版本、Sqlserver版本&#xff0c; 介绍 架构设计主要体现在代码层的架构和窗体层调用的架构。 在窗体中&#xff0c;由Program.cs为入口&#xff0c;启动后直接调用LoginForm.cs进入登陆界面&#xff0c;登陆成功后进入主菜…

docker-compose安装ES7.14和Kibana7.14(有账号密码)

一、docker-compose安装ES7.14.0和kibana7.14.0 1、下载镜像 1.1、ES镜像 docker pull elasticsearch:7.14.0 1.2、kibana镜像 docker pull kibana:7.14.0 2、docker-compose安装ES和kibana 2.1、创建配置文件目录和文件 #创建目录 mkdir -p /home/es-kibana/config mkdir…

【JAVA学习笔记】40 - 抽象类、模版设计模式(抽象类的使用)

项目代码 https://github.com/yinhai1114/Java_Learning_Code/tree/main/IDEA_Chapter10/src/com/yinhai/abstract_ 一、抽象类的引入 很多时候在创建类的时候有一个父类&#xff0c;比如animal类&#xff0c;他的子类会有各种方法&#xff0c;为了复用需要进行方法的重写&…

UML(Unified Modeling Language)统一建模语言,及工具介绍、使用

1. UML介绍&#xff1a; UML&#xff08;Unified Modeling Language&#xff09;统一建模语言。是一种图形化语言。 在UML 2.5 中共包含14种图形&#xff1a;类图、用例图、活动图、对象图、时序图、交互概述图、包图、配置文件图、部署图、组件图、组合结构图、状态机图、通…

Day07 Stream流递归Map集合Collections可变参数

Stream 也叫Stream流&#xff0c;是Jdk8开始新增的一套API (java.util.stream.*)&#xff0c;可以用于操作集合或者数组的数据。 Stream流大量的结合了Lambda的语法风格来编程&#xff0c;提供了一种更加强大&#xff0c;更加简单的方式操作 public class Demo1 {public stati…

真空室的内表面加工

真空室和部件的内表面是在高真空和超高真空下实现工作压力的重要因素。必须在该条件下进行加工&#xff0c;以最小化有效表面&#xff0c;并产生具有最小解吸率的表面。 真空室和部件的表面往往是在焊接和机械加工后经过精细玻璃珠喷砂的。具有限定直径的高压玻璃珠被吹到表面…

跟着NatureMetabolism学作图:R语言ggplot2转录组差异表达火山图

论文 Independent phenotypic plasticity axes define distinct obesity sub-types https://www.nature.com/articles/s42255-022-00629-2#Sec15 s42255-022-00629-2.pdf 论文中没有公开代码&#xff0c;但是所有作图数据都公开了&#xff0c;我们可以试着用论文中提供的数据…

一百九十二、Flume——Flume数据流监控工具Ganglia单机版安装

一、目的 在安装好Flume之后&#xff0c;需要用一个工具可以对Flume数据传输进行实时监控&#xff0c;这就是Ganglia 二、Ganglia介绍 Ganglia 由 gmond、gmetad 和 gweb 三部分组成。 &#xff08;一&#xff09;第一部分——gmond gmond&#xff08;Ganglia Monitoring Da…

群晖synology DSM 7.2设置钉钉Webhooks通知

现在越来越多的小伙伴都有了自己的Nas系统&#xff0c;为了更加方便的接收Nas的消息&#xff0c;这篇文章带着大家一起配置一个钉钉&#xff08;机器人&#xff09;即时消息通知 首先登录钉钉的开放平台&#xff1a;开发者后台统一登录 - 钉钉统一身份认证 1.创建一个机器人&…

MYSQL(索引+SQL优化)

索引: 索引是帮助MYSQL高效获取数据的排好序的数据结构 1)假设现在进行查询数据&#xff0c;select * from user where userID89 2)没有索引是一行一行从MYSQL进行查询的&#xff0c;还有就是数据的记录都是存储在MYSQL磁盘上面的&#xff0c;比如说插入数据的时候是向磁盘上面…

十三水中各种牌型判断LUA版

近期回归程序行业&#xff0c;由于业务需求需要做十三水游戏&#xff0c;什么是十三水就不在多讲&#xff0c;下面是判断十三水牌型的方法&#xff08;带大小王&#xff09; GetSSSPaiType {}; local this GetSSSPaiType; local huaseTable {}; local numTable {}; functi…

逐字稿 | 视频理解论文串讲(下)【论文精读】

1 为什么研究者这么想把这个双流网络替换掉&#xff0c;想用3D 卷积神经网络来做&#xff1f; 大家好&#xff0c;上次我们讲完了上半部分&#xff0c;就是 2D 网络和一些双流网络以及。它们的。变体。今天我们就来讲一下下半部分&#xff0c;就是 3D 网络和 video Transformer…

Java 控制台 进度条

Java 控制台 进度条 progress-bar简介效果图使用介绍 progress-bar 简介 gitee链接: https://gitee.com/sincere-jxx/progress-bar main分支 进度条颜色可变&#xff0c;绿色&#xff08;默认&#xff09;&#xff0c;红色&#xff0c;黄色&#xff0c;蓝色等 长度50&#x…

Nacos全面知识 ----微服务 SpringCloud

快速入门 分级存储模型 修改集群配置 Nacos设置负载均衡策略 集群优先 权重优先 Nacos热更新配置 Nacos添加配置信息 微服务配置拉取 热更新:推荐使用第二种方法进行热部署 ConfigurationProperties(prefix "pattern") 是 Spring Boot 中用于自动配置属性的注解。它…

SparkSQL的Shuffle分区设定及异常数据处理API(去重、缺失值处理)

一、Spark SQL的Shuffle分区数目设定 在允许spark程序时&#xff0c;查看WEB UI监控页面发现&#xff0c;某个Stage中有200个Task任务&#xff0c;也就是说RDD有200分区Partion。 产生原因&#xff1a; 在Spark SQL中&#xff0c;当Job中产生Shuffle时&#xff0c;默认的分区数…