DETR论文详解

文章目录

  • 前言
  • 一、DETR理论
  • 二、模型架构
    • 1. CNN
    • 2. Transformer
    • 3. FFN
  • 三、损失函数
  • 四、代码实现
  • 总结

前言

 DETR是Facebook团队在2020年提出的一篇论文,名字叫做《End-to-End Object Detection with Transformers》端到端的基于Transformers的目标检测,DETR是Detection Transformers的缩写。DETR摆脱了传统目标检测模型中复杂的组件,例如NMS、先验框等,是一种基于Transformer的简单的端到端的目标检测架构,并且取的了与Faster R-CNN差不多的成绩。接下来对DETR这篇论文进行介绍。

一、DETR理论

 DETR将目标检测看做一个直接的预测集合问题,集合的大小是固定的,集合中的每个元素可以当作一个检测框,可以理解为DETR根据图片的输入去预测集合中元素的属性(分类,坐标)。
 DETR的推理概如下图所示:输入是一张图片,图片首先经过CNN网络进行下采样,然后将CNN输出的特征图拉直成向量进入Transformer的编码器中,Transformer的解码器的输入是object queries,(object queries是可学习的参数,可以理解为object queries学习预测框的先验知识)然后与Transformer的编码器输出做交叉注意力并行得到最终的检测框。在DETR论文中这个集合的大小取100,也就是对于每张图片,都会一口气预测出100个框,对于预测框中的分类预测为’no object’的则不显示,然后我们可以设置一个阈值,把集合中置信度低于阈值的预测框去掉,从而得到最终的输出。
 DETR根据Transformer的全局建模能力对图像进行全局的上下文推理,因此对于大物体的检测效果很好,但是对于小物体的检测效果不如Faster R-CNN。但是由于其架构的简单性深受大家喜欢。
在这里插入图片描述

二、模型架构

DETR的模型结构主要由三部分组成,分别是CNN,Transformer和FFN。如下图所示:
在这里插入图片描述

1. CNN

 DETR使用CNN骨干网络(例如ResNet50)将输入的照片由[3,H,W]下采样成[2048,H/32,W/32],然后在使用一个卷积核用来减少通道数,由[2048,H/32,W/32]变为[d,H/32,W/32]。然后将其进行展平与位置编码表进行相加送入到Transformer编码器中。

2. Transformer

 在论文中,DETR使用了6层的Transformer。
编码器 DETR中的编码器架构与经典的Transformer编码器相同,由多头自注意力层和FFN组成。下图为encoder部分的自注意力可视化,可以看到encoder主要负责预测物体的主体部分。
在这里插入图片描述

解码器 DETR中的解码器架构也与经典的Transformer解码器相同,稍微不同的是,经典的Transformer模型是自回归方式,而DETR在每个解码器层并行解码N个对象,因此N个输入嵌入必须不同,输入嵌入是可以学习的位置编码,称之为’object query’,在经过解码器的解码后并行计算出N个输出。下图为解码器的可视化输出,可以看到解码器主要用于区分物体的边缘或者轮廓。
在这里插入图片描述

3. FFN

 当得到解码器的N个输出后,使用两个线性层和ReLU激活函数将每个输出分别映射到预测框的输出类别和坐标位置。由于集合中的边界框比实际照片中的物体数量多,所以集合中剩余的那些边界框预测的标签则为’no object’,表示预测为背景。

三、损失函数

 DETR在进行训练时,例如我们设置集合的大小为100,那么模型最终会输出100个框,Ground True可能只有2个,那么我们如何算Loss呢?

1.论文里首先是使用匈牙利算法进行最佳二分图匹配。 如何理解最佳二分图匹配呢?

例如:有100个预测框,最终只有2个预测框与真实的Ground True相匹配。最佳二分图匹配就是那么选哪两个预测框与2个Ground True进行匹配得到的最终Loss最小。

那么我们如何衡量一个预测框与Groud True之间的损失呢(匹配损失)?公式如下:
在这里插入图片描述
通过这个公式我们可以获得预测框分别于Ground True进行匹配的Loss,最终使用匈牙利算法选定哪两个预测框与Ground True相对应,而其他的98个预测框可以看作与’no object‘相匹配。

2. 接着我们得出100个预测框如何与Ground True进行匹配,使用如下公式算的最终的Loss:
在这里插入图片描述
在实际中,为了保持正负样本的均衡,当预测框的种类为’no object’,其对数概率项的权重被缩小了10倍。同时为了使Bounding box loss ( L b o x L_{box} Lbox)对于不同大小的预测框的惩罚项尽可能公平, L b o x L_{box} Lbox L o s s I o u Loss_{Iou} LossIou L 1 L_1 L1组成,公式如下:
在这里插入图片描述

