基于随机森林回归预测葡萄酒质量

基于随机森林回归预测葡萄酒质量

  • 1.作者介绍
  • 2.随机森林算法与数据集介绍
    • 2.1定义
    • 2.2核心思想
    • 2.3主要步骤
    • 2.4数据集介绍
  • 3.算法实现
    • 3.1数据加载与探索
    • 3.2数据可视化
    • 3.3数据预处理(标准化、划分训练/测试集)
    • 3.4模型训练与优化(随机森林回归 + 超参数搜索)
    • 3.5模型评估(MSE、R²、混淆矩阵、准确率)
    • 3.6特征重要性分析
    • 3.7模型保存与加载
  • 4.结果分析
  • 5.完整代码

基于随机森林回归预测葡萄酒质量

1.作者介绍

朱亚彬, 男 ,西安工程大学电子信息学院 , 2024级研究生
研究方向:任务卸载与边缘计算
电子邮件:2292036787@qq.com

王晓睿,男,西安工程大学电子信息学院,2024级研究生,张宏伟人工智能课题组
研究方向:智能视觉检测与工业自动化技术
电子邮件:3234002295@qq.com

2.随机森林算法与数据集介绍

2.1定义

随机森林回归(Random Forest Regression)是一种基于集成学习的回归算法,通过组合多个决策树来提高预测的准确性和稳定性。

2.2核心思想

随机森林回归通过构建多棵决策树,并将它们的预测结果进行平均(或加权平均)来得到最终的回归结果。每棵决策树在训练时使用随机抽样的数据和特征,从而减少过拟合的风险。

2.3主要步骤

数据采样(Bootstrap Sampling):从训练集中随机抽取样本(有放回抽样),形成多个子集。每个子集用于训练一棵决策树。
特征随机选择:在每棵树的节点分裂时,随机选择一部分特征(而不是所有特征)来寻找最佳分割点。这增加了模型的多样性。
构建决策树:对每个子集数据,构建一棵决策树。树的生长通常不进行剪枝,直到达到最大深度或节点中的样本数小于某个阈值。
集成预测:对于回归问题,将所有树的预测结果取平均值,作为最终的预测值。
在这里插入图片描述
(1)随机选择一个样本子集作为该决策树的训练集。
(2)随机选择一部分特征(总特征数的平方根)作为该决策树的特征集。
(3)基于训练集和特征集构建决策树,直到达到预定的叶子节点数或无法分割为止。
(4)重复以上步骤,建立多颗决策树。
(5)对于一个新的样本,将它输入到每棵决策树中,得到多个预测结果。
(6)对多个预测结果进行平均,得到最终的预测结果。
算法公式基于决策树回归模型,每个预测树的预测函数为:
在这里插入图片描述
多棵决策树的预测函数可以表示为:
在这里插入图片描述
该部分参考链接:https://blog.csdn.net/m0_61399808/article/details/130733650

2.4数据集介绍

名称:Wine Quality Dataset(葡萄酒质量数据集):
数据集地址:https://archive.ics.uci.edu/dataset/186/wine+quality
数据集包含红葡萄酒和白葡萄酒两个子集(分别有1599和4898个样本)。每个样本有11个特征,包括固定酸度、挥发性酸度、柠檬酸、残糖、氯化物等。目标是一个回归任务,预测葡萄酒的质量评分(0到10分)。本次实验使用的是红葡萄酒数据集进行测试。

3.算法实现

基于红酒质量数据集,使用随机森林回归模型进行质量预测。主要步骤包括:数据加载与探索、数据可视化(散点图、热力图、直方图等)、数据预处理(标准化、划分训练/测试集)、模型训练与优化(随机森林回归 + 超参数搜索)、模型评估(MSE、R²、混淆矩阵、准确率)、特征重要性分析、模型保存与加载。

3.1数据加载与探索

(1)程序从指定路径加载红酒质量数据集:数据文件为 CSV 格式,分隔符为 ;使用 pandas.read_csv 读取数据,编码格式为 utf-8。
在这里插入图片描述
(2)数据探索。
在这里插入图片描述

3.2数据可视化

