DERT目标检测—End-to-End Object Detection with Transformers

DERT:使用Transformer的端到端目标检测

论文题目:End-to-End Object Detection with Transformers
官方代码:https://github.com/facebookresearch/detr

论文题目中包括的一个创新点End to End(端到端的方法)简单的理解就是没有使用到NMS等后处理操作来处理生成的多个重复的框

简单的端到端的目标检测系统

引言与概述

We present a new method that views object detection as a
direct set prediction problem.

将目标检测任务直接看成是集合预测的问题。

  1. 提出了一个全局的二分图匹配的损失函数:a set-based global loss that forces unique predictions via bipartite matching
  2. 结合了Transformer结构,在解码器的部分进行并行的出框。

在引言的部分论文中简单概括了之前的目标检测所用到的一些方法。

  • 双阶段目标检测的算法: FasterRcnn
  • 单阶段目标检测的算法:Yolo
  • 基于中心点进行生成的算法:CenterNet

在这里插入图片描述

在没有深入的学习目标检测网络具体的细节之前。对这一个过程进行一.个直观的信息描述,

  1. 首先经过一个CNN网络提取一部分的特征得到对应的特征图,并将得到的特征进行拉直处理。准备送入之后的Transformer结构中去。

  2. 将拉直之后的token送入编码器的结构部分,(endcode去进一步学习全局的特征信息。为我们decode出预测框的部分做铺垫。)使用endcode可以认为是将图片中的每一个点和其他的点之间就有交互信息了。就可以知道大概那一块是那个物体。

  3. 对同一个物体就出一个检测框的结果。通过query和我们的特征就可以确定要出多少检测框(论文中固定出框数为100)

  4. 最后一步就是我得出的这100个框,如何和我的Ground Truth框之间做一个关联匹配问题呢 计算loss 没有匹配的框则会标记为没有物体。

目标检测相关工作

Most modern object detection methods make predictions relative to some ini-tial guesses. Two-stage detectors [37,5] predict boxes w.r.t. proposals, whereas single-stage methods make predictions w.r.t. anchors [23] or a grid of possible object centers [53,46]. Recent work [52] demonstrate that the final performance of these systems heavily depends on the exact way these initial guesses are set.

在之前的目标检测的相关工作中作者就提到了,之前相关的目标检测的工作,取决于我们的先验猜测,双阶段的候选框proposals,单阶段的anchors于 centernet的中心点检测取决于,中心点选取的位置。

从而提出了一种新的方法,基于集合的方式来做这个目标检测的任务。

目标函数

DETR infers a fixed-size set of N predictions, in a single pass through the decoder, where N is set to be significantly larger than the typical number of objects in an image N=100

二分图匹配问题+匈牙利算法。找到一个唯一解使得最后可以完成最后的一个分配。(代价矩阵的构建就可以看成是,将100个预测框和Ground Truth框之间进行二分图匹配

σ ^ = arg ⁡ min ⁡ σ ∈ S N ∑ i N L match  ( y i , y ^ σ ( i ) ) \hat{\sigma}=\underset{\sigma \in \mathfrak{S}_{N}}{\arg \min } \sum_{i}^{N} \mathcal{L}_{\text {match }}\left(y_{i}, \hat{y}_{\sigma(i)}\right) σ^=σSNargminiNLmatch (yi,y^σ(i))

最后匹配完成之后就可以和之后的目标检测差不多的损失函数。

L Hungarian  ( y , y ^ ) = ∑ i = 1 N [ − log ⁡ p ^ σ ^ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box  ( b i , b ^ σ ^ ( i ) ) ] \mathcal{L}_{\text {Hungarian }}(y, \hat{y})=\sum_{i=1}^{N}\left[-\log \hat{p}_{\hat{\sigma}(i)}\left(c_{i}\right)+\mathbb{1}_{\left\{c_{i} \neq \varnothing\right\}} \mathcal{L}_{\text {box }}\left(b_{i}, \hat{b}_{\hat{\sigma}}(i)\right)\right] LHungarian (y,y^)=i=1N[logp^σ^(i)(ci)+1{ci=}Lbox (bi,b^σ^(i))]

简单的看也就是分类损失加上回归损失的表达形式。

在这里插入图片描述
这一个更为详细的图里面就引出了另外的一个十分重要的概念object Queries(是一个可学习的参数在经过学习之后就可以确定出哪些查询会对应哪些目标,从而避免重复的操作。

连接一个分类头完成最终的结果的一个预测。论文中给出的简化的版本代码编写。

import torch
from torch import nn
from torchvision.models import resnet50
class DETR(nn.Module):
def __init__(self, num_classes, hidden_dim, nheads,
num_encoder_layers, num_decoder_layers):
super().__init__()
# We take only convolutional layers from ResNet-50 model
self.backbone=nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
self.conv = nn.Conv2d(2048, hidden_dim, 1)self.transformer = nn.Transformer(hidden_dim, nheads,num_encoder_layers, num_decoder_layers)self.linear_class = nn.Linear(hidden_dim, num_classes + 1)self.linear_bbox = nn.Linear(hidden_dim, 4)self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))def forward(self, inputs):x = self.backbone(inputs)h = self.conv(x)H, W = h.shape[-2:]
pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),], dim=-1).flatten(0, 1).unsqueeze(1)h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),self.query_pos.unsqueeze(1))return self.linear_class(h), self.linear_bbox(h).sigmoid()detr = DETR(num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)detr.eval()inputs = torch.randn(1, 3, 800, 1200)logits, bboxes = detr(inputs)