3 辅助损失函数。 在训练过程中Transformer的decoder中加入了辅助损失函数,也就是将每层decoder的输出都通过参数共享的FFN映射成预测框然后计算Loss。

四、代码实现

 DETR的前向推理过程代码如下所示:

import torch
from torch import nn
from torchvision.models import resnet50class DETR(nn.Module):def __init__(self, num_classes, hidden_dim, nheads,num_encoder_layers, num_decoder_layers):super().__init__()# We take only convolutional layers from ResNet-50 modelself.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])self.conv = nn.Conv2d(2048, hidden_dim, 1)self.transformer = nn.Transformer(hidden_dim, nheads,num_encoder_layers, num_decoder_layers)self.linear_class = nn.Linear(hidden_dim, num_classes + 1)self.linear_bbox = nn.Linear(hidden_dim, 4)self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))def forward(self, inputs):x = self.backbone(inputs)h = self.conv(x)H, W = h.shape[-2:]pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),], dim=-1).flatten(0, 1).unsqueeze(1)h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),self.query_pos.unsqueeze(1))return self.linear_class(h), self.linear_bbox(h).sigmoid()

总结

  DETR是基于Transformer架构的端到端的目标检测模型,把目标检测看作一个直接的集合预测问题,他简化了传统目标检测模型繁杂的前处理和后处理过程,并且随着Transformer架构的增加模型的性能也有所增加。DETR相当于使用’object queries’替代了’anchor’,使用二分图匹配去掉了之前的NMS。同时DETR还有一些不足,因为他是Transformer架构的所以不好优化,对于小目标的物体检测不足等等。但由于DETR的简单性有效性后续出来了一大批工作对于DETR做出了改进。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/390614.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java重修笔记 第二十七天 匿名内部类

匿名内部类 1. 定义:无类名(底层自动分配类名“外部类名$1”),既是类也是对象,定义在外部类的局部位置,例如方法体和代码块中,通过new类或接口并在大括号里重写方法来实现。 2. 使用场景&…

c++网络编程实战——开发基于协议的文件传输模块(一)如何实现一个简单的tcp长连接

前言 在之前的几篇内容中我们已经介绍过基于ftp协议的文件传输模块,而这个系列我们所想实现的就是如何实现基于tcp进行的文件传输模块,话不多说,开坑开坑! 什么是tcp长连接 我们知道tcp在建立连接的时候会通过三次握手与四次挥手来建立tcp连接&#x…

大数据-62 Kafka 高级特性 主题 kafka-topics相关操作参数 KafkaAdminClient 偏移量管理

