ML-Decoder: Scalable and Versatile Classification Head

1、引言

论文链接:https://openaccess.thecvf.com/content/WACV2023/papers/Ridnik_ML-Decoder_Scalable_and_Versatile_Classification_Head_WACV_2023_paper.pdf

        因为 transformer 解码器分类头[1] 在少类别多标签分类数据集上表现得很好,但由于其查询复杂度为 O(n^2),n 为类别数量,故 transformer 解码器分类头对于多类别数据集是不可行的,且 transformer 解码器分类头只适用于多标签分类任务,故 Tal Ridnik 等引入了一种新的基于多头注意力机制的分类头——ML-Decoder[2]。ML-Decoder 可以用于单标签分类、多标签分类和多标签 ZSL(zero shot learning) 任务,它提供更好的精度-速度 trade-off,可以用于上万类别的数据集,可以作为各种分类头的 drop-in 替代品,结合词查询可以用于 ZSL。

2、方法

        ML-Decoder 流如图 1 右所示,相对于  transformer 解码器分类头,ML-Decoder 有一下改变。

图1  transformer-decoder vs. ML-Decoder

2.1  移除自注意力机制

        通过删除自注意力机制将 ML-Decoder 的查询复杂度由 O(n^2) 降至 O(n),并未影响表示能力。

2.2  组解码

        为了使查询数量与类别数量无关,使用固定的 k 组查询,而不是一个类别对应一个查询。在前馈神经网络后,通过组全连接层在将每个组查询扩展到 g=n/k 个输出的同时池化嵌入维度。如图 2 所示。

图2  组全连接方案(g=4)

2.3  固定查询        

        查询总是被输入到一个多头注意力层,该注意力层会先对查询应用一个可学习的投影计算。因此,将查询权重设置为可学习的是多余的——可学习的投影可以将任何固定值查询转换为可学习查询获得的任何值。

3、模块介绍

3.1  Cross-Attention

        Cross-Attention 的核心其实就是多头注意力机制,输入的 q 为固定查询,k 和 v 均为图像嵌入。Cross-Attention 和 Feed-Forward 模块构成所谓的 TransformerDecoder(Layer),python 代码如下所示:

class TransformerDecoder(nn.Module):def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1) -> None:super().__init__()self.dropout = nn.Dropout(dropout)self.norm0 = nn.LayerNorm(d_model)self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)# Implementation of Feedforward modelself.feed_forward = nn.Sequential(nn.LayerNorm(d_model),nn.Linear(d_model, dim_feedforward),nn.ReLU(),nn.Dropout(dropout),nn.Linear(dim_feedforward, d_model))self.norm1 = nn.LayerNorm(d_model)def forward(self, tgt: Tensor, memory: Tensor) -> Tensor:tgt = tgt + self.dropout(tgt)tgt = self.norm0(tgt)tgt0 = self.multihead_attn(tgt, memory, memory)[0]tgt = tgt + self.dropout(tgt0)tgt0 = self.feed_forward(tgt)tgt = tgt + self.dropout(tgt0)return self.norm1(tgt)

3.2  Group Fully Connected Pooling  

        Group Fully Connected Pooling的目的是将每个组查询扩展到 g=n/k 个输出的同时池化嵌入维度。即将每组查询结果与对应的可学习的 (hidde_dim, g) 维矩阵相乘,python 代码如下所示:

class GroupFC(object):def __init__(self, groups: int):self.groups = groupsdef __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):"""计算每组类的 logits 值(未加偏置):param h: shape=(b, groups, hidden_dim):param duplicate_pooling: shape=(groups, hidden_dim, duplicate_factor), duplicate_factor 每组的类别数:param out_extrap: shape=(b, groups, duplicate_factor):return:"""for i in range(h.shape[1]):h_i = h[:, i, :]w_i = duplicate_pooling[i, :, :]out_extrap[:, i, :] = torch.matmul(h_i, w_i)

4、总结

        作者开源的 ML-Decoder 的 python 实现代码在:https://github.com/Alibaba-MIIL/ML_Decoder/blob/main/src_files/ml_decoder/ml_decoder.py

        论文[2] 在 paper with code 上的战绩如图 3 所示,表现还是不错的。

图3  来自论文[2] 的结果

        由于当参数 zsl != 0 时 wordvec_proj 的输入 query_embed = None,本人还未学习过 ZSL 领域,且使用该代码时报错(zsl = 0,当然应该是我的原因,但懒得排错了),于是参考作者的代码写了一个 MLDecoder 类(只考虑 zsl = 0),剩下的代码如下所示。

