【Sklearn】基于随机森林算法的数据分类预测(Excel可直接替换数据)
- 1.模型原理
- 1.1 模型原理
- 1.2 数学模型
- 2.模型参数
- 3.文件结构
- 4.Excel数据
- 5.下载地址
- 6.完整代码
- 7.运行结果
1.模型原理
随机森林(Random Forest)是一种集成学习方法,通过组合多个决策树来构建强大的分类或回归模型。随机森林的模型原理和数学模型如下:
1.1 模型原理
随机森林是一种集成学习方法,它结合了多个决策树来改善预测的准确性和鲁棒性。每个决策树都是独立地训练,并且它们的预测结果综合起来形成最终的预测。随机森林的主要思想是构建一个“森林”,其中每棵树都是一个分类器,而每个分类器都在随机的数据子集上进行训练。在预测时,通过投票或平均来综合所有分类器的结果。
随机森林的主要步骤:
-
随机抽样(Bootstrap抽样): 从原始训练数据中随机抽取多个样本,允许同一个样本在一个抽样中出现多次,形成一个新的训练集。
-
随机特征选择: 对每个决策树的训练过程中,在节点分裂时,只考虑部分特征,而不是全部特征。这样有助于增加树之间的多样性,减少过拟合。
-
独立训练: 对于每个样本和每个决策树,使用随机抽样的训练数据和随机选择的特征进行训练,得到多棵独立的决策树。
-
预测聚合: 在预测时,将每棵树的预测结果进行投票(分类问题)或平均(回归问题),以决定最终的分类或预测值。
1.2 数学模型
随机森林的数学模型是由多个决策树组成的集合,每个决策树都是一个独立的分类器或回归器。随机森林的预测是通过对每个决策树的预测结果进行综合得到的。可以用以下形式表示:
F ( x ) = 1 T ∑ t = 1 T f t ( x ) F(x) = \frac{1}{T} \sum_{t=1}^{T} f_t(x) F(x)=T1t=1∑Tft(x)
其中, F ( x ) F(x) F(x)表示随机森林的预测结果, T T T表示决策树的数量, f t ( x ) f_t(x) ft(x)表示第 t t t棵决策树的预测结果。
在训练每棵决策树时,随机森林通过随机抽样和随机特征选择增加了每棵树之间的多样性,从而减少了过拟合的风险。在预测时,通过对多个决策树的预测结果进行综合,提高了模型的准确性和稳定性。
总之,随机森林通过构建多个独立的决策树,并对它们的预测结果进行综合,从而创建了一个强大的集成模型,适用于分类和回归任务。
2.模型参数
RandomForestClassifier
是scikit-learn
中随机森林分类器的类,它具有多个参数可以调整。以下是你提到的参数以及它们的说明:
-
n_estimators: 随机森林中决策树的数量。默认为100。
-
criterion: 衡量分割质量的标准。可以是"gini"(基尼系数)或"entropy"(信息熵)。默认是"gini"。
-
max_depth: 决策树的最大深度。默认为None,表示不限制深度。
-
min_samples_split: 节点分裂所需的最小样本数。默认为2。
-
min_samples_leaf: 叶节点所需的最小样本数。默认为1。
-
min_weight_fraction_leaf: 叶节点所需的最小权重分数总和。默认为0。
-
max_features: 寻找最佳分割时要考虑的特征数量。可以是整数、浮点数、字符串或None。默认是"auto",意味着"sqrt(n_features)"。
-
max_leaf_nodes: 最大叶节点数。默认为None。
-
min_impurity_decrease: 分割需要达到的最小不纯度减少量。默认为0。
-
bootstrap: 是否对数据进行有放回抽样。默认为True。
-
oob_score: 是否计算袋外(oob)准确率。默认为False。
-
n_jobs: 并行处理的作业数。默认为None,表示使用1个作业。
-
random_state: 随机数生成器的种子,用于重现随机结果。
-
class_weight: 类别权重,用于处理不平衡数据集。
-
verbose: 控制训练过程中的输出信息。默认为0,不显示输出。
这些参数可以根据你的数据集和问题进行调整,以获得最佳的模型性能。
3.文件结构
iris.xlsx % 可替换数据集
Main.py % 主函数
4.Excel数据
5.下载地址
- 资源下载地址
6.完整代码
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as snsdef random_forest_classification(data_path, test_size=0.2, random_state=42):# 加载数据data = pd.read_excel(data_path)# 分割特征和标签X = data.iloc[:, :-1] # 所有列除了最后一列y = data.iloc[:, -1] # 最后一列# 划分训练集和测试集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)# 创建随机森林分类器# 1. ** n_estimators: ** 随机森林中决策树的数量。默认为100。# 2. ** criterion: ** 衡量分割质量的标准。可以是"gini"(基尼系数)或"entropy"(信息熵)。默认是"gini"。# 3. ** max_depth: ** 决策树的最大深度。默认为None,表示不限制深度。# 4. ** min_samples_split: ** 节点分裂所需的最小样本数。默认为2。# 5. ** min_samples_leaf: ** 叶节点所需的最小样本数。默认为1。# 6. ** min_weight_fraction_leaf: ** 叶节点所需的最小权重分数总和。默认为0。# 7. ** max_features: ** 寻找最佳分割时要考虑的特征数量。可以是整数、浮点数、字符串或None。默认是"auto",意味着"sqrt(n_features)"。# 8. ** max_leaf_nodes: ** 最大叶节点数。默认为None。# 9. ** min_impurity_decrease: ** 分割需要达到的最小不纯度减少量。默认为0。# 10. ** bootstrap: ** 是否对数据进行有放回抽样。默认为True。# 11. ** oob_score: ** 是否计算袋外(oob)准确率。默认为False。# 12. ** n_jobs: ** 并行处理的作业数。默认为None,表示使用1个作业。# 13. ** random_state: ** 随机数生成器的种子,用于重现随机结果。# 14. ** class_weight: ** 类别权重,用于处理不平衡数据集。# 15. ** verbose: ** 控制训练过程中的输出信息。默认为0,不显示输出。model = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=random_state)# 在训练集上训练模型model.fit(X_train, y_train)# 在测试集上进行预测y_pred = model.predict(X_test)# 计算准确率accuracy = accuracy_score(y_test, y_pred)return confusion_matrix(y_test, y_pred), y_test.values, y_pred, accuracyif __name__ == "__main__":# 使用函数进行分类任务data_path = "iris.xlsx"confusion_mat, true_labels, predicted_labels, accuracy = random_forest_classification(data_path)print("真实值:", true_labels)print("预测值:", predicted_labels)print("准确率:{:.2%}".format(accuracy))# 绘制混淆矩阵plt.figure(figsize=(8, 6))sns.heatmap(confusion_mat, annot=True, fmt="d", cmap="Blues")plt.title("Confusion Matrix")plt.xlabel("Predicted Labels")plt.ylabel("True Labels")plt.show()# 用圆圈表示真实值,用叉叉表示预测值# 绘制真实值与预测值的对比结果plt.figure(figsize=(10, 6))plt.plot(true_labels, 'o', label="True Labels")plt.plot(predicted_labels, 'x', label="Predicted Labels")plt.title("True Labels vs Predicted Labels")plt.xlabel("Sample Index")plt.ylabel("Label")plt.legend()plt.show()
7.运行结果