程序通过多个图表展示数据特征与质量之间的关系。
(1)特征与质量的散点图
在这里插入图片描述
(2)相关性热力图
在这里插入图片描述
(3)不同品质红酒的分布
在这里插入图片描述
(4)不同品质的酒精含量直方图
在这里插入图片描述

3.3数据预处理(标准化、划分训练/测试集)

(1)特征和标签划分:将数据集分为特征(X)和目标变量(y),其中 y 是红酒的质量评分。
在这里插入图片描述
(2)划分训练集和测试集:使用 train_test_split 将数据划分为训练集(80%)和测试集(20%)。
random_state=42 确保结果可重复,stratify=y 保持训练集和测试集中 y 的分布一致。
在这里插入图片描述
(3)数据标准化
使用 StandardScaler 对特征进行标准化(均值为 0,方差为 1)。在训练集上拟合并转换,在测试集上仅转换。
在这里插入图片描述

3.4模型训练与优化(随机森林回归 + 超参数搜索)

创建随机森林回归模型,设置 100 棵树,随机种子为 42。
使用 10 折交叉验证评估模型在训练集上的性能。
计算均方误差(MSE)和 R² 分数。neg_mean_squared_error 返回负值,故取反。
使用 GridSearchCV 搜索最佳超参数组合。
搜索的参数包括 max_features(特征选择方式)和 max_depth(树的最大深度)。
同样采用 10 折交叉验证。
重新训练模型:使用搜索到的最佳参数重新训练模型。
在这里插入图片描述

3.5模型评估(MSE、R²、混淆矩阵、准确率)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.6特征重要性分析

## 3.6特征重要性分析

3.7模型保存与加载

在这里插入图片描述

4.结果分析

(1)图表可视化
在这里插入图片描述
图中的横轴表示的是真实品质,纵坐标表示的是预测品质。对角线表示预测值与完全值完全一样,品质为6时准确率最高。
在这里插入图片描述
从混淆矩阵中可以看出,品质为6,7的时候预测的准确率最高。
(2)MSE,准确率等评估指标
在这里插入图片描述
MSE的值越小,模型的预测性能越好。R²的值范围在0到1之间。R²值越接近1,表示模型对目标变量的解释能力越强,模型的预测性能越好。结果表明模型在测试集上的性能超过了训练集。
在这里插入图片描述
图为特征重要性的排列,从图中可以看出,alochol的特征是最重要的。表明alochol的含量越高,葡萄酒的质量品质越高。

5.完整代码

