【机器学习】CatBoost 模型实践:回归与分类的全流程解析

一. 引言

本篇博客首发于掘金 https://juejin.cn/post/7441027173430018067。
PS:转载自己的文章也算原创吧。

在机器学习领域,CatBoost 是一款强大的梯度提升框架,特别适合处理带有类别特征的数据。本篇博客以脱敏后的保险数据集为例,展示如何利用 CatBoost 完成分类和回归任务,并以可视化的方式解析特征重要性与结果。

我们将完成以下任务:

  1. 回归任务:预测保险索赔金额。
  2. 分类任务:判断保险案件是否需要调查。
  3. 可视化分析:利用散点图与分割线展示结果。

二. CatBoost 模型简介

CatBoost 是由俄罗斯搜索巨头 Yandex 于 2017 年开源的机器学习库,其名称来源于 “Category” 和 “Boosting” 的组合,旨在高效处理类别特征的梯度提升算法。与其他模型(如 XGBoost 和 LightGBM)相比,CatBoost 具有以下优势:

  • 支持类别特征:无需对类别特征进行独热编码,直接处理类别数据,避免数据膨胀。
  • 对缺失值的鲁棒性:无需特殊预处理即可直接处理缺失值。
  • 防止过拟合:内置多种正则化手段,减少梯度偏差和预测偏移,提高模型的准确性和泛化能力。
  • 对称树结构:采用对称决策树(Oblivious Trees),在每个层级使用相同的特征和分割点,提升训练和预测效率。

三. 实战项目环境与数据准备

本项目使用了脱敏后的保险数据集,包含以下特征:

  • 类别特征:险种代码、出险原因、医疗责任类别等。
  • 数值特征:基本保额、索赔金额等。
  • 标签:是否需要调查(分类任务)。

所有数据均已脱敏,支持迁移至其他表格数据集。

因为不好分享,所以后续第七节补充了一个基于sklearn "California Housing"数据集的流程代码与说明。


四. 回归任务:预测保险索赔金额

数据预处理

在回归任务中,我们根据特征预测索赔金额。以下是数据清洗与预处理的关键步骤:

  1. 过滤无效数据:移除缺失或非法值的记录。
  2. 特征转换:将类别特征转为字符串类型。
  3. 分割数据集:按 80% 和 20% 的比例划分训练集与测试集。

4.1 模型训练与评估

我们使用 CatBoost 进行回归建模,模型参数包括:

  • 学习率:0.02
  • 深度:8
  • 迭代次数:10,000(支持提前停止)

以下是模型的关键代码:

from catboost import CatBoostRegressor# 初始化 CatBoost 回归模型
cat_regressor = CatBoostRegressor(iterations=10000,learning_rate=0.02,depth=8,eval_metric='RMSE',early_stopping_rounds=1500,random_seed=42
)# 训练模型
cat_regressor.fit(X_train, y_train,cat_features=categorical_features_indices,eval_set=(X_test, y_test),verbose=100
)

4.2 特征重要性分析

特征重要性是衡量特征对模型预测贡献程度的指标,可以帮助我们更好地理解模型。

# 获取特征重要性
feature_importances = cat_regressor.get_feature_importance()
feature_names = X_train.columns# 可视化特征重要性
import matplotlib.pyplot as plt
importance_df = pd.DataFrame({'Feature': feature_names,'Importance': feature_importances
}).sort_values(by='Importance', ascending=True)plt.figure(figsize=(10, 6))
plt.barh(importance_df['Feature'], importance_df['Importance'], color='salmon')
plt.xlabel('特征重要性')
plt.ylabel('特征名称')
plt.title('CatBoost 特征重要性分析')
plt.show()

结果展示
在这里插入图片描述


4.3 模型评估

我们可以均方误差 (MSE) 以及 平均绝对误差 (MAE) 来评估模型在测试集上的回归性能,同时展示模型的学习曲线:

