YOLOv8改进 - 注意力篇 - 引入SCAM注意力机制

一、本文介绍

作为入门性篇章,这里介绍了SCAM注意力在YOLOv8中的使用。包含SCAM原理分析,SCAM的代码、SCAM的使用方法、以及添加以后的yaml文件及运行记录。

二、SCAM原理分析

SCAM官方论文地址:SCAM文章

SCAM官方代码地址:SCAM代码

SCAM注意力机制(空间上下文感知模块):

空间上下文感知模块(SCAM)在FEM和FFM之后,特征映射已经考虑了局部上下文信息,并且能够很好地表示小对象特征。在此阶段对小目标和背景之间的全局关系进行建模比在主干阶段更有效。利用全局上下文信息来表示像素之间的跨空间关系,可以抑制无用背景,增强目标和背景的区分能力。受GCNet和SCP的启发,SCAM由三个分支组成。第一个部分使用GAP和GMP整合全球信息。第二个分支使用1 × 1卷积生成特征映射的线性变换结果,该特征映射在图4中称为value。第三个分支使用1 × 1卷积来简化查询和键的倍数。这个卷积在图4中称为QK。随后,将第一分支和第三分支分别与第二分支矩阵相乘。得到的两个分支分别表示跨通道和空间的上下文信息。最后,利用广播Hadamard积在这两个分支上得到了SCAM的输出。

相关代码:

SCAM注意力的代码,如下。

class SCAM(nn.Module):def __init__(self, in_channels, reduction=1):super(SCAM, self).__init__()self.in_channels = in_channelsself.inter_channels = in_channelsself.k = Conv(in_channels, 1, 1, 1)self.v = Conv(in_channels, self.inter_channels, 1, 1)self.m = Conv_withoutBN(self.inter_channels, in_channels, 1, 1)self.m2 = Conv(2, 1, 1, 1)self.avg_pool = nn.AdaptiveAvgPool2d(1)  # GAPself.max_pool = nn.AdaptiveMaxPool2d(1)  # GMPdef forward(self, x):n, c, h, w = x.size(0), x.size(1), x.size(2), x.size(3)# avg max: [N, C, 1, 1]avg = self.avg_pool(x).softmax(1).view(n, 1, 1, c)max = self.max_pool(x).softmax(1).view(n, 1, 1, c)# k: [N, 1, HW, 1]k = self.k(x).view(n, 1, -1, 1).softmax(2)# v: [N, 1, C, HW]v = self.v(x).view(n, 1, c, -1)# y: [N, C, 1, 1]y = torch.matmul(v, k).view(n, c, 1, 1)# y2:[N, 1, H, W]y_avg = torch.matmul(avg, v).view(n, 1, h, w)y_max = torch.matmul(max, v).view(n, 1, h, w)# y_cat:[N, 2, H, W]y_cat = torch.cat((y_avg, y_max), 1)y = self.m(y) * self.m2(y_cat).sigmoid()return x + y

四、YOLOv8中SCAM使用方法

1.YOLOv8中添加SCAM模块:

首先在ultralytics/nn/modules/conv.py最后添加SCAM模块的代码。

2.在conv.py的开头__all__ = 内添加SCAM模块的类别名:

3.在同级文件夹下的__init__.py内添加SCAM的相关内容:(分别是from .conv import SCAM ;以及在__all__内添加SCAM)

4.在ultralytics/nn/tasks.py进行LSKA注意力机制的注册,以及在YOLOv8的yaml配置文件中添加SCAM即可。

首先打开task.py文件,按住Ctrl+F,输入parse_model进行搜索。找到parse_model函数。在其最后一个else前面添加以下注册代码:

        elif m is SCAM:c2 = ch[f]args = [c2]

然后,就是新建一个名为YOLOv8_SCAM.yaml的配置文件:(路径:ultralytics/cfg/models/v8/YOLOv8_SCAM.yaml)

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call CPAM-yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SCAM, [1024]]#11代表卷积核大小,可以填写7、11、23、35、41、53- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 13], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 10], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

其中参数中nc,由自己的数据集决定。本文测试,采用的coco8数据集,有80个类别。

在根目录新建一个train.py文件,内容如下

from ultralytics import YOLOwith warnings.catch_warnings():warnings.simplefilter("ignore")
# 加载一个模型model = YOLO('ultralytics/cfg/models/v8/YOLOv8_SCAM.yaml')  # 从YAML建立一个新模型
# 训练模型results = model.train(data='ultralytics/cfg/datasets/coco8.yaml', epochs=1,imgsz=640,optimizer="SGD")

训练输出:​

​​

五、总结

以上就是SCAM的原理及使用方式,但具体SCAM注意力机制的具体位置放哪里,效果更好。需要根据不同的数据集做相应的实验验证。希望本文能够帮助你入门YOLO中注意力机制的使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/435899.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Google Protocol Buffers快速入门指南

声明:未经作者允许,禁止转载。 概念 Portocol Buffer是谷歌提出来的一种序列化结构数据的机制,它的可扩展性特别强,支持C、C#、Java、Go和Python等主流编程语言。使用Portocol Buffer时,仅需要定义好数据的结构化方式…

Mysql梳理10——使用SQL99实现7中JOIN操作

10 使用SQL99实现7中JOIN操作 10.1 使用SQL99实现7中JOIN操作 本案例的数据库文件分享: 通过百度网盘分享的文件:atguigudb.sql 链接:https://pan.baidu.com/s/1iEAJIl0ne3Y07kHd8diMag?pwd2233 提取码:2233 # 正中图 SEL…

每日OJ题_牛客_添加逗号_模拟_C++_Java

