笔记01----Transformer高效语义分割解码器模块DEPICT(即插即用)

学习笔记01----即插即用的解码器模块DEPICT

    • 前言
    • 源码下载
    • DEPICT实现
    • 实验

前言

文 章 标 题:《Rethinking Decoders for Transformer-based Semantic Segmentation: Compression is All You Need》
当前的 Transformer-based 方法(如 DETR 和其变体)取得了显著进展。但这些解码器(decoder)的设计更多是基于经验,缺乏理论解释,难以确定性能瓶颈并进行进一步改进。
该论文将语义分割任务建模为“从主空间到子空间的信息压缩”问题,强调从高维图像特征中提取类别相关的紧凑表示。
提出 DEPICT 解码器:

  • 基于 自注意力(MSSA) 和 交叉注意力(MSCA) 设计简单高效的解码器。
  • MSSA 构建主子空间,去除冗余,优化图像特征。
  • MSCA 动态提取类别相关特征,生成类别嵌入的低维表示。

源码下载

源代码地址:https://github.com/QishuaiWen/DEPICT

DEPICT实现

在这里插入图片描述
DEPICT流程:
1. 图像特征输入: 通过vit的主干网络对图像进行特征提取。这些特征中可能包含很多不重要的信息,比如背景噪声。我们的目标是提取出与分类相关的特征。
2.sa模式—自注意力模块(MSSA): 通过自注意力机制(Multi-head Subspace Self-Attention, MSSA),捕捉图像块之间的全局关系,去掉不相关信息,优化出更加紧凑的主要特征(主子空间)。它的具体操作是将 类别嵌入向量图像特征进行 拼接操作 输入 MSSA模块进行特征优化。
3.ca模式—交叉注意力模块(MSCA):类别嵌入(这是一个可学习的特征向量)作为查询,图像特征作为键和值,通过交叉注意力(Multi-head Subspace Cross-Attention, MSCA)提取每个类别的相关特征,生成类别嵌入的低维表示。它的具体操作是将 类别嵌入向量 作为 查询向量 通过MSCA进行特征优化。
类别嵌入向量是一个可学习的参数,是从 主空间中提取 出的,与类别强相关的特征子集,是图像特征的降维。
4.生成分割掩码:用点积操作比较图像特征和类别嵌入,生成每块图像属于每个类别的概率。

