DETR整体模型结构解析

DETR流程

  1. Backbone用卷积神经网络抽特征。最后通过一层1*1卷积转化到d_model维度fm(B,d_model,HW)。

  2. position embedding建立跟fm维度相同的位置编码(B,d_model,HW)。

  3. Transformer Encoder,V为fm,K,Q为fm+position embedding。因为V代表的是图像特征。所以不添加位置编码

  4. Transformer Decoder。生成一个固定大小(query_num)的object query(B,q_num,d_model)比如100个预测框。Decoder输入tgt与object query形状相同。代码中为torch.zero()。第一层selfattention K,V为tgt+query,Q为tgt。第二层Q为上一层输出+query。V为encoder输出,K为encoder输出+position。这里V仍然代表图像特征所以不添加位置编码

  5. 用输出的100个object query框和ground truth框做一个匹配,然后在一一配对好的框中去计算目标检测的loss(分类loss与回归loss(L1+IOU))

  6. 二分图匹配与匈牙利算法

    DETR 预测了一组固定大小的 N = 100 个边界框

    将 ground-truth 也扩展成 N = 100 个检测框

    使用一个额外的特殊类标签 ϕ 来表示在未检测到任何对象,或者认为是背景类别。

    这样预测和真实都是两个100 个元素的集合了

    采用匈牙利算法进行二分图匹配,对预测集合和真实集合的元素进行一一对应,使得匹配损失最小。

  7. 推理过程不需要二分图匹配,只需要取最大得分框即可

代码详细参考:

transformer 在 CV 中的应用(二) DETR 目标检测网络 -

网络结构

参数说明:B:batchsize大小,C通道数,H,W:CNN输出特征图的高宽。d_model设定的特征维度大小如512。
Q,K,V:自注意力矩阵。l_q:Q矩阵的长度,l_kv:K,V矩阵的长度。KV矩阵的长度必须相同,Q矩阵长度可以跟KV矩阵长度不同
Q矩阵维度:(B,l_q,d_model)
K矩阵维度:(B,l_kv,d_model)
V矩阵维度:(B,l_kv,d_model)
object_query维度(B,q_num,d_model)

Backbone:

img→CNNbackbone→fm特征图(B,C,H,W) → fm特征图输入到transformer中时要再经过一层卷积将通道数转化成d_ model。C→d_model.

position embedding(B,d_model,H*W)。backbone通过CNN提取图像特征,然后通过特征图生成尺度对应的位置编码。

position embedding:

位置编码官方实现了两种,一种是固定位置编码,另一种是自学习位置编码,这里就介绍固定位置编码。

位置编码要考虑 x, y 两个方向,图像中任意一个点 (h, w) 有一个位置,这个位置编码长度为 256 ,前 128 维代表 h 的位置编码, 后 128 维代表 w 的位置编码,把这两个 128 维的向量拼接起来就得到一个 256 维的向量,它代表 (h, w) 的位置编码。位置编码的计算公式如下图所示

在这里插入图片描述

Transformer
DETRtransformer结构图
在这里插入图片描述

接受CNN提取的特征(B,d_model,HW),位置编码(B,d_model,HW),querys(B,query_num,d_model)

encoder:q,k添加位置编码。v代表图像本身特征,不添加位置编码。multi_head_attention跟FFN后都带了两个残差连接。

# post代表norm放在后面
def forward_post(self,src,src_mask: Optional[Tensor] = None,src_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None):q = k = self.with_pos_embed(src, pos)  #q,k增加positionsrc2 = self.self_attn(q, k, value=src, attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0]src = src + self.dropout1(src2)   # 残差src = self.norm1(src)# ffnsrc2 = self.linear2(self.dropout(self.activation(self.linear1(src))))# 残差src = src + self.dropout2(src2)src = self.norm2(src)return src

decoder:

设定一个object queries(num_query,d_model)

有两层multihead self attention

  • 第一层obquery添加到K,Q上作为position embedding

第二层的Q来于decoder,K,V来自于encoder输出。