目录 牛客_添加逗号_模拟 题目解析 C代码1 C代码2 Java代码 牛客_添加逗号_模拟 添加逗号_牛客题霸_牛客网 题目解析 读取输入:读取一行字符串。分割字符串:使用空格将字符串分割成单词数组。拼接字符串:将单词数组中的每个单词用逗号…

Oracle控制文件全部丢失如何使用RMAN智能恢复?

1.手动删除所有控制文件模拟故障产生 2.此时启动数据库发现控制文件丢失 3.登录rman 4.列出故障 list failure; 5.让RMAN列举恢复建议 advise failure; 6.使用RMAN智能修复 repair failure;

Java常用三类定时器快速入手指南

文章目录 Java常用三类定时器快速入手指南一、序言二,Timer相关1、概念2、Timer类3、TimerTask类4、ScheduleExecutorService接口 三,Scheduled相关1、配置1.1 SpringMVC配置1.2 SpringBoot配置(1)单线程(2&#xff09…

cpp,git,unity学习

c#中的? 1. 空值类型(Nullable Types) ? 可以用于值类型(例如 int、bool 等),使它们可以接受 null。通常,值类型不能为 null,但是通过 ? 可以表示它们是可空的。 int? number null; // …

桥接(桥梁)模式

简介 桥接模式(Bridge Pattern)又叫作桥梁模式、接口(Interface)模式或柄体(Handle and Body)模式,指将抽象部分与具体实现部分分离,使它们都可以独立地变化,属于结构型…

【AAOS】CarService -- Android汽车服务

概述 Android Automative OS理解为Android OS + Android Automative Service,而CarService就是提供汽车相关功能的最主要模块。 CarService与Android OS的关系:CarService运行于独立的进程中,其作为原有Android服务的补充,在汽车设备上运行。CarService在整体车载通信中起…

JAVA线程基础二——锁的概述之乐观锁与悲观锁

乐观锁与悲观锁 乐观锁和悲观锁是在数据库中引入的名词,但是在并发包锁里面也引入了类似的思想,所以这里还是有必要讲解下。 悲观锁指对数据被外界修改持保守态度,认为数据很容易就会被其他线程修改,所以在数据被处理前先对数据进行加锁&…

python 如何引用变量

在字符串中引入变量有三种方法: 1、 连字符 name zhangsan print(my name is name) 结果为 my name is zhangsan 2、% 字符 name zhangsan age 25 price 4500.225 print(my name is %s%(name)) print(i am %d%(age) years old) print(my price is %f%(pric…

项目管理专业资质认证ICB 3中关于项目经理素质的标准

项目管理专业资质认证ICB 3中关于项目经理素质的标准,的确很全面,下面摘录之:

15年408计算机网络

第一题: 解析: 接收方使用POP3向邮件服务器读取邮件,使用的TCP连接,TCP向上层提供的是面向连接的,可靠的数据传输服务。 第二题: 解析:物理层-不归零编码和曼彻斯特编码 编码1:电平在…

LabVIEW项目编码器选择

在LabVIEW项目中,选择增量式(Incremental Encoder)和绝对式(Absolute Encoder)编码器取决于项目的具体需求。增量式编码器和绝对式编码器在工作原理、应用场景、精度和成本等方面存在显著差异。以下从多方面详细阐述两…

又一年国庆至,“打工人”在欢呼,OTA们在雀跃

国庆“黄金周”倒计时最后一天,旅游出行即将迎来新一轮高峰。 安信国际指出,国庆期间,出游人次的增长确定性高于人均消费的增长。预计国内旅游收入7,000-7,500亿元,较2019年同期增8%-15%;预计国内旅游人次8.5-9.0亿人…

今日指数项目A股大盘数据采集

1、A股大盘数据采集 1.1 A股大盘数据采集准备 1.1.1 配置ID生成器bean A股大盘数据采集入库时,主键ID保证唯一,所以在stock_job工程配置ID生成器: Configuration public class CommonConfig {/*** 配置基于雪花算法生成全局唯一id* 参与…

springboot+大数据+基于协同过滤算法的校园食堂订餐系统【内含源码+文档+部署教程】

博主介绍:✌全网粉丝10W,前互联网大厂软件研发、集结硕博英豪成立工作室。专注于计算机相关专业毕业设计项目实战6年之久,选择我们就是选择放心、选择安心毕业✌ 🍅由于篇幅限制,想要获取完整文章或者源码,或者代做&am…

Git常用方法——详解

一、下载安装git git官网: Git - Downloads (git-scm.com) 下载安装Git(超详细超简单)_git下载-CSDN博客 二、克隆下载至本地 1、复制HTTPS链接 在gitee或者gitLab或者gitHub上复制HTTPS链接 2、打开Open Git Bash here 在本地想要新建文…

闯关训练一:Linux基础

闯关任务:完成SSH连接与端口映射并运行hello_world.py 1.创建开发机 2.SSH连接 3. VS-Code 连接 选择 Linux 平台 ,输入密码 ,选择进入文件夹 4.端口映射 按照下文安装Docs pip install gradio 运行server.py import gradio as grdef …

工业制造场景中的设备管理深度解析

在工业制造的广阔领域中,设备管理涵盖多个关键方面,对企业的高效生产和稳定运营起着举足轻重的作用。 一、设备运行管理 1.设备状态监测 实时监控设备的运行状态是确保生产顺利进行的重要环节。通过传感器和数据采集系统等先进技术,获取设备…

新书速览|Stable Diffusion-ComfyUI AI绘画工作流解析

《Stable Diffusion-ComfyUI AI绘画工作流解析》 本书内容 《Stable Diffusion-ComfyUI AI绘画工作流解析》从零开始,详尽系统地讲解从本地部署ComfyUI、下载安装自定义节点,到搭建各种工作流程的全过程。同时,辅以3D形象转绘、艺术二维码和证…