一、本文介绍
本文给大家带来的改进是Triplet Attention三重注意力机制。这个机制,它通过三个不同的视角来分析输入的数据,就好比三个人从不同的角度来观察同一幅画,然后共同决定哪些部分最值得注意。三重注意力机制的主要思想是在网络中引入了一种新的注意力模块,这个模块包含三个分支,分别关注图像的不同维度。比如说,一个分支可能专注于图像的宽度,另一个分支专注于高度,第三个分支则聚焦于图像的深度,即色彩和纹理等特征。这样一来,网络就能够更全面地理解图像内容,就像是得到了一副三维眼镜,能够看到图片的立体效果一样。本文改进是基于ResNet18、ResNet34、ResNet50、ResNet101,文章中均以提供,本专栏的改进内容全网独一份深度改进RT-DETR非那种无效Neck部分改进,同时本文的改进也支持主干上的即插即用,本文内容也支持PP-HGNetV2版本的修改。
专栏目录: RT-DETR改进有效系列目录 | 包含卷积、主干、RepC3、注意力机制、Neck上百种创新机制
专栏链接:RT-DETR剑指论文专栏,持续复现各种顶会内容——论文收割机RT-DETR
目录
一、本文介绍
二、Triplet Attention机制原理
2.1 Triplet Attention的基本原理
2.2 Triplet Attention和其它简单注意力机制的对比
2.3 Triplet Attention的实现流程
三、Triplet Attention的完整代码
四、手把手教你添加Triplet Attention
4.1 修改Basicclock/Bottleneck的教程
4.1.1 修改一
4.1.2 修改二
4.2 修改主干上即插即用的教程
4.2.1 修改一(如果修改了4.1教程此步无需修改)
4.2.2 修改二
4.2.3 修改三
4.2.4 修改四
五、Triplet Attention的yaml文件
5.1 替换ResNet的yaml文件1(ResNet18版本)
5.2 替换ResNet的yaml文件1(ResNet50版本)
5.3 即插即用的yaml文件(HGNetV2版本)
六、成功运行记录
6.1 ResNet18运行成功记录截图
6.2 ResNet50运行成功记录截图
6.3 HGNetv2运行成功记录截图
七、全文总结
二、Triplet Attention机制原理
论文地址:官方论文地址
代码地址:官方代码地址
2.1 Triplet Attention的基本原理
三重注意力(Triplet Attention)的基本原理是利用三支结构捕获输入数据的跨维度交互,从而计算注意力权重。这个方法能够构建输入通道或空间位置之间的相互依赖性,而且计算代价小。三重注意力由三个分支组成,每个分支负责捕获空间维度H或W与通道维度C之间的交互特征。通过对每个分支中的输入张量进行排列变换,然后通过Z池操作和一个大小为k×k的卷积层,生成注意力权重。这些权重是通过一个S形激活层生成的,然后应用于排列变换后的输入张量,再变换回原来的输入形状
三重注意力(Triplet Attention)的主要改进点包括:
-
跨维度的注意力权重计算: 通过一个创新的三支结构捕获通道、高度、宽度三个维度之间的交互关系来计算注意力权重。
-
旋转操作和残差变换: 通过旋转输入张量和应用残差变换来建立不同维度间的依赖,这是三重注意力机制中的关键步骤。
-
维度间依赖性的重要性: 强调在计算注意力权重时,捕获跨维度依赖性的重要性,这是三重注意力的核心直觉和设计理念。
下面的图片是三重注意力的一个抽象表示图,展示了三个分支如何捕获跨维度交互。图中的每个子图表示三重注意力中的一个分支:
1. 分支(a): 这个分支直接处理输入张量,没有进行旋转,然后通过残差变换来提取特征。
2. 分支(b): 这个分支首先沿着宽度(W)和通道(C)的维度旋转输入张量,然后进行残差变换。
3. 分支(c): 这个分支沿着高度(H)和通道(C)的维度旋转输入张量,之后同样进行残差变换。
总结:通过这样的设计,三重注意力模型能够有效地捕获输入张量中的空间和通道维度之间的交互关系。这种方法使模型能够构建通道与空间位置之间的相互依赖性,提高模型对特征的理解能力。
2.2 Triplet Attention和其它简单注意力机制的对比
下面的图片是论文中三重注意力机制和其它注意力机制的一个对比大家有兴趣可以看看,横向扩展以下自己的知识库。
这张图片是一幅对比不同注意力模块的图示,其中包括:
1.Squeeze Excitation (SE) Module:
这个模块使用全局平均池化 (Global Avg Pool) 生成通道描述符,接着通过两个全连接层(1x1 Conv),中间使用ReLU激活函数,最后通过Sigmoid函数生成每个通道的权重。
2. Convolutional Block Attention Module (CBAM):
首先使用全局平均池化和全局最大池化(GAP + GMP)结合,再通过一个卷积层和ReLU激活函数,最后经过另一个卷积层和Sigmoid函数生成注意力权重。
3. Global Context (GC) Module:
从一个1x1卷积层开始,经过Softmax函数进行归一化,接着进行另一个1x1卷积,然后使用LayerNorm和最终的1x1卷积,通过广播加法结合原始特征图。
4. Triplet Attention (我们的方法):
分为三个分支,每个分支进行不同的处理:通道池化后的7x7卷积,Z池化,再接一个7x7卷积,然后是批量归一化和Sigmoid函数。每个分支都有一个Permute操作来调整维度。最后,三个分支的结果通过平均池化聚合起来生成最终的注意力权重。
每种模块都设计用于处理特征图(C x H x W),其中C是通道数,H是高度,W是宽度。这些模块通过不同方式计算注意力权重,增强网络对特征的重要部分的关注度,从而在各种视觉任务中提高性能。图片中的符号⊗代表矩阵乘法,⊕代表广播元素级加法。
2.3 Triplet Attention的实现流程
下面的图片是三重注意力(Triplet Attention)的具体实现流程图。图中详细展示了三个分支如何处理输入张量,并最终合成三重注意力。下面是对这个过程的描述:
-
上部分支: 负责计算通道维度C和空间维度W的注意力权重。这个分支对输入张量进行Z池化(Z-Pool)操作,然后通过一个卷积层(Conv),接着用Sigmoid函数生成注意力权重。
-
中部分支: 负责捕获通道维度C与空间维度H和W之间的依赖性。这个分支首先进行相同的Z池化和卷积操作,然后同样通过Sigmoid函数生成注意力权重。
-
下部分支: 用于捕获空间维度之间的依赖性。这个分支保持输入的身份(Identity,即不改变输入),执行Z池化和卷积操作,之后也通过Sigmoid函数生成注意力权重。
每个分支在生成注意力权重后,会对输入进行排列(Permutation),然后将三个分支的输出进行平均聚合(Avg),最终得到三重注意力输出。
这种结构通过不同的旋转和排列操作,能够综合不同维度上的信息,更好地捕获数据的内在特征,同时这种方法在计算上是高效的,并且可以作为一个模块加入到现有的网络架构中,增强网络对复杂数据结构的理解和处理能力。
三、Triplet Attention的完整代码
大家复制代码的时候需要注意这是一种无参数的注意力机制,所以在看第四章添加教程的时候需要按照无参数的注意力机制进行添加。
import torch
import torch.nn as nn__all__ = ['TripletAttention', 'BottleNeck_TripletAttention', 'BasicBlock_TripletAttention']class BasicConv(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,bn=True, bias=False):super(BasicConv, self).__init__()self.out_channels = out_planesself.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,dilation=dilation, groups=groups, bias=bias)self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else Noneself.relu = nn.ReLU() if relu else Nonedef forward(self, x):x = self.conv(x)if self.bn is not None:x = self.bn(x)if self.relu is not None:x = self.relu(x)return xclass ZPool(nn.Module):def forward(self, x):return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)class AttentionGate(nn.Module):def __init__(self):super(AttentionGate, self).__init__()kernel_size = 7self.compress = ZPool()self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)def forward(self, x):x_compress = self.compress(x)x_out = self.conv(x_compress)scale = torch.sigmoid_(x_out)return x * scaleclass TripletAttention(nn.Module):def __init__(self, no_spatial=False):super(TripletAttention, self).__init__()self.cw = AttentionGate()self.hc = AttentionGate()self.no_spatial = no_spatialif not no_spatial:self.hw = AttentionGate()def forward(self, x):x_perm1 = x.permute(0, 2, 1, 3).contiguous()x_out1 = self.cw(x_perm1)x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()x_perm2 = x.permute(0, 3, 2, 1).contiguous()x_out2 = self.hc(x_perm2)x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()if not self.no_spatial:x_out = self.hw(x)x_out = 1 / 3 * (x_out + x_out11 + x_out21)else:x_out = 1 / 2 * (x_out11 + x_out21)return x_outfrom collections import OrderedDict
import torch.nn.functional as Fclass ConvNormLayer(nn.Module):def __init__(self,ch_in,ch_out,filter_size,stride,groups=1,act=None):super(ConvNormLayer, self).__init__()self.act = actself.conv = nn.Conv2d(in_channels=ch_in,out_channels=ch_out,kernel_size=filter_size,stride=stride,padding=(filter_size - 1) // 2,groups=groups)self.norm = nn.BatchNorm2d(ch_out)def forward(self, inputs):out = self.conv(inputs)out = self.norm(out)if self.act:out = getattr(F, self.act)(out)return outclass BasicBlock_TripletAttention(nn.Module):expansion = 1def __init__(self,ch_in,ch_out,stride,shortcut,act='relu',variant='b',att=False):super(BasicBlock_TripletAttention, self).__init__()self.shortcut = shortcutif not shortcut:if variant == 'd' and stride == 2:self.short = nn.Sequential()self.short.add_sublayer('pool',nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True))self.short.add_sublayer('conv',ConvNormLayer(ch_in=ch_in,ch_out=ch_out,filter_size=1,stride=1))else:self.short = ConvNormLayer(ch_in=ch_in,ch_out=ch_out,filter_size=1,stride=stride)self.branch2a = ConvNormLayer(ch_in=ch_in,ch_out=ch_out,filter_size=3,stride=stride,act='relu')self.branch2b = ConvNormLayer(ch_in=ch_out,ch_out=ch_out,filter_size=3,stride=1,act=None)self.att = attif self.att:self.se = TripletAttention(ch_out)def forward(self, inputs):out = self.branch2a(inputs)out = self.branch2b(out)if self.att:out = self.se(out)if self.shortcut:short = inputselse:short = self.short(inputs)out = out + shortout = F.relu(out)return outclass BottleNeck_TripletAttention(nn.Module):expansion = 4def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='d', att=False):super().__init__()if variant == 'a':stride1, stride2 = stride, 1else:stride1, stride2 = 1, stridewidth = ch_outself.branch2a = ConvNormLayer(ch_in, width, 1, stride1, act=act)self.branch2b = ConvNormLayer(width, width, 3, stride2, act=act)self.branch2c = ConvNormLayer(width, ch_out * self.expansion, 1, 1)self.shortcut = shortcutif not shortcut:if variant == 'd' and stride == 2:self.short = nn.Sequential(OrderedDict([('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True)),('conv', ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1))]))else:self.short = ConvNormLayer(ch_in, ch_out * self.expansion, 1, stride)self.att = attif self.att:self.se = TripletAttention(ch_out * 4)def forward(self, x):out = self.branch2a(x)out = self.branch2b(out)out = self.branch2c(out)if self.att:out = self.se(out)if self.shortcut:short = xelse:short = self.short(x)out = out + shortout = F.relu(out)return out
四、手把手教你添加Triplet Attention
修改教程分两种,一种是替换修改ResNet中的Basicclock/Bottleneck模块的,一种是在主干上即插即用的修改教程,如果你只需要一种那么修改对应的就行,互相之间并不影响,需要注意的是即插即用的需要修改ResNet改进才行,链接如下:
ResNet文章地址:【RT-DETR改进涨点】ResNet18、34、50、101等多个版本移植到ultralytics仓库(RT-DETR官方一比一移植)
4.1 修改Basicclock/Bottleneck的教程
4.1.1 修改一
第一还是建立文件,我们找到如下ultralytics/nn/modules文件夹下建立一个目录名字呢就是'Addmodules'文件夹(用群内的文件的话已经有了无需新建)!然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。
4.1.2 修改二
第二步此处需要注意,因为我这里默认大家修改了ResNet系列的模型了,同级目录下应该有一个ResNet.py的文件夹,我们这里需要找到我们'ultralytics/nn/Addmodules/ResNet.py'创建的ResNet的文件夹(默认大家已经创建了!!!)
我们只需要修改上面的两步即可,后面复制yaml文件进行运行即可了,修改方法大家只要仔细看是非常简单的。
4.2 修改主干上即插即用的教程
4.2.1 修改一(如果修改了4.1教程此步无需修改)
第一还是建立文件,我们找到如下ultralytics/nn/modules文件夹下建立一个目录名字呢就是'Addmodules'文件夹(用群内的文件的话已经有了无需新建)!然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。
4.2.2 修改二
第二步我们在该目录下创建一个新的py文件名字为'__init__.py'(用群内的文件的话已经有了无需新建),然后在其内部导入我们的检测头如下图所示。
4.2.3 修改三
第三步我门中到如下文件'ultralytics/nn/tasks.py'进行导入和注册我们的模块(用群内的文件的话已经有了无需重新导入直接开始第四步即可)!
从今天开始以后的教程就都统一成这个样子了,因为我默认大家用了我群内的文件来进行修改!!
4.2.4 修改四
按照我的添加在parse_model里添加即可。
elif m in {TripletAttention}:c2 = ch[f]args = [c2, *args]
到此就修改完成了,大家可以复制下面的yaml文件运行。
五、Triplet Attention的yaml文件
5.1 替换ResNet的yaml文件1(ResNet18版本)
需要修改如下的ResNet主干才可以运行本文的改进机制 !
ResNet文章地址:【RT-DETR改进涨点】ResNet18、34、50、101等多个版本移植到ultralytics仓库(RT-DETR官方一比一移植)
# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'# [depth, width, max_channels]l: [1.00, 1.00, 1024]backbone:# [from, repeats, module, args]- [-1, 1, ConvNormLayer, [32, 3, 2, 1, 'relu']] # 0-P1- [-1, 1, ConvNormLayer, [32, 3, 1, 1, 'relu']] # 1- [-1, 1, ConvNormLayer, [64, 3, 1, 1, 'relu']] # 2- [-1, 1, nn.MaxPool2d, [3, 2, 1]] # 3-P2- [-1, 2, Blocks, [64, BasicBlock_TripletAttention, 2, True]] # 4- [-1, 2, Blocks, [128, BasicBlock_TripletAttention, 3, True]] # 5-P3- [-1, 2, Blocks, [256, BasicBlock_TripletAttention, 4, True]] # 6-P4- [-1, 2, Blocks, [512, BasicBlock_TripletAttention, 5, True]] # 7-P5head:- [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 8 input_proj.2- [-1, 1, AIFI, [1024, 8]]- [-1, 1, Conv, [256, 1, 1]] # 10, Y5, lateral_convs.0- [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 11- [6, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 12 input_proj.1- [[-2, -1], 1, Concat, [1]]- [-1, 3, RepC3, [256, 0.5]] # 14, fpn_blocks.0- [-1, 1, Conv, [256, 1, 1]] # 15, Y4, lateral_convs.1- [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 16- [5, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 17 input_proj.0- [[-2, -1], 1, Concat, [1]] # 18 cat backbone P4- [-1, 3, RepC3, [256, 0.5]] # X3 (19), fpn_blocks.1- [-1, 1, Conv, [256, 3, 2]] # 20, downsample_convs.0- [[-1, 15], 1, Concat, [1]] # 21 cat Y4- [-1, 3, RepC3, [256, 0.5]] # F4 (22), pan_blocks.0- [-1, 1, Conv, [256, 3, 2]] # 23, downsample_convs.1- [[-1, 10], 1, Concat, [1]] # 24 cat Y5- [-1, 3, RepC3, [256, 0.5]] # F5 (25), pan_blocks.1- [[19, 22, 25], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 3]] # Detect(P3, P4, P5)
5.2 替换ResNet的yaml文件1(ResNet50版本)
需要修改如下的ResNet主干才可以运行本文的改进机制 !
ResNet文章地址:【RT-DETR改进涨点】ResNet18、34、50、101等多个版本移植到ultralytics仓库(RT-DETR官方一比一移植)
# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'# [depth, width, max_channels]l: [1.00, 1.00, 1024]backbone:# [from, repeats, module, args]- [-1, 1, ConvNormLayer, [32, 3, 2, 1, 'relu']] # 0-P1- [-1, 1, ConvNormLayer, [32, 3, 1, 1, 'relu']] # 1- [-1, 1, ConvNormLayer, [64, 3, 1, 1, 'relu']] # 2- [-1, 1, nn.MaxPool2d, [3, 2, 1]] # 3-P2- [-1, 3, Blocks, [64, BottleNeck_TripletAttention, 2, True]] # 4- [-1, 4, Blocks, [128, BottleNeck_TripletAttention, 3, True]] # 5-P3- [-1, 6, Blocks, [256, BottleNeck_TripletAttention, 4, True]] # 6-P4- [-1, 3, Blocks, [512, BottleNeck_TripletAttention, 5, True]] # 7-P5head:- [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 8 input_proj.2- [-1, 1, AIFI, [1024, 8]] # 9- [-1, 1, Conv, [256, 1, 1]] # 10, Y5, lateral_convs.0- [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 11- [6, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 12 input_proj.1- [[-2, -1], 1, Concat, [1]] # 13- [-1, 3, RepC3, [256]] # 14, fpn_blocks.0- [-1, 1, Conv, [256, 1, 1]] # 15, Y4, lateral_convs.1- [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 16- [5, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 17 input_proj.0- [[-2, -1], 1, Concat, [1]] # 18 cat backbone P4- [-1, 3, RepC3, [256]] # X3 (19), fpn_blocks.1- [-1, 1, Conv, [256, 3, 2]] # 20, downsample_convs.0- [[-1, 15], 1, Concat, [1]] # 21 cat Y4- [-1, 3, RepC3, [256]] # F4 (22), pan_blocks.0- [-1, 1, Conv, [256, 3, 2]] # 23, downsample_convs.1- [[-1, 10], 1, Concat, [1]] # 24 cat Y5- [-1, 3, RepC3, [256]] # F5 (25), pan_blocks.1- [[19, 22, 25], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 6]] # Detect(P3, P4, P5)
5.3 即插即用的yaml文件(HGNetV2版本)
此版本为HGNetV2-l的yaml文件!
# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'# [depth, width, max_channels]l: [1.00, 1.00, 1024]backbone:# [from, repeats, module, args]- [-1, 1, HGStem, [32, 48]] # 0-P2/4- [-1, 6, HGBlock, [48, 128, 3]] # stage 1- [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8- [-1, 6, HGBlock, [96, 512, 3]] # stage 2- [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16- [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut- [-1, 6, HGBlock, [192, 1024, 5, True, True]]- [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3- [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32- [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4head:- [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2- [-1, 1, AIFI, [1024, 8]]- [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1- [[-2, -1], 1, Concat, [1]]- [-1, 3, RepC3, [256]] # 16, fpn_blocks.0- [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0- [[-2, -1], 1, Concat, [1]] # cat backbone P4- [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1- [-1, 1, TripletAttention, []] # 22- [-1, 1, Conv, [384, 3, 2]] # 23, downsample_convs.0- [[-1, 17], 1, Concat, [1]] # cat Y4- [-1, 3, RepC3, [256]] # F4 (25), pan_blocks.0- [-1, 1, TripletAttention, []] # 26- [-1, 1, Conv, [384, 3, 2]] # 27, downsample_convs.1- [[-1, 12], 1, Concat, [1]] # cat Y5- [-1, 3, RepC3, [256]] # F5 (29), pan_blocks.1- [-1, 1, TripletAttention, []] # 30- [[22, 26, 30], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
六、成功运行记录
6.1 ResNet18运行成功记录截图
6.2 ResNet50运行成功记录截图
6.3 HGNetv2运行成功记录截图
七、全文总结
到此本文的正式分享内容就结束了,在这里给大家推荐我的RT-DETR改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~
专栏链接:RT-DETR剑指论文专栏,持续复现各种顶会内容——论文收割机RT-DETR