第二层self attention K添加编码,Q增加object queries。V代表图像特征,不添加任何信息

KV要有相同维度,Q可以跟KV在长度维度上不同,d_model维度相同

softmax(QKt/(d^0.5))V→矩阵乘法Q*Kt:(l_q,d_model)@(d_model_l_kv)→(l_q,l_kv)

再乘以V(l_q,l_kv)@(l_kv,d_model)→(l_q,d_model)

def forward_post(self, tgt, memory,tgt_mask: Optional[Tensor] = None,memory_mask: Optional[Tensor] = None,tgt_key_padding_mask: Optional[Tensor] = None,memory_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None,query_pos: Optional[Tensor] = None):''':param tgt: query_pos rep 2tensor of shape (bs, c, h, w) ->tgt = torch.zeros_like(query_embed):param memory::param tgt_mask::param memory_mask::param tgt_key_padding_mask::param memory_key_padding_mask::param pos::param query_pos::return:'''q = k = self.with_pos_embed(tgt, query_pos)tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,key_padding_mask=tgt_key_padding_mask)[0]tgt = tgt + self.dropout1(tgt2)tgt = self.norm1(tgt)tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),key=self.with_pos_embed(memory, pos),value=memory, attn_mask=memory_mask,key_padding_mask=memory_key_padding_mask)[0]tgt = tgt + self.dropout2(tgt2)tgt = self.norm2(tgt)tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))tgt = tgt + self.dropout3(tgt2)tgt = self.norm3(tgt)return tgt
  • DETR在计算attention的时候没有使用masked attention,因为将特征图展开成一维以后,所有像素都可能是互相关联的,因此没必要规定mask。

  • object queries的转换过程:object queries是预定义的目标查询的个数,代码中默认为100。它的意义是:根据Encoder编码的特征,Decoder将100个查询转化成100个目标,即最终预测这100个目标的类别和bbox位置。最终预测得到的shape应该为[N, 100, C],N为Batch Num,100个目标,C为预测的100个目标的类别数+1(背景类)以及bbox位置(4个值)

得到预测结果以后,将object predictions和ground truth box之间通过匈牙利算法进行二分匹配:假如有K个目标,那么100个object predictions中就会有K个能够匹配到这K个ground truth,其他的都会和“no object”匹配成功,使其在理论上每个object query都有唯一匹配的目标,不会存在重叠,所以DETR不需要nms进行后处理。

匹配

匈牙利匹配算法

匈牙利匹配算法,二分图匹配算法

scipy.optimize.linear_sum_assignment(cost_matrix, maximize=False)
#cost_matrix 二分图开销矩阵

https://blog.csdn.net/CV_Autobot/article/details/129096035

https://blog.csdn.net/lemonxiaoxiao/article/details/108672039

query与gt匹配

transformer通过query输出n_q数量的bbox与对应分类置信度

真实框[gt1,gt2,…gtn]

每个bbox与gt之间有一个距离度量。

距离度量由三部分组成:真实类别的置信度得分+边界框的L1loss+边界框的IOUloss

通过匈牙利算法找出距离最小的query_bbox为gt对应的prebbox

loss训练

整体流程

pred输出→100(num_queries)class,100(num_queries)boxes

gt(tagert)→100class,100boxes(包含背景类)

pred,gt→计算相互loss,得到二分图成本矩阵,然后计算匈牙利匹配算法→return匹配上的classes与boxes

匹配成功的框→计算真正的class损失(),box回归损失(GLOUloss)。。

预测框与真实框的差异来自于两方面:1.二分图匹配时带来的差异。2.预测框与真实框之间的差异。

  • 分类损失:交叉熵损失,针对所有predictions。没有匹配到的querybbox应该分类为背景

  • 回归损失:bbox loss采用了L1 loss和giou loss,针对匹配成功的querybbox

  • cardinality 损失,对应函数是 loss_cardinality ; cardinality 损失是计算预测有物体的个数的绝对损失,值是为了记录,不参与反向传播

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/335940.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Python】解决Python报错:TypeError: can only concatenate str (not “int“) to str

