代码功能
加载数据:从 UCI Adult Census 数据集中读取样本,进行清洗和编码。
特征处理:对分类特征进行标签编码,对数值特征进行标准化。
模型训练:使用 TabNet 模型对数据进行分类训练,采用早停机制提高效率。
性能评估:计算模型在测试集上的准确率、精确率、召回率和 F1 分数。
解释性分析:输出每个特征的重要性评分,帮助理解模型决策依据。
代码
# Train and evaluate a TabNet classifier on the UCI Adult (census income) data.
#
# Install the required libraries first:
#   pip install pytorch-tabnet scikit-learn pandas
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from pytorch_tabnet.tab_model import TabNetClassifier

# Load the UCI Adult dataset (the file has no header row).
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    "age", "workclass", "fnlwgt", "education", "education-num",
    "marital-status", "occupation", "relationship", "race", "sex",
    "capital-gain", "capital-loss", "hours-per-week", "native-country",
    "income",
]
# skipinitialspace=True strips the blank that follows every comma, so a
# missing field is read as "?" rather than " ?"; the NA marker must match
# the stripped form or no missing value is ever detected.
data = pd.read_csv(url, header=None, names=columns, na_values="?", skipinitialspace=True)

# Drop rows that contain missing values.
data = data.dropna()

# Encode the target. `label_encoder` stays dedicated to the target so that
# label_encoder.classes_ / inverse_transform keep referring to income.
label_encoder = LabelEncoder()
data["income"] = label_encoder.fit_transform(data["income"])

# Encode each categorical feature as integer codes, one encoder per column so
# every mapping remains recoverable. With the target already numeric, the
# remaining non-numeric columns are the categorical string features.
categorical_features = data.select_dtypes(exclude=["number"]).columns
feature_encoders = {}
for col in categorical_features:
    encoder = LabelEncoder()
    data[col] = encoder.fit_transform(data[col])
    feature_encoders[col] = encoder

# Column positions of categorical / numeric features inside X, plus the number
# of distinct codes per categorical feature (needed by TabNet's embeddings).
feature_names = columns[:-1]
cat_idxs = [feature_names.index(col) for col in categorical_features]
cat_dims = [len(feature_encoders[col].classes_) for col in categorical_features]
num_idxs = [i for i, name in enumerate(feature_names) if name not in feature_encoders]

# Separate features and target. Cast to float so the numeric columns can be
# standardized in place; categorical codes stay integer-valued.
X = data.drop("income", axis=1).values.astype(np.float64)
y = data["income"].values

# Hold out a test set, then carve a validation set out of the training data.
# Early stopping monitors the validation set only, so the test metrics below
# are not biased by model selection.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
)

# Standardize the numeric features only; the scaler is fitted on the training
# split and reused for the other splits. Categorical codes are left untouched
# because their magnitudes carry no meaning.
scaler = StandardScaler()
X_train[:, num_idxs] = scaler.fit_transform(X_train[:, num_idxs])
X_valid[:, num_idxs] = scaler.transform(X_valid[:, num_idxs])
X_test[:, num_idxs] = scaler.transform(X_test[:, num_idxs])

# Define the TabNet model, declaring which columns are categorical.
clf = TabNetClassifier(cat_idxs=cat_idxs, cat_dims=cat_dims)

# Train with early stopping on validation accuracy.
clf.fit(
    X_train,
    y_train,
    eval_set=[(X_valid, y_valid)],
    eval_name=["valid"],
    eval_metric=["accuracy"],
    max_epochs=50,
    patience=10,
    batch_size=1024,
    virtual_batch_size=128,
)

# Predict on the untouched test set.
y_pred = clf.predict(X_test)

# Compute the standard classification metrics.
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)

print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")

# Report the importance score of every input feature.
feature_importances = clf.feature_importances_
for name, importance in zip(feature_names, feature_importances):
    print(f"Feature: {name}, Importance: {importance:.4f}")