【传知代码】基于标签相关性的多标签学习(论文复现)

在当今信息爆炸的时代,数据中包含的标签信息对于理解和分析复杂问题至关重要。在诸如文本分类、图像识别和推荐系统等应用中,如何有效地利用标签相关性提升多标签学习的效果成为了研究的热点之一。基于标签相关性的多标签学习方法,通过挖掘不同标签之间的潜在关联,旨在提高模型对多标签数据的准确性和泛化能力。

本文所涉及所有资源均在传知代码平台可获取

目录

概述

算法流程

核心代码

最后总结


概述

        帕金森病是一种使人虚弱的慢性神经系统疾病。传统中医(TCM)是一种诊断帕金森病的新方法,而用于诊断帕金森病的中医数据集是一个多标签数据集。考虑到帕金森病数据集中的症状(标签)之间总是存在相关性,可以通过利用标签相关性来促进多标签学习过程。目前的多标签分类方法主要尝试从标签对或标签链中挖掘相关性。该文章提出了一种简单且高效的多标签分类框架,称为潜在狄利克雷分布多标签(LDAML),该框架旨在通过使用类别标签的主题模型来学习全局相关性。简而言之,研究人员试图通过主题模型在标签集上获得抽象的“主题”,从而能够挖掘标签之间的全局相关性。大量实验清楚地验证了所提出的方法是一个通用且有效的框架,能够提高大多数多标签算法的性能。基于该框架,研究人员在中医帕金森病数据集上取得了令人满意的实验结果,这可以为该领域的发展提供参考和帮助。

        多标签学习(Multi-Label Learning)是一种机器学习方法,用于处理具有多个标签的数据样本。与传统的单标签学习不同,每个数据点在多标签学习中可以同时属于一个或多个类别,而不仅仅是一个确定的标签。其目标是经过算法训练后输出一个分类模型,即学习一组从特征空间到标记空间的实值函数映射。假设使用X=RdX=Rd表示一个d维的输入空间,Y={y1,y2,y3,...,yq}Y={y1​,y2​,y3​,...,yq​}表示可能输出的q个类别,多标签任务即在训练集合D={(x1,Y1),(x2,Y2),...,(xm,Ym)}D={(x1​,Y1​),(x2​,Y2​),...,(xm​,Ym​)}上学习一个X到Y的函数,该函数可以衡量x和y的相关性,对于未见过的实例x预测其对应的标签y。

今天介绍的论文是多标签学习经典算法——LDAML,论文链接:https://ieeexplore.ieee.org/abstract/document/8217717 ,如下图所示:

论文提出了一种通用且高效的多标签分类框架——Latent Dirichlet Allocation Multi-Label (LDAML)。该框架通过利用标签间的关联性进行多标签分类。该框架可以应用于大多数当前的多标签分类方法,使其性能得到提升。通过使用LDAML框架,可以显著提升简单方法(如Binary Relevance, BR)的性能,甚至超过某些最新的方法,同时保持较低的时间成本。提出的改进LDAML在某些特殊数据集(如帕金森数据集)上取得了最佳性能。特别是在帕金森数据集上,改进的LDAML框架实现了最优性能,达到了本文的最终目标。该方法能够在未来为医生提供指导和帮助。

算法流程

        与通过查找标签子集或标签链来利用相关性的传统方法不同,LDAML通过发现标签的抽象“主题”来利用相关性。假设为d维特征向量的输入空间,表示q类标号的输出空间。给定多标签训练集,其中为d维特征向量,为对应的标签集。我们可以将每个实例看作一个文档,每个标签看作文档中的一个单词。直观地说,一定有一些抽象的“主题”,期望特定的标签或多或少地出现在实例中,特别是在包含大量相关标签的多标签数据集中。LDAML算法的主要流程分为两步:(1)从训练集中挖掘标签主题;(2)计算主题的离散分布。