🧑 博主简介:阿里巴巴嵌入式技术专家,深耕嵌入式人工智能领域,具备多年的嵌入式硬件产品研发管理经验。 📒 博客介绍:分享嵌入式开发领域的相关知识、经验、思考和感悟,欢迎关注。提供嵌入式方向…

反转!Greenplum 还在,快去 Fork 源码

↑ 关注“少安事务所”公众号,欢迎⭐收藏,不错过精彩内容~ 今早被一条消息刷爆群聊,看到知名开源数仓 Greenplum 的源码仓“删库跑路”了。 要知道 GP 新东家 Broadcom 前几日才刚刚免费开放了 VMware Workstation PRO 17 和 VMware Fusion P…

Selenium 自动化测试工具(1) (Selenium 工作原理,常用API的使用)

文章目录 什么是自动化测试什么是测试工具:Selenium 工作原理(重要)Selenium API定位元素CSS 选择器xpath 定位元素 通过Java代码实现自动化1. 定位元素2. 关闭浏览器3. 获取元素文本4. 鼠标点击与键盘输入5. 清空内容6.打印信息 什么是自动化测试 关于自动化&…

使用C语言实现学生信息管理系统

前言 在我们实现学生信息管理系统的过程中,我们几乎会使用到C语言最常用最重要的知识,对于刚学习完C语言的同学来说是一次很好的巩固机会,其中还牵扯到数据结果中链表的插入和删除内容。 实现学生信息管理系统 文件的创建与使用 对于要实现…

设计模式13——桥接模式

