一、本文介绍
本文给大家带来的改进机制是EMAttention注意力机制,它的核心思想是,重塑部分通道到批次维度,并将通道维度分组为多个子特征,以保留每个通道的信息并减少计算开销。EMA模块通过编码全局信息来重新校准每个并行分支中的通道权重,并通过跨维度交互来捕获像素级别的关系。本文改进是基于ResNet18、ResNet34、ResNet50、ResNet101,文章中均以提供,本专栏的改进内容全网独一份深度改进RT-DETR非那种无效Neck部分改进,同时本文的改进也支持主干上的即插即用,本文内容也支持PP-HGNetV2版本的修改。
专栏目录: RT-DETR改进有效系列目录 | 包含卷积、主干、RepC3、注意力机制、Neck上百种创新机制
专栏链接:RT-DETR剑指论文专栏,持续复现各种顶会内容——论文收割机RT-DETR
目录
一、本文介绍
二、EMAttention的框架原理
三、EMAttention的核心代码
四、 手把手教你添加EMAttention(注意看此处)
4.1 修改Basicclock/Bottleneck的教程
4.1.1 修改一
4.1.2 修改二
4.2 修改主干上即插即用的教程
4.2.1 修改一(如果修改了4.1教程此步无需修改)
4.2.2 修改二
4.2.3 修改三
4.2.4 修改四
五、EMAttention的yaml文件
5.1 替换ResNet的yaml文件1(ResNet18版本)
5.2 替换ResNet的yaml文件1(ResNet50版本)
5.3 即插即用的yaml文件(HGNetV2版本)
六、成功运行记录
6.1 ResNet18运行成功记录截图
6.2 ResNet50运行成功记录截图
6.3 HGNetv2运行成功记录截图
七、全文总结
二、EMAttention的框架原理
官方论文地址: 官方论文地址
官方代码地址: 官方代码地址
主要原理是一个新型的高效多尺度注意力(EMA)这个模块通过重塑部分通道到批次维度,并将通道维度分组为多个子特征,以保留每个通道的信息并减少计算开销。EMA模块通过编码全局信息来重新校准每个并行分支中的通道权重,并通过跨维度交互来捕获像素级别的关系。
提出的创新点主要包括:
1. 高效多尺度注意力(EMA)模:这是一种新型的注意力机制,专为计算机视觉任务设计,旨在同时减少计算开销和保留每个通道的关键信息。
2. 通道和批次维度的重组:EMA通过重新组织通道维度和批次维度,提高了模型处理特征的能力。
3. 跨维度交互:模块利用跨维度的交互来捕捉像素级别的关系,这在传统的注意力模型中较为少见。
4. 全局信息编码和通道权重校准:EMA模块在并行分支中编码全局信息,用于通道权重的重新校准,增强了特征表示的能力。
这张图片是文章中提出的高效多尺度注意力(EMA)模块的示意图。"g"表示输入通道被分成的组数。"X Avg Pool"和"Y Avg Pool"分别代表一维水平和垂直的全局池化操作。在EMA模块中,输入首先被分组,然后通过不同的分支进行处理:一个分支进行一维全局池化,另一个通过3x3的卷积进行特征提取。两个分支的输出特征之后通过sigmoid函数和归一化操作进行调制,最终通过跨维度交互模块合并,以捕捉像素级的成对关系。经过最终的sigmoid调节后,输出特征映射以增强或减弱原始输入特征,从而得到最终输出。
三、EMAttention的核心代码
使用方法看章节四。
import torch
from torch import nnclass EMA(nn.Module):def __init__(self, channels, factor=32):super(EMA, self).__init__()self.groups = factorassert channels // self.groups > 0self.softmax = nn.Softmax(-1)self.agp = nn.AdaptiveAvgPool2d((1, 1))self.pool_h = nn.AdaptiveAvgPool2d((None, 1))self.pool_w = nn.AdaptiveAvgPool2d((1, None))self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)def forward(self, x):b, c, h, w = x.size()group_x = x.reshape(b * self.groups, -1, h, w) # b*g,c//g,h,wx_h = self.pool_h(group_x)x_w = self.pool_w(group_x).permute(0, 1, 3, 2)hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))x_h, x_w = torch.split(hw, [h, w], dim=2)x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())x2 = self.conv3x3(group_x)x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))x12 = x2.reshape(b * self.groups, c // self.groups, -1) # b*g, c//g, hwx21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))x22 = x1.reshape(b * self.groups, c // self.groups, -1) # b*g, c//g, hwweights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)return (group_x * weights.sigmoid()).reshape(b, c, h, w)
四、 手把手教你添加EMAttention(注意看此处)
修改教程分两种,一种是替换修改ResNet中的Basicclock/Bottleneck模块的,一种是在主干上即插即用的修改教程,如果你只需要一种那么修改对应的就行,互相之间并不影响,需要注意的是即插即用的需要修改ResNet改进才行,链接如下:
ResNet文章地址:【RT-DETR改进涨点】ResNet18、34、50、101等多个版本移植到ultralytics仓库(RT-DETR官方一比一移植)
4.1 修改Basicclock/Bottleneck的教程
4.1.1 修改一
第一还是建立文件,我们找到如下ultralytics/nn/modules文件夹下建立一个目录名字呢就是'Addmodules'文件夹(用群内的文件的话已经有了无需新建)!然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。
4.1.2 修改二
第二步此处需要注意,因为我这里默认大家修改了ResNet系列的模型了,同级目录下应该有一个ResNet.py的文件夹,我们这里需要找到我们'ultralytics/nn/Addmodules/ResNet.py'创建的ResNet的文件夹(默认大家已经创建了!!!)
我们只需要修改上面的两步即可,后面复制yaml文件进行运行即可了,修改方法大家只要仔细看是非常简单的。
4.2 修改主干上即插即用的教程
4.2.1 修改一(如果修改了4.1教程此步无需修改)
第一还是建立文件,我们找到如下ultralytics/nn/modules文件夹下建立一个目录名字呢就是'Addmodules'文件夹(用群内的文件的话已经有了无需新建)!然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。
4.2.2 修改二
第二步我们在该目录下创建一个新的py文件名字为'__init__.py'(用群内的文件的话已经有了无需新建),然后在其内部导入我们的检测头如下图所示。
4.2.3 修改三
第三步我门中到如下文件'ultralytics/nn/tasks.py'进行导入和注册我们的模块(用群内的文件的话已经有了无需重新导入直接开始第四步即可)!
从今天开始以后的教程就都统一成这个样子了,因为我默认大家用了我群内的文件来进行修改!!
4.2.4 修改四
按照我的添加在parse_model里添加即可。
elif m in {EMA}:c2 = ch[f]args = [c2, *args]
到此就修改完成了,大家可以复制下面的yaml文件运行。
五、EMAttention的yaml文件
5.1 替换ResNet的yaml文件1(ResNet18版本)
需要修改如下的ResNet主干才可以运行本文的改进机制 !
ResNet文章地址:【RT-DETR改进涨点】ResNet18、34、50、101等多个版本移植到ultralytics仓库(RT-DETR官方一比一移植)
# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'# [depth, width, max_channels]l: [1.00, 1.00, 1024]backbone:# [from, repeats, module, args]- [-1, 1, ConvNormLayer, [32, 3, 2, 1, 'relu']] # 0-P1- [-1, 1, ConvNormLayer, [32, 3, 1, 1, 'relu']] # 1- [-1, 1, ConvNormLayer, [64, 3, 1, 1, 'relu']] # 2- [-1, 1, nn.MaxPool2d, [3, 2, 1]] # 3-P2- [-1, 2, Blocks, [64, BasicBlock_EMA, 2, True]] # 4- [-1, 2, Blocks, [128, BasicBlock_EMA, 3, True]] # 5-P3- [-1, 2, Blocks, [256, BasicBlock_EMA, 4, True]] # 6-P4- [-1, 2, Blocks, [512, BasicBlock_EMA, 5, True]] # 7-P5head:- [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 8 input_proj.2- [-1, 1, AIFI, [1024, 8]]- [-1, 1, Conv, [256, 1, 1]] # 10, Y5, lateral_convs.0- [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 11- [6, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 12 input_proj.1- [[-2, -1], 1, Concat, [1]]- [-1, 3, RepC3, [256, 0.5]] # 14, fpn_blocks.0- [-1, 1, Conv, [256, 1, 1]] # 15, Y4, lateral_convs.1- [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 16- [5, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 17 input_proj.0- [[-2, -1], 1, Concat, [1]] # 18 cat backbone P4- [-1, 3, RepC3, [256, 0.5]] # X3 (19), fpn_blocks.1- [-1, 1, Conv, [256, 3, 2]] # 20, downsample_convs.0- [[-1, 15], 1, Concat, [1]] # 21 cat Y4- [-1, 3, RepC3, [256, 0.5]] # F4 (22), pan_blocks.0- [-1, 1, Conv, [256, 3, 2]] # 23, downsample_convs.1- [[-1, 10], 1, Concat, [1]] # 24 cat Y5- [-1, 3, RepC3, [256, 0.5]] # F5 (25), pan_blocks.1- [[19, 22, 25], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 3]] # Detect(P3, P4, P5)
5.2 替换ResNet的yaml文件1(ResNet50版本)
需要修改如下的ResNet主干才可以运行本文的改进机制 !
ResNet文章地址:【RT-DETR改进涨点】ResNet18、34、50、101等多个版本移植到ultralytics仓库(RT-DETR官方一比一移植)
# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'# [depth, width, max_channels]l: [1.00, 1.00, 1024]backbone:# [from, repeats, module, args]- [-1, 1, ConvNormLayer, [32, 3, 2, 1, 'relu']] # 0-P1- [-1, 1, ConvNormLayer, [32, 3, 1, 1, 'relu']] # 1- [-1, 1, ConvNormLayer, [64, 3, 1, 1, 'relu']] # 2- [-1, 1, nn.MaxPool2d, [3, 2, 1]] # 3-P2- [-1, 3, Blocks, [64, BottleNeck_EMA, 2, True]] # 4- [-1, 4, Blocks, [128, BottleNeck_EMA, 3, True]] # 5-P3- [-1, 6, Blocks, [256, BottleNeck_EMA, 4, True]] # 6-P4- [-1, 3, Blocks, [512, BottleNeck_EMA, 5, True]] # 7-P5head:- [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 8 input_proj.2- [-1, 1, AIFI, [1024, 8]] # 9- [-1, 1, Conv, [256, 1, 1]] # 10, Y5, lateral_convs.0- [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 11- [6, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 12 input_proj.1- [[-2, -1], 1, Concat, [1]] # 13- [-1, 3, RepC3, [256]] # 14, fpn_blocks.0- [-1, 1, Conv, [256, 1, 1]] # 15, Y4, lateral_convs.1- [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 16- [5, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 17 input_proj.0- [[-2, -1], 1, Concat, [1]] # 18 cat backbone P4- [-1, 3, RepC3, [256]] # X3 (19), fpn_blocks.1- [-1, 1, Conv, [256, 3, 2]] # 20, downsample_convs.0- [[-1, 15], 1, Concat, [1]] # 21 cat Y4- [-1, 3, RepC3, [256]] # F4 (22), pan_blocks.0- [-1, 1, Conv, [256, 3, 2]] # 23, downsample_convs.1- [[-1, 10], 1, Concat, [1]] # 24 cat Y5- [-1, 3, RepC3, [256]] # F5 (25), pan_blocks.1- [[19, 22, 25], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 6]] # Detect(P3, P4, P5)
5.3 即插即用的yaml文件(HGNetV2版本)
此版本为HGNetV2-l的yaml文件!
# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'# [depth, width, max_channels]l: [1.00, 1.00, 1024]backbone:# [from, repeats, module, args]- [-1, 1, HGStem, [32, 48]] # 0-P2/4- [-1, 6, HGBlock, [48, 128, 3]] # stage 1- [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8- [-1, 6, HGBlock, [96, 512, 3]] # stage 2- [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16- [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut- [-1, 6, HGBlock, [192, 1024, 5, True, True]]- [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3- [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32- [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4head:- [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2- [-1, 1, AIFI, [1024, 8]]- [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1- [[-2, -1], 1, Concat, [1]]- [-1, 3, RepC3, [256]] # 16, fpn_blocks.0- [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0- [[-2, -1], 1, Concat, [1]] # cat backbone P4- [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1- [-1, 1, EMA, []] # 22- [-1, 1, Conv, [384, 3, 2]] # 23, downsample_convs.0- [[-1, 17], 1, Concat, [1]] # cat Y4- [-1, 3, RepC3, [256]] # F4 (25), pan_blocks.0- [-1, 1, EMA, []] # 26- [-1, 1, Conv, [384, 3, 2]] # 27, downsample_convs.1- [[-1, 12], 1, Concat, [1]] # cat Y5- [-1, 3, RepC3, [256]] # F5 (29), pan_blocks.1- [-1, 1, EMA, []] # 30- [[22, 26, 30], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
六、成功运行记录
6.1 ResNet18运行成功记录截图
6.2 ResNet50运行成功记录截图
6.3 HGNetv2运行成功记录截图
七、全文总结
到此本文的正式分享内容就结束了,在这里给大家推荐我的RT-DETR改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~
专栏链接:RT-DETR剑指论文专栏,持续复现各种顶会内容——论文收割机RT-DETR