从训练集中挖掘标签主题: 首先,我们将LDA引入到训练集d中,每个实例xi表示文档,每个标签表示第i个实例中的第j个标签。然后利用LDA模型生成过程计算实例-主题 θ 的概率分布矩阵,其中 表示第i个实例注入第j主题的概率。
主题的离散分布: 计算实例-主题分布矩阵后,得到每个实例属于每个主题的概率值。为了确定实例确切属于哪个主题,我们需要用离散值0/1来代替概率值。在这里我们使用的离散化方法如下所示:

在这里我们的训练集数据与测试集数据分布相似,因此我们可以假设测试数据集的主题概率分布与训练数据集相同。首先我们对训练集提取出具有标记相关性的k个主题(利用算法1),然后我们使用多标签分类模型MTMT​对训练集的特征-主题进行拟合,然后利用训练好的MT模型对未知标记集合的测试集特征数据生成含有标记相关性的k个主题(这里需要注意的是,MTMT​可以随便选取一个有效的多标签分类模型,文章的重点是利用标签相关性来提高各种多标签学习模型的效率)。

文章在四份数据集上用多种多标签学习分类模型分别加上LDAML算法与其原始模型的分类效果进行对比,实验结果如图所示:

以上实验结果表明,LDAML能够在性能和时间成本之间取得良好的平衡。目前的大多数方法都可以应用于LDAML。我们可以采用目前最先进的方法作为LDAML在原始基础上取得突破的基本方法(base model)。另一方面,唯一额外的时间代价是计算主题概率分布矩阵的小词空间。因此,LDAML的时间成本接近于其基础方法的时间成本。通过采用BR或CC等较弱的方法作为基本方法,可以在较低的时间成本下提高接近实际状态的性能。这些结果表明,LDAML是一个通用的框架,可以为具有标签相关性的多标签问题提供鲁棒且更优的解决方案。 

核心代码

由于改论文代码目前尚未开源,因此在本文中我将给出由本人根据论文算法流程一比一复制的复现代码,代码源文件我将放在附件中,其核心逻辑如下:

#########################伪代码###########################
# 导入必要的库
Import libraries# 定义函数
Function discretize(theta):# 初始化二进制矩阵 YTInitialize YT as a zero matrix with the same shape as thetaFor each row i in theta:Find the maximum value in row iFor each column j in row i:If the difference between the max value and theta[i][j] is less than 1/K:Set YT[i][j] to 1Else:Set YT[i][j] to 0Return YTFunction convert_to_one_hot(data):# 获取唯一值和类别数Find unique values in dataInitialize one_hot_encoded as a zero matrixFor each value in data:Find the index of the value in unique valuesSet the corresponding position in one_hot_encoded to 1Return one_hot_encodedFunction lda(labels, n):# 进行潜在狄利克雷分配(LDA)Initialize LDA model with n componentsFit and transform labels using LDA modelDiscretize the transformed dataReturn the discretized dataFunction metric_cal(test, pred):# 计算并打印评估指标Calculate accuracy, precision, recall, F1 score, and AUCPrint the calculated metrics# 主程序
If __name__ == "__main__":# 加载数据Load data from Excel file# 定义标签列和特征Define label_cols and featuresConvert features and labels to NumPy arrays# 设置主题数Set n to 6# 对标签进行LDACall lda function to get Y_T# 将特征与离散化的标签组合Concatenate features and Y_T to get XYT# 划分训练集和测试集Split XYT and labels into X_train, X_test, y_train, y_test# 初始化多标签分类器Initialize MT_classifier as RankSVM# 从训练集和测试集中提取主题Extract yt_train and yt_test from X_train and X_testRemove last n columns from X_train and X_test# 训练多标签分类器Fit MT_classifier using X_train and yt_train# 预测测试集的主题Predict yt_proba and yt_pred using MT_classifier on X_testConvert yt_pred to integer# 使用预测的主题扩展训练集和测试集Concatenate X_train with yt_train to get X_train_augConcatenate X_test with yt_pred to get X_test_aug# 初始化并训练二进制相关性分类器Initialize base_classifier as MLPClassifierInitialize clf as BinaryRelevance with base_classifierFit clf using X_train_aug and y_train# 预测测试集的标签Predict y_pred and y_score using clf on X_test_aug# 计算评估指标Calculate hamming loss, ranking loss, coverage error, and average precisionPrint calculated metrics# 对每个标签计算并打印评估指标For each label i:Extract test and pred for label iCall metric_cal function to calculate and print metricsPrint separatorPrint final separator