class MLDecoder(nn.Module):"""Args:groups: 查询/类别组数hidden_dim: Transformer 解码器特征维度in_dim: 输入 tensor 特征维度(CNN 编码器输出为通道数,Transformer 编码器输出为最后一个维度)"""def __init__(self, num_classes, groups, in_dim=2048, hidden_dim=768, mlp_dim=2048, nhead=8, dropout=0.1):super().__init__()self.proj = nn.Linear(in_dim, hidden_dim)# non-learnable queriesself.query_embed = nn.Embedding(groups, hidden_dim)self.query_embed.requires_grad_(False)self.num_classes = num_classesself.decoder = TransformerDecoder(d_model=hidden_dim, nhead=nhead, dim_feedforward=mlp_dim, dropout=dropout)# group fully-connectedself.duplicate_factor = math.ceil(num_classes / groups)  # 每组类别数量,math.ceil: 向上取整self.duplicate_pooling = torch.nn.Parameter(torch.zeros((groups, hidden_dim, self.duplicate_factor)))self.duplicate_pooling_bias = torch.nn.Parameter(torch.zeros(num_classes))torch.nn.init.xavier_normal_(self.duplicate_pooling)self.group_fc = GroupFC(groups)def forward(self, x):# 确保解码器输入 shape 为 [b, h * w, c]if len(x.shape) == 4:x = x.flatten(2).transpose(1, 2)x = F.relu(self.proj(x), True)  # (b, h * w, hidden_dim)# Cross-Attention + Feed-Forwardquery_embed = self.query_embed.weight  # (groups, hidden_dim)# tensor.expend: 增大一个维度至指定大小, 不增大的维度为-1,例如将 shape 由 (b, n, c)->(b, 2n, c), 参数 size=(-1, 2n,-1)tgt = query_embed[None].expand(x.shape[0], -1, -1)  # (b, groups, hidden_dim)h = self.decoder(tgt, x)  # (b, groups, hidden_dim)# Group Fully Connected Poolingout_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)self.group_fc(h, self.duplicate_pooling, out_extrap)h_out = out_extrap.flatten(1)[:, :self.num_classes]  # (b, num_classes)return h_out + self.duplicate_pooling_bias

参考文献

[1] Shilong Liu, Lei Zhang, Xiao Yang, Hang Su, and Jun Zhu. Query2label: A simple transformer way to multi-label classification. arXiv preprint arXiv:2107.10834, 2021.

[2] Tal Ridnik, Gilad Sharir, Avi Ben-Cohen, Emanuel Ben Baruch, and Asaf Noy. Ml-decoder: Scalable and versatile classification head. In IEEE/CVF Winter Conference on Applications of Computer Vision, WACV 2023, Waikoloa, HI, USA, January 2-7, 2023, pages 32–41. IEEE, 2023.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/292010.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

css3之动画animation

