YOLOv5、YOLOv8改进:S2注意力机制

目录

1.简介

2.YOLOv5改进

2.1增加以下S2-MLPv2.yaml文件

2.2common.py配置

2.3yolo.py配置


1.简介

 S2-MLPv2注意力机制

请添加图片描述

 

最近,出现了基于 MLP 的视觉主干。与 CNN 和视觉Transformer相比,基于 MLP 的视觉架构具有较少的归纳偏差,在图像识别方面实现了有竞争力的性能。其中,spatial-shift MLP (S2-MLP),采用直接的空间位移操作,取得了比包括 MLP-mixer 和 ResMLP 在内的开创性工作更好的性能。使用具有金字塔结构的较小补丁,视觉置换器 (ViP) 和Global Filter Network (GFNet) 实现了比 S2-MLP 更好的性能。

在本文中,我们改进了 S2-MLP 视觉主干。我们沿通道维度扩展特征图,并将扩展后的特征图分成几个部分。我们对分割部分进行不同的空间移位操作。

本文对空间移位MLP (S2-MLP)模型进行了改进,提出了S2-MLPv2模型。将feature map进行扩展,并将扩展后的feature map分为三部分。它将每个部分单独移动,然后通过split-attention融合分开的特征图。同时,我们利用层次金字塔来提高其建模细粒度细节的能力,以获得更高的识别精度。在没有外部训练数据集的情况下,采用224×224的images,我们的s2-mlv2-medium模型在ImageNet1K数据集上取得了83.6%的top-1准确率,这是目前基于MLP的方法中最先进的性能。同时,与基于transformer的方法相比,我们的S2-MLPv2模型在不需要自我注意的情况下,参数更少,达到了相当的精度。

        与基于MLP的先驱作品如MLP-Mixer、ResMLP以及最近类似MLP的模型如Vision Permutator和GFNet相比,空间移位MLP的另一个重要优势是,空间移位MLP的形状对图像的输入尺度是不变的。因此,经过特定尺度图像预训练的空间移位MLP模型可以很好地应用于具有不同尺寸输入图像的下游任务。未来的工作将致力于不断提高空间移位MLP体系结构的图像识别精度。一个有希望且直接的方向是尝试更小尺寸的patch和更高级的四层金字塔,如CycleMLP和AS-MLP,以进一步减少FLOPs和缩短基于transformer模型之间的识别差距。
 

 S2注意力机制(S2 Attention Mechanism)是一种用于序列建模和注意力机制改进的方法,特别在自然语言处理(NLP)领域中得到广泛应用。它是对传统的自注意力机制(self-attention)进行改进,旨在提高序列中不同位置之间的关联性建模能力。

  1. 自注意力机制回顾: 自注意力机制是一种用于处理序列数据的方法,最早在Transformer模型中提出并广泛用于NLP任务中。在自注意力机制中,序列中的每个位置都可以与其他所有位置进行交互,以便捕获位置之间的关系。然而,这种全局的交互可能会导致计算复杂度的增加,并且可能过于强调距离较近的位置。

  2. S2注意力机制的改进: S2注意力机制引入了一种分段结构,将序列分为不同的段(segments)。每个段内的位置之间可以进行交互,但不同段之间的交互被限制。这种分段结构在捕获长距离依赖关系时更加高效,因为不同段之间的关联性通常较弱。

  3. 注意力计算: 在S2注意力中,注意力权重的计算仍然涉及对查询(query)、键(key)和值(value)的操作。不同之处在于,每个段的注意力计算是独立的,而不同段之间的注意力权重设为固定值(通常为0)。

  4. 优势与应用: S2注意力机制的主要优势是在捕获序列中的长距离依赖关系时表现更加高效。这在处理长文本或长序列时特别有用,可以减少计算成本,同时提高建模性能。S2注意力机制在机器翻译、文本生成、命名实体识别等NLP任务中都有应用,以更好地处理长文本的关系建模。

2.YOLOv5改进

2.1增加以下S2-MLPv2.yaml文件

# parameters
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4[-1, 3, C3, [128]],[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, SPPF, [1024, 5]],  # 9]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 6], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [512, False]],  # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 4], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [256, False]],  # 17 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 14], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [512, False]],  # 20 (P4/16-medium)[-1, 1, Conv, [512, 3, 2]],[[-1, 10], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [1024, False]],  # 23 (P5/32-large)[-1, 1, S2Attention, [1024]], #修改[[17, 20, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

2.2common.py配置

./models/common.py文件增加以下模块