# 导入必要的库
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
from sklearn.pipeline import make_pipeline
from sklearn.metrics import mean_squared_error, r2_score, confusion_matrix
from sklearn.model_selection import GridSearchCV
import joblib# 设置中文字体
plt.rcParams["font.sans-serif"] = "SimHei"
plt.rcParams["axes.unicode_minus"] = False# #加载数据
# dataset_url = 'wine+quality/winequality-red.csv'
# data = pd.read_csv(dataset_url)
# #查看数据的前五行
# print( data.head() )# 加载数据
dataset_url = 'wine+quality/winequality-red.csv'
data = pd.read_csv(dataset_url, sep=';', encoding="utf-8")# 查看数据的基本信息
print(data.head())
print(data.shape)
print(data.describe())# 绘制特征与质量的关系图
plt.figure(figsize=(12, 5))
plt.subplot(121)
sns.scatterplot(x='volatile acidity', y='quality', data=data)
plt.title('Volatile Acidity vs Quality')
plt.subplot(122)
sns.scatterplot(x='alcohol', y='quality', data=data)
plt.title('Alcohol vs Quality')
plt.savefig("results/quality.png", dpi=300)
plt.show()# 绘制相关性热力图
correlation = data.corr()
plt.figure(figsize=(12, 10))
sns.heatmap(correlation, annot=True)
plt.title("Correlation Matrix")
plt.savefig("results/Correlation_Matrix.png", dpi=400)
plt.show()# 绘制不同品质的红酒分布
sns.catplot(data=data, x="quality", kind='count')
plt.title('不同品质的红酒的分布')
plt.xlabel('品质')
plt.ylabel('数量')
plt.show()# 提取品质为 4、5、6、7 的数据
data_quality_4 = data.loc[data["quality"] == 4, :]
data_quality_5 = data.loc[data["quality"] == 5, :]
data_quality_6 = data.loc[data["quality"] == 6, :]
data_quality_7 = data.loc[data["quality"] == 7, :]# 绘制不同品质的酒精含量直方图
plt.figure(figsize=(9, 7), dpi=80)
plt.subplot(2, 2, 1)
plt.hist(data_quality_4["alcohol"], bins=50)
plt.xlabel("酒精含量")
plt.ylabel("数量")
plt.title("品质为4的酒精度数量直方图")plt.subplot(2, 2, 2)
plt.hist(data_quality_5["alcohol"], bins=50)
plt.xlabel("酒精含量")
plt.ylabel("数量")
plt.title("品质为5的酒精度数量直方图")plt.subplot(2, 2, 3)
plt.hist(data_quality_6["alcohol"], bins=50)
plt.xlabel("酒精含量")
plt.ylabel("数量")
plt.title("品质为6的酒精度数量直方图")plt.subplot(2, 2, 4)
plt.hist(data_quality_7["alcohol"], bins=50)
plt.xlabel("酒精含量")
plt.ylabel("数量")
plt.title("品质为7的酒精度数量直方图")plt.tight_layout()
plt.savefig("results/quality_4567_vs.png", dpi=300)
plt.show()# 特征和标签
X = data.drop('quality', axis=1)  # 特征,去除 'quality' 列
y = data['quality']  # 标签,仅保留 'quality' 列# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)# 数据标准化
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)# 创建随机森林回归模型
model = RandomForestRegressor(n_estimators=100, random_state=42)# 使用10折交叉验证评估模型在训练集上的性能
cv_mse = -cross_val_score(model, X_train_scaled, y_train, cv=10, scoring='neg_mean_squared_error').mean()
cv_r2 = cross_val_score(model, X_train_scaled, y_train, cv=10, scoring='r2').mean()
print(f'交叉验证的均方误差 (MSE): {cv_mse}')
print(f'交叉验证的R²分数: {cv_r2}')# 超参数优化
pipeline = make_pipeline(RandomForestRegressor(random_state=42))
hyperparameters = {'randomforestregressor__max_features': ['sqrt', 'log2'],'randomforestregressor__max_depth': [None, 5, 3, 1]
}
clf = GridSearchCV(pipeline, hyperparameters, cv=10)
clf.fit(X_train_scaled, y_train)
print(f'最佳超参数: {clf.best_params_}')# 使用最佳参数重新训练模型
best_model = clf.best_estimator_
best_model.fit(X_train_scaled, y_train)# 评估模型在测试集上的性能
y_pred = best_model.predict(X_test_scaled)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f'测试集的均方误差 (MSE): {mse}')
print(f'测试集的R²分数: {r2}')# 绘制真实值与预测值的散点图
plt.scatter(y_test, y_pred, alpha=0.5)
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--', lw=2)
plt.xlabel('True Quality')
plt.ylabel('Predicted Quality')
plt.title('Predicted vs True Quality Scores')
plt.savefig("results/Predicted_vs_True_Quality_Scores.png", dpi=400)
plt.show()# 计算混淆矩阵
cm = confusion_matrix(y_test, np.round(y_pred))# 绘制混淆矩阵
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, cmap='Blues', fmt='d', xticklabels=[4, 5, 6, 7], yticklabels=[4, 5, 6, 7])
plt.xlabel('预测品质')
plt.ylabel('实际品质')
plt.title('混淆矩阵')
plt.savefig("results/混淆矩阵.png", dpi=400)
plt.show()# 显示特征重要性
feature_names = X.columns
importances = best_model.named_steps['randomforestregressor'].feature_importances_
sorted_indices = np.argsort(importances)[::-1]print("特征重要性(从高到低):")
for i in range(len(feature_names)):print(f"{feature_names[sorted_indices[i]]:<20} | {importances[sorted_indices[i]]}")# 绘制特征重要性条形图
plt.figure(figsize=(10, 6))
sns.barplot(x=importances[sorted_indices], y=feature_names[sorted_indices])
plt.title('Feature Importances')
plt.xlabel('Importance')
plt.ylabel('Features')
plt.savefig("results/Feature Importances.png", dpi=400)
plt.show()# 保存模型
joblib.dump(best_model, 'rf_regressor.pkl')# 加载模型并进行预测
rf_regressor_loaded = joblib.load('rf_regressor.pkl')
y_pred_loaded = rf_regressor_loaded.predict(X_test_scaled)# 计算预测准确率
accuracy = np.mean(np.round(y_pred_loaded) == y_test)
print("预测准确率: {:.2f}%".format(accuracy * 100))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/38205.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【赵渝强老师】在Docker中运行达梦数据库

