Mamba-yolo|结合Mamba注意力机制的视觉检测

一、本文介绍

 PDF地址:https://arxiv.org/pdf/2405.16605v1

代码地址:GitHub - LeapLabTHU/MLLA: Official repository of MLLA

Demystify Mamba in Vision: A Linear AttentionPerspective一文中引入Baseline Mamba,指明Mamba在处理各种高分辨率图像的视觉任务有着很好的效率。发现了强大的Mamba和线性注意力Transformer( linear attention Transformer)非常相似,然后就分析了两者之间的异同。将Mamba模型重述为linear attention Transformer的变体,并且主要有六大差异,分别是:input gate, forget gate,shortcut, no attention normalization, single-head, and modified block design。作者对每个设计都细致的分析了优缺点,评估了性能,最终发现forget gate和block design是Mamba这么给力的主要贡献点。基于以上发现,作者提出了一个类似mamba的线性注意力模型,Mamba-Like Linear Attention (MLLA) ,相当于取其精华,去其糟粕,把mamba两个最为关键的优点设计结合到线性注意力模型当中,具有可并行计算和快速推理的特点。本文将结合YOlOV8检测模型通过添加MLLA模块提升检测精度。

二、宏观架构设计

线性注意 Transformer 模型通常采用图 (a) 中的设计,它由线性注意力模块和 MLP 模块组成。相比之下,Mamba 通过结合 H3和 Gated Attention这两个设计来改进,得到如图 (b) 所示的架构。改进的 Mamba Block 集成了多种操作,例如选择性 SSM、深度卷积、线性映射、激活函数、门控机制等,并且往往比传统的 Transformer 设计更有效。

MLLA (Mamba-Like Linear Attention)的则是通过将Mamba模型的一些核心设计融入线性注意力机制,从而提升模型的性能。具体来说,MLLA主要整合了Mamba中的"忘记门”(forget gate9)和模块设计(block design)这两个关键因素,这些因素被认为是Mamba成功的主要原因。
以下是对MLLA原理的详细分析:
1.忘记门(Forget Gate)
1.忘记门提供了局部偏差和位置信息。所有的忘记门元素严格限制在0到1之间,这意味着模型在接收到当前输入后会持续衰减失前的隐藏状态。这种特性确保了模型对输入序列的顺序敏感。
2.忘记门的局部偏差和位置信息对于图像处理任务来说非常重要,尽管引入忘记门会导致计算需要采用递归的形式,从而降低并行计算的效率。
2.模块设计(Block Design)
1.Mamba的模块设计在保持相似的浮点运算次数(FLOPS)的同时,通过替换注意力子模块为线性注意力来提升性能。结果表明,采用这种模块设计能够显著提高模型的表现。
3.线性注意力的改进:
1.线性注意力被重新设计以整合忘记门和模块设计,这种改进后的模型被称为MLLA。实验结果显示,MLLA在图像分类和高分辨率密集预测任务中均优于各种视觉Mamba模型
4.并行计算和快速推理速度:
1.MLLA通过使用位置编码(ROPE)来替代忘记门,从而在保持并行计算和快速推理速度的同时,提供必要的位置信息。这使得MLLA在处理非自回归的视觉任务时更加有效

结合yolov8改进

核心代码
 