在主文件main.py中我复现了LDAML算法的整个流程,并实现了从输入数据到输出评价指标的全过程,在这里默认采用的多标签学习分类起MTMT​和MM是RankSVM和二元回归+深度学习。 

# 定义LIFTClassifier类,继承自BaseEstimator和ClassifierMixin
class LIFTClassifier(BaseEstimator, ClassifierMixin):# 初始化函数,接受一个基本分类器作为参数def __init__(self, base_classifier=DecisionTreeClassifier()):设置base_classifier为传入的参数初始化classifiers字典# 训练模型函数def fit(self, X, y):获取标签数量遍历每个标签对每个标签训练一个分类器将训练好的分类器存入classifiers字典返回self# 预测函数def predict(self, X):获取标签数量初始化预测结果矩阵遍历每个标签使用对应的分类器进行预测将预测结果存入预测结果矩阵返回预测结果矩阵# 预测概率函数def predict_proba(self, X):获取标签数量初始化概率预测结果矩阵遍历每个标签使用对应的分类器进行概率预测将预测概率结果存入概率预测结果矩阵返回概率预测结果矩阵# 定义MLkNN类
class MLkNN:# 初始化函数,接受一个k值作为参数def __init__(self, k=3):设置k值初始化k近邻模型# 训练模型函数def fit(self, X, y):保存训练数据X和y使用X训练k近邻模型# 预测函数def predict(self, X):获取样本数量初始化预测结果矩阵遍历每个样本获取样本的k+1个最近邻排除样本自身计算邻居标签的和根据标签和判断最终预测结果返回预测结果矩阵# 预测概率函数def predict_proba(self, X):获取样本数量初始化概率预测结果矩阵遍历每个样本获取样本的k+1个最近邻排除样本自身计算每个标签的概率返回概率预测结果矩阵# 定义RankSVM类,继承自BaseEstimator和ClassifierMixin
class RankSVM(BaseEstimator, ClassifierMixin):# 初始化函数,接受参数C, kernel, gammadef __init__(self, C=1.0, kernel='rbf', gamma='scale'):设置C, kernel, gamma值初始化模型列表初始化多标签二值化器# 训练模型函数def fit(self, X, y):使用多标签二值化器转换y获取标签数量遍历每个标签将当前标签转换为二值格式使用SVM训练二值化后的标签将训练好的SVM模型加入模型列表# 预测函数def predict(self, X):初始化预测结果矩阵遍历每个SVM模型使用模型进行预测将预测结果存入预测结果矩阵返回预测结果矩阵# 预测概率函数def predict_proba(self, X):初始化概率预测结果矩阵遍历每个SVM模型使用模型进行概率预测将预测概率结果存入概率预测结果矩阵返回概率预测结果矩阵# 定义MultiLabelDecisionTree类
class MultiLabelDecisionTree:# 初始化函数,接受参数max_depth, random_statedef __init__(self, max_depth=None, random_state=None):设置max_depth, random_state值初始化标签幂集转换器初始化决策树分类器# 训练模型函数def fit(self, X, y):使用标签幂集转换器转换y使用转换后的y训练决策树分类器# 预测概率函数def predict_proba(self, X):使用决策树分类器进行概率预测将预测概率结果转换为原始标签格式返回概率预测结果# 预测函数def predict(self, X):使用决策树分类器进行预测将预测结果转换为原始标签格式返回预测结果# 定义MLP神经网络类,继承自nn.Module
class MLP(nn.Module):# 初始化函数,接受输入大小、隐藏层大小和输出大小作为参数def __init__(self, input_size, hidden_size, output_size):调用父类的初始化函数初始化全连接层1初始化ReLU激活函数初始化全连接层2初始化Sigmoid激活函数# 前向传播函数def forward(self, x):通过全连接层1通过ReLU激活函数通过全连接层2通过Sigmoid激活函数返回输出# 定义BPMLL类,继承自BaseEstimator和ClassifierMixin
class BPMLL(BaseEstimator, ClassifierMixin):# 初始化函数,接受参数input_size, hidden_size, output_size, epochs, lrdef __init__(self, input_size, hidden_size, output_size, epochs=10, lr=0.0001):设置输入大小、隐藏层大小、输出大小、训练轮数、学习率初始化MLP模型初始化优化器初始化损失函数# 训练模型函数def fit(self, X_train, X_val, y_train, y_val):将训练数据和验证数据转换为张量创建训练数据集和数据加载器遍历每个训练轮次设置模型为训练模式遍历训练数据加载器清零梯度前向传播计算损失反向传播更新参数设置模型为评估模式计算验证损失并打印# 预测概率函数def predict_proba(self, X):设置模型为评估模式禁用梯度计算进行前向传播返回预测概率结果# 预测函数def predict(self, X, threshold=0.5):获取预测概率结果根据阈值判断最终预测结果返回预测结果# 定义RandomKLabelsetsClassifier类,继承自BaseEstimator和ClassifierMixin
class RandomKLabelsetsClassifier(BaseEstimator, ClassifierMixin):# 初始化函数,接受参数base_classifier, labelset_size, model_countdef __init__(self, base_classifier=None, labelset_size=3, model_count=10):设置基本分类器、标签集大小、模型数量初始化RakelD模型# 训练模型函数def fit(self, X, y):使用RakelD模型训练数据返回self# 预测函数def predict(self, X):使用RakelD模型进行预测返回预测结果# 预测概率函数def predict_proba(self, X):使用RakelD模型进行概率预测返回概率预测结果

