基于transformer的解码decode目标检测框架(修改DETR源码)

提示:transformer结构的目标检测解码器,包含loss计算,附有源码

文章目录

  • 前言
  • 一、main函数代码解读
    • 1、整体结构认识
    • 2、main函数代码解读
    • 3、源码链接
  • 二、decode模块代码解读
    • 1、decoded的TransformerDec模块代码解读
    • 2、decoded的TransformerDecoder模块代码解读
    • 3、decoded的DecoderLayer模块代码解读
  • 三、decode模块训练demo代码解读
    • 1、解码数据输入格式
    • 2、解码训练demo代码解读
  • 四、decode模块预测demo代码解读
    • 1、预测数据输入格式
    • 2、解码预测demo代码解读
  • 五、losses模块代码解读
    • 1、matcher初始化
    • 2、二分匹配matcher代码解读
    • 3、num_classes参数解读
    • 4、losses的demo代码解读


前言

最近重温DETR模型,越发感觉detr模型结构精妙之处,不同于anchor base 与anchor free设计,直接利用100框给出预测结果,使用可学习learn query深度查找,使用二分匹配方式训练模型。为此,我基于detr源码提取解码decode、loss计算等系列模块,并重构、修改、整合一套解码与loss实现的框架,该框架可适用任何backbone特征提取接我框架,实现完整训练与预测,我也有相应demo指导使用我的框架。那么,接下来,我将完整介绍该框架源码。同时,我将此源码进行开源,并上传github中,供读者参考。


一、main函数代码解读

1、整体结构认识

在介绍main函数代码前,我先说下整体框架结构,该框架包含2个文件夹,一个losses文件夹,用于处理loss计算,一个是obj_det文件,用于transformer解码模块,该模块源码修改于detr模型,也包含main.py,该文件是整体解码与loss计算demo示意代码,如下图。

在这里插入图片描述

2、main函数代码解读

该代码实际是我随机创造了标签target数据与backbone特征提取数据及位置编码数据,使其能正常运行的demo,其代码如下:

import torch
from obj_det.transformer_obj import TransformerDec
from losses.matcher import HungarianMatcher
from losses.loss import SetCriterionif __name__ == '__main__':Model = TransformerDec(d_model=256, output_intermediate_dec=True, num_classes=4)num_classes = 4   #  类别+1matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)  # 二分匹配不同任务分配的权重losses = ['labels', 'boxes', 'cardinality']  # 计算loss的任务weight_dict = {'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2}  # 为dert最后一个设置权重criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=0.1, losses=losses)# 下面使用iter,我构造了虚拟模型编码数据与数据加载标签数据src = torch.rand((391, 2, 256))pos_embed = torch.ones((391, 1, 256))# 创造真实target数据target1 = {'boxes':torch.rand((5,4)),'labels':torch.tensor([1,3,2,1,2])}target2 = {'boxes': torch.rand((3, 4)), 'labels': torch.tensor([1, 1, 2])}target = [target1, target2]res = Model(src, pos_embed)losses = criterion(res, target)print(losses)

如下图:

在这里插入图片描述

3、源码链接

源码链接:点击这里

二、decode模块代码解读

该模块主要是使用transform方式对backbone提取特征的解码,主要使用learn query等相关trike与transform解码方式内容。
我主要介绍TransformerDec、TransformerDecoder、DecoderLayer模块,为依次被包含关系,或说成后者是前者组成部分。

1、decoded的TransformerDec模块代码解读

该类大意是包含了learn query嵌入、解码transform模块调用、head头预测logit与boxes等内容,是实现解码与预测内容,该模块参数或解释已有注释,读者可自行查看,其代码如下:

class TransformerDec(nn.Module):'''d_model=512, 使用多少维度表示,实际为编码输出表达维度nhead=8, 有多少个头num_queries=100, 目标查询数量,可学习querynum_decoder_layers=6, 解码循环层数dim_feedforward=2048, 类似FFN的2个nn.Linear变化dropout=0.1,activation="relu",normalize_before=False,解码结构使用2种方式,默认False使用post解码结构output_intermediate_dec=False, 若为True保存中间层解码结果(即:每个解码层结果保存),若False只保存最后一次结果,训练为True,推理为Falsenum_classes: num_classes数量与数据格式有关,若类别id=1表示第一类,则num_classes=实际类别数+1,若id=0表示第一个,则num_classes=实际类别数额外说明,coco类别id是1开始的,假如有三个类,名称为[dog,cat,pig],batch=2,那么参数num_classes=4,表示3个类+1个背景,模型输出src_logits=[2,100,5]会多出一个预测,target_classes设置为[2,100],其值为4(该值就是背景,而有类别值为123),那么target_classes中没有值为0,我理解模型不对0类做任何操作,是个无效值,模型只对1234进行loss计算,然4为背景会比较多,作者使用权重0.1避免其背景过度影响。forward return: 返回字典,包含{'pred_logits':[],  # 为列表,格式为[b,100,num_classes+2]'pred_boxes':[],  # 为列表,格式为[b,100,4]'aux_outputs'[{},...] # 为列表,元素为字典,每个字典为{'pred_logits':[],'pred_boxes':[]},格式与上相同}'''def __init__(self, d_model=512, nhead=8, num_queries=100, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,activation="relu", normalize_before=False, output_intermediate_dec=False, num_classes=1):super().__init__()self.num_queries = num_queriesself.query_embed = nn.Embedding(num_queries, d_model)  # 与编码输出表达维度一致self.output_intermediate_dec = output_intermediate_decdecoder_layer = DecoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)decoder_norm = nn.LayerNorm(d_model)self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/178373.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu 20.04设置虚拟内存 (交换内存swap)解决内存不足

数据库服务器程序在运行起来之后,系统内存不足。 在系统监控中发现,当数据库服务程序启动后,占用了大量内存空间,导致系统的剩余的内存往往只有几十MB。 在ubuntu系统中,swap空间就是虚拟内存,所以考虑在磁…

【kubernetes】k8s对象☞pod

文章目录 1、什么是pod2、pod的使用2.1 用于管理pod的工作负载资源2.2 pod怎样管理多个容器2.3 pod 操作系统2.4 pod和控制器2.5 pod模板 3、pod的更新与替换3.1 资源共享和通信3.2 pod中的存储3.3 pod联网 4、容器的特权模式4.1 linux 特权容器4.2 windows特权容器 5、静态pod…

ES 8.x新特性一览(完整版)

一、看点 在 2022 年 2 月 11 日,Elasticsearch(ES)正式发布了 8.0 版本,而截止到 2023 年 10 月,历经一年半时间,ES官方已经连续发布了多个版本,最新版本为 8.10.4。这一系列的更新引入了众多引…

极智开发 | H100服务器的庐山真面目

欢迎关注我的公众号 [极智视界],获取我的更多经验分享 大家好,我是极智视界,本文分享一下 H100服务器的庐山真面目。 邀您加入我的知识星球「极智视界」,星球内有超多好玩的项目实战源码和资源下载,链接:https://t.zsxq.com/0aiNxERDq H100 是英伟达最强显卡,当然其实也…

【GitLab、GitLab Runner、Docker】GitLab CI/CD 应用

安装Gitlab开源版 官方文档-安装Gitlab 使用Docker安装 sudo docker run --detach \--hostname gitlab.example.com \--env GITLAB_OMNIBUS_CONFIG"external_url http://${ip}:9999/; gitlab_rails[gitlab_shell_ssh_port] 8822;" \--publish 443:443 --publish 99…

基于深度学习的人脸专注度检测计算系统 - opencv python cnn 计算机竞赛

文章目录 1 前言2 相关技术2.1CNN简介2.2 人脸识别算法2.3专注检测原理2.4 OpenCV 3 功能介绍3.1人脸录入功能3.2 人脸识别3.3 人脸专注度检测3.4 识别记录 4 最后 1 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 基于深度学习的人脸专注度…

自动驾驶算法(一):Dijkstra算法讲解与代码实现

目录 0 本节关键词:栅格地图、算法、路径规划 1 Dijkstra算法详解 2 Dijkstra代码详解 0 本节关键词:栅格地图、算法、路径规划 1 Dijkstra算法详解 用于图中寻找最短路径。节点是地点,边是权重。 从起点开始逐步扩展,每一步为一…

Python---字符串切片-----序列名称[开始位置下标 : 结束位置下标 : 步长]

字符串切片:是指对操作的对象截取其中一部分的操作。字符串、列表、元组都支持切片操作。 本文以字符串为例。 基本语法: 顾头不顾尾: ----------类似range() 范围,顾头不顾尾 相关链接Python----ran…

k8s调度约束