网络模型结构

在这里插入图片描述

主干网络与预处理的部分

根据论文官方的代码对模型的结构进行说明:

  1. 输入时一个800 x 1066的三通道图片,将其输入到主干网络提取器ResNet50中进行特征的提取 得到的特征图大小是 25x34(下采样了32倍) 将通道数拓展为2048。

  2. 将得到的特征图经过一个1x1的卷积层输入的通道数是2048 输出的通道数是 256得到了**[ 25 34 256]的结构**

  3. 将最后的两个维度进行一个展平的操作步骤得到了 [850 ,256]的结构

在这里插入图片描述

其中的850就是我们后面使用的Transformer中token的个数,256即为特征向量的长度。

Transformer结构部分

论文中也给出了一个改进之后的Transformer结构。结构之前的Transfomer结构给出类比的结果。

在这里插入图片描述
在标准的Transformer中位置编码只作用在输入的位置处,并且只作用一次。而在DERT的Transformer中位置编码是在每一个编码器,和解码器的部分都需要操作一次的。

学习这个网络模型的难点就在需要注意,模型之间的连线来确定好各个Q K V是通过哪些变量的计算来生成的(结合源码)

损失函数

  1. 从100个预测框中,找出和真实标注框所匹配的N个框(图中对应的是两个框),也就是说我们在训练集样本中标注了几个框,就需要在那100个得到的预测框中筛选出几个框(N)来进行匹配

在这里插入图片描述

我们需要做的任务就是向代价矩阵中进行填值使得匹配的结果最为合适

− 1 { c i ≠ ∅ } p ^ σ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box  ( b i , b ^ σ ( i ) ) . -\mathbb{1}_{\left\{c_{i} \neq \varnothing\right\}} \hat{p}_{\sigma(i)}\left(c_{i}\right)+\mathbb{1}_{\left\{c_{i} \neq \varnothing\right\}} \mathcal{L}_{\text {box }}\left(b_{i}, \hat{b}_{\sigma(i)}\right) . 1{ci=}p^σ(i)(ci)+1{ci=}Lbox (bi,b^σ(i)).
在这里插入图片描述
我们首先看公式的前半部分:即为对应的类别损失:Class Cost

− 1 { c i ≠ ∅ } p ^ σ ( i ) ( c i ) -\mathbb{1}_{\left\{c_{i} \neq \varnothing\right\}} \hat{p}_{\sigma(i)}\left(c_{i}\right) 1{ci=}p^σ(i)(ci)