调用LDAML算法的方法放在main.py文件中,首先我们需要将文件路径修改成自己所要使用的数据集路径。这里我使用的文件路径为’./测试数据.xlsx’,供大家一键运行熟悉项目。然后大家需要将自己的标签列名称提取变量label_cols中,用于对数据集划分特征集合与标签集合。 

构建想要的多标签学习分类算法,这里我给大家复现了多种经典的多标签分类器,如LIFT、MlkNN和RankSVM等,并帮大家配置好了参数,大家可以将想要使用的算法对应行的注释删掉即可(MTMT​和MM都是一样)。

设置好这些外在参数后,我们就可以运行代码,主文件将自动调用第三方库和multi_label_learn.py文件中的函数来进行训练和测试。下面是我选取的几种测试指标,分别会输出模型对整体的多标签分类性能指标(Hamming loss、Ranking loss、Coverage error和Average precision)和对单一标签的分类指标(Accuracy、Precision、Recall、F1 Score和AUC)。 

下面是在测试数据集上模型的表现:

最后总结

        多标签学习作为处理现实世界复杂数据的重要方法,其有效性在很大程度上依赖于如何处理标签之间的相关性。本文探讨了基于标签相关性的多标签学习的关键技术和方法。我们首先介绍了不同的标签相关性建模方法,包括基于图结构的方法、注意力机制和迁移学习等。

通过实例和案例分析,我们展示了这些方法如何提高模型对多标签数据的分类精度和泛化能力。特别是在面对标签稀疏性和噪声数据时,这些方法显示出了明显的优势和适应能力。未来的研究方向可能包括更加复杂的标签关联建模、跨领域的标签迁移学习以及与深度学习技术的进一步集成,以应对日益复杂和多样化的数据挑战。

基于标签相关性的多标签学习不仅在学术研究中具有深远意义,也在实际应用中展现了巨大潜力。我们希望本文能够为读者提供一个全面的视角,激发更多关于多标签学习和标签关联性研究的探索与创新。