动画animation 一.优点二.定义和使用三.动画序列和解释四.常见属性及解释五.简写(名字和时间不能省略)(持续时间在何时开始的时间前)(简写中无animation-play-state)六.例子1.大数据热点图2.奔跑的熊大(一个…

设计模式6--抽象工厂模式

定义 案例一 案例二 优缺点

代码随想录-二叉树(路径)

目录 257. 二叉树的所有路径 题目描述: 输入输出描述: 思路和想法: 404. 左叶子之和 题目描述: 输入输出描述: 思路和想法: 513.找树左下角的值 题目描述: 输入输出描述:…

Android裁剪图片为波浪形或者曲线形的ImageView

如果需要做一个自定义的波浪效果的进度条,裁剪图片,对ImageView的图片进行裁剪,比如下面2张图,如何实现? 先看下面的效果,看到其实只需要对第一张高亮的图片进行处理即可,灰色状态的作为背景图。…

基于SSM的戒烟网站(有报告)。Javaee项目。ssm项目。

演示视频: 基于SSM的戒烟网站(有报告)。Javaee项目。ssm项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构,通过Spring SpringMv…

腾讯云优惠券领取方法大公开,省钱不再是难事

腾讯云—腾讯倾力打造的云计算品牌,以卓越科技能力助力各行各业数字化转型,为全球客户提供领先的云计算、大数据、人工智能服务,以及定制化行业解决方案和提供可靠上云服务,助力企业和开发者稳定上云! 然而&#xff0…

数据结构进阶篇 之 【二叉树顺序存储(堆)】的整体实现讲解(赋完整实现代码)

做人要谦虚,多听听别人的意见,然后记录下来,看看谁对你有意见 一、二叉树的顺序(堆)结构及实现 1.二叉树的顺序结构 2.堆的概念及结构 3.堆的实现 3.1 向下调整算法 AdJustDown 3.2 向上调整算法 AdJustUP 3.3 …

C语言例1-8:设 char x,y; ,scanf(“x=%c,y=%c“,x,y); 后使 x 为 ‘X‘, y为 ‘Y‘,则键盘上的正确输入是

代码如下&#xff1a; #include<stdio.h> int main(void) {char x,y;scanf("x%c,y%c",&x,&y);printf("x%c,y%c\n",x,y);return 0; } 键盘输入选项A: xXyY 结果如下&#xff1a; 键盘输入选项B: xX,yY 结果如下&#xff1a; 键盘输入选项…

通过Jmeter准备压测数据-mysql示例

1、新建线程组 总共30万条数据 2、创建jdbc链接 创建jdbc连接配置 配置mysql连接 需要在jmeter安装的路径\apache-jmeter-5.6.3\lib\ext 目录下添加mysql 驱动 3、创建jdbc请求 jdbc链接名称需要与上一步中的保持一致&#xff0c;同时添加insert语句 例如 INSERT INTO test…

关系型数据库mysql(7)sql高级语句①

目录 一.MySQL常用查询 1.按关键字&#xff08;字段&#xff09;进行升降排序 按分数排序 &#xff08;默认为升序&#xff09; 按分数升序显示 按分数降序显示 根据条件进行排序&#xff08;加上where&#xff09; 根据多个字段进行排序 ​编辑 2.用或&#xff08;or&…

Python Flask框架 -- flask-migrate迁移ORM模型

# 之前使用的这个db.create_all()很有局限性&#xff0c;它不能把在class里修改的东西同步上数据库&#xff0c;所以不用了 # with app.app_context(): # 请求应用上下文 # db.create_all() # 把所有的表同步到数据库中去 例如&#xff0c;在User类中增加一个email字段&…

C语言例1-3:设 int a; ,语句 for(a=0;a==0;a++); 和语句 for(a=0;a=0;a++); 执行的循环次数分别是

答案&#xff1a;1,0 代码如下&#xff1a; #include<stdio.h> int main(void) {int a;for(a0;a0;a){printf("1\n");} return 0; } 结果如下&#xff1a; 代码如下&#xff1a; #include<stdio.h> int main(void) {int a;for(a0;a0;a){printf("…

【前端】layui学习笔记

参考视频&#xff1a;LayUI 1.介绍 官网&#xff1a;http://layui.apixx.net/index.html 国人16年开发的框架,拿来即用,门槛低 … 2. LayUi的安装及使用 Layui 是一套开源的 Web UI 组件库&#xff0c;采用自身轻量级模块化规范&#xff0c;遵循原生态的 HTML/CSS/JavaScript…

深入解析大语言模型显存占用:训练与推理

深入解析大语言模型显存占用&#xff1a;训练与推理 文章脉络 估算模型保存大小 估算模型在训练时占用显存的大小 全量参数训练 PEFT训练 估算模型在推理时占用显存的大小 总结 对于NLP领域的从业者和研究人员来说&#xff0c;有没有遇到过这样一个场景&#xff0c;你的…

C语言例1-11:语句 while(!a); 中的表达式 !a 可以替换为

A. a!1 B. a!0 C. a0 D. a1 答案&#xff1a;C while()成真才执行&#xff0c;所以!a1 &#xff0c;也就是 a0 原代码如下&#xff1a; #include<stdio.h> int main(void) {int a0;while(!a){a;printf("a\n");} return 0; } 结果如…

平台介绍-搭建赛事运营平台(8)

平台介绍-搭建赛事运营平台&#xff08;5&#xff09;提到了字典是分级的&#xff0c;本篇具体介绍实现。 平台级别的代码是存储在核心库中&#xff0c;品牌级别的代码是存储在品牌库中&#xff08;注意代码类是一样的&#xff09;。这部分底层功能封装为jar包&#xff0c;然后…

算法打卡day21(开始回溯)

今日任务&#xff1a; 1&#xff09;77.组合 77.组合 题目链接&#xff1a;77. 组合 - 力扣&#xff08;LeetCode&#xff09; 文章讲解&#xff1a;代码随想录 (programmercarl.com) 视频讲解&#xff1a;带你学透回溯算法-组合问题&#xff08;对应力扣题目&#xff1a;77…

Stable Diffusion之核心基础知识和网络结构解析

Stable Diffusion核心基础知识和网络结构解析 一. Stable Diffusion核心基础知识1.1 Stable Diffusion模型工作流程1. 文生图(txt2img)2. 图生图3. 图像优化模块 1.2 Stable Diffusion模型核心基础原理1. 扩散模型的基本原理2. 前向扩散过程详解3. 反向扩散过程详解4. 引入Late…

axios+springboot上传图片到本地(vue)

结果&#xff1a; 前端文件&#xff1a; <template> <div> <input type"file" id"file" ref"file" v-on:change"handleFileUpload()"/> <button click"submitFile">上传</button> </div&g…

3D汽车模型线上三维互动展示提供视觉盛宴

VR全景虚拟看车软件正在引领汽车展览行业迈向一个全新的时代&#xff0c;它不仅颠覆了传统展览的局限&#xff0c;还为参展者提供了前所未有的高效、便捷和互动体验。借助于尖端的vr虚拟现实技术、逼真的web3d开发、先进的云计算能力以及强大的大数据处理&#xff0c;这一在线展…