Yolov8有效涨点,添加多种注意力机制,修改损失函数提高目标检测准确率

目录

简介

CBAM注意力机制原理及代码实现

原理

 代码实现

 GAM注意力机制

原理

代码实现

修改损失函数

YAML文件

完整代码


🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

http://t.csdnimg.cn/sVHxvicon-default.png?t=N7T8http://t.csdnimg.cn/sVHxv

简介

Ultralytics 推出了最新版本的 YOLO 模型。注意力机制是提高模型性能最热门的方法之一。

本次将介绍几种常见的注意力机制,这些注意力机制在大多数的数据集上均能有效的提升目标检测的精度/召回率/准确率。

CBAM注意力机制原理及代码实现
原理
CBAM注意力机制结构图

CBAM(Convolutional Block Attention Module)是一种用于卷积神经网络(CNN)的注意力机制,它能够增强网络对输入特征的关注度,提高网络性能。CBAM 主要包含两个子模块:通道注意力模块(Channel Attention Module)和空间注意力模块(Spatial Attention Module)。

以下是CBAM注意力机制的基本原理:

1. 通道注意力模块(Channel Attention Module):
输入:经过卷积层的特征图。
处理步骤:
对每个通道进行全局平均池化,得到通道的全局平均值。
通过两个全连接层,将全局平均值映射为两个权重向量(一个用于缩放,一个用于偏置)。
将这两个权重向量与原始特征图相乘,以加权调整每个通道的重要性。

2. 空间注意力模块(Spatial Attention Module):**
输入:通道注意力模块的输出。
处理步骤:
     对每个通道的特征图进行分别的最大池化和平均池化,得到两个空间特征图。
     将这两个空间特征图相加,通过一个卷积层产生一个权重图。
     将原始特征图与权重图相乘,以加权调整每个空间位置的重要性。

3. 整合:
   将通道注意力模块和空间注意力模块的输出相乘,得到最终的注意力增强特征图。
   将这个注意力增强的特征图传递给网络的下一层进行进一步处理。

CBAM的关键优势在于它能够同时考虑通道和空间信息,有助于网络更好地理解和利用输入特征。这种注意力机制有助于提高网络在视觉任务上的性能,使其能够更有针对性地关注重要的特征。

 代码实现

路径:"./ultralytics/nn/modules/conv.py"