详细复现过程的项目源码、数据和预训练好的模型可从该文章下方附件获取。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/391093.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JAVA项目基于SpringBoot的外卖点餐管理系统

目录 一、前言 二、技术介绍 三、系统实现 四、论文参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 随着生活节…

Flink笔记整理(六)

Flink笔记整理(六) 完整系列在公众号:是十三不是四十三,欢迎关注~ 文章目录 Flink笔记整理(六)八、状态管理8.1 Flink中的状态概述状态的分类 8.2 按键分区状态(Keyed State)值状态…

Windows 添加自定义服务实现开机(用户登录之前)自动运行 Python 脚本

实现效果 使用 Python 编写的一个脚本, 希望在 Windows 系统启动时, 用户登录之前就自动运行. 准备工作 首先确保 Python 脚本可以手动正常运行, 演示起见, 编写下面的一个简单的脚本用于在 C 盘根目录中生成一个包含脚本运行时间戳的文本文件. Python 脚本存放在 C:\Python…

python爬虫实践

两个python程序的小实验(附带源码) 题目1 爬取http://www.gaosan.com/gaokao/196075.html 中国大学排名,并输出。提示:使用requests库获取页面的基本操作获取该页面,运用BeautifulSoup解析该页面绑定对象soup&#x…

搭建jenkins一键部署java项目

一、搭建jenkins 链接: https://pan.baidu.com/s/1jzx15PiyI8EhLd_vg7q8bw 提取码: ydhl 复制这段内容后打开百度网盘手机App,操作更方便哦 直接使用docker导入镜像,运行就好 docker run -di --name jenkins -p 8080:8080 -v /home/jenkins_home:/var/je…

【人工智能】NLP入门指南:自然语言处理基础全解析

文章目录 前言一、NLPNLP(自然语言处理)NLU(自然语言理解)NLG(自然语言生成) 二、分词1.什么是分词2.常见的分词工具3.jieba分词 三、词向量1.什么是词向量2.文本张量表示方法3.常见的词向量模型3.1 ont-ho…

15.4 zookeeper java client之Curator使用(❤❤❤❤❤)

Curator使用 1. 为什么使用Curator对比Zookeeper原生2. 集成Curator2.1 依赖引入curator-frameworkcurator-recipes2.2 `yml`配置连接信息2.3 CuratorConfig配置类2.4 Curator实现Zookeeper分布式锁业务2.4.1 业务:可重入锁和不可重入锁可重入锁和不可重入锁InterProcessMutex …

scratch魔法门 2024年6月scratch四级 中国电子学会图形化编程 少儿编程等级考试四级真题和答案解析

目录 scratch魔法门 一、题目要求 1、准备工作 2、功能实现 二、案例分析 1、角色分析 2、背景分析 3、前期准备 三、解题思路 1、思路分析 2、详细过程 四、程序编写 五、考点分析 六、推荐资料 1、入门基础 2、蓝桥杯比赛 3、考级资料 4、视频课程 5、pyt…

基于JAVA的美甲店员工管理系统,源码、部署+讲解

摘 要 随着社会科技的飞速发展和进步,网络技术的应用已经深入到生活的方方面面。在这样的背景下,企事业单位的绩效考评体系也受到了极大的影响和冲击。传统的绩效考评方式已经无法满足现代社会的高效需求,因此,研发一款智能化、高…

42 PCB布线叠层与阻抗介绍43 PCB布线过孔添加与设置44 差分对添加与设置45 布线间距规则与介绍

42 PCB布线叠层与阻抗介绍&&43 PCB布线过孔添加与设置&44 差分对添加与设置&&45 布线间距规则与介绍 第一部分 42 PCB布线叠层与阻抗介绍1 板子是怎么来的。2 四层板为例,做叠层和阻抗计算。 第二部分 43 PCB布线过孔添加与设置介绍PCBEdotor中…

c#中的BitConverter的java实现