在这里插入图片描述

  • 首先要提取出GT中的坐标框对应的类别信息(第一张图有两个框,第二张图中有四个框。值为类别编号

  • 对应两个图片给出的200个预测框的值(2N)我们将其进行拼接,计算出包含真实类别的概率值。

在这里插入图片描述
在计算的时候Cost class这个张量需要加符号用来计算损失函数的值

  • 第二部分我们对应的是边界框回归的一个损失。

1 { c i ≠ ∅ } L box  ( b i , b ^ σ ( i ) ) \mathbb{1}_{\left\{c_{i} \neq \varnothing\right\}} \mathcal{L}_{\text {box }}\left(b_{i}, \hat{b}_{\sigma(i)}\right) 1{ci=}Lbox (bi,b^σ(i))

论文中关于回归损失函数的描述信息为:
Bounding box loss. The second part of the matching cost and the Hungarian
loss is Lbox(·) that scores the bounding boxes. Unlike many detectors that do box
predictions as a ∆ w.r.t. some initial guesses, we make box predictions directly.
While such approach simplify the implementation it poses an issue with relative
scaling of the loss. The most commonly-used 1 loss will have different scales for
small and large boxes even if their relative errors are similar. To mitigate this
issue we use a linear combination of the 1 loss and the generalized IoU loss [38]
Liou(·, ·) that is scale-invariant. Overall, our box loss is Lbox(bi, ˆbσ(i)) defined as
λiouLiou(bi, ˆbσ(i)) + λL1||bi − ˆbσ(i)||1 where λiou, λL1 ∈ R are hyperparameters.
These two losses are normalized by the number of objects inside the batch.

总结一下:也就如果和之前一样使用常规的L1损失来作为回归损失,可能会导致,大小检测框的相对计算一致。因此在这个基础上引出了GIOU损失与L1损失相结合的最终回归损失部分。

λ iou  L iou  ( b i , b ^ σ ( i ) ) + λ L 1 ∥ b i − b ^ σ ( i ) ∥ 1 \lambda_{\text {iou }} \mathcal{L}_{\text {iou }}\left(b_{i}, \hat{b}_{\sigma(i)}\right)+\lambda_{\mathrm{L} 1}\left\|b_{i}-\hat{b}_{\sigma(i)}\right\|_{1} λiou Liou (bi,b^σ(i))+λL1 bib^σ(i) 1

在这里插入图片描述

match操作

− 1 { c i ≠ ∅ } p ^ σ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box  ( b i , b ^ σ ( i ) ) -\mathbb{1}_{\left\{c_{i} \neq \varnothing\right\}} \hat{p}_{\sigma(i)}\left(c_{i}\right)+\mathbb{1}_{\left\{c_{i} \neq \varnothing\right\}} \mathcal{L}_{\text {box }}\left(b_{i}, \hat{b}_{\sigma(i)}\right) 1{ci=}p^σ(i)(ci)+1{ci=}Lbox (bi,b^σ(i))

L box  ( ⋅ ) = λ iou  L iou  ( b i , b ^ σ ( i ) ) + λ L1  ∥ b i − b ^ σ ( i ) ∥ 1 \mathcal{L}_{\text {box }}(\cdot)=\lambda_{\text {iou }} \mathcal{L}_{\text {iou }}\left(b_{i}, \hat{b}_{\sigma(i)}\right)+\lambda_{\text {L1 }}\left\|b_{i}-\hat{b}_{\sigma(i)}\right\|_{1} Lbox ()=λiou Liou (bi,b^σ(i))+λL1  bib^σ(i) 1

  • 我们对应代码部分实际的计算步骤就是:cost = -cost_class + 5 × cost_bbor - 2 × cost_GIoUs

  • 把计算得到的结果填写入矩阵之中,就可以得到两个图片总的代价矩阵,我们在使用split操作将其分开得到两个代价矩阵的结果。

(分别进行匈牙利匹配)

在这里插入图片描述

计算损失并反向传播

在这个地方论文中提出了一个新的损失函数。—匈牙利损失函数。使用筛选出的预测框与真实标注框计算损失。

L Hungarian  ( y , y ^ ) = ∑ i = 1 N [ − log ⁡ p ^ σ ^ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box  ( b i , b ^ σ ^ ( i ) ) ] \mathcal{L}_{\text {Hungarian }}(y, \hat{y})=\sum_{i=1}^{N}\left[-\log \hat{p}_{\hat{\sigma}(i)}\left(c_{i}\right)+\mathbb{1}_{\left\{c_{i} \neq \varnothing\right\}} \mathcal{L}_{\text {box }}\left(b_{i}, \hat{b}_{\hat{\sigma}}(i)\right)\right] LHungarian (y,y^)=i=1N[logp^σ^(i)(ci)+1{ci=}Lbox (bi,b^σ^(i))]

和之前代价矩阵计算所用的那个函数其实差不多(类别损失+坐标损失)。区别主要在于一下几点。

  1. 这里在计算类别损失的时候我们是使用N也就是100个预测框来参与运算。而不是只计算标注类别的损失。
  2. 加了 log也就是使用交叉熵损失函数(计算平均值)
  3. 中间层的输出也是参与了损失计算的。(主网络损失+网络中间层的损失)

在这里插入图片描述

用预测结果与真实的结果计算交叉熵损失(所有的框 92代表背景)

在回归损失中,公式也给出了只使用真实的标注框不含背景。

1 { c i ≠ ∅ } L box  ( b i , b ^ σ ^ ( i ) ) \mathbb{1}_{\left\{c_{i} \neq \varnothing\right\}} \mathcal{L}_{\text {box }}\left(b_{i}, \hat{b}_{\hat{\sigma}}(i)\right) 1{ci=}Lbox (bi,b^σ^(i))

最后就可以得到最终的结果了:结合反向传播对整个网络进行训练和优化

L Hungarian  ( y , y ^ ) = ∑ i = 1 N [ − log ⁡ p ^ σ ^ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box  ( b i , b ^ σ ^ ( i ) ) ] \mathcal{L}_{\text {Hungarian }}(y, \hat{y})=\sum_{i=1}^{N}\left[-\log \hat{p}_{\hat{\sigma}(i)}\left(c_{i}\right)+\mathbb{1}_{\left\{c_{i} \neq \varnothing\right\}} \mathcal{L}_{\text {box }}\left(b_{i}, \hat{b}_{\hat{\sigma}}(i)\right)\right] LHungarian (y,y^)=i=1N[logp^σ^(i)(ci)+1{ci=}Lbox (bi,b^σ^(i))]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/431009.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MISC - 第四天(OOK编码,audacity音频工具,摩斯电码,D盾,盲文识别,vmdk文件压缩)

前言 各位师傅大家好,我是qmx_07,今天继续讲解MISC知识点 FLAG 附件是一张图片,尝试binwalk无果 使用StegSolve工具Data Extract查看时 发现PK字段,是大多数压缩包的文件头点击Save Bin保存zip文件 解压缩失败使用修复软件:htt…

代码随想录Day17 图论-2

103. 水流问题 本题思路很简单 要求我们找到可以满足到达两个边界的单元格的坐标 有一个优化的思路就是 我们从边界的节点向中间遍历 然后用两个数组表示 一个是第一组边界的数组 一个是第二边界的数组 如果两个数组都遍历到了某一个单元格 就说明该单元格时满足题目要求的 #…

OpenLayers 开源的Web GIS引擎 - 添加地图控件地图控件

中心点按钮、地图放大缩小滑块、全图和比例尺控件 直接上代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.…

如何选择适合的干式电抗器?

干式电抗器是电力系统中重要的电气设备&#xff0c;主要用于限制电网中的短路电流&#xff0c;提高电力系统的稳定性和可靠性。选择适合的干式电抗器对于保障电力系统的正常运行具有重要意义。以下是选择适合的干式电抗器的一些建议&#xff1a; 1. 根据电力系统的需求选择合适…

大模型算法岗常见面试题100道(值得收藏)非常详细收藏我这一篇就够了

大模型应该是目前当之无愧的最有影响力的AI技术&#xff0c;它正在革新各个行业&#xff0c;包括自然语言处理、机器翻译、内容创作和客户服务等等&#xff0c;正在成为未来商业环境的重要组成部分。 截至目前大模型已经超过200个&#xff0c;在大模型纵横的时代&#xff0c;不…

【设计模式】创建型模式(四):建造者模式

《设计模式之创建型模式》系列&#xff0c;共包含以下文章&#xff1a; 创建型模式&#xff08;一&#xff09;&#xff1a;工厂模式创建型模式&#xff08;二&#xff09;&#xff1a;抽象工厂模式创建型模式&#xff08;三&#xff09;&#xff1a;单例模式创建型模式&#…

Android14请求动态申请存储权限

Android14请求动态申请存储权限 Android14和Android15存储权限有增加多了选择部分&#xff0c;还是全部。一个小小的存储权限真的被它玩出了花来。本来Android13就将存储权限进行了3个细分&#xff0c;是图片&#xff0c;音频还是视频文件。 步骤一&#xff1a;AndroidManife…

峟思:山洪灾害监测预警系统全面解析

在自然灾害频发的今天&#xff0c;山洪灾害以其突发性强、破坏力大而备受关注。为了有效预防和减少山洪灾害带来的损失&#xff0c;山洪灾害监测预警系统应运而生。本文将详细介绍该系统的主要组成部分、关键传感器及其工作机制&#xff0c;以期为防灾减灾工作提供有力支持。 山…

项目小总结

这段时间主要把大概的开发流程了解完毕 修改了&#xff0c;并画了几个界面 一.界面 修改为 博客主页 个人中心 二.前后端分离开发 写前端时 就可以假设拿到这些数据了 const blogData2 {blog:{id:1,title: "如何编程飞人",author_id: 1,content: "这是一篇…

最新版C/C++通过CLion2024进行Linux远程开发保姆级教学

目前来说&#xff0c;对Linux远程开发支持相对比较好的也就是Clion和VSCode了&#xff0c;这两个其实对于C和C语言开发都很友好&#xff0c;大可不必过于纠结使用那个&#xff0c;至于VS和QtCreator&#xff0c;前者太过重量级了&#xff0c;后者更是不用说&#xff0c;主要用于…

【论文阅读】Grounding Language with Visual Affordances over Unstructured Data

Abstract 最近的研究表明&#xff0c;大型语言模型&#xff08;llms&#xff09;可以应用于将自然语言应用于各种各样的机器人技能。然而&#xff0c;在实践中&#xff0c;学习多任务、语言条件机器人技能通常需要大规模的数据收集和频繁的人为干预来重置环境或帮助纠正当前的…

vue node node-sass sass-loader 版本 对应 与 兼容

警告&#xff1a; LibSass 和 Node Sass 已弃用。虽然它们将继续无限期地接收维护版本&#xff0c;但没有计划添加其他功能或与任何新的 CSS 或 Sass 功能兼容。仍在使用它的项目应该转移到 Dart Sass。 sass Sass是一种预处理器脚本语言&#xff0c;可以解释或编译成…

Java—反射机制详解

介绍反射 反射的基本概念 反射&#xff08;Reflection&#xff09;是Java语言中的一种机制&#xff0c;它允许程序在运行时检查和操作类、接口、字段和方法等类的内部结构。通过反射&#xff0c;你可以在运行时获取类的信息&#xff0c;包括类的构造器、字段、方法等&#xf…

在 Windows 上运行 Vue 项目时解决 ‘NODE_OPTIONS‘ 错误

在 Windows 上运行 Vue 项目时解决 ‘NODE_OPTIONS’ 错误 在 Windows 系统上启动 Vue 项目时&#xff0c;遭遇报错。具体报错信息如下&#xff1a; ‘NODE_OPTIONS‘ 不是内部或外部命令&#xff0c;也不是可运行的程序或批处理文件。这个错误通常意味着 Windows 系统无法识…

机器翻译之创建Seq2Seq的编码器、解码器

1.创建编码器、解码器的基类 1.1创建编码器的基类 from torch import nn#构建编码器的基类 class Encoder(nn.Module): #继承父类nn.Moduledef __init__(self, **kwargs): #**kwargs&#xff1a;不定常的关键字参数super().__init__(**kwargs)def forward(self, X, *args…

基于SpringBoot+Vue+MySQL的美食点餐管理系统

系统展示 用户前台界面 管理员后台界面 系统背景 在数字化快速发展的今天&#xff0c;餐饮行业也迎来了转型升级的重要机遇。传统餐饮管理方式面临效率低下、顾客体验不佳等问题。为此&#xff0c;开发一款基于SpringBootVueMySQL架构的美食点餐管理系统显得尤为重要。该系统旨…

【可图(Kolors)部署与使用】大规模文本到图像生成模型部署与使用教程

✨ Blog’s 主页: 白乐天_ξ( ✿&#xff1e;◡❛) &#x1f308; 个人Motto&#xff1a;他强任他强&#xff0c;清风拂山冈&#xff01; &#x1f4ab; 欢迎来到我的学习笔记&#xff01; 1.Kolors 简介 1.1.什么是Kolors&#xff1f; 开发团队 Kolors 是由快手 Kolors 团队…

网页护眼宝——全方位解析 Chrome Dark Reader 插件

网页护眼宝——全方位解析 Chrome Dark Reader 插件 1. 基本介绍&#xff1a;Chrome 插件的力量与 Dark Reader 的独特之处 随着现代浏览器的功能越来越强大&#xff0c;Chrome 插件为用户提供了极大的定制化能力。从广告屏蔽、性能优化到页面翻译&#xff0c;Chrome 插件几乎…

视频监控相关笔记

一、QT 之 QTreeWidget 树形控件 Qt编程指南&#xff0c;Qt新手教程&#xff0c;Qt Programming Guide 一个树形结构的节点中的图表文本 、附带数据的添加&#xff1a; QTreeWidgetItem* TourTreeWnd::InsertNode(NetNodeInfo node, QTreeWidgetItem* parent_item) { // …

C++: unordered系列关联式容器

目录 1. unordered系列关联式容器1.1 unordered_map1.2 unordered_set 2. 哈希概念3. 哈希冲突4. 闭散列5. 开散列 博客主页: 酷酷学 感谢关注!!! 正文开始 1. unordered系列关联式容器 在C98中&#xff0c;STL提供了底层为红黑树结构的一系列关联式容器&#xff0c;在查询时…