class ChannelAttention(nn.Module):"""Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""def __init__(self, channels: int) -> None:super().__init__()self.pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)self.act = nn.Sigmoid()def forward(self, x: torch.Tensor) -> torch.Tensor:return x * self.act(self.fc(self.pool(x)))class SpatialAttention(nn.Module):"""Spatial-attention module."""def __init__(self, kernel_size=7):"""Initialize Spatial-attention module with kernel size argument."""super().__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.act = nn.Sigmoid()def forward(self, x):"""Apply channel and spatial attention on input for feature recalibration."""return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))class CBAM(nn.Module):"""Convolutional Block Attention Module."""def __init__(self, c1, kernel_size=7):  # ch_in, kernelssuper().__init__()self.channel_attention = ChannelAttention(c1)self.spatial_attention = SpatialAttention(kernel_size)def forward(self, x):"""Applies the forward pass through C1 module."""return self.spatial_attention(self.channel_attention(x))

添加完代码以后需要在"./ultralytics/nn/tasks.py"进行注册

 GAM注意力机制
原理

目标的设计是一种减少信息缩减并放大全局维度交互特征的机制。我们采用 CBAM 的顺序通道空间注意力机制并重新设计子模块。整个过程如图所示。

GAM结构图


通道注意力机制
通道注意力子模块使用 3D 排列来保留三个维度的信息。然后,它使用两层 MLP(多层感知器)放大跨维度通道空间依赖性。 (MLP是一种编码器-解码器结构,其缩减比为r,与BAM相同。)通道注意子模块如图所示。 

通道注意力子模块


空间注意力机制
在空间注意力子模块中,为了关注空间信息,我们使用两个卷积层进行空间信息融合。我们还使用与 BAM 相同的通道注意子模块的缩减率 r。同时,最大池化会减少信息并产生负面影响。我们删除池化以进一步保留特征图。因此,空间注意力模块有时会显着增加参数的数量。为了防止参数显着增加,我们在 ResNet50 中采用带有通道洗牌的组卷积。没有组卷积的空间注意力子模块如图所示。 

空间注意力子模块
代码实现

代码添加在 ./ultralytics/nn/modules/conv.py 中,同样需要在task.py中注册

class GAM_Attention(nn.Module):def __init__(self, c1, c2, group=True, rate=4):super(GAM_Attention, self).__init__()self.channel_attention = nn.Sequential(nn.Linear(c1, int(c1 / rate)),nn.ReLU(inplace=True),nn.Linear(int(c1 / rate), c1))self.spatial_attention = nn.Sequential(nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(c1, int(c1 / rate),kernel_size=7,padding=3),nn.BatchNorm2d(int(c1 / rate)),nn.ReLU(inplace=True),nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(int(c1 / rate), c2,kernel_size=7,padding=3),nn.BatchNorm2d(c2))def forward(self, x):b, c, h, w = x.shapex_permute = x.permute(0, 2, 3, 1).view(b, -1, c)x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)x_channel_att = x_att_permute.permute(0, 3, 1, 2)# x_channel_att=channel_shuffle(x_channel_att,4) #last shufflex = x * x_channel_attx_spatial_att = self.spatial_attention(x).sigmoid()x_spatial_att = channel_shuffle(x_spatial_att, 4)  # last shuffleout = x * x_spatial_att# out=channel_shuffle(out,4) #last shufflereturn out
修改损失函数

WIoU是一种新型的损失函数,代码实现

def WIoU(cls, pred, target, self=None):self = self if self else cls(pred, target)dist = torch.exp(self.l2_center / self.l2_box.detach())return self._scaled_loss(dist * self.iou)

 这个其实就是修改了loss.py中的BboxLoss,在本段代码的第十二行,将type改成了WIoU

class BboxLoss(nn.Module):def __init__(self, reg_max, use_dfl=False):"""Initialize the BboxLoss module with regularization maximum and DFL settings."""super().__init__()self.reg_max = reg_maxself.use_dfl = use_dfldef forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):"""IoU loss."""weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)loss,iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False,type_='WIoU')loss_iou=loss.sum()/target_scores_sum# DFL lossif self.use_dfl:target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weightloss_dfl = loss_dfl.sum() / target_scores_sumelse:loss_dfl = torch.tensor(0.0).to(pred_dist.device)return loss_iou, loss_dfl
YAML文件
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 9  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, GAM_Attention, [512,512]]- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 16 (P3/8-small)- [-1, 1, GAM_Attention, [256,256]]- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 20 (P4/16-medium)- [-1, 1, GAM_Attention, [512,512]]- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 24 (P5/32-large)- [-1, 1, GAM_Attention, [1024,1024]]- [[17, 21, 25], 1, Detect, [nc]]  # Detect(P3, P4, P5)

在head部分,可以将GAM_attention改成不同的注意力机制,来改变网络结构,从而提升目标检测 的精度

完整代码

链接: https://pan.baidu.com/s/1IDnEZxpcaEgBowlTxX2iNA?pwd=vdrs 提取码: vdrs 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/269795.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微信小程序开发系列(十七)·事件传参·mark-自定义数据

目录 步骤一:按钮的创建 步骤二:按钮属性配置 步骤三:添加点击事件 步骤四:参数传递 步骤五:打印数据 步骤六:获取数据 步骤七:父进程验证 总结:data-*自定义数据和mark-自定…

什么是物联网?物联网如何工作?

物联网到底是什么? 物联网(Internet of Things,IoT)的概念最早于1999年被提出,官方解释为“万物相连的互联网”,是在互联网基础上延伸和扩展,将各种信息传感设备与网络结合起来而形成的一个巨大网络,可以实…

抖音视频评论批量采集软件|视频下载工具

《轻松搞定!视频评论批量采集软件,助您高效工作》 在短视频这个充满活力和创意的平台上,了解用户评论是了解市场和观众心声的重要途径之一。为了帮助您快速获取大量视频评论数据,我们推出了一款操作便捷、功能强大的软件&#xff…

gitlab的安装

1、下载rpm 安装包 (1)直接命令下载 wget https://mirrors.tuna.tsinghua.edu.cn/gitlab-ce/yum/el7/gitlab-ce-11.6.10-ce.0.el7.x86_64.rpm(2)直接去服务器上下载包 Index of /gitlab-ce/yum/el7/ | 清华大学开源软件镜像站 | Tsinghua Open Source…

【组合递归】【StringBuilder】Leetcode 17. 电话号码的字母组合

【组合递归】【StringBuilde】Leetcode 17. 电话号码的字母组合 StringBulider常用方法!!!!!!!!!!!!!!17. 电…

[计算机网络]--五种IO模型和select

前言 作者:小蜗牛向前冲 名言:我可以接受失败,但我不能接受放弃 如果觉的博主的文章还不错的话,还请点赞,收藏,关注👀支持博主。如果发现有问题的地方欢迎❀大家在评论区指正 目录 一、五种IO…

如何使用 ArcGIS Pro 制作三维地形图

伴随硬件性能的提高和软件算法的优化,三维地图的应用场景会越来越多,这里为大家介绍一下在ArcGIS Pro怎么制作三维地形图,希望能对你有所帮助。 数据来源 教程所使用的数据是从水经微图中下载的DEM和影像数据,除了DEM和影像数据…

华为Web举例:私网用户通过三元组NAT访问Internet

Web举例:私网用户通过三元组NAT访问Internet 介绍私网用户通过三元组NAT访问Internet的配置举例。 组网需求 某公司在网络边界处部署了FW作为安全网关。为了使私网中10.1.1.0/24网段的用户可以正常访问Internet,需要在FW上配置源NAT策略。除了公网接口…

Go 与 Rust:导航编程语言景观

在当今构建软件时,开发者在编程语言上有着丰富的选择。两种脱颖而出的语言是 Go 和 Rust - 都很强大但却截然不同。本文将从各种因素比较这两种语言,以帮助您确定哪种更适合您的需求。 我们将权衡它们在并发、安全性、速度、互操作性等方面的方法。我们将…

深度学习预测分析API:金融领域的Game Changer

🚀 引言 在这个AI遍地开花的时代,谁能成为金融领域的真正Game Changer?那必然是是深度学习预测分析API。如大脑般高效运转的系统不仅颠覆了传统操作,更是以无与伦比的速度和精度赋予了金融数据以全新的生命。 💼 广泛…

ARM系统控制和管理接口System Control and Management Interface

本文档描述了一个可扩展的独立于操作系统的软件接口,用于执行各种系统控制和管理任务,包括电源和性能管理。 本文档描述了系统控制和管理接口(SCMI),它是一组操作系统无关的软件接口,用于系统管理。SCMI 是…

Docker之自定义镜像上传阿里云

目录 一、制作jdk镜像 1. alpine Linux简介 2. 通过alpine进行制作镜像 1. 制作jdk2.0 2. 制作jdk3.0 二、镜像上传阿里云及下载 1. 前期准备 2. push (推) 镜像 一、制作jdk镜像 1. alpine Linux简介 Alpine Linux是一个轻量级的Linux发行版,专注于安全、…

硬盘删除的文件如何恢复?分享硬盘数据恢复方法

随着信息时代的飞速发展,硬盘作为我们储存数据的主要场所其重要性日益凸显。但硬盘数据的丢失或误删也成为了许多用户头疼的问题。当您发现重要的文件从硬盘中消失时不必过于焦虑。本文将为您介绍五种高效且原创的数据恢复策略,帮助您找回那些珍贵的文件…

PRewrite: Prompt Rewriting with Reinforcement Learning

PRewrite: Prompt Rewriting with Reinforcement Learning 基本信息 2024-01谷歌团队提交到arXiv 博客贡献人 徐宁 作者 Weize Kong,Spurthi Amba Hombaiah,Mingyang Zhang 摘要 工程化的启发式编写对于LLM(大型语言模型)应…

基于SSM的学科竞赛管理系统。Javaee项目。ssm项目。

演示视频: 基于SSM的学科竞赛管理系统。Javaee项目。ssm项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构,通过Spring SpringMvcMybatisVueLayuiElemen…

持续更新 | 与您分享 Flutter 2024 年路线图

作者 / Michael Thomsen Flutter 是一个拥有繁荣社区的开源项目,我们致力于确保我们的计划公开透明,并将毫无隐瞒地分享从问题到设计规范的所有内容。我们了解到许多开发者对 Flutter 的功能路线图很感兴趣。我们往往会在一年中不断更改并调整这些计划&a…

蓝桥杯练习系统(算法训练)ALGO-986 藏匿的刺客

资源限制 内存限制:256.0MB C/C时间限制:1.0s Java时间限制:3.0s Python时间限制:5.0s 问题描述 强大的kAc建立了强大的帝国,但人民深受其学霸及23文化的压迫,于是勇敢的鹏决心反抗。   kAc帝国防…

alpine创建lnmp环境alpine安装nginx+php5.6+mysql

前言 制作lnmp环境,你可以在alpine基础镜像中安装相关的服务,也可以直接使用Dockerfile创建自己需要的环境镜像。 注意:提前确认自己的alpine版本,本次创建基于alpine3.6进行创建,官方在一些版本中删除了php5 1、拉取…

HTML实体字符列表,知识点详解

css盒模型 1,css盒模型基本概念? 2,标准模型和IE模型的区别:计算高度和宽度的不同,怎么不同,高度宽度是怎么计算的? 3,css如何设置这两种模型? 4,js如何设置…

EPSON RA8000CE (RTC模块)压电侠

RA8000CE是一个集成了32.768 kHz数字温度补偿晶体振荡器(DTCXO)的RTC模块。它包括各种功能,如具有闰年校正的秒到年时钟/日历,时间警报,唤醒计时器,时间更新中断,时钟输出和时间戳功能,可以在外部或内部事件…