写文章的初心主要是用来帮助自己快速的回忆这个模式该怎么用,主要是下面的UML图可以起到大作用,在你学习过一遍以后可能会遗忘,忘记了不要紧,只要看一眼UML图就能想起来了。同时也请大家多多指教。 桥接模式(Bridge&a…

存储器和CPU的连接与TCP的流量控制

存储器与CPU的连接 存储容量的拓展 (1)位拓展:增加存储字长 (2)字拓展 增加存储器字的数量 例题:设CPU有16根地址线,8根数据线,并用MREQ作为访问存储控制信号(低电平有效),WR作为…

建议大家少用点儿网站测速工具

春节休息期间明月有接了几个服务器代运维的业务,期间就发现不少新手站长们还在用 17ce、站长工具等等这些网站测速工具来评判站点访问速度的,感觉很有必要给大家聊聊这个事儿,因为这毕竟也是一个涉及服务器安全的一个重要环节了。 其实&#…

C++ list类

目录 0.前言 1.list介绍 1.1优势 1.2劣势 1.3容器属性 2.list使用 2.1构造函数 2.1.1默认构造函数 2.1.2填充构造函数 2.1.3范围构造函数 2.1.4拷贝构造函数 2.1.5初始化列表构造函数 2.2迭代器 2.2.1 begin() 2.2.2 end() 2.2.3 cbegin() 2.2.4 cend() 2.2.…

100个 Unity小游戏系列四 -Unity 抽奖游戏专题二 水果机游戏

一、演示效果 二、知识点 2.1 布局 private void CreateItems(){for (int i 0; i < rewardDatas.Length; i){var reward_data rewardDatas[i];GameObject fruitOjb;if (i < itemRoot.childCount){fruitOjb itemRoot.GetChild(i).gameObject;}else{fruitOjb Instant…

MATLAB分类与判别模型算法: 快速近邻法(FastNN)分类程序【含Matlab源码 MX_005期】

算法思路介绍&#xff1a; 1. 数据准备阶段&#xff1a; 生成一个合成数据集 X&#xff0c;其中包含三个簇&#xff0c;每个簇分布在不同的区域。 定义聚类层数 L 和每个层次的子集数量 l。 2. 聚类阶段&#xff1a; 使用K均值聚类算法将初始数据集 X 分成 l 个簇。…

mac m1安装homebrew管理工具(brew命令)完整流程

背景 因为mac上的brew很久没用了&#xff0c;版本非常旧&#xff0c;随着mac os的更新&#xff0c;本机的homebrew大部分的功能都无法使用&#xff0c;幸好过去通过brew安装的工具比较少&#xff0c;于是决定重新安装一遍brew。 卸载旧版brew 法一&#xff1a;通过使用线上…

【PB案例学习笔记】-13 徒手做个电子时钟

写在前面 这是PB案例学习笔记系列文章的第11篇&#xff0c;该系列文章适合具有一定PB基础的读者。 通过一个个由浅入深的编程实战案例学习&#xff0c;提高编程技巧&#xff0c;以保证小伙伴们能应付公司的各种开发需求。 文章中设计到的源码&#xff0c;小凡都上传到了gite…

渗透测试工具Cobalt strike-2.CS基础使用

三、结合metasploit,反弹shell 在kali中开启使用命令开启metasploit msfconsole ┌──(root㉿oldboy)-[~] └─# msfconsole --- msf6 > use exploit/multi/handler [*] Using configured payload generic/shell_reverse_tcp --- msf6 exploit(multi/handler) > show …

【5.基础知识和程序编译及调试】

一、GCC概述&#xff1a;是GUN推出的多平台编译器&#xff0c;可将C/C源程序编译成可执行文件。编译流程分为以下四个步骤&#xff1a; 1、预处理 2、编译 3、汇编 4、链接 注&#xff1a;编译器根据程序的扩展名来分辨编写源程序所用的语言。根据不同的后缀名对他们进行相…

058.最后一个单词的长度

题意 给你一个字符串 s&#xff0c;由若干单词组成&#xff0c;单词前后用一些空格字符隔开。返回字符串中 最后一个 单词的长度。 单词 是指仅由字母组成、不包含任何空格字符的最大子字符串。 难度 简单 示例 1&#xff1a; 输入&#xff1a;s "Hello World" 输…

excel表格里怎样不删除0,又不显示0呢?

在单元格里不显示0&#xff0c;大体上有这么几种方法&#xff1a; 1.设置单元格自定义格式 选中数据区域&#xff0c;鼠标右键&#xff0c;点一下设置单元格格式&#xff0c;选中数字&#xff0c;自定义&#xff0c;在右侧的类型栏&#xff0c;设置格式&#xff1a; [0]&quo…

android11禁止进入屏保和自动休眠

应某些客户要求&#xff0c;关闭了开机进入屏保&#xff0c;一段时间会休眠的问题。以下diff可供参考&#xff1a; diff --git a/overlay/frameworks/base/packages/SettingsProvider/res/values/defaults.xml b/overlay/frameworks/base/packages/SettingsProvider/res/value…

《微服务王国的守护者:Spring Cloud Dubbo的奇幻冒险》

5. 经典问题与解决方案 5.3 服务追踪与链路监控 在微服务架构的广袤宇宙中&#xff0c;服务间的调用关系错综复杂&#xff0c;如同一张庞大的星系网络。当一个请求穿越这个星系&#xff0c;经过多个服务节点时&#xff0c;如何追踪它的路径&#xff0c;如何监控整个链路的健康…

ssm校园疫情防控管理系统-计算机毕业设计源码30796

目 录 摘要 1 绪论 1.1目的及意义 1.2开发现状 1.3ssm框架介绍 1.3论文结构与章节安排 2 校园疫情防控管理系统系统分析 2.1 可行性分析 2.2 系统流程分析 2.2.1 数据流程 3.3.2 业务流程 2.3 系统功能分析 2.3.1 功能性分析 2.3.2 非功能性分析 2.4 系统用例分…

手摸手教你uniapp原生插件开发

行有余力,心无恐惧 这篇技术文章写了得有两三个礼拜,虽然最近各种事情,工作上的生活上的,但是感觉还是有很多时间被浪费.还记得几年前曾经有一段时间7点多起床运动,然后工作学习,看书提升认知.现在我都要佩服那会儿的自己.如果想回到那种状态,我觉得需要有三个重要的条件. 其…