点一下关注吧!!!非常感谢!!持续更新!!! 目前已经更新到了: Hadoop(已更完)HDFS(已更完)MapReduce(已更完&am…

类加载机制

概述 所谓机制就是某种流程规范或运作模式。简单来说,将类文件加载到JVM中的过程,需要对这个过程进行限定和约束,这就是Java类加载的机制。 具体说来,对Java类加载机制的描述可以从三个方面: 按需加载 需要某一个类…

Web开发-html篇-上

HTML发展史 HTML的历史可以追溯到20世纪90年代初。当时,互联网尚处于起步阶段,Web浏览器也刚刚问世。HTML的创建者是蒂姆伯纳斯-李(Tim Berners-Lee),他在1991年首次提出了HTML的概念。HTML的初衷是为了方便不同计算机…

python常用库

目录 tqdm库介绍用法 argparse库介绍用法 tqdm库 介绍 封装一个可视化,可拓展的进度条,以了解项目运行的时长,了解项目进展情况。 传入第 用法 安装 pip install tqdm1直接使用 for i in tqdm(range(1000)):time.sleep(0.01)等价 for i…

DNS处理模块 dnspython

DNS处理模块 dnspython 标题介绍安装dnspython 模块常用方法介绍实践:DNS域名轮询业务监控 标题介绍 Dnspython 是 Python 的 DNS 工具包。它可用于查询、区域传输、动态更新、名称服务器测试和许多其他事情。 dnspython 模块提供了大量的 DNS 处理方法&#xff0c…

django集成pytest进行自动化单元测试实战

文章目录 一、引入pytest相关的包二、配置pytest1、将django的配置区分测试环境、开发环境和生产环境2、配置pytest 三、编写测试用例1、业务测试2、接口测试 四、进行测试 在Django项目中集成Pytest进行单元测试可以提高测试的灵活性和效率,相比于Django自带的测试…

PyQt5入门

Python中经常使用的GUI控件集有PyQt、Tkinter、wxPython、Kivy、PyGUI和Libavg。其中PyQt是Qt(c语言实现的)为Python专门提供的扩展 PyQt是一套Python的GUI开发框架,即图形用户界面开发框架.。而在Python中则使用PyQt这一工具包(PyQt5、PyQt5-tools、PyQt5-stubs&am…

卡码网--数组篇(二分法)

系列文章目录 文章目录 系列文章目录前言数组二分查找 前言 详情看:https://programmercarl.com/ 总结知识点用于复习 数组 概念: 数组是存放在连续内存空间上的相同类型数据的集合。 数组可以方便的通过下标索引的方式获取到下标对应的数据。 特点:…

安卓基本布局(下)

TableLayout 常用属性描述collapseColumns设置需要被隐藏的列的列号。shrinkColumns设置允许被伸缩的列的列号。stretchColumns设置允许被拉伸的列的列号。 <TableLayout xmlns:android"http://schemas.android.com/apk/res/android"android:id"id/TableL…

状体管理-装饰器

State 自己的状态 注意:不是状态变量的所有更改都会引起刷新。只有可以被框架观察到的修改才会引起UI刷新。 1、boolean、string、number类型时&#xff0c;可以观察到数值的变化。 2、class或者Object时&#xff0c;可以观察 自身的赋值 的变化&#xff0c;第一层属性赋值的变…

CC++:贪吃蛇小游戏教程

❀创作不易&#xff0c;关注作者不迷路❀&#x1f600;&#x1f600; 目录 &#x1f600;贪吃蛇简介 &#x1f603;贪吃蛇的实现 &#x1f40d;生成地图 &#x1f40d;生成蛇模块 ❀定义蛇的结构体 ❀初始化蛇的相关信息 ❀初始化食物的相关信息 &#x1f40d;光标定位和…

[Spring] SpringBoot统一功能处理与图书管理系统

&#x1f338;个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 &#x1f3f5;️热门专栏: &#x1f9ca; Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 &#x1f355; Collection与…

USB 2.0 规范摘录

文章目录 1、USB 体系简介2、USB 数据流模型四种传输类型 3、USB 物理规范和电气规范4、USB 协议层规范事务传输&#xff08;Transaction&#xff09;的流程 5、USB 框架6、USB 主机&#xff1a;硬件和软件7、USB HUB 规范数据的转发唤醒信号的转发USB HUB 的帧同步HUB Repeate…

前端常见场景、JS计算精度丢失问题(Decimal.js 介绍)

目录 一. Decimal.js 介绍 二. 常用方法 1. 创建 Decimal 实例 2.加法 add 或 plus 3.减法 sub 或 minus 4.乘法 times 或 mul 5.除法 div 或 dividedBy 6.取模 7.幂运算 8.平方根 9.保留小数位 toFixed方法(四舍五入) 三.项目应用 前端精度丢失问题通常由以下原因…

【Kubernetes】kubeadmu快速部署k8s集群

目录 一.组件部署 二.环境初始化 三.所有节点部署docker&#xff0c;以及指定版本的kubeadm 四.所有节点安装kubeadm&#xff0c;kubelet和kubectl 五.高可用配置 六.部署K8S集群 1.master01 节点操作 2.master02、master03节点 3.master01 节点 4.master02、master…

C语言 ——— 学习、使用 strcmp函数 并模拟实现

目录 strcmp函数的功能 学习strcmp函数​编辑 使用strcmp函数 模拟实现strcmp函数 strcmp函数的功能 strcmp函数的功能是字符串比较&#xff0c;两个字符串的对应位置的字符进行比较&#xff0c;直到字符不同或达到终止的 \0 字符为止 举例说明&#xff1a; 字符串1&am…

leetcode-二叉树oj题1(共三道)--c语言

目录 a. 二叉树的概念以及实现参照博客&#xff1a; 一、三道题的oj链接 二、每题讲解 1.单值二叉树 a. 题目&#xff1a; b. 题目所给代码 c. 思路 d. 代码&#xff1a; 2. 相同的树 a. 题目 b. 题目所给代码 c. 思路 d. 代码 3. 二叉树的前序遍历 a. 题目 b.…

前端-05-VSCode自定义代码片段console.log(js/ts配置)、代码段快捷提示放在首位

目录 配置VSCode自定义代码片段console.log()log代码段快捷提示放在首位 配置VSCode自定义代码片段console.log() 点击VSCode左下角设置图标&#xff0c;点击用户代码片段 点击用户代码片段后&#xff0c;VSCode上方出现弹窗如下图&#xff08;没有显示这两个文件的话搜索一下…