(即插即用模块-Attention部分) 二十、(2021) GAA 门控轴向注意力

在这里插入图片描述

文章目录

  • 1、Gated Axial-Attention
  • 2、代码实现

paper:Medical Transformer: Gated Axial-Attention for Medical Image Segmentation

Code:https://github.com/jeya-maria-jose/Medical-Transformer


1、Gated Axial-Attention

论文首先分析了 ViTs 在训练小规模数据集时的弊端以及指出了 ViTs 的计算复杂度偏高。为此,论文提出了一种门控轴向注意力(Gated Axial-Attention),其通过在自注意力模块中引入额外的门控机制来扩展现有的体系结构。在分析了位置偏差难以学习、相对位置编码不够准确等问题后,通过将可控制的影响位置偏差施加在编码的非本地上下文来实现改进。Gated Axial-Attention的 核心思想是Gate门控机制,通过引入 Gate 控制机制来控制位置编码对 Self-Attention 的影响程度。

对于一个输入特征 X,Gated Axial-Attention的实现过程:

  1. 输入特征图: 将输入图像提取特征图,并进行通道维度上的线性变换,得到 Query、Key 和 Value 向量。

  2. Axial-Attention

    在高度方向上进行 1D Self-Attention,计算像素之间的依赖关系。

    在宽度方向上进行 1D Self-Attention,计算像素之间的依赖关系。

  3. Positional Encoding:计算相对位置编码,将像素位置信息融入到 Query、Key 和 Value 向量中。

  4. Gate 控制机制:通过可学习的 Gate 参数,控制相对位置编码对 Self-Attention 的影响程度。

  5. 输出特征图: 将经过 Self-Attention 和 Gate 控制的特征图进行线性变换,得到最终输出特征图。


Gated Axial-Attention 结构图:
在这里插入图片描述


2、代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import mathdef conv1x1(in_planes, out_planes, stride=1):"""1x1 卷积"""return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)class qkv_transform(nn.Conv1d):"""Conv1d for qkv_transform"""class AxialAttention(nn.Module):def __init__(self, in_planes, out_planes, groups=8, kernel_size=56,stride=1, bias=False, width=False):assert (in_planes % groups == 0) and (out_planes % groups == 0)super(AxialAttention, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.groups = groupsself.group_planes = out_planes // groupsself.kernel_size = kernel_sizeself.stride = strideself.bias = biasself.width = width# Multi-head self attentionself.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1,padding=0, bias=False)self.bn_qkv = nn.BatchNorm1d(out_planes * 2)self.bn_similarity = nn.BatchNorm2d(groups * 3)self.bn_output = nn.BatchNorm1d(out_planes * 2)# Position embeddingself.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True)query_index = torch.arange(kernel_size).unsqueeze(0)key_index = torch.arange(kernel_size).unsqueeze(1)relative_index = key_index - query_index + kernel_size - 1self.register_buffer('flatten_index', relative_index.view(-1))if stride > 1:self.pooling = nn.AvgPool2d(stride, stride=stride)self.reset_parameters()def forward(self, x):# pdb.set_trace()if self.width:x = x.permute(0, 2, 1, 3)else:x = x.permute(0, 3, 1, 2)  # N, W, C, HN, W, C, H = x.shapex = x.contiguous().view(N * W, C, H)# Transformationsqkv = self.bn_qkv(self.qkv_transform(x))q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H),[self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)# Calculate position embeddingall_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2,self.kernel_size,self.kernel_size)q_embedding, k_embedding, v_embedding = torch.split(all_embeddings,[self.group_planes // 2, self.group_planes // 2,self.group_planes], dim=0)qr = torch.einsum('bgci,cij->bgij', q, q_embedding)kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)qk = torch.einsum('bgci, bgcj->bgij', q, k)stacked_similarity = torch.cat([qk, qr, kr], dim=1)stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)# stacked_similarity = self.bn_qr(qr) + self.bn_kr(kr) + self.bn_qk(qk)# (N, groups, H, H, W)similarity = F.softmax(stacked_similarity, dim=3)sv = torch.einsum('bgij,bgcj->bgci', similarity, v)sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)if self.width:output = output.permute(0, 2, 1, 3)else:output = output.permute(0, 2, 3, 1)if self.stride > 1:output = self.pooling(output)return outputdef reset_parameters(self):self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))# nn.init.uniform_(self.relative, -0.1, 0.1)nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes))if __name__ == '__main__':x = torch.randn(4, 512, 7, 7).cuda()# kernel_size 要跟 h,w 相同model = AxialAttention(512, 512, kernel_size=7).cuda()out = model(x)print(out.shape)

本文只是对论文中的即插即用模块做了整合,对论文中的一些地方难免有遗漏之处,如果想对这些模块有更详细的了解,还是要去读一下原论文,肯定会有更多收获。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/479647.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[C++ 核心编程]笔记 4.1 封装

4.1.1 封装的意义 封装是C面向对象三大特性之一 封装的意义: 将属性和行为作为一个整体,表现生活中的事物将属性和行为加以权限控制 封装意义一: 在设计类的时候,属性和行为写在一起,表现事物 语法: class 类名{ 访问权限: 属性 /行为 }…

韩顺平 一周学会Linux | Linux 实操篇-组管理和权限管理

一、Linux 组 1. 组基本介绍 在linux 中的每个用户必须属于一个组,不能独立于组外。在linux 中每个文件有所有者、所在组、其它组的概念。 2. 文件/目录 所有者 一般为文件的创建者,谁创建了该文件,就自然的成为该文件的所有者。 1) 查看文件所有者&…