最近在做c#项目的java迁移,发现部分C#方法java中没有对应实现如图: 且java中的数字类型都是有符号的所以转无符号的时候需要进行手动对符号位& 0xFFFF进行处理,目前只整理了项目中使用到的方法,后续有用到其他方法在进行追加如…

linux学习记录(一)--------目录及文件操作

文章目录 前言Linux目录及文件操作1.Linux目录结构2.常用的Linux命令3.vi编辑器的简单使用4.vi的两个模式 前言 小白学习linux记录有错误随时指出~ Linux目录及文件操作 Linux采用Shell命令->操作文件 1.Linux目录结构 根目录:/ 用户目录:~或者/ho…

H5+JS 4096小游戏

主要实现 1.使用WASD或方向按钮控制游戏 2.最高值4096,玩到4096视为胜利 3.随机生成2、4、8方块 4.移动方块 5.合并方块 JS代码干了什么 初始化游戏界面:创建游戏板和控制按钮。 定义游戏相关变量:如棋盘大小、棋盘状态、得分等。 初始化棋…

软件测试生命周期、BUG描述与处理策略

软件测试的生命周期 需求分析:需求是否完整、是否正确 测试计划:确定由谁测试、测试的起止时间、设计哪些模块 测试设计、测试开发:写测试用例(手工、自动化测试用例)、编写测试工具 执行测试用例 测试评估&…

面向未来的S2B2C电商供应链系统发展趋势与创新探索

S2B2C电商供应链系统的发展趋势及创新方向。首先分析当前市场环境和消费者需求的变化,如个性化消费、即时配送、绿色环保等趋势对供应链系统的影响。随后,预测并讨论未来供应链系统可能的技术革新,如物联网(IoT)在物流…

【系统架构设计师】二十四、安全架构设计理论与实践①

目录 一、安全架构概述 1.1 信息安全面临的威胁 1.1.1 安全威胁分类 1.1.2 常见的安全威胁 1.2 安全架构的定义和范围 二、安全模型 2.1 状态机模型 2.2 Bell-LaPadula模型 2.3 Biba模型 2.4 Clark-Wilson模型 2.5 Chinese Wall 模型 往期推荐 一、安全架构概述 1…

基于LoRA和AdaLoRA微调Qwen1.5-0.5B-Chat

本文只开放基于LoRA和AdaLoRA微调代码,具体技术可以自行学习。 Qwen1.5-0.5B-Chat权重路径:https://huggingface.co/Qwen/Qwen1.5-0.5B 数据集路径:https://github.com/DB-lost/self-llm/blob/master/dataset/huanhuan.json 1. 知识点 LoRA, AdaLoRA技术 具体技术可以去看…

数据结构第十讲:二叉树OJ题

数据结构第十讲:二叉树OJ题 1.单值二叉树2.相同的树3.对称二叉树4. 另一棵树的子树5.二叉树的前序遍历6.二叉树的中序遍历7.二叉树的后序遍历8.二叉树的构建及其遍历9.二叉树选择题9.1二叉树性质19.2二叉树性质29.3二叉树性质三9.4选择题 1.单值二叉树 链接: OJ题链…

『python爬虫』beautifulsoup库获取文本的方法.get_text()、.text 和 .string区别(保姆级图文)

目录 区别.string(不推荐用).text(get_text的简化版少敲代码的时候用).get_text(推荐用,功能强大,为什么不爱呢?) 示例代码总结 欢迎关注 『python爬虫』 专栏,持续更新中 欢迎关注 『python爬虫』 专栏,持续更新中 区别 省流直接看get_text 推荐用这个…

【Git】如何优雅地使用Git中的tag来管理项目版本

目录 tagtag 和 branch区别操作命令打tag,当前分支标记tag提交到远程服务器删除本地tag删除远程tag切换到特定的tag查看所有tag查看标签详细信息 好书推荐 tag Git中的tag(标签)用于给项目在特定时间点(某个版本发布)…