# 获取训练和测试集的 RMSE
evals_result = cat_regressor.get_evals_result()
train_rmse = evals_result['learn']['RMSE']
test_rmse = evals_result['validation']['RMSE']# 绘制 RMSE 曲线
plt.figure(figsize=(10, 6))
plt.plot(train_rmse, label='训练集 RMSE')
plt.plot(test_rmse, label='测试集 RMSE')
plt.title('训练与测试集的 RMSE 学习曲线')
plt.xlabel('迭代次数')
plt.ylabel('RMSE')
plt.legend()
plt.show()

五. 分类任务:判别是否调查

5.1 数据标注与模型选择

分类任务以 是否调查 作为标签(1 表示需要调查,0 表示无需调查),特征包括所有数值和类别字段。

为了完成分类任务,我们选用 CatBoostClassifier。模型参数类似于回归模型,分类评估指标包括准确率、混淆矩阵和分类报告。


5.2 训练结果与模型评估

训练结果显示,分类准确率达 94.0%。以下是模型的分类报告:

分类报告 (训练集):precision    recall  f1-score   support0       0.96      0.98      0.97     130871       0.74      0.57      0.64      1354accuracy                           0.94     14441macro avg       0.85      0.77      0.80     14441
weighted avg       0.94      0.94      0.94     14441
5.3 代码示例
from catboost import CatBoostClassifier# 初始化分类器
cat_classifier = CatBoostClassifier(iterations=1000,learning_rate=0.02,depth=8,eval_metric='Accuracy',early_stopping_rounds=150,random_seed=42
)# 模型训练
cat_classifier.fit(X_train, y_train,cat_features=categorical_features_indices,eval_set=(X_test, y_test),verbose=100
)

六. 可视化分析

为更直观地理解模型,我们利用散点图和分割线对预测结果进行展示:

  • 散点图:展示实际金额与预测金额的分布。
  • 分割线:通过 KMeans 聚类划分四个金额档次。

以下代码生成散点图与分割线:

# 使用 KMeans 聚类生成分割线
from sklearn.cluster import KMeanskmeans = KMeans(n_clusters=4, random_state=42)
df['cluster'] = kmeans.fit_predict(df[['预测金额']])# 绘制散点图
plt.figure(figsize=(12, 8))
plt.scatter(df['预测金额'], df['是否调查'], c=df['cluster'], cmap='tab10')
plt.title("预测金额与是否调查的散点图")
plt.xlabel("预测金额")
plt.ylabel("是否调查")
plt.colorbar(label='Cluster')
plt.show()

散点图展示

在这里插入图片描述


七. 补充学习

7.1 基础数据集

California Housing 数据集包含加利福尼亚州 20,640 个街区的人口、住房和收入信息。目标是预测每个街区的房价中位数 MedHouseVal

数据特征

  1. MedInc:街区的收入中位数。
  2. HouseAge:街区住房的平均年龄。
  3. AveRooms:每个街区的平均房间数。
  4. AveBedrms:每个街区的平均卧室数。
  5. Population:街区的总人口。
  6. AveOccup:每户的平均人数。
  7. Latitude:街区的纬度。
  8. Longitude:街区的经度。

7.2 实践步骤

7.2.1 导入数据与预处理

我们使用 Scikit-learn 加载数据并进行预处理。

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import pandas as pd# 加载 California Housing 数据集
data = fetch_california_housing(as_frame=True)
df = data.frame# 特征和目标变量
X = df.drop(columns="MedHouseVal")
y = df["MedHouseVal"]# 数据划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)print(f"训练集大小: {X_train.shape}, 测试集大小: {X_test.shape}")

训练集大小: (16512, 8), 测试集大小: (4128, 8)


7.2.2 训练 CatBoost 回归模型

使用 CatBoost 对房价进行预测。