FBX福币交易所创业板指放量大涨2.73% 谷子经济概念持续爆发

查查配分析11月27日电 27日,A股三大指数探底回升,沪指涨逾1%,创业板指涨近3%。全市成交额较上个交易日放量至1.49万亿元。 截至收盘,上证指数涨1.53%,报3309.78点;深证成指涨2.25%,报10566.10点;创业板指涨2.73%,报2208.78点。 FBX福币凭借用户友好的界面和对透明度的承诺,迅速…

前端性能优化之任务管理/调度

浏览器的一帧 前面我们提到如何使用requestAnimationFrame来检测是否产生了卡顿。除此之外,如果你也处理过简单的异步任务管理(闲时执行等),还可以使用requestIdleCallback来检测卡顿。其实,requestAnimationFrame和requestIdleCallback都会在浏览器的每一帧中被执行到,…

Ubuntu20.04安装kalibr

文章目录 环境配置安装wxPython下载编译测试报错1问题描述问题分析问题解决 参考 环境配置 Ubuntu20.04,python3.8.10,boost自带的1.71 sudo apt update sudo apt-get install python3-setuptools python3-rosinstall ipython3 libeigen3-dev libboost…

QUAD-MxFE平台

QUAD-MxFE平台 16Tx/16Rx直接L/S/C频段采样相控阵/雷达/电子战/卫星通信开发平台 概览 优势和特点 四通道MxFE数字化处理卡 使用MxFE的多通道、宽带系统开发平台 与Xilinx VCU118评估板(不包括)搭配使用 16个RF接收(Rx)通道(32个数字Rx通道…

互联网视频推拉流EasyDSS视频直播点播平台视频转码有哪些技术特点和应用?

视频转码本质上是一个先解码再编码的过程。在转码过程中,原始视频码流首先被解码成原始图像数据,然后再根据目标编码标准、分辨率、帧率、码率等参数重新进行编码。这样,转换前后的码流可能遵循相同的视频编码标准,也可能不遵循。…

开源加密库mbedtls及其Windows编译库

目录 1 项目简介 2 功能特性 3 性能优势 4 平台兼容性 5 应用场景 6 特点 7 Windows编译 8 编译静态库及其测试示例下载 1 项目简介 Mbed TLS是一个由ARM Maintained的开源项目,它提供了一个轻量级的加密库,适用于嵌入式系统和物联网设备。这个项…

GESP C++等级考试 二级真题(2024年9月)

若需要在线模拟考试,可进入题库中心,在线备考,检验掌握程度: https://www.hixinao.com/tidan/exam-157.html?time1732669362&sid172&index1

upload-labs 靶场(11~21)

免责声明 本博客文章仅供教育和研究目的使用。本文中提到的所有信息和技术均基于公开来源和合法获取的知识。本文不鼓励或支持任何非法活动,包括但不限于未经授权访问计算机系统、网络或数据。 作者对于读者使用本文中的信息所导致的任何直接或间接后果不承担任何…

嵌入式硬件实战基础篇(四)多路直流稳压电源

设计一个多路直流稳压电源 要求设计制作一个多路输出直流稳压电源,可将220 V / 5 0HZ交流电转换为5路直流稳压输出。具体要求: 输出直流电压 12V, 5V;和一路输出3- 15V连续可调直流稳压电源: 输出电流Iom500mA; 稳压系数 Sr≤0.05;

【人工智能】深入解析GPT、BERT与Transformer模型|从原理到应用的完整教程

在当今人工智能迅猛发展的时代,自然语言处理(NLP)领域涌现出许多强大的模型,其中GPT、BERT与Transformer无疑是最受关注的三大巨头。这些模型不仅在学术界引起了广泛讨论,也在工业界得到了广泛应用。那么,G…

【计算机视觉+MATLAB】自动检测并可视化圆形目标:通过 imfindcircles 和 viscircles 函数

引言 自动检测图像中的圆形或圆形对象,并可视化检测到的圆形。 函数详解 imfindcircles imfindcircles是MATLAB中的一个函数,用于在图像中检测并找出圆形区域。 基本语法: [centers, radii] imfindcircles(A, radiusRange) [centers, r…

17. C++模板(template)1(泛型编程,函数模板,类模板)

⭐本篇重点:泛型编程,函数模板,类模板 ⭐本篇代码:c学习/07.函数模板 橘子真甜/c-learning-of-yzc - 码云 - 开源中国 (gitee.com) 目录 一. 泛型编程 二. 函数模板 2.1 函数模板的格式 2.2 函数模板的简单使用 2.3 函数模板…

学习threejs,设置envMap环境贴图创建反光效果

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:threejs gis工程师 文章目录 一、🍀前言1.1 ☘️THREE.CubeTextureLoader 立…

v-for产生 You may have an infinite update loop in a component render function

参考文章&#xff1a; 报错解析 [Vue warn]: You may have an infinite update loop in a component render function. 另外一个解决方法 例如: MyList 是一个数组&#xff0c;我希望将排序后的结果返回进行for循环&#xff0c;因此设计了一个myMethon函数 <div v-for"…

spring boot框架漏洞复现

spring - java开源框架有五种 Spring MVC、SpringBoot、SpringFramework、SpringSecurity、SpringCloud spring boot版本 版本1: 直接就在根下 / 版本2:根下的必须目录 /actuator/ 端口:9093 spring boot搭建 1:直接下载源码打包 2:运行编译好的jar包:actuator-testb…

【Linux】线程的互斥和同步

【Linux】线程的互斥和同步 线程间的互斥 临界资源&#xff1a;多线程执行共享的资源就叫做临界资源临界区&#xff1a;每个线程内部&#xff0c;访问临界资源的代码&#xff0c;就叫做临界区互斥&#xff1a;任何时刻&#xff0c;互斥保证有且只有一个执行流进入临界区&#…

集合Queue、Deque、LinkedList、ArrayDeque、PriorityQueue详解

1、 Queue与Deque的区别 在研究java集合源码的时候&#xff0c;发现了一个很少用但是很有趣的点&#xff1a;Queue以及Deque&#xff1b; 平常在写leetcode经常用LinkedList向上转型Deque作为栈或者队列使用&#xff0c;但是一直都不知道Queue的作用&#xff0c;于是就直接官方…

亮相全国集群智能与协同控制大会,卓翼飞思无人智能科研方案成焦点

无人集群智能协同技术是人工智能发展的必然趋势&#xff0c;也是我国新一代人工智能的核心研究领域。为加强集群智能与协同控制需求牵引和对接、技术交流和互动&#xff0c;11月23-25日&#xff0c;由中国指挥与控制学会主办的第八届全国集群智能与协同控制大会在贵阳市隆重召开…