本文这次分享的是三重注意力机制Triplet Attention。现在注意力机制在计算机视觉任务中被广泛研究和应用,如 Squeeze-and-Excitation Networks (SENet)、Convolutional Block Attention Module (CBAM) 等。然而,这些方法存在一些局限性,例如需要大量可学习参数,或者在计算通道注意力时没有考虑跨维度交互等。Triplet Attention可以有效的解决这些问题,其强调在计算注意力权重时捕获跨维度交互的重要性,以提供丰富的特征表示。
1. 三重注意力机制Triplet Attention
这张图展示了一个名为 “Triplet Attention” 的神经网络模块的结构。该模块的输入是一个形状为CxHxW的张量,整个模块由三个并行的分支组成,每个分支都执行相似的操作:
- Z - Pool 操作:每个分支首先进行 Z - Pool 操作。Z - Pool 是一种池化操作,它将输入张量在通道维度上进行平均池化和最大池化,并将结果在通道维度上进行拼接。
- 卷积操作(Conv):接着进行卷积操作,卷积核的大小未在图中明确标出,但卷积操作后张量的形状仍然保持为 。
- Sigmoid 激活函数:卷积操作之后,应用 Sigmoid 激活函数,将输出值压缩到 0 到 1 之间。
2.Triplet Attention结构
1. 模块 (d) 首先对输入张量进行维度置换(Permute),改变张量的维度顺序。
2. 然后进行 Z - Pool 操作,提取通道维度的特征。 Z - Pool 操作是一种在神经网络中用于特征提取的池化操作。它结合了平均池化(Average Pooling)和最大池化(Maximum Pooling)的优点。 在 Z - Pool 操作中,对于输入张量的某一维度(通常是通道维度),会同时进行平均池化和最大池化操作,然后将这两个池化结果在该维度上进行拼。
3. 利用 7x7 卷积、批归一化(Batch Norm)和 Sigmoid 函数生成注意力权重 。
4. 再次进行维度置换后,通过 1x1 卷积将权重应用到原始输入张量,达到对输入特征重新加权的目的。
3. 接下来,我们将详细介绍如何将Triplet_Attention集成到 YOLOv8 模型中。
这是我的GitHub代码:tgf123/YOLOv8_improve (github.com)
这是改进讲解:YOLOv8模型改进 第二十讲 添加三重注意力机制Triplet Attention 提升小目标、遮挡目标_哔哩哔哩_bilibili
3.1 如何添加
1. 首先,在我上传的代码中yolov8_improve中找到Triplet_Attention.py代码部分,它包含两个部分一个是Triplet_Attention.py的核心代码,一个是yolov8模型的配置文件。
2. 然后我们在ultralytics文件夹下面创建一个新的文件夹,名字叫做change_models, 然后再这个文件夹下面创建Triplet_Attention.py文件,然后将iRMB的核心代码放入其中
3. 在 task.py文件中导入Triplet_Attention
from ultralytics.change_models.triplet_attention import C2f_TripletAttention,TripletAttention
4. 然后将 Triplet_Attention添加到下面当中
第一个改进修改的地方
第二个改进修改的地方
5. 最后将配置文件复制到下面文件夹下
6. 运行代码跑通
from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorldif __name__=="__main__":# 使用自己的YOLOv8.yamy文件搭建模型并加载预训练权重训练模型model = YOLO(r"D:\bilibili\model\YOLOV8_new\ultralytics-main\ultralytics\cfg\models\v8\yolov8_Triplet_Attention.yaml")\.load(r'D:\bilibili\model\YOLOV8_new\ultralytics-main\yolov8n.pt') # build from YAML and transfer weightsresults = model.train(data=r'D:\bilibili\model\ultralytics-main\ultralytics\cfg\datasets\VOC_my.yaml',epochs=100,imgsz=640,batch=8,cache = False,# single_cls = False, # 是否是单类别检测# workers = 0,# resume='',# amp = False)
from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorldif __name__=="__main__":# 使用自己的YOLOv8.yamy文件搭建模型并加载预训练权重训练模型model = YOLO(r"D:\bilibili\model\YOLOV8_new\ultralytics-main\ultralytics\cfg\models\v8\yolov8_irmb.yaml")\.load(r'D:\bilibili\model\YOLOV8_new\ultralytics-main\yolov8n.pt') # build from YAML and transfer weightsresults = model.train(data=r'D:\bilibili\model\ultralytics-main\ultralytics\cfg\datasets\VOC_my.yaml',epochs=100, imgsz=640, batch=8)