import torch
import torch.nn as nn__all__ = ['MLLAttention']class Mlp(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xclass ConvLayer(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, dilation=1, groups=1,bias=True, dropout=0, norm=nn.BatchNorm2d, act_func=nn.ReLU):super(ConvLayer, self).__init__()self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else Noneself.conv = nn.Conv2d(in_channels,out_channels,kernel_size=(kernel_size, kernel_size),stride=(stride, stride),padding=(padding, padding),dilation=(dilation, dilation),groups=groups,bias=bias,)self.norm = norm(num_features=out_channels) if norm else Noneself.act = act_func() if act_func else Nonedef forward(self, x: torch.Tensor) -> torch.Tensor:if self.dropout is not None:x = self.dropout(x)x = self.conv(x)if self.norm:x = self.norm(x)if self.act:x = self.act(x)return xclass RoPE(torch.nn.Module):r"""Rotary Positional Embedding."""def __init__(self, base=10000):super(RoPE, self).__init__()self.base = basedef generate_rotations(self, x):# 获取输入张量的形状*channel_dims, feature_dim = x.shape[1:-1][0], x.shape[-1]k_max = feature_dim // (2 * len(channel_dims))assert feature_dim % k_max == 0, "Feature dimension must be divisible by 2 * k_max"# 生成角度theta_ks = 1 / (self.base ** (torch.arange(k_max, dtype=x.dtype, device=x.device) / k_max))angles = torch.cat([t.unsqueeze(-1) * theta_ks for t intorch.meshgrid([torch.arange(d, dtype=x.dtype, device=x.device) for d in channel_dims],indexing='ij')], dim=-1)# 计算旋转矩阵的实部和虚部rotations_re = torch.cos(angles).unsqueeze(dim=-1)rotations_im = torch.sin(angles).unsqueeze(dim=-1)rotations = torch.cat([rotations_re, rotations_im], dim=-1)return rotationsdef forward(self, x):# 生成旋转矩阵rotations = self.generate_rotations(x)# 将 x 转换为复数形式x_complex = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))# 应用旋转矩阵pe_x = torch.view_as_complex(rotations) * x_complex# 将结果转换回实数形式并展平最后两个维度return torch.view_as_real(pe_x).flatten(-2)class MLLAttention(nn.Module):r""" Linear Attention with LePE and RoPE.Args:dim (int): Number of input channels.num_heads (int): Number of attention heads.qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True"""def __init__(self, dim=3, input_resolution=[160, 160], num_heads=4, qkv_bias=True, **kwargs):super().__init__()self.dim = dimself.input_resolution = input_resolutionself.num_heads = num_headsself.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)self.elu = nn.ELU()self.lepe = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)self.rope = RoPE()def forward(self, x):"""Args:x: input features with shape of (B, N, C)"""x = x.reshape((x.size(0), x.size(2) * x.size(3), x.size(1)))b, n, c = x.shapeh = int(n ** 0.5)w = int(n ** 0.5)# self.rope = RoPE(shape=(h, w, self.dim))num_heads = self.num_headshead_dim = c // num_headsqk = self.qk(x).reshape(b, n, 2, c).permute(2, 0, 1, 3)q, k, v = qk[0], qk[1], x# q, k, v: b, n, cq = self.elu(q) + 1.0k = self.elu(k) + 1.0q_rope = self.rope(q.reshape(b, h, w, c)).reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)k_rope = self.rope(k.reshape(b, h, w, c)).reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)z = 1 / (q @ k.mean(dim=-2, keepdim=True).transpose(-2, -1) + 1e-6)kv = (k_rope.transpose(-2, -1) * (n ** -0.5)) @ (v * (n ** -0.5))x = q_rope @ kv * zx = x.transpose(1, 2).reshape(b, n, c)v = v.transpose(1, 2).reshape(b, h, w, c).permute(0, 3, 1, 2)x = x + self.lepe(v).permute(0, 2, 3, 1).reshape(b, n, c)x = x.transpose(2, 1).reshape((b, c, h, w))return xdef extra_repr(self) -> str:return f'dim={self.dim}, num_heads={self.num_heads}'if __name__ == "__main__":# Generating Sample imageimage_size = (1, 64, 160, 160)image = torch.rand(*image_size)# Modelmodel = MLLAttention(64)out = model(image)print(out.size())

修改一

第一还是建立文件,我们找到如下ultralvtics/n文件夹下建立一个目录名字呢就是'Addmodules文件夹(用群内的文件的话已经有了无需新建)!然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。

修改二

第二步我们在该目录下创建一个新的py文件名字为'  __init__ .py,然后在其内部导入我们的检测头如
下图所示。

修改三 

第三步我门中到如下文件uitralytics/nn/tasks.py进行导入和注册我们的模块

修改四

按照我的添加在parse model里添加即可。

修改5

修改6 配置yolov8-MLLA.yaml文件

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
 
# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOP
 
# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9
 
# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12
 
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
 
 
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
 
 
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)
  - [-1, 1, MLLAttention, []]  # 22 (P5/32-large) # 添加在大目标检测层后!
 
  - [[15, 18, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

7. 训练代码

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLOif __name__ == '__main__':model = YOLO('yolov8-MLLA.yaml')# 如何切换模型版本, 上面的ymal文件可以改为 yolov8s.yaml就是使用的v8s,# 类似某个改进的yaml文件名称为yolov8-XXX.yaml那么如果想使用其它版本就把上面的名称改为yolov8l-XXX.yaml即可(改的是上面YOLO中间的名字不是配置文件的)!# model.load('yolov8n.pt') # 是否加载预训练权重,科研不建议大家加载否则很难提升精度model.train(data=r"C:\Users\Administrator\PycharmProjects\yolov5-master\yolov5-master\Construction Site Safety.v30-raw-images_latestversion.yolov8\data.yaml",# 如果大家任务是其它的'ultralytics/cfg/default.yaml'找到这里修改task可以改成detect, segment, classify, posecache=False,imgsz=640,epochs=150,single_cls=False,  # 是否是单类别检测batch=16,close_mosaic=0,workers=0,device='0',optimizer='SGD', # using SGD# resume='runs/train/exp21/weights/last.pt', # 如过想续训就设置last.pt的地址amp=True,  # 如果出现训练损失为Nan可以关闭ampproject='runs/train',name='exp',)

8.开启训练

专栏推荐

专栏将持续收集整理市场上深度学习的相关项目,旨在为准备从事深度学习工作或相关科研活动的伙伴,储备、提升更多的实际开发经验,每个项目实例都可作为实际开发项目写入简历,且都附带完整的代码与数据集。可通过百度云盘进行获取,实现开箱即用

正在跟新中~

深度学习落地实战_机 _ 长的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/383897.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

与Bug较量:Codigger之软件项目体检Software Project HealthCheck来帮忙

在软件工程师的世界里,与 Java 小程序中的 Bug 作战是一场永不停歇的战役。每一个隐藏在代码深处的 Bug 都像是一个狡猾的敌人,时刻准备着给我们的项目带来麻烦。 最近,我就陷入了这样一场与 Java 小程序 Bug 的激烈较量中。这个小程序原本应…

数据结构 | LinkedList与链表

前言 ArrayList底层使用连续的空间,任意位置(尤其是0位置下标)插入或删除元素时,需要将该位置后序元素 整体 往前或往后搬移,故时间复杂度为O(N). 优点(给定一个下标,可以快速查找到对应的元素,时间复杂度为O(1))增容需要申请新空间,拷贝数据,释放旧空间,会有不小的消耗.增容一…

【Redis进阶】集群

1. 集群分片算法 1.1 集群概述 首先对于"集群"这个概念是存在不同理解的: 广义的"集群":表示由多台主机构成的分布式系统,称为"集群"狭义的"集群":指的是redis提供的一种集群模式&…

uniapp中@click或者@tap多层嵌套的问题解决方法

我们在开发页面的过程中。例如要设计一个九宫格的相册,并且加上删除上传图片和点击图片后预览图片大图的功能例如下图的演示功能。 点击图片后显示大图预览图片,点击x号后要删除掉当前的图片,那么我们设计的时候如果我们代码写成如下的格式 …

谷粒商城实战笔记-55-商品服务-API-三级分类-修改-拖拽数据收集

文章目录 一,拖拽后结点的parentCid的更新二,拖拽后结点的父节点下所有结点的sort排序属性的变化更新排序的逻辑代码分析 三,拖拽后结点及其子节点catLevel的变化判断是否需要更新 catLevel获取拖动后的新节点 更新 catLevel完整代码 这一节的…

基于JSP、java、Tomcat、mysql三层交互的项目实战--校园交易网(2)登录,注册功能实现

技术支持:JAVA、JSP 服务器:TOMCAT 7.0.86 编程软件:IntelliJ IDEA 2021.1.3 x64 登陆页面如下 在这个页面中我们实现了一个登录页面和一个注册页面的Jsp文件,和两个java 的服务层文件 分别是web包下的denglu.jsp和zhuce.jsp以…

“微软蓝屏”事件,给IT行业带来的宝贵经验和教训

“微软蓝屏”事件是指2024年7月19日发生的一次全球性技术故障,主要涉及微软视窗(Windows)操作系统及其相关应用和服务。 以下是对该事件的详细解析: 一、事件概述 发生时间:2024年7月19日事件影响:全球多个…

2023河南萌新联赛第(二)场 南阳理工学院

A. 国际旅行Ⅰ 题目&#xff1a; 思路&#xff1a; 因为题意上每个国家可以相互到达&#xff0c;所以只需要排序&#xff0c;输出第k小的值就可以了。 AC代码&#xff1a; #include<bits/stdc.h> #define int long long #define IOS ios::sync_with_stdio(0);cin.tie…

基于k8s快速搭建docker镜像服务的demo

基于k8s快速搭建docker镜像服务的demo 一、环境准备 如标题&#xff0c;你需要环境中有和2个平台&#xff0c;并且服务器上也已经安装好docker服务 接下来我来构建一个docker镜像&#xff0c;然后使用harbork8s来快速部署服务demo 二、部署概述 使用docker构建镜像&#x…

HTML常见标签——超链接a标签

一、a标签简介 二、a标签属性 href属性 target属性 三、a标签的作用 利用a标签进行页面跳转 利用a标签返回页面顶部以及跳转页面指定区域 利用a标签实现文件下载 一、a标签简介 <a>标签用于做跳转、导航&#xff0c;是双标签&#xff0c;记作<a></a>&#…

zyx青岛实训day14 7/25

Git 一种分布式版本控制系统&#xff0c;用于跟踪和管理代码的变更 一&#xff0e;Git的主要功能&#xff1a; 二&#xff0e;准备git机器 修改静态ip&#xff0c;主机名 三&#xff0e;git仓库的建立&#xff1a; 1.安装git [rootgit ~]# yum -y install git 2.创建一个…

ROS2入门到精通—— 2-8 ROS2实战:机器人安全通过狭窄区域的方案

0 前言 室内机器人需要具备适应性和灵活性&#xff0c;以便在狭窄的空间中进行安全、高效的导航。本文提供一些让机器人在狭窄区域安全通过的思路&#xff0c;希望帮助读者根据实际开发适当调整和扩展 1 Voronoi图 Voronoi图&#xff1a;根据给定的一组“种子点”&#xff0…

在英特尔 Gaudi 2 上加速蛋白质语言模型 ProtST

引言 蛋白质语言模型 (Protein Language Models, PLM) 已成为蛋白质结构与功能预测及设计的有力工具。在 2023 年国际机器学习会议 (ICML) 上&#xff0c;MILA 和英特尔实验室联合发布了ProtST模型&#xff0c;该模型是个可基于文本提示设计蛋白质的多模态模型。此后&#xff0…

项目一缓存商品

文章目录 概要整体架构流程技术细节小结 概要 因为商品是经常被浏览的,所以数据库的访问量就问大大增加,造成负载过大影响性能,所以我们需要把商品缓存到redis当中,因为redis是存在内存中的,所以效率会比MySQL的快. 整体架构流程 技术细节 我们在缓存时需要保持数据的一致性所…

如何实现Web服务只允许特定客户端访问

如何实现Web服务只允许特定客户端访问 需求来源 为了满足B/S系统给客户演示的需要&#xff0c;需要部署一套系统允许公网能够访问&#xff0c;便于业务人员到客户哪里进行系统演示&#xff0c;但是目前网络安全非常重要&#xff0c;希望能防止暴力破解或者端口扫描等黑客攻击…

C#/WinFrom TCP通信+ 网线插拔检测+客服端异常掉线检测

Winfor Tcp通信(服务端) 今天给大家讲一下C# 关于Tcp 通信部分&#xff0c;这一块的教程网上一大堆&#xff0c;不过关于掉网&#xff0c;异常断开连接的这部分到是到是没有多少说明&#xff0c;有方法 不过基本上最多的两种方式&#xff08;1.设置一个超时时间&#xff0c;2.…

【Python面试题收录】Python编程基础练习题②(数据类型+文件操作+时间操作)

本文所有代码打包在Gitee仓库中https://gitee.com/wx114/Python-Interview-Questions 一、数据类型 第一题 编写一个函数&#xff0c;实现&#xff1a;先去除左右空白符&#xff0c;自动检测输入的数据类型&#xff0c;如果是整数就转换成二进制形式并返回出结果&#xff1b…

第二讲:NJ网络配置

Ethernet/IP网络拓扑结构 一. NJ EtherNet/IP 1、网络端口位置 NJ的CPU上面有两个RJ45的网络接口,其中一个是EtherNet/IP网络端口(另一个是EtherCAT的网络端口) 2、网络作用 如图所示,EtherNet/IP网络既可以做控制器与控制器之间的通信,也可以实现与上位机系统的对接通…

安防视频监控EasyCVR视频汇聚平台修改配置后无法启动的原因排查与解决

安防视频监控/视频集中存储/云存储/磁盘阵列EasyCVR平台基于云边端一体化架构&#xff0c;兼容性强、支持多协议接入&#xff0c;包括国标GB/T 28181协议、部标JT808、GA/T 1400协议、RTMP、RTSP/Onvif协议、海康Ehome、海康SDK、大华SDK、华为SDK、宇视SDK、乐橙SDK、萤石云SD…

RV1126 Linux 系统,接外设,时好时坏(二)排查问题的常用命令

在 RV1126 Linux 系统中,排查外设连接问题时,可以使用多种命令来诊断和调试。以下是一些常用的命令和工具: 1. 查看系统日志 dmesg: 显示内核环形缓冲区的消息,通常包含设备初始化、驱动加载和错误等信息。 dmesg | grep <设备名或相关关键字>journalctl: 查看系统…