from catboost import CatBoostRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error# 初始化 CatBoost 回归模型
cat_regressor = CatBoostRegressor(iterations=1000,learning_rate=0.1,depth=6,eval_metric="RMSE",random_seed=42,verbose=100
)# 模型训练
cat_regressor.fit(X_train, y_train, eval_set=(X_test, y_test), verbose=100, early_stopping_rounds=50)# 模型预测
y_pred_train = cat_regressor.predict(X_train)
y_pred_test = cat_regressor.predict(X_test)# 模型评估
mse_train = mean_squared_error(y_train, y_pred_train)
mse_test = mean_squared_error(y_test, y_pred_test)
mae_test = mean_absolute_error(y_test, y_pred_test)print(f"训练集均方误差 (MSE): {mse_train}")
print(f"测试集均方误差 (MSE): {mse_test}")
print(f"测试集平均绝对误差 (MAE): {mae_test}")

输出如下:

0:	learn: 1.0934740	test: 1.0841841	best: 1.0841841 (0)	total: 1.24s	remaining: 20m 38s
100:	learn: 0.4867395	test: 0.5154868	best: 0.5154868 (100)	total: 1.54s	remaining: 13.7s
200:	learn: 0.4320149	test: 0.4798269	best: 0.4798269 (200)	total: 1.8s	remaining: 7.18s
300:	learn: 0.4020581	test: 0.4657293	best: 0.4657293 (300)	total: 2.07s	remaining: 4.8s
400:	learn: 0.3803801	test: 0.4582868	best: 0.4582868 (400)	total: 2.35s	remaining: 3.5s
500:	learn: 0.3633580	test: 0.4534430	best: 0.4534430 (500)	total: 2.61s	remaining: 2.6s
600:	learn: 0.3488402	test: 0.4491723	best: 0.4491723 (600)	total: 2.89s	remaining: 1.92s
700:	learn: 0.3358611	test: 0.4461323	best: 0.4461323 (700)	total: 3.17s	remaining: 1.35s
800:	learn: 0.3234759	test: 0.4431320	best: 0.4431320 (800)	total: 3.44s	remaining: 854ms
900:	learn: 0.3126821	test: 0.4403978	best: 0.4403978 (900)	total: 3.71s	remaining: 407ms
999:	learn: 0.3025414	test: 0.4386906	best: 0.4386902 (998)	total: 3.97s	remaining: 0usbestTest = 0.438690174
bestIteration = 998Shrink model to first 999 iterations.
训练集均方误差 (MSE): 0.09158491090576551
测试集均方误差 (MSE): 0.19244906768098075
测试集平均绝对误差 (MAE): 0.28701415230111493

7.2.3 可视化预测结果

展示预测值与实际值的对比,以及模型的特征重要性。

实际值与预测值对比
import matplotlib.pyplot as plt# 对比测试集的预测值和实际值
plt.figure(figsize=(10, 6))
plt.scatter(range(len(y_test)), y_test, color="blue", label="真实值", alpha=0.6)
plt.scatter(range(len(y_pred_test)), y_pred_test, color="red", label="预测值", alpha=0.6)
plt.title("真实房价与预测房价对比")
plt.xlabel("样本索引")
plt.ylabel("房价中位数")
plt.legend()
plt.show()

特征重要性分析
# 特征重要性可视化
feature_importances = cat_regressor.get_feature_importance()
feature_names = data.feature_namesplt.figure(figsize=(10, 6))
plt.barh(feature_names, feature_importances, color="skyblue")
plt.title("CatBoost 特征重要性")
plt.xlabel("重要性得分")
plt.ylabel("特征名称")
plt.show()

在这里插入图片描述


7.3 数据结果

  • 模型评估结果:
    • 训练集均方误差 (MSE): 0.09158491090576551
    • 测试集均方误差 (MSE): 0.19244906768098075
    • 测试集平均绝对误差 (MAE): 0.28701415230111493
  • 特征重要性解读:
    根据特征重要性分析,MedInc(收入中位数)对预测房价的影响最大,而经纬度特征(Latitude 和 Longitude)也提供了显著的信息。

八. 总结