Docker是一个客户端服务器&#xff08;Client-Server&#xff09;架构。Docker客户端和Docker守护进程交流&#xff0c;而Docker的守护进程是运作Docker的核心&#xff0c;起着非常重要的作用&#xff08;如构建、运行和分发Docker容器等&#xff09;。达梦官方提供了DM 8在Doc…

【C语言】深入理解指针(二):从数组到二维数组的指针魔法

前言 在C语言中&#xff0c;指针一直是一个神秘而强大的存在。它不仅可以帮助我们高效地操作内存&#xff0c;还能让代码更加灵活和高效。今天&#xff0c;我们就来深入探讨指针的多种用法&#xff0c;从数组到二维数组&#xff0c;一步步揭开指针的神秘面纱。 一、数组名的指…

【MySQL】事务

目录 基本概念事务操作自动提交事务开启事务提交事务回滚事务代码示例 事务的特性 ACID事务的隔离级别读未提交 read uncommitted读已提交 read committed可重复读 repeatable read序列化&#xff08;串行&#xff09; serializable操作示例 基本概念 在 MySQL 中的事务&#…

flutter doctor提示cmdline-tools component is missing错误的解决

flutter doctor检测环境后报错如下: STEP1: 配置command-lines &#x1f4cc; 打开Androidstudio &#xff0c;找到sdkmanager &#x1f447; 安装command-line tools 如果找不到&#xff0c;记得打开右下角的「Show Package Details} 再次运行flutter doctor 即可正常 如…

iptables和netfilter内部报文处理

一、Iptables和netfilter 1.iptables基础 netfilter强大功能以及灵活性是通过iptables界面来实现。此命令行工具和它的前身ipchains语法相似&#xff1b;不过iptables使用netfilter子系统来增进网络连接、检验和处理方面的能力&#xff1b;ipchains使用错综复杂的规则集合来过…

[项目]基于FreeRTOS的STM32四轴飞行器: 十一.MPU6050配置与读取

基于FreeRTOS的STM32四轴飞行器: 十一.MPU6050 一.芯片介绍二.配置I2C三.编写驱动四.读取任务的测试 一.芯片介绍 芯片应该放置在PCB中间&#xff0c;X Y轴原点&#xff0c;敏感度131表示范围越小越灵敏。理想状态放置在地面上X&#xff0c;Y&#xff0c;Z轴为0&#xff0c;即…

JVM垃圾回收笔记01

文章目录 前言1. 如何判断对象可以回收1.1 引用计数法1.2 可达性分析算法查看根对象哪些对象可以作为 GC Root ?对象可以被回收&#xff0c;就代表一定会被回收吗&#xff1f; 1.3 引用类型1.强引用&#xff08;StrongReference&#xff09;2.软引用&#xff08;SoftReference…

解决Popwindow宽高的问题。

问题 在使用Popwindow进行自定义的过程中&#xff0c;需要设置popwindow的宽高。但是宽高很多时候容易出问题。比如下面的例子。 布局文件如下 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas.andr…

Bell-1量子计算机分析:开启量子计算2.0时代的创新引擎

Bell-1量子计算机:开启量子计算2.0时代的创新引擎 一、引言 1.1 研究背景 在当今科技飞速发展的时代,量子计算作为前沿领域,正深刻地改变着科技格局,引领新一轮科技革命与产业变革。自 20 世纪 80 年代量子计算概念被提出以来,历经多年的理论探索与技术攻坚,已取得了众…

什么?中断禁用失效了?

什么&#xff1f;中断禁用失效了&#xff1f; 1. 前言 道友们&#xff0c;在嵌入式的开发中我们不管是RTOS或NO-RTOS的开发&#xff0c;都无法避免“多线程”的应用场景&#xff0c;高优先级的任务或中断打断低优先级的任务或中断&#xff0c;此时为了要保证共享数据的安全性…