import numpy as np
import torch
from torch import nn
from torch.nn import init# https://arxiv.org/abs/2108.01072
def spatial_shift1(x):b,w,h,c = x.size()x[:,1:,:,:c//4] = x[:,:w-1,:,:c//4]x[:,:w-1,:,c//4:c//2] = x[:,1:,:,c//4:c//2]x[:,:,1:,c//2:c*3//4] = x[:,:,:h-1,c//2:c*3//4]x[:,:,:h-1,3*c//4:] = x[:,:,1:,3*c//4:]return xdef spatial_shift2(x):b,w,h,c = x.size()x[:,:,1:,:c//4] = x[:,:,:h-1,:c//4]x[:,:,:h-1,c//4:c//2] = x[:,:,1:,c//4:c//2]x[:,1:,:,c//2:c*3//4] = x[:,:w-1,:,c//2:c*3//4]x[:,:w-1,:,3*c//4:] = x[:,1:,:,3*c//4:]return xclass SplitAttention(nn.Module):def __init__(self,channel=512,k=3):super().__init__()self.channel=channelself.k=kself.mlp1=nn.Linear(channel,channel,bias=False)self.gelu=nn.GELU()self.mlp2=nn.Linear(channel,channel*k,bias=False)self.softmax=nn.Softmax(1)def forward(self,x_all):b,k,h,w,c=x_all.shapex_all=x_all.reshape(b,k,-1,c) a=torch.sum(torch.sum(x_all,1),1) hat_a=self.mlp2(self.gelu(self.mlp1(a))) hat_a=hat_a.reshape(b,self.k,c) bar_a=self.softmax(hat_a) attention=bar_a.unsqueeze(-2) out=attention*x_all out=torch.sum(out,1).reshape(b,h,w,c)return outclass S2Attention(nn.Module):def __init__(self, channels=512 ):super().__init__()self.mlp1 = nn.Linear(channels,channels*3)self.mlp2 = nn.Linear(channels,channels)self.split_attention = SplitAttention()def forward(self, x):b,c,w,h = x.size()x=x.permute(0,2,3,1)x = self.mlp1(x)x1 = spatial_shift1(x[:,:,:,:c])x2 = spatial_shift2(x[:,:,:,c:c*2])x3 = x[:,:,:,c*2:]x_all=torch.stack([x1,x2,x3],1)a = self.split_attention(x_all)x = self.mlp2(a)x=x.permute(0,3,1,2)return x

2.3yolo.py配置

在 models/yolo.py文件夹下

定位到parse_model函数中,新增以下代码

elif m is S2Attention:c1, c2 = ch[f], args[0]if c2 != no:c2 = make_divisible(c2 * gw, 8)

以上就修改完成了

又遇到的问题欢迎评论区留言讨论

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/97954.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

中国剩余定理及扩展

目录 中国剩余定理解释 中国剩余定理扩展——求解模数不互质情况下的线性方程组: 代码实现: 互质: 非互质: 中国剩余定理解释 在《孙子算经》中有这样一个问题:“今有物不知其数,三三数之剩二&#x…

go es实例

go es实例 1、下载第三方库 go get github.com/olivere/elastic下载过程中出现如下报错: 解决方案: 2、示例 import package mainimport ("context""encoding/json""fmt""reflect""time""…

【前端】快速掌握HTML+CSS核心知识点

文章目录 1.HTML核心基础知识1.1.编写第一个HTML网页1.2.超链接a标签和路径1.3.图像img标签的用法1.4.表格table标签用法1.5.列表ul、ol、dl标签用法1.6.表单form标签用法1.7.区块标签和行内标签用法 2.CSS核心基础知识2.1.CSS标签选择器viewport布局2.2.CSS样式的几种写法2.3.…

【Linux取经路】解析环境变量,提升系统控制力

文章目录 一、进程优先级1.1 什么是优先级?1.2 为什么会有优先级?1.3 小结 二、Linux系统中的优先级2.1 查看进程优先级2.2 PRI and NI2.3 修改进程优先级2.4 进程优先级的实现原理2.5 一些名词解释 三、环境变量3.1 基本概念3.2 PATH:Linux系…

k8s 常见面试题

前段时间在这个视频中分享了 https://github.com/bregman-arie/devops-exercises 这个知识仓库。 这次继续分享里面的内容,本次主要以 k8s 相关的问题为主。 k8s 是什么,为什么企业选择使用它 k8s 是一个开源应用,给用户提供了管理、部署、扩…

Learning to Super-resolve Dynamic Scenes for Neuromorphic Spike Camera论文笔记

摘要 脉冲相机使用了“integrate and fire”机制来生成连续的脉冲流,以极高的时间分辨率来记录动态光照强度。但是极高的时间分辨率导致了受限的空间分辨率,致使重建出的图像无法很好保留原始场景的细节。为了解决这个问题,这篇文章提出了Sp…

idea2023 springboot2.7.5+mybatisplus3.5.2+jsp 初学单表增删改查

创建项目 修改pom.xml 为2.7.5 引入mybatisplus 2.1 修改pom.xml <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-boot-starter</artifactId><version>3.5.2</version></dependency><!--mysq…

【STM32 学习】电源解析(VCC、VDD、VREF+、VBAT)

VCC电源电压GND电源供电负电压&#xff08;通常接地&#xff09;VDD模块工作正电压VSS模块工作负电压VREFADC参考正电压VREF-ADC参考负电压VBAT电池或其他电源供电VDDA模拟供电正电压VSSA模拟供电负电压 一、VCC&#xff08;供电电压&#xff09; VCC是指芯片的电源电压&#…

MNIST手写数字数据集+7000张图片下载

MNIST手写数字图像数据集是一个经典的用于图像分类任务的数据集&#xff0c;其中包含了大量的手写数字图像样本 数据集点击下载&#xff1a; MNIST手写数字数据集7000张图片.rar

函数栈帧理解

本文是从汇编角度来展示的函数调用&#xff0c;而且是在vs2013下根据调试展开的探究&#xff0c;其它平台在一些指令上会有点不同&#xff0c;指令不多&#xff0c;简单记忆一下即可&#xff0c;在我前些年的学习中&#xff0c;学的这几句汇编指令对我调试找错误起了不小的作用…

【令牌桶算法与漏桶算法】

&#x1f4a7; 令牌桶算法与漏桶算法 \color{#FF1493}{令牌桶算法与漏桶算法} 令牌桶算法与漏桶算法&#x1f4a7; &#x1f337; 仰望天空&#xff0c;妳我亦是行人.✨ &#x1f984; 个人主页——微风撞见云的博客&#x1f390; &#x1f433; 《数据结构与算法》专…

【前端|JS实战第1篇】使用JS来实现属于自己的贪吃蛇游戏!

前言 贪吃蛇游戏是经典的小游戏&#xff0c;也是学习前端JS的一个很好的练习项目。在本教程中&#xff0c;我们将使用 JavaScript 来逐步构建一个贪吃蛇游戏。我们会从创建游戏区域开始&#xff0c;逐步添加蛇的移动、食物的生成以及游戏逻辑等功能。 &#x1f680; 作者简介&a…

韦东山-电子量产工具项目:业务系统

代码结构 所有代码都已通过测试跑通&#xff0c;其中代码结构如下&#xff1a; 一、include文件夹 1.1 common.h #ifndef _COMMON_H #define _COMMON_Htypedef struct Region {int iLeftUpX; //区域左上方的坐标int iLeftUpY; //区域左下方的坐标int iWidth; //区域宽…

java八股文面试[java基础]——String StringBuilder StringBuffer

String类型定义&#xff1a; final String 不可以继承 final char [] 不可以修改 String不可变的好处&#xff1a; hash值只需要算一次&#xff0c;当String作为map的key时&#xff0c; 不需要考虑hash改变 天然的线程安全 知识来源&#xff1a; 【基础】String、StringB…

Python web实战之细说 Django 的单元测试

关键词&#xff1a; Python Web 开发、Django、单元测试、测试驱动开发、TDD、测试框架、持续集成、自动化测试 大家好&#xff0c;今天&#xff0c;我将带领大家进入 Python Web 开发的新世界&#xff0c;深入探讨 Django 的单元测试。通过本文的实战案例和详细讲解&#xff…

ubuntu安装Microsoft Edge并设置为中文

1、下载 edge.deb 版本并安装 sudo dpkg -i microsoft-edg.deb 2. 设置默认中文显示 如果是通过.deb方式安装的&#xff1a; 打开默认安装路径下的microsoft-edge-dev文件&#xff0c;在文件最开头加上: export LANGUAGEZH-CN.UTF-8 &#xff0c;保存退出。 cd /opt/micr…

PHP8的字符串操作3-PHP8知识详解

今天继续分享字符串的操作&#xff0c;前面说到了字符串的去除空格和特殊字符&#xff0c;获取字符串的长度&#xff0c;截取字符串、检索字符串。 今天继续分享字符串的其他操作。如&#xff1a;替换字符串、分割和合成字符串。 5、替换字符串 替换字符串就是对指定字符串中…

【算法系列篇】滑动窗口

文章目录 前言什么是滑动窗口1.长度最小的子数组1.1 题目要求1.2 做题思路 1.3 Java代码实现2.无重复字符的最长子串2.1 题目要求2.2 做题思路2.3 Java代码实现 3.最大连续1的个数 III3.1 题目要求3.2 做题思路3.3 Java代码实现 4.将x减到0的最小操作数4.1 题目要求4.2 做题思路…

【仿写框架之仿写Tomact】四、封装HttpRequest对象(属性映射http请求报文)、HttpResponse对象(属性映射http响应报文)

文章目录 1、创建HttpRequest对象2、创建HttpResponse对象 1、创建HttpRequest对象 HttpRequest对象中的属性与HTTP协议中的内容对应&#xff0c;用于后序servlet从request中获取请求中的参数。 参照http请求报文&#xff1a; import java.io.BufferedReader; import java…

配置使用Gitee账号认证登录Grafana

三方社会化身份源 集成gitee第三方登录 第三方登录的原理 所谓第三方登录&#xff0c;实质就是 OAuth 授权。用户想要登录 A 网站&#xff0c;A 网站让用户提供第三方网站的数据&#xff0c;证明自己的身份。获取第三方网站的身份数据&#xff0c;就需要 OAuth 授权。 举例来…