import torch
import torch.nn as nn
from einops import rearrange
from timm.models.layers import trunc_normal_
from dec_blocks import Transformer
from segm.model.utils import init_weights
class MaskTransformer(nn.Module):def __init__(self,n_cls,#类别数量patch_size,# 图像分块大小n_layers,  # Transformer 的层数n_heads,  # 多头注意力中的头数d_model,  # 特征的嵌入维度dropout,  # dropout 概率mode='ca',  # 模式选择:'ca' (交叉注意力) 或 'sa' (自注意力)):super().__init__()self.patch_size = patch_sizeself.n_cls = n_clsself.mode = mode# cls_emb 是类别嵌入矩阵,初始化为随机值,形状为 (1, n_cls, d_model)。# 在 DEPICT 中,类别嵌入对应于主子空间的基向量 Pself.cls_emb = nn.Parameter(torch.randn(1, n_cls, d_model))if mode == 'sa':# 提取图像主特征self.net = Transformer(d_model, n_layers, n_heads, 100, dropout)self.decoder_norm = nn.LayerNorm(d_model)elif mode == 'ca':# 用于优化图像特征的主特征self.snet = Transformer(d_model, n_layers, n_heads, 100, dropout)# 用于进一步提取类别嵌入self.cnet = Transformer(d_model, 3, n_heads, 50, dropout)self.snorm = nn.LayerNorm(d_model)self.cnorm = nn.LayerNorm(d_model)else:raise ValueError(f"Provided mode: {mode} is not valid.")self.mask_norm = nn.LayerNorm(n_cls)self.apply(init_weights)trunc_normal_(self.cls_emb, std=0.02)@torch.jit.ignoredef no_weight_decay(self):return {"cls_emb"}def forward(self, x, im_size=None):H, W = im_sizeGS = H // self.patch_size# 扩张维度从(1, n_cls, d_model)到(batch_size,n_cls,d_model)cls_emb = self.cls_emb.expand(x.size(0), -1, -1)if self.mode == 'sa':# 拼接图像特征和类别嵌入# (batch_size,num_patches,d_model)x = torch.cat((x, cls_emb), 1)# 通过 Transformer 网络x = self.net(x)# 归一化处理x = self.decoder_norm(x)# patches优化后的图像特征。# cls_seg_feat:更新后的类别嵌入patches, cls_seg_feat = x[:, :-self.n_cls], x[:, -self.n_cls:]else:# 优化图像特征x = self.snet(x)# 归一化处理x = self.snorm(x)# 通过交叉注意力提取类别嵌入cls_emb = self.cnet(x, query=cls_emb)# 归一化cls_emb = self.cnorm(cls_emb)# patches优化后的图像特征。# cls_seg_feat:更新后的类别嵌入patches, cls_seg_feat = x, cls_emb#  向量标准化patches = patches / patches.norm(dim=-1, keepdim=True)cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)# 点积操作:生成掩码# patches:形状为 (batch_size, num_patches, d_model)。# cls_seg_feat:形状为 (batch_size, n_cls, d_model)# 转为 (batch_size, d_model, n_cls),方便点积运算。# 输出 masks 的形状为 (batch_size, num_patches, n_cls),表示每个 patch 属于每个类别的得分。masks = patches @ cls_seg_feat.transpose(1, 2)# 标准化为了简化训练masks = self.mask_norm(masks)# 重排掩码形状masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS))return masks

调用测试代码

def main():# 配置参数n_cls = 10           # 类别数,例如分割任务有 10 个类别patch_size = 16       # 图像分块大小n_layers = 4          # Transformer 层数n_heads = 8           # 多头注意力头数d_model = 128         # 特征嵌入维度dropout = 0.1         # dropout 比例mode = 'ca'           # 模式选择:'ca' 或 'sa'# 初始化 MaskTransformermodel = MaskTransformer(n_cls=n_cls,patch_size=patch_size,n_layers=n_layers,n_heads=n_heads,d_model=d_model,dropout=dropout,mode=mode)# 测试输入batch_size = 2        # 批次大小image_size = 128      # 图像尺寸(假设输入为 128x128)num_patches = (image_size // patch_size) ** 2  # 分块后有多少个 patch# 生成随机的图像特征输入 (batch_size, num_patches, d_model)x = torch.randn(batch_size, num_patches, d_model)# 设置 im_sizeim_size = (image_size, image_size)# 运行模型masks = model(x, im_size=im_size)# 输出形状print("Output masks shape:", masks.shape)

实验

ADE20KcityscapePascalContext数据集
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/474217.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

A037-基于Spring Boot的二手物品交易的设计与实现

🙊作者简介:在校研究生,拥有计算机专业的研究生开发团队,分享技术代码帮助学生学习,独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取,记得注明来意哦~🌹 赠送计算机毕业设计600…

EEG+EMG学习系列 (1) :一个基于小波的自动睡眠评分模型

EEGEMG学习系列:一个基于小波的自动睡眠评分模型 0. 引言1. 主要贡献2. 提出的方法2.1 工作框图2.1 正交小波滤波器组2.2 小波分解2.3 特征提取 3. 结果4. 总结欢迎来稿 论文地址:https://www.mdpi.com/1660-4601/19/12/7176 论文题目:An Automated Wave…

自动化运维-检测Linux服务器CPU、内存、负载、IO读写、机房带宽和服务器类型等信息脚本

前言:以上脚本为今年8月1号发布的,当时是没有任何问题,但现在脚本里网络速度测试py文件获取不了了,测速这块功能目前无法实现,后面我会抽时间来研究,大家如果有建议也可以分享下。 脚本内容: #…

H.265流媒体播放器EasyPlayer.js网页直播/点播播放器WebGL: CONTEXT_LOST_WEBGL错误引发的原因

EasyPlayer无插件直播流媒体音视频播放器属于一款高效、精炼、稳定且免费的流媒体播放器,可支持多种流媒体协议播放,无须安装任何插件,起播快、延迟低、兼容性强,使用非常便捷。 EasyPlayer.js能够同时支持HTTP、HTTP-FLV、HLS&a…

OCRSpace申请free api流程

0.OCRSpace概述 OCR.Space是一款功能强大的在线光学字符识别(OCR)工具。 格式与语言支持广泛:支持多种图片格式,如 JPG、PNG、GIF、PDF 等作为输入。在语言方面,它支持英语、中文、法语、德语等20多种语言的文字识别…

Linux Kernel Programming 2

目录 书写内核框架 起手我们需要理解的是:用户态和内核态 库和系统调用 API 内核空间组件 探索 LKM(Linux Kernel Module体系) LKM 框架 内核源代码树中的内核模块 modinfo 动手!写年轻人的第一个内核模块程序 先试试看&…

机器学习基础04

目录 1.朴素贝叶斯-分类 1.1贝叶斯分类理论 1.2条件概率 1.3全概率公式 1.4贝叶斯推断 1.5朴素贝叶斯推断 1.6拉普拉斯平滑系数 1.7API 2.决策树-分类 2.1决策树 2.2基于信息增益的决策树建立 2.2.1信息熵 2.2.2信息增益 2.2.3信息增益决策树建立步骤 2.3基于基…

ChatGPT学术专用版,一键润色纠错+中英互译+批量翻译PDF

ChatGPT academic项目是由中科院团队基于ChatGPT专属定制。论文润色、语法检查、中英互译、代码解释等可一键搞定,堪称科研神器。 功能介绍 我们以3.5版本为例,ChatGPT学术版总共分为五个区域:输入控制区、输出对话区、基础功能区、函数插件…

fpga 同步fifo

FIFO 基础知识 FIFO(First In First Out,即先入先出),是一种数据缓存器,用来实现数据先入先出 的读写方式。在 FPGA 或者 ASIC 中使用到的 FIFO 一般指的是对数据的存储具有先入先出 特性的缓存器,常被用于…

模式:每个服务一个数据库

Pattern: Database per service。 背景 如用微服务架构模式开发一个在线商店应用程序。大多数服务需要在某种数据库中持久化数据。如,订单服务存储订单信息,而客户服务存储客户信息。 问题 微服务应用程序中的数据库架构是什么? 驱动力…

Java 全栈知识体系

包含: Java 基础, Java 部分源码, JVM, Spring, Spring Boot, Spring Cloud, 数据库原理, MySQL, ElasticSearch, MongoDB, Docker, k8s, CI&CD, Linux, DevOps, 分布式, 中间件, 开发工具, Git, IDE, 源码阅读,读书笔记, 开源项目...

WebRTC实现双端音视频聊天(Vue3 + SpringBoot)

目录 概述 相关概念 双端连接整体实现步骤概述 文章代码实现注意点 STUN和TURN服务器的搭建 开发过程描述 后端开发流程 前端开发流程 效果演示 Gitee源码地址 概述 文章描述使用WebRTC技术实现一对一音视频通话。 由于设备摄像头限制(一台电脑作测试无法…

机器学习3

六、朴素贝叶斯分类 背景知识:第三大点的第4点:概率 基础定义_数学概率中事件的定义-CSDN博客 1、条件概率 𝑃(𝐴|𝐵)𝑃(𝐴∩𝐵)/𝑃(𝐵) :A事件在…

SpringBoot Data Redis连接Redis-Cluster集群

使用SpringBoot Data Redis无法连接Redis-Cluster集群 最近在研究系统高并发下的缓存架构,因此自己在自己买的云服务器上搭建好Redis 5.0 版本的集群后,使用springboot的 RedisTemplate连接是发现总是访问不到集群节点。上网百度了发现没有好的解决办法&…

网页作业9

<!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>服务中心</title><style>* {margin:…

基于yolov8、yolov5的行人检测识别系统(含UI界面、训练好的模型、Python代码、数据集)

摘要&#xff1a;行人检测在交通管理、智能监控和公共安全中起着至关重要的作用&#xff0c;不仅能帮助相关部门实时监控人群动态&#xff0c;还为自动化监控系统提供了可靠的数据支撑。本文介绍了一款基于YOLOv8、YOLOv5等深度学习框架的行人检测模型&#xff0c;该模型使用了…

递归(3)----力扣40组合数2,力扣473火柴拼正方形

给定一个候选人编号的集合 candidates 和一个目标数 target &#xff0c;找出 candidates 中所有可以使数字和为 target 的组合。 candidates 中的每个数字在每个组合中只能使用 一次 。 注意&#xff1a;解集不能包含重复的组合。 示例 1: 输入: candidates [10,1,2,7,6,1…

1Panel 推送 SSL 证书到阿里云、腾讯云

本文首发于 Anyeの小站&#xff0c;点击链接 访问原文体验更佳 前言 都用 CDN 了还在乎那点 1 年证书钱么&#xff1f; 开句玩笑话&#xff0c;按照 Apple 的说法&#xff0c;证书有效期不该超过 45 天。那么证书有效期的缩短意味着要更频繁地更新证书。对于我这样的“裸奔”…

通过shell脚本分析部署nginx网络服务

通过shell脚本分析部署nginx网络服务 1.接收用户部署的服务名称 [rootlocalhost xzy]# vim 1.sh [rootlocalhost xzy]# chmod x 1.sh [rootlocalhost xzy]# ./1.sh2.判断服务是否安装 已安装&#xff1b;自定义网站配置路径为/www&#xff1b;并创建共享目录和网页文件&…

tcp 超时计时器

在 TCP&#xff08;传输控制协议&#xff09;中有以下四种重要的计时器&#xff1a; 重传计时器&#xff08;Retransmission Timer&#xff09; 作用&#xff1a;用于处理数据包丢失的情况。当发送方发送一个数据段后&#xff0c;就会启动重传计时器。如果在计时器超时之前没有…