List-Watch Kubernetes 是通过 List-Watch的机制进行每个组件的协作,保持数据同步的,每个组件之间的设计实现了解耦。 List-Watch机制 工作机制:用户通过 kubectl请求给 APIServer 来建立一个 Pod。APIServer会将Pod相关元信息存入 etcd 中…

注册中心ZK、nameServer、eureka、Nacos介绍与对比

前言 注册中心的由来 微服务架构是存在着很多跨服务调用,每个服务都存在着多个节点,如果有多个提供者和消费者,当提供者增加/减少或者消费者增加/减少,双方都需要感知发现。所以诞生了注册中心这个中间件。 市面上有很多注册中心,如 Zookeeper、NameServer、Eureka、Na…

【Tomcat Servlet】如何在idea上部署一个maven项目?

目录 1.创建项目 2.引入依赖 3.创建目录 4.编写代码 5.打包程序 6.部署项目 7.验证程序 什么是Tomcat和Servlet? 以idea2019为例: 1.创建项目 1.1 首先创建maven项目 1.2 项目名称 2.引入依赖 2.1 网址输入mvnrepository.com进入maven中央仓库->地址…

2.4G合封芯片 XL2422,集成M0核MCU,高性能 低功耗

XL2422芯片是一款高性能低功耗的SOC集成无线收发芯片,集成M0核MCU,工作在2.400~2.483GHz世界通用ISM频段。该芯片集成了射频接收器、射频发射器、频率综合器、GFSK调制器、GFSK解调器等功能模块,并且支持一对多线网和带ACK的通信模式。发射输…

【Windows-软件-OS】(01)Windows操作系统配置环境变量,快速上手

前言 "Windows"操作系统配置环境变量,快速上手; 实操 【实操一】 环境 Windows 11 专业版(22621.2428); 图片 (1) (2) (3) &#x…

HTTP和HTTPS本质区别——SSL证书

HTTP和HTTPS是两种广泛使用的协议,尽管它们看起来很相似,但是它们在网站数据传输的安全性上有着本质上的区别。 HTTP是明文传输协议,意味着通过HTTP发送的数据是未经加密的,容易受到拦截、窃听和篡改的风险。而HTTPS通过使用SSL或…

【vtk学习笔记4】基本数据类型

一、可视化数据的基本特点 可视化数据有以下特点: 离散型 计算机处理的数据是对无限、连续的空间进行采样,生成的有限采样点数据。在某些离散点上有精确的值,但点与点之间值不可知,只有通过插值方式获取数据具有规则或不规则的结…

由QTableView/QTableWidget显示进度条和按钮,理解qt代理delegate用法

背景: 我的最初应用场景,就是要在表格上用进度条显示数据,以及放一个按钮。 qt-creator中有自带的delegate示例可以参考,但终归自己动手还是需要理解细节,否则不能随心所欲。 自认没那个天赋,于是记录下…

基于springboot实现疫情防控期间外出务工人员信息管理系统项目【项目源码+论文说明】计算机毕业设计

基于springboot疫情防控期间外出务工人员信息管理系统 摘要 网络的广泛应用给生活带来了十分的便利。所以把疫情防控期间某村外出务工人员信息管理与现在网络相结合,利用java技术建设疫情防控期间某村外出务工人员信息管理系统,实现疫情防控期间某村外出…

Linux编译器vim的使用

文章目录 vim基本概念vim的常用三种模式vim三种模式的相互转换 vim命令模式下的命令集移动光标删除文字剪切/删除复制替换撤销和恢复跳转至指定行 vim底行模式下的命令集 vim基本概念 vim是Linux下的一个多模式的编译器 简单来说就是写代码的工具 不提供编译调试等功能 有语法…

0基础学习PyFlink——时间滚动窗口(Tumbling Time Windows)

大纲 mapreduce完整代码参考资料 在《0基础学习PyFlink——个数滚动窗口(Tumbling Count Windows)》一文中,我们发现如果窗口内元素个数没有达到窗口大小时,计算个数的函数是不会被调用的。如下图中红色部分 那么有没有办法让上图中(B,2&…

Prometheus接入AlterManager配置钉钉告警(基于K8S环境部署)

文章目录 一、钉钉群创建报警机器人二、安装Webhook-dingtalk插件三、配置Webhook-dingtalk插件对接钉钉群四、配置AlterManager告警发送至Webhook-dingtalk五、Prometheus接入AlterManager配置六、部署PrometheusAlterManager(放到一个Pod中)七、测试告警 注意:请基…