如何用mmclassification训练多标签多分类数据

这里使用的源码版本是 mmclassification-0.25.0
训练数据标签文件格式如下,每行的空格前面是路径(图像文件所在的绝对路径),后面是标签名,因为特殊要求这里我的每张图像都记录了三个标签每个标签用“,”分开(具体看自己的需求),我的训练标签数量是17个。
在这里插入图片描述
训练参数配置文件,用ResNet作为特征提取主干,多标签分类要使用MultiLabelLinearClsHead作为分类头。数据集的格式使用CustomDataset,并修改该结构的定义文件,后面有详细内容。

# checkpoint saving
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(interval=100,hooks=[dict(type='TextLoggerHook'),# dict(type='TensorboardLoggerHook')])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
optimizer = dict(lr=0.1, momentum=0.9, type='SGD', weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
runner = dict(max_epochs=100, type='EpochBasedRunner')
lr_config = dict(policy='step', step=[30,60,90,])model = dict(type='ImageClassifier',backbone=dict(type='ResNet',depth=18,num_stages=4,out_indices=(3, ),style='pytorch'), neck=dict(type='GlobalAveragePooling'),head=dict(type='MultiLabelLinearClsHead',num_classes=17,in_channels=512,))dataset_type = 'CustomDataset'          #'MultiLabelDataset'
img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [dict(type='LoadImageFromFile'),dict(type='RandomResizedCrop', size=224),dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),dict(type='Normalize', **img_norm_cfg),dict(type='ImageToTensor', keys=['img']),dict(type='ToTensor', keys=['gt_label']),dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [dict(type='LoadImageFromFile'),dict(type='Resize', size=(256, -1)),dict(type='CenterCrop', crop_size=224),dict(type='Normalize', **img_norm_cfg),dict(type='ImageToTensor', keys=['img']),dict(type='Collect', keys=['img'])
]data = dict(samples_per_gpu=32,workers_per_gpu=2,train=dict(type=dataset_type,data_prefix='rootpath/images',ann_file='rootpath/train.txt',pipeline=train_pipeline),val=dict(type=dataset_type,data_prefix='rootpath/images',ann_file='rootpath/val.txt',pipeline=test_pipeline),test=dict(type=dataset_type,data_prefix='rootpath/images',ann_file='rootpath/test.txt',pipeline=test_pipeline))evaluation = dict(interval=1, metric='accuracy')

其他需要修改的地方:
1、修改加载数据的格式,将./mmclassification-0.25.0/mmcls/datasets/custom.py的CustomDataset里面的load_annotations函数替换成下面的函数:

    ###修改成多标签分类数据加载方式###def load_annotations(self):"""Load image paths and gt_labels."""if self.ann_file is None:samples = self._find_samples()elif isinstance(self.ann_file, str):lines = mmcv.list_from_file(self.ann_file, file_client_args=self.file_client_args)samples = [x.strip().rsplit(' ', 1) for x in lines]else:raise TypeError('ann_file must be a str or None')data_infos = []for filename, gt_label in samples:info = {'img_prefix': self.data_prefix}info['img_info'] = {'filename': filename.strip()}temp_label = np.zeros(len(self.CLASSES))# if not self.multi_label:#     info['gt_label'] = np.array(gt_label, dtype=np.int64)# else:### multi-label classifyif len(gt_label) == 1:temp_label[np.array(gt_label, dtype=np.int64)] = 1info['gt_label'] = temp_labelelse:for label in gt_label.split(','):i = self.CLASSES.index(label)temp_label[np.array(i, dtype=np.int64)] = 1# for i in range(np.array(gt_label.split(','), dtype=np.int64).shape[0]):#     temp_label[np.array(gt_label.split(','), dtype=np.int64)[i]] = 1info['gt_label'] = temp_label# print(info)data_infos.append(info)return data_infos

记得在初始函数__init__里修改成自己要训练的类别:
在这里插入图片描述

2、修改评估数据的函数,将./mmclassification-0.25.0/mmcls/models/losses/accuracy.py里面的accuracy_torch函数替换成如下函数。我这里只是增加了一些度量函数,方便可视化多标签的指标情况,并没有更新其他地方,训练时还是会验证原来的指标,里面调用的Metric类可以参考这篇文章:https://blog.csdn.net/u013250861/article/details/122727704

def accuracy_torch(pred, target, topk=(1,), thrs=0.):if isinstance(thrs, Number):thrs = (thrs,)res_single = Trueelif isinstance(thrs, tuple):res_single = Falseelse:raise TypeError(f'thrs should be a number or tuple, but got {type(thrs)}.')res = []maxk = max(topk)num = pred.size(0)pred = pred.float()#### ysn修改,增加对多标签分类的度量函数 ###pred_ = (pred > 0.5).float()        # 将 pred 中大于0.5的元素替换为1,其余替换为0# print("pred shape:", pred.shape, "pred:", pred)# # print("pred_ shape:", pred_.shape, "pred_:", pred_)# # print("target shape", target.shape, "target:", target)from mmcls.utils import get_root_loggerlogger = get_root_logger()from sklearn.metrics import classification_reportclass_report = classification_report(target.numpy(), pred_.numpy(), target_names=[“这里可以写成你的训练类型列表,也可以不使用这个参数”])     #分类报告汇总了精确率、召回率和 F1 分数等指标logger.info("\nClassification Report:\n{}".format(class_report))myMetic = Metric(pred_.numpy(), target.numpy())ham = myMetic.hamming_distance()avgPrecision, _ = myMetic.avgPrecision()avgRecall, _, _  = myMetic.avgRecall()ranking_loss = myMetic.get_ranking_loss()accuracy_multiclass = myMetic.accuracy_multiclass()logger.info("\nHam:{}\tAvgPrecision:{}\tAvgRecall:{}\tRanking_loss:{}\tAccuracy_Multilabel:{}".format(ham, avgPrecision, avgRecall, ranking_loss, accuracy_multiclass))####原来的代码###pred_score, pred_label = pred.topk(maxk, dim=1)pred_label = pred_label.t()target = target.argmax(dim=1)     ### ysn修改,这里是多标签分类标签列表的格式,单标签分类去掉这一句 ###correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))for k in topk:res_thr = []for thr in thrs:# Only prediction values larger than thr are counted as correct_correct = correct & (pred_score.t() > thr)correct_k = _correct[:k].reshape(-1).float().sum(0, keepdim=True)res_thr.append((correct_k.mul_(100. / num)))if res_single:res.append(res_thr[0])else:res.append(res_thr)return res

3、修改推理部分,将./mmclassification-0.25.0/mmcls/apis/inference.py里面的inference_model函数修改如下,推理多标签时候可以指定输出所有得分阈值大于0.5的所有标签类型。

def inference_model(model, img):"""Inference image(s) with the classifier.Args:model (nn.Module): The loaded classifier.img (str/ndarray): The image filename or loaded image.Returns:result (dict): The classification results that contains`class_name`, `pred_label` and `pred_score`."""cfg = model.cfgdevice = next(model.parameters()).device  # model device# build the data pipelineif isinstance(img, str):if cfg.data.test.pipeline[0]['type'] != 'LoadImageFromFile':cfg.data.test.pipeline.insert(0, dict(type='LoadImageFromFile'))data = dict(img_info=dict(filename=img), img_prefix=None)else:if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':cfg.data.test.pipeline.pop(0)data = dict(img=img)test_pipeline = Compose(cfg.data.test.pipeline)data = test_pipeline(data)data = collate([data], samples_per_gpu=1)if next(model.parameters()).is_cuda:# scatter to specified GPUdata = scatter(data, [device])[0]# forward the model# with torch.no_grad():#     scores = model(return_loss=False, **data)#     pred_score = np.max(scores, axis=1)[0]#     pred_label = np.argmax(scores, axis=1)[0]#     result = {'pred_label': pred_label, 'pred_score': float(pred_score)}# result['pred_class'] = model.CLASSES[result['pred_label']]# return result## ysn修改 ##with torch.no_grad():scores = model(return_loss=False, **data)# print(scores, type(scores), len(scores), len(model.CLASSES))result = {'pred_label':[], 'pred_score': [], 'pred_class':[]}for i in range(len(scores[0])):if scores[0][i]>0.5:result['pred_label'].append(int(i))result['pred_score'].append(float(scores[0][i]))result['pred_class'].append(model.CLASSES[int(i)])else:continuereturn result

通过以上修改,可以成功运行和评估我的多标签分类训练了。
由于我没有找到mmcls官方的训练多标签的训练教程,因此做了上述修改。如果有其他更方便有效的多标签多分类方法或者项目,欢迎在该文章下面留言,非常感谢。

参考文章
https://blog.csdn.net/litt1e/article/details/125316552
https://blog.csdn.net/u013250861/article/details/122727704

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/455074.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

力扣71~75题

题71(中等): python代码: class Solution:def simplifyPath(self, path: str) -> str:#首先根据/分割字符串,再使用栈来遍历存储p_listpath.split(/)p_stack[]for i in p_list:#如果为空则肯定是//或者///if i:con…

mac m1 安装openresty以及redis限流使用

一切源于一篇微信文章 早上我上着班,听着歌1.打算使用腾讯云服务器centos-7实验:安装ngx_devel_kitmac m1 os 12.7.6 安装openresty测试lua限流: 终于回到初心了! 早上我上着班,听着歌 突然微信推送了一篇文章《Nginx 实现动态封…

记录一次从nacos配置信息泄露到redis写计划任务接管主机

经典c段打点开局。使用dddd做快速的打点发现某系统存在nacos权限绕过 有点怀疑是蜜罐,毕竟nacos这实在是有点经典 nacos利用 老规矩见面先上nacos利用工具打一波看看什么情况 弱口令nacos以及未授权访问,看这记录估计被光顾挺多次了啊 手动利用Nacos-…

MySQL - Navicat自动备份MySQL数据

对于从事IT开发的工程师,数据备份我想大家并不陌生,这件工程太重要了!对于比较重要的数据,我们希望能定期备份,每天备份1次或多次,或者是每周备份1次或多次。 如果大家在平时使用Navicat操作数据库&#x…

深入解析Python数据容器

Python数据容器 1,数据容器介绍2,数据容器的分类3,数据容器:list(列表)3.1,列表的定义3.2,列表的下标索引3.3,列表的常用操作3.3.1,查找指定元素下标3.3.2&am…

【OpenAI】第三节(上下文)什么是上下文?全面解读GPT中的上下文概念与实际案例

文章目录 一、GPT上下文的定义1.1 上下文的组成 二、GPT上下文的重要性2.1 提高生成文本的相关性2.2 增强对话的连贯性2.3 支持多轮对话 三、使用上下文改善编程对话3.1 使用上下文的概念3.2 使用上下文改善对话的作用3.3 使用上下文改善对话的方法3.4 案例分析 四、利用历史记…

安装Openeuler出现的问题

1.正常安装中,不显示已有的网络,ens33 尝试:手敲ens33配置,包括使用uuidgen ens33 配置还是不行 可能解决办法1:更换安装的版本。譬如说安装cenos 7 64位 启动虚拟机,更换版本之后的安装界面,…

Excel常用操作培训

以下是Excel的基本操作,内部培训专用。喜欢就点赞收藏哦! 目录 1 Excel基本操作 1.1 常用快捷键 1.1.1快捷键操作工作簿、工作表 1.1.2快捷键操作 1.1.3单元格操作 1.1.4输入操作 2.1 常见功能描述 2.1.1 窗口功能栏 2.1.2 剪切板 2.1.3 字体…

计算机网络——传输层服务

传输层会给段加上目标ip和目标端口号 应用层去识别报文的开始和结束

海南聚广众达电子商务咨询有限公司靠谱吗怎么样?

在当今这个数字化浪潮席卷全球的时代,抖音电商以其独特的魅力成为了众多商家争相入驻的新蓝海。而在这片浩瀚的电商海洋中,如何找到一家既专业又可靠的合作伙伴,成为了众多商家心中的一大难题。今天,我们就来深入剖析一下海南聚广…

组件可控个性化生成新方法MagicTailor:生成过程中可以自由地定制ID

今天的文章来自公众号粉丝投稿,文章提出了一种组件可控的个性化生成方法MagicTailor,旨在个性化生成过程中可以自由地定制ID的特定组件。 相关链接 论文阅读:https://arxiv.org/pdf/2410.13370 项目主页:https://correr-zhou.gi…

拼多多详情API接口的获取与应用

一、拼多多详情API接口概述 1. API接口定义与功能 拼多多开放平台为开发者提供了丰富的API接口,其中商品详情API接口尤为重要。该接口允许开发者通过编程方式获取商品的详细信息,包括商品标题、价格、描述、图片、规格参数、库存等。这些信息对于电商数…

无人机之自主飞行关键技术篇

无人机自主飞行指的是无人机利用先进的算法和传感器,实现自我导航、路径规划、环境感知和自动避障等能力。这种飞行模式大大提升了无人机的智能化水平和操作的自动化程度。 一、传感器技术 传感器是无人机实现自主飞行和数据采集的关键组件,主要包括&a…

Unity3D学习FPS游戏(1)获取素材、快速了解三维模型素材(骨骼、网格、动画、Avatar、材质贴图)

前言:最近重拾Unity,准备做个3D的FPS小游戏,这里以官方FPS案例素材作为切入。 导入素材和素材理解 安装Unity新建项目新建文件夹和Scene如何去理解三维模型素材找到模型素材素材预制体结构骨骼和网格材质(Material)、…

No.18 笔记 | XXE(XML 外部实体注入)漏洞原理、分类、利用及防御整理

一、XXE 漏洞概述 (一)定义 XXE(XML 外部实体注入)漏洞源于 XML 解析器对外部实体的不当处理,攻击者借此注入恶意 XML 实体,可实现敏感文件读取、远程命令执行和内网渗透等危险操作。 (二&am…

一、Python基础语法(有C语言基础速成版)

在python中,变量是没有类型的,变量存储的数据是有类型的 可以把变量当做一个存放物品的盒子 一、字面量 字面量:在代码中,被写下来的 固定的值 python中常见的值的类型 二、注释 # 我是单行注释,一般要加个空格&a…

java设计模式——装饰者模式

定义: 装饰者模式是一种结构型设计模式,它允许动态地给对象添加新的功能,而不会改变其原有的结构。与继承不同,装饰者模式通过组合而不是继承来扩展对象的功能,这样可以有效地避免类爆炸问题(多个子类的冗余…

动手学深度学习9.7. 序列到序列学习(seq2seq)-笔记练习(PyTorch)

本节课程地址:62 序列到序列学习(seq2seq)【动手学深度学习v2】_哔哩哔哩_bilibili 本节教材地址:9.7. 序列到序列学习(seq2seq) — 动手学深度学习 2.0.0 documentation 本节开源代码:...>…

pdf编辑软件有哪些?方便好用的pdf编辑软件分享

PDF文件因其跨平台、格式固定的特性,成为了工作、学习和生活中不可或缺的一部分。然而,随着需求的不断增加,仅仅阅读PDF文件已难以满足我们的需求,编辑、转换PDF文件成为了新的焦点,下面给大家分享几款方便好用的PDF编…

《Linux从小白到高手》综合应用篇:深入理解Linux常用关键内核参数及其调优

1. 题记 有关Linux关键内核参数的调整,我前面的调优文章其实就有涉及到,只是比较零散,本篇集中深入介绍Linux常用关键内核参数及其调优,Linux调优80%以上都涉及到内核的这些参数的调整。 2. 文件系统相关参数 fs.file-max 参数…