通过本项目,我们完成了基于 CatBoost 的回归与分类建模,并展示了预测结果的可视化。CatBoost 的强大功能和易用性使其在处理类别特征和缺失值的数据中表现优异。

希望本篇博客能为大家带来启发,助力实际项目的落地实现。如果对您有所帮助,也欢迎点赞与分享😊。

源码已上传到:https://github.com/YYForReal/ML-DL-RL-Learning/blob/main/ML-Learning/Catboost/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/483513.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用三维模型的顶点法向量计算法线贴图

法线贴图的核心概念是在不增加额外多边形数目的情况下,通过模拟细节来改善光照效果。具体流程包括: 法线的计算与存储:通过法线映射将三维法线向量转化为法线贴图的 RGB 值。渲染中的使用:在片段着色器中使用法线贴图来替代原有的…

Hadoop分布式文件系统(二)

目录 1. 引言1. Hadoop文件操作命令2. 部分常用的Hadoop FS Shell命令2.1 ls列出文件2.2 mkdir创建目录2.3 put上传文件2.4 cat查看文件2.5 get复制文件2.6 rm删除文件 3. Hadoop系统管理命令4. HDFS Java API 示例参考 1. 引言 大多数HDFS Shell命令的行为和对应的Unix Shell命…

ESP32-S3模组上跑通ES8388(13)

接前一篇文章:ESP32-S3模组上跑通ES8388(12) 二、利用ESP-ADF操作ES8388 2. 详细解析 上一回解析了es8388_init函数中的第6段代码,本回继续往下解析。为了便于理解和回顾,再次贴出es8388_init函数源码,在…

LearnOpenGL学习(光照 -- 颜色,基础光照,材质,光照贴图)

光照 glm::vec3 lightColor(0.0f, 1.0f, 0.0f); glm::vec3 toyColor(1.0f, 0.5f, 0.31f); glm::vec3 result lightColor * toyColor; // (0.0f, 0.5f, 0.0f); 说明:当我们把光源的颜色与物体的颜色值相乘,所得到的就是这个物体所反射的颜色。 创建…

Linux条件变量线程池详解

一、条件变量 【互斥量】解决了线程间同步的问题,避免了多线程对同一块临界资源访问产生的冲突,但同一时刻对临界资源的访问,不论是生产者还是消费者,都需要竞争互斥锁,由此也带来了竞争的问题。即生产者和消费者、消费…

Figma入门-自动布局

Figma入门-自动布局 前言 在之前的工作中,大家的原型图都是使用 Axure 制作的,印象中 Figma 一直是个专业设计软件。 最近,很多产品朋友告诉我,很多原型图都开始用Figma制作了,并且很多组件都是内置的,对…

威联通-001 手机相册备份

文章目录 前言1.Qfile Pro2.Qsync Pro总结 前言 威联通有两种数据备份手段:1.Qfile Pro和2.Qsync Pro,实践使用中存在一些区别,针对不同备份环境选择是不同。 1.Qfile Pro 用来备份制定目录内容的。 2.Qsync Pro 主要用来查看和操作文…

大R玩家流失预测在休闲社交游戏中的应用

摘要 预测玩家何时会离开游戏为延长玩家生命周期和增加收入贡献创造了独特的机会。玩家可以被激励留下来,战略性地与公司组合中的其他游戏交叉链接,或者作为最后的手段,通过游戏内广告传递给其他公司。本文重点预测休闲社交游戏中高价值玩家…

基于Java Springboot宠物咖微信小程序

一、作品包含 源码数据库全套环境和工具资源部署教程 二、项目技术 前端技术:Html、Css、Js、Vue、Element-ui 数据库:MySQL 后端技术:Java、Spring Boot、MyBatis 三、运行环境 开发工具:IDEA/eclipse 微信开发者工具 数…

ultralytics-YOLOv11的目标检测解析

