TabNet 模型示例

代码功能

加载数据:从 UCI Adult Census 数据集中读取样本,进行清洗和编码。
特征处理:对分类特征进行标签编码,对数值特征进行标准化。
模型训练:使用 TabNet 模型对数据进行分类训练,采用早停机制提高效率。
性能评估:计算模型在测试集上的准确率、精确率、召回率和 F1 分数。
解释性分析:输出每个特征的重要性评分,帮助理解模型决策依据。
在这里插入图片描述

代码

# 安装必要的库
# pip install pytorch-tabnet scikit-learn pandasimport numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from pytorch_tabnet.tab_model import TabNetClassifier# 加载UCI Adult数据集
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = ["age", "workclass", "fnlwgt", "education", "education-num", "marital-status", "occupation", "relationship", "race", "sex", "capital-gain", "capital-loss", "hours-per-week", "native-country", "income"
]
data = pd.read_csv(url, header=None, names=columns, na_values=" ?", skipinitialspace=True)# 处理缺失值
data = data.dropna()# 标签编码
label_encoder = LabelEncoder()
data["income"] = label_encoder.fit_transform(data["income"])# 将分类特征编码为数字
categorical_features = data.select_dtypes(include=["object"]).columns
for col in categorical_features:data[col] = label_encoder.fit_transform(data[col])# 分离特征和标签
X = data.drop("income", axis=1).values
y = data["income"].values# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)# 标准化数值特征
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 定义TabNet模型
clf = TabNetClassifier()# 训练模型
clf.fit(X_train, y_train,eval_set=[(X_test, y_test)],eval_name=["test"],eval_metric=["accuracy"],max_epochs=50,patience=10,batch_size=1024,virtual_batch_size=128
)# 模型预测
y_pred = clf.predict(X_test)# 计算主流评估指标
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")# 特征重要性
feature_importances = clf.feature_importances_
for name, importance in zip(columns[:-1], feature_importances):print(f"Feature: {name}, Importance: {importance:.4f}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/476492.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一次封装,解放双手:Requests如何实现0入侵请求与响应的智能加解密

引言 之前写了 Requests 自动重试的文章,突然想到,之前还用到过 Requests 自动加解密请求的逻辑,分享一下。之前在做逆向的时候,发现一般医院的小程序请求会这么玩,请求数据可能加密也可能不加密,但是返回…

锂电池学习笔记(一) 初识锂电池

前言 锂电池近几年一直都是很热门的产品,充放电管理更是学问蛮多,工作生活中难免会碰到,所以说学习锂电池是工程师的必备知识储备,今天学习锂电池的基本知识,分类,优缺点,循序渐进 学习参考 【…

《Vue零基础入门教程》第四课: 应用实例

往期内容 《Vue零基础入门教程》第一课:Vue简介 《Vue零基础入门教程》第二课:搭建开发环境 《Vue零基础入门教程》第三课:起步案例 参考官方文档 https://cn.vuejs.org/api/application#create-app 示例 const {createApp} Vue// 通…

介绍一下strncmp(c基础)

strncmp是strcmp的进阶版 链接介绍一下strcmp(c基础)-CSDN博客 作用 比较两个字符串的前n位 格式 #include <string.h> strncmp (arr1,arr2,n); 工作原理&#xff1a;strcmp函数按照ACII&#xff08;字符编码顺序&#xff09;比较两个字符串。它从两个字符串的第一…

Lucene(2):Springboot整合全文检索引擎TermInSetQuery应用实例附源码

前言 本章代码已分享至Gitee: https://gitee.com/lengcz/springbootlucene01 接上文。Lucene(1):Springboot整合全文检索引擎Lucene常规入门附源码 如何在指定范围内查询。从lucene 7 开始&#xff0c;filter 被弃用&#xff0c;导致无法进行调节过滤。 TermInSetQuery 指定…

【电路笔记 TMS320F28335DSP】时钟+看门狗+相关寄存器(功能模块使能、时钟频率配置、看门狗配置)

时钟源和主时钟&#xff08;SYSCLKOUT&#xff09; 外部晶振&#xff1a;通常使用外部晶振&#xff08;如 20 MHz&#xff09;作为主要时钟源。内部振荡器&#xff1a;还可以选择内部振荡器&#xff08;INTOSC1 和 INTOSC2&#xff09;&#xff0c;适合无需高精度外部时钟的应…

java 并发编程 (1)java中如何实现并发编程

目录 1. 继承 Thread 类 2. 实现 Runnable 接口 3. 使用 FutureTask 4. 使用 Executor 框架 5. 具体案例 1. 继承 Thread 类 概述&#xff1a;通过继承 Thread 类并重写其 run() 方法来创建一个新的线程。 步骤&#xff1a; 创建一个继承 Thread 类的子类。重…

巧用观测云可用性监测(云拨测)

前言 做为系统运维或者开发&#xff0c;很多时候我们需要能够实时感知我们所运维的系统和服务的情况&#xff0c;比如以下的场景&#xff1a; 系统上线前测试&#xff1a;包括功能完整性检查&#xff0c;确保页面元素&#xff08;如图像、视频、脚本等&#xff09;都能够正常…

python oa服务器巡检报告脚本的重构和修改(适应数盾OTP)有空再去改

Two-Step Vertification required&#xff1a; Please enter the mobile app OTPverification code: 01.因为巡检的服务器要双因子认证登录&#xff0c;也就是登录堡垒机时还要输入验证码。这对我的巡检查服务器的工作带来了不便。它的机制是每一次登录&#xff0c;算一次会话…

Unreal从入门到精通之如何绘制用于VR的3DUI交互的手柄射线

文章目录 前言实现方式MenuLaser实现步骤1.Laser和Cursor2.移植函数3.启动逻辑4.检测射线和UI的碰撞5.激活手柄射线6.更新手柄射线位置7.隐藏手柄射线8.添加手柄的Trigger监听完整节点如下:效果图前言 之前我写过一篇文章《Unreal5从入门到精通之如何在VR中使用3DUI》,其中讲…

Win11 22H2/23H2系统11月可选更新KB5046732发布!

系统之家11月22日报道&#xff0c;微软针对Win11 22H2/23H2版本推送了2024年11月最新可选更新补丁KB5046732&#xff0c;更新后&#xff0c;系统版本号升至22621.4541和22631.4541。本次更新后系统托盘能够显示缩短的日期和时间&#xff0c;文件资源管理器窗口很小时搜索框被切…

【数据结构】【线性表】【练习】反转链表

申明 该题源自力扣题库19&#xff0c;文章内容&#xff08;代码&#xff0c;图表等&#xff09;均原创&#xff0c;侵删&#xff01; 题目 给你单链表的头指针head以及两个整数left和right&#xff0c;其中left<right&#xff0c;请你反转从位置left到right的链表节点&…

【赵渝强老师】MySQL的慢查询日志

MySQL的慢查询日志可以把超过参数long_query_time时间的所有SQL语句记录进来&#xff0c;帮助DBA人员优化所有有问题的SQL语句。通过mysqldumpslow工具可以查看慢查询日志。 视频讲解如下&#xff1a; MySQL的慢查询日志 【赵渝强老师】MySQL的慢查询日志 下面通过具体的演示…

基于docker进行任意项目灵活发布

引言 不管是java还是python程序等&#xff0c;使用docker发布的优势有以下几点&#xff1a; 易于维护。直接docker命令进行管理&#xff0c;如docker stop、docker start等&#xff0c;快速方便无需各种进程查询关闭。环境隔离。项目代码任何依赖或设置都可以基本独立&#x…

Android 分区相关介绍

目录 一、MTK平台 1、MTK平台分区表配置 2、MTK平台刷机配置表 3、MTK平台分区表配置不生效 4、Super分区的研究 1&#xff09;Super partition layout 2&#xff09;Block device table 二、高通平台 三、展锐平台 四、相关案例 1、Super分区不够导致编译报错 经验…

数据库类型介绍

1. 关系型数据库&#xff08;Relational Database, RDBMS&#xff09;&#xff1a; • 定义&#xff1a;基于关系模型&#xff08;即表格&#xff09;存储数据&#xff0c;数据之间通过外键等关系相互关联。 • 特点&#xff1a;支持复杂的SQL查询&#xff0c;数据一致性和完整…

当产业经济插上“数字羽翼”,魔珐有言AIGC“3D视频创作大赛”成功举办

随着AI技术的飞速发展&#xff0c;3D数字人技术已成为驱动各行各业转型升级的重要力量。在这一背景下&#xff0c;2024山东3D数字人视频创作大赛应运而生&#xff0c;并在一番激烈的角逐后圆满落幕&#xff0c;为科技与创意的交融写下浓墨重彩的一笔。 11月20日&#xff0c;一…

经济增长初步

1.人均产出 人均产出&#xff0c;通常指的是一个国家、地区或组织在一定时期内&#xff0c;每个劳动人口平均创造的生产总值。它是衡量一个地区或国家经济效率和劳动生产率的重要指标。具体来说&#xff0c;人均产出可以通过以下公式计算&#xff1a; 人均产出总产出/劳动人口…

图像增强夜视仪行业全面而深入的分析

图像增强夜视设备&#xff08;I2ND 或 INVD&#xff09;是一种增强监视、安全和军事应用的微光可见度的技术。 它允许用户在非常弱的光线甚至完全黑暗的条件下看到东西。 一、市场研究 1. 市场规模与增长趋势 据QYResearch调研团队最新报告&#xff0c;预计2029年全球图像增强…

002 MATLAB语言基础

01 变量命名规则 变量名只能由字母、数字和下划线组成&#xff0c;且必须以字母开头&#xff1b; 变量名区分字母的大小写&#xff1b; 变量名不能超过最大长度限制&#xff1b; 关键字不能作为变量名&#xff0c;如for、end和if等&#xff1b; 注意&#xff1a;存变量命名时…