单表达式倒计时工具:datetime的极度优雅(Kimi)

一个简单表达式&#xff0c;也可以优雅自成工具。 笔记模板由python脚本于2025-03-22 20:25:49创建&#xff0c;本篇笔记适合任意喜欢学习的coder翻阅。 【学习的细节是欢悦的历程】 博客的核心价值&#xff1a;在于输出思考与经验&#xff0c;而不仅仅是知识的简单复述。 Pyth…

[笔记.AI]多头自注意力机制(Multi-Head Attention)

多头自注意力是深度学习领域&#xff0c;特别是自然语言处理&#xff08;NLP&#xff09;和Transformer模型中的关键概念。其发展源于对序列数据中复杂依赖关系的建模需求&#xff0c;特别是在Transformer架构的背景下。 举例 比喻-读长篇文章 用一个简单的比喻来理解“多头注…

SOFABoot-02-模块化隔离方案

sofaboot 前言 大家好&#xff0c;我是老马。 sofastack 其实出来很久了&#xff0c;第一次应该是在 2022 年左右开始关注&#xff0c;但是一直没有深入研究。 最近想学习一下 SOFA 对于生态的设计和思考。 sofaboot 系列 SOFABoot-00-sofaboot 概览 SOFABoot-01-蚂蚁金…

【实用部署教程】olmOCR智能PDF文本提取系统:从安装到可视化界面实现

文章目录 引言系统要求1. 环境准备&#xff1a;安装Miniconda激活环境 2. 配置pip源加速下载3. 配置学术加速&#xff08;访问国外资源&#xff09;4. 安装系统依赖5. 安装OLMOCR6. 运行OLMOCR处理PDF文档7. 理解OLMOCR输出结果9. 可视化UI界面9.1 安装界面依赖9.2 创建界面应用…

asp.net core mvc模块化开发

razor类库 新建PluginController using Microsoft.AspNetCore.Mvc;namespace RazorClassLibrary1.Controllers {public class PluginController : Controller{public IActionResult Index(){return View();}} }Views下Plugin下新建Index.cshtml {ViewBag.Title "插件页…

边缘计算革命:重构软件架构的范式与未来

摘要 边缘计算通过将算力下沉至网络边缘&#xff0c;正在颠覆传统中心化软件架构的设计逻辑。本文系统分析了边缘计算对软件架构的范式革新&#xff0c;包括分布式分层架构、实时资源调度、安全防护体系等技术变革&#xff0c;并结合工业物联网、智慧医疗等场景案例&#xff0c…

单链表:数据结构的灵动之链

本文主要讲解链表的概念和结构以及实现单链表 目录 一、链表的概念及结构 二、单链表的实现 1.1链表的实现&#xff1a; 1.2单链表的实现&#xff1a; 单链表尾插&#xff1a; 单链表的头插&#xff1a; 单链表的尾删&#xff1a; 单链表头删&#xff1a; 单链表查找&#…

链表题型-链表操作-JS

一定要注意链表现在的头节点是空节点还是有值的节点。 一、移除链表中的元素 有两种方式&#xff0c;直接使用原来的链表进行删除操作&#xff1b;设置一个虚拟头节点进行删除操作。 直接使用原来的链表进行删除操作时&#xff0c;需要考虑是不是头节点&#xff0c;因为移除…

读《浪潮之巅》:探寻科技产业的兴衰密码

引言&#xff1a;邂逅《浪潮之巅》 在信息技术飞速发展的今天&#xff0c;科技公司如繁星般闪烁&#xff0c;又似流星般划过。而我与《浪潮之巅》的相遇&#xff0c;就像在浩渺的科技海洋中&#xff0c;发现了一座指引方向的灯塔。初次听闻这本书&#xff0c;是在一次技术交流会…

【和春笋一起学C++】文本文件I/O

在windows系统中读取键盘的输入和在屏幕上显示输出统称为&#xff1a;控制台输入/输出。把读取文本文件和把字符输出到文本文件中统称为&#xff1a;文本文件I/O。 目录 1. 输出文本文件 2. 读取文本文件 1. 输出文本文件 把字符输出到文本文件中和输出到控制台很相似&#x…