1. Python的调用 from ultralytics import YOLO import os def detect_predict():model YOLO(../weights/yolo11n.pt)print(model)results model(../ultralytics/assets/bus.jpg)if not os.path.exists(results[0].save_dir):os.makedirs(results[0].save_dir)for result in…

蓝桥杯准备训练(lesson1,c++方向)

前言 报名参加了蓝桥杯(c)方向的宝子们,今天我将与大家一起努力参赛,后序会与大家分享我的学习情况,我将从最基础的内容开始学习,带大家打好基础,在每节课后都会有练习题,刚开始的练…

vscode 如何支持点击跳转函数,以C++为例,Python等其它编程语言同理,Visual Studio Code。

VScode(Visual Studio Code)按住Ctrl鼠标左键,没法跳转到对应的函数怎么办。 如下图所示 1、点击有四个小方块的图标 2、输入C(如果你的编程语言是C,其它的就输其它的) 3、找到C Extension(其它编程语言&#xff0…

【包教包会】CocosCreator3.x——重写Sprite,圆角、3D翻转、纹理循环、可合批调色板、不影响子节点的位移旋转缩放透明度

一、效果演示 重写Sprite组件,做了以下优化: 1、新增自变换,在不影响子节点的前提下位移、旋转、缩放、改变透明度 新增可合批调色板,支持色相、明暗调节 新增圆角矩形、3D透视旋转、纹理循环 所有功能均支持合批、原生平台&…

Java八股文(11-29start)

p1 缓存预热也要预热到布隆过滤器.过滤不存在的数据 布隆过滤器需要存储 添加数据的时候进行预热.布隆过滤器里面是位图结构,通过多个hash函数获得下标.改为1. 查询 id进行查询获得对应下标是否为1.可能会出现误判. 判断id是否存在. 穿透就是查询一个不存在的id.一直查询数…

【Gitlab】gitrunner并发配置

并发介绍 涉及到并发控制的一共有4个参数: concurrent , limit ,request_concurrency,parallel 全局的配置: [rootiZ2vc6igbukkxw6rbl64ljZ config]# vi config.toml concurrent 4 #这是一个总的全局控制,它限制了所有pipline,所有runner执行器…

智能运维在配电所设备监控中的应用与洞察

在配电所的设备监控中,智能运维正发挥着越来越重要的作用。通过对配电所内各关键设备的实时监测和数据分析,智能运维系统不仅提高了运维效率,还为我们提供了更深入的设备运行洞察。 一、设备监控概况 配电所内设有多个监测点,包括…

Lumos学习王佩丰Excel第十九讲:Indirect函数

一、认识indirect单元格引用 1、了解Indirect函数的意义及语法 Indirect:引用函数,间接引用。 函数语法:INDIRECT(ref_text,[a1]) 其中,ref_text是一个表示单元格地址或名称的字符串,a1是一个可选的逻辑值参数&…

QT6学习第八天 QFrame 类

QT6学习第八天 QFrame 类族QLabel 标签部件按钮部件QLineEdit 行编辑器部件QAbstractSpinBoxQAbstractSlider 今天来学一学 QFrame 类。 QFrame 类族 QFrame 类是带有边框的部件的基类。它的子类包括常用的标签部件 QLabel、以及 QLCDNumber、QSplitter、QStackedWidget、QToo…

Nginx学习-安装以及基本的使用

一、背景 Nginx是一个很强大的高性能Web和反向代理服务,也是一种轻量级的Web服务器,可以作为独立的服务器部署网站,应用非常广泛,特别是现在前后端分离的情况下。而在开发过程中,我们常常需要在window系统下使用Nginx…

【AI系统】Ascend C 语法扩展

Ascend C 语法扩展 Ascend C 的本质构成其实是标准 C加上一组扩展的语法和 API。本文首先对 Ascend C 的基础语法扩展进行简要介绍,随后讨论 Ascend C 的两种 API——基础 API 和高阶 API。 接下来针对 Ascend C 的几种关键编程对象——数据存储、任务间通信与同步…