YOLO11改进|注意力机制篇|引入局部注意力HaloAttention

在这里插入图片描述

目录

    • 一、【HaloAttention】注意力机制
      • 1.1【HaloAttention】注意力介绍
      • 1.2【HaloAttention】核心代码
    • 二、添加【HaloAttention】注意力机制
      • 2.1STEP1
      • 2.2STEP2
      • 2.3STEP3
      • 2.4STEP4
    • 三、yaml文件与运行
      • 3.1yaml文件
      • 3.2运行成功截图

一、【HaloAttention】注意力机制

1.1【HaloAttention】注意力介绍

在这里插入图片描述

下图是【HaloAttention】的结构图,让我们简单分析一下运行过程和优势

处理过程

  • 图像分块:

  • 输入图像大小为 4×4×𝑐,其中 𝑐
    是通道数。该图像首先被分割为多个小块(如图所示被分为 4 个 2×2×𝑐的小块),每个块称为一个“block”。

  • Haloing 操作:

  • 在图像分块后,使用 haloing 操作扩展每个小块的边界。图中显示的是一个 halo 值为 1 的情况,即每个小块在其原有区域上扩展了 1 个像素的边界,形成了带有额外边界信息的邻域窗口。这一操作目的是为了在计算注意力时捕获块与块之间的上下文信息。

  • 邻域窗口计算:

  • Haloing 之后,每个小块拥有邻近区域的信息,即在扩展后的邻域窗口中包含了来自周围小块的部分信息。图中显示了每个小块及其周围邻域的窗口(如红色小块与其邻域的相关部分)。

  • 查询与注意力机制:

  • 在邻域窗口中应用 注意力机制。以每个小块作为查询(Query),与其扩展后的邻域窗口进行注意力计算,从中提取重要的上下文特征。注意力机制的引入使得每个小块不仅能够学习到自身的特征,还能从周围的块中获取相关的上下文信息,从而增强特征表达。

  • 输出:

  • 通过注意力机制的加权输出每个小块的结果,形成新的特征图。输出的特征图大小仍然是分块前的大小,但每个块内的特征已经经过上下文增强和融合。
    优势

  • 降低计算复杂度:

  • 通过将图像分割成小块并只在局部区域内应用注意力机制,减少了全局自注意力带来的高计算开销。这种方法可以大幅度降低计算复杂度,特别适合处理高分辨率图像或大规模数据集。

  • 局部上下文捕获:

  • Haloing 操作的引入使得每个块在计算注意力时能够感知到其邻域的上下文信息,克服了仅依赖自身区域的局限性。因此,它能够更好地捕捉局部细节和相关性,特别是在需要高精度定位的任务中(如图像分割或检测任务)。

  • 有效的特征增强:

  • 通过分块后的注意力机制,模型可以集中计算各个小块的注意力权重,并在局部范围内提升特征表达能力。这样可以避免全局注意力在大图像上计算时引入的冗余信息,同时仍能保证特征的有效整合。

  • 灵活性强:

  • 该方法可广泛应用于图像分类、目标检测、语义分割等任务中,并且可以根据实际需求调整分块大小和 halo 值,灵活适应不同的计算资源和任务要求。在这里插入图片描述

1.2【HaloAttention】核心代码

import torch
from torch import nn, einsum
import torch.nn.functional as Ffrom einops import rearrange, repeatdef to(x):return {"device": x.device, "dtype": x.dtype}def pair(x):return (x, x) if not isinstance(x, tuple) else xdef expand_dim(t, dim, k):t = t.unsqueeze(dim=dim)expand_shape = [-1] * len(t.shape)expand_shape[dim] = kreturn t.expand(*expand_shape)def rel_to_abs(x):b, l, m = x.shaper = (m + 1) // 2col_pad = torch.zeros((b, l, 1), **to(x))x = torch.cat((x, col_pad), dim=2)flat_x = rearrange(x, "b l c -> b (l c)")flat_pad = torch.zeros((b, m - l), **to(x))flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)final_x = flat_x_padded.reshape(b, l + 1, m)final_x = final_x[:, :l, -r:]return final_xdef relative_logits_1d(q, rel_k):b, h, w, _ = q.shaper = (rel_k.shape[0] + 1) // 2logits = einsum("b x y d, r d -> b x y r", q, rel_k)logits = rearrange(logits, "b x y r -> (b x) y r")logits = rel_to_abs(logits)logits = logits.reshape(b, h, w, r)logits = expand_dim(logits, dim=2, k=r)return logitsclass RelPosEmb(nn.Module):def __init__(self, block_size, rel_size, dim_head):super().__init__()height = width = rel_sizescale = dim_head**-0.5self.block_size = block_sizeself.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)def forward(self, q):block = self.block_sizeq = rearrange(q, "b (x y) c -> b x y c", x=block)rel_logits_w = relative_logits_1d(q, self.rel_width)rel_logits_w = rearrange(rel_logits_w, "b x i y j-> b (x y) (i j)")q = rearrange(q, "b x y d -> b y x d")rel_logits_h = relative_logits_1d(q, self.rel_height)rel_logits_h = rearrange(rel_logits_h, "b x i y j -> b (y x) (j i)")return rel_logits_w + rel_logits_hclass HaloAttention(nn.Module):def __init__(self, dim, block_size, halo_size, dim_head=64, heads=8):super().__init__()assert halo_size > 0, "halo size must be greater than 0"self.dim = dimself.heads = headsself.scale = dim_head**-0.5self.block_size = block_sizeself.halo_size = halo_sizeinner_dim = dim_head * headsself.rel_pos_emb = RelPosEmb(block_size=block_size,rel_size=block_size + (halo_size * 2),dim_head=dim_head,)self.to_q = nn.Linear(dim, inner_dim, bias=False)self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)self.to_out = nn.Linear(inner_dim, dim)def forward(self, x):b, c, h, w, block, halo, heads, device = (*x.shape,self.block_size,self.halo_size,self.heads,x.device,)assert (h % block == 0 and w % block == 0), "fmap dimensions must be divisible by the block size"assert (c == self.dim), f"channels for input ({c}) does not equal to the correct dimension ({self.dim})"# get block neighborhoods, and prepare a halo-ed version (blocks with padding) for deriving key valuesq_inp = rearrange(x, "b c (h p1) (w p2) -> (b h w) (p1 p2) c", p1=block, p2=block)kv_inp = F.unfold(x, kernel_size=block + halo * 2, stride=block, padding=halo)kv_inp = rearrange(kv_inp, "b (c j) i -> (b i) j c", c=c)# derive queries, keys, valuesq = self.to_q(q_inp)k, v = self.to_kv(kv_inp).chunk(2, dim=-1)# split headsq, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=heads), (q, k, v))# scaleq *= self.scale# attentionsim = einsum("b i d, b j d -> b i j", q, k)# add relative positional biassim += self.rel_pos_emb(q)# mask out padding (in the paper, they claim to not need masks, but what about padding?)mask = torch.ones(1, 1, h, w, device=device)mask = F.unfold(mask, kernel_size=block + (halo * 2), stride=block, padding=halo)mask = repeat(mask, "() j i -> (b i h) () j", b=b, h=heads)mask = mask.bool()max_neg_value = -torch.finfo(sim.dtype).maxsim.masked_fill_(mask, max_neg_value)# attentionattn = sim.softmax(dim=-1)# aggregateout = einsum("b i j, b j d -> b i d", attn, v)# merge and combine headsout = rearrange(out, "(b h) n d -> b n (h d)", h=heads)out = self.to_out(out)# merge blocks back to original feature mapout = rearrange(out,"(b h w) (p1 p2) c -> b c (h p1) (w p2)",b=b,h=(h // block),w=(w // block),p1=block,p2=block,)return outif __name__ == "__main__":input = torch.rand(3, 32, 64, 64).cuda()model = HaloAttention(dim=32,block_size=2,halo_size=1,).cuda()output = model(input)print(input.size(), output.size())

二、添加【HaloAttention】注意力机制

2.1STEP1

首先找到ultralytics/nn文件路径下新建一个Add-module的python文件包【这里注意一定是python文件包,新建后会自动生成_init_.py】,如果已经跟着我的教程建立过一次了可以省略此步骤,随后新建一个HaloAttention.py文件并将上文中提到的注意力机制的代码全部粘贴到此文件中,如下图所示在这里插入图片描述

2.2STEP2

在STEP1中新建的_init_.py文件中导入增加改进模块的代码包如下图所示在这里插入图片描述

2.3STEP3

找到ultralytics/nn文件夹中的task.py文件,在其中按照下图添加在这里插入图片描述

2.4STEP4

定位到ultralytics/nn文件夹中的task.py文件中的def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)函数添加如图代码,【如果不好定位可以直接ctrl+f搜索定位】

在这里插入图片描述

三、yaml文件与运行

3.1yaml文件

以下是添加【HaloAttention】注意力机制在Backbone中的yaml文件,大家可以注释自行调节,效果以自己的数据集结果为准

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'# [depth, width, max_channels]n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPss: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPsm: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPsl: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPsx: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs# YOLO11n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2- [-1, 1, Conv, [128,3,2]] # 1-P2/4- [-1, 2, C3k2, [256, False, 0.25]]- [-1, 1, Conv, [256,3,2]] # 3-P3/8- [-1, 2, C3k2, [512, False, 0.25]]- [-1, 1, Conv, [512,3,2]] # 5-P4/16- [-1, 2, C3k2, [512, True]]- [-1, 1, Conv, [1024,3,2]] # 7-P5/32- [-1, 2, C3k2, [1024, True]]- [-1, 1, HaloAttention, [2, 1]]- [-1, 1, SPPF, [1024, 5]] # 9- [-1, 2, C2PSA, [1024]] # 10# YOLO11n head
head:- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 6], 1, Concat, [1]] # cat backbone P4- [-1, 2, C3k2, [512, False]] # 13- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 4], 1, Concat, [1]] # cat backbone P3- [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 14], 1, Concat, [1]] # cat head P4- [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 11], 1, Concat, [1]] # cat head P5- [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)- [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

以上添加位置仅供参考,具体添加位置以及模块效果以自己的数据集结果为准

3.2运行成功截图

在这里插入图片描述

OK 以上就是添加【HaloAttention】注意力机制的全部过程了,后续将持续更新尽情期待

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/444033.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用FastAPI做人工智能后端服务器时,接口内的操作不是异步操作的解决方案

在做AI模型推理的接口时,这时候接口是非异步的,但是uvicorn运行FastAPI时就会出现阻塞所有请求。 这时候需要解决这个问题: api.py: import asyncio from fastapi import FastAPI from fastapi.responses import StreamingResp…

嵌入式开发:STM32 硬件 CRC 使用

测试平台:STM32G474系列 STM32硬件的CRC不占用MCU的资源,计算速度快。由于硬件CRC需要配置一些选项,配置不对就会导致计算结果错误,导致使用上没有软件计算CRC方便。但硬件CRC更快的速度在一些有时间资源要求的场合还是非…

从 Reno TCP 到 Scalable TCP,HighSpeed TCP

前文 Scalable TCP 如何优化长肥管道 介绍了 Scalable TCP,但联系另一个类似的算法 HighSpeed TCP(简称 HSTCP),就会看到一个类似从 Reno TCP 经 BIC 到 CUBIC 的路线,但采用了不同的策略。 Reno TCP 经 BIC 到 CUBIC 路线的核心在于 “在长…

4反馈、LC、石英、RC振荡器

1什么是振荡器? 我们看看振荡器在无线通信中扮演什么角色? 1)无线通信的波是指电磁波‌。 2‌)电磁波的频率高于100KHz才能在空气中传播。‌ 3)空气中的高频电磁波的相位和振幅可以排列组合包含信息。 4)无…

DBMS-3.4 SQL(4)——存储过程和函数触发器

本文章的素材与知识来自李国良老师和王珊老师。 存储过程和函数 一.存储过程 1.语法 2.示例 (1) 使用DELIMITER更换终止符后用于编写存储过程语句后,在下次执行SQL语句时记得再使用DELIMITER将终止符再换回分号。 使用DELIMITER更换终止符…

Ubuntu 22.04.4 LTS更换下载源

方法1:使用图形界面更换下载源 1. 打开软件和更新应用 2. 在Ubuntu 软件标签中,点击“下载自”旁边的下拉菜单,选择“其他” 3. 点击“选择最佳服务器”来自动选择最快的服务器 4. 选择服务器 5. 确定并关闭窗口,系统会提示您重新…

ElasticSearch备考 -- Multi match

一、题目 索引task有3个字段a、b、c,写一个查询去匹配这三个字段为mom,其中b的字段评分比a、c字段大一倍,将他们的分数相加作为最后的总分数 二、思考 通过题目要求对多个字段进行匹配查询,可以考虑multi match、bool query操作。…

【C++第十八章】Map和Set

Map和Set map和set的介绍 容器分为两种,序列式容器和关联式容器,序列式容器因为底层是线性序列的数据结构,存储的是元素本身,而关联式容器中不单是为了存储数据,还要进行查找,所以存储的是键值对&#xff…

网络编程(17)——asio多线程模型IOThreadPool

十七、day17 之前我们介绍了IOServicePool的方式,一个IOServicePool开启n个线程和n个iocontext,每个线程内独立运行iocontext, 各个iocontext监听各自绑定的socket是否就绪,如果就绪就在各自线程里触发回调函数。为避免线程安全问题&#xf…

腾讯云SDK点播播放数据

点播播放质量监控提供点播播放全链路的数据统计、质量监控及可视化分析服务。支持实时数据上报、数据聚合、多维筛选和精细化定向分析,可帮助企业实时掌控大盘运营状况、了解用户习惯和行为特征,有效指导运营决策、驱动业务增长。 注意事项 点播播放质…

Python 工具库每日推荐 【Pandas】

文章目录 引言Python数据处理库的重要性今日推荐:Pandas工具库主要功能:使用场景:安装与配置快速上手示例代码代码解释实际应用案例案例:销售数据分析案例分析高级特性数据合并和连接时间序列处理数据透视表扩展阅读与资源优缺点分析优点:缺点:总结【 已更新完 TypeScrip…

基于 CSS Grid 的简易拖拉拽 Vue3 组件,从代码到NPM发布(1)- 拖拉拽交互

基于特定的应用场景,需要在页面中以网格的方式,实现目标组件在网格中可以进行拖拉拽、修改大小等交互。本章开始分享如何一步步从代码设计,最后到如何在 NPM 上发布。 请大家动动小手,给我一个免费的 Star 吧~ 大家如果发现了 Bug…

探索未来:mosquitto-python,AI领域的新宠

文章目录 探索未来:mosquitto-python,AI领域的新宠背景:为何选择mosquitto-python?库简介:mosquitto-python是什么?安装指南:如何安装mosquitto-python?函数用法:5个简单…

代码随想录算法训练营第四十六天 | 647. 回文子串,516.最长回文子序列

四十六天打卡,今天用动态规划解决回文问题,回文问题需要用二维dp解决 647.回文子串 题目链接 解题思路 没做出来,布尔类型的dp[i][j]:表示区间范围[i,j] (注意是左闭右闭)的子串是否是回文子串&#xff0…

深入理解Transformer的笔记记录(精简版本)---- Transformer

自注意力机制开启大规模预训练时代 1 从机器翻译模型举例 1.1把编码器和解码器联合起来看待的话,则整个流程就是(如下图从左至右所示): 1.首先,从编码器输入的句子会先经过一个自注意力层(即self-attention),它会帮助编码器在对每个单词编码时关注输入句子中的的其他单…

【JavaEE】——回显服务器的实现

阿华代码,不是逆风,就是我疯 你们的点赞收藏是我前进最大的动力!! 希望本文内容能够帮助到你!! 目录 一:引入 1:基本概念 二:UDP socket API使用 1:socke…

2-118 基于matlab的六面体建模和掉落仿真

基于matlab的六面体建模和掉落仿真,将对象建模为刚体来模拟将立方体扔到地面上。同时考虑地面摩擦力、刚度和阻尼所施加的力,在三个维度上跟踪平移运动和旋转运动。程序已调通,可直接运行。 下载源程序请点链接:2-118 基于matla…

基于SpringBoot“花开富贵”花园管理系统【附源码】

效果如下: 系统注册页面 系统首页界面 植物信息详细页面 后台登录界面 管理员主界面 植物分类管理界面 植物信息管理界面 园艺记录管理界面 研究背景 随着城市化进程的加快和人们生活质量的提升,越来越多的人开始追求与自然和谐共生的生活方式&#xf…

使用激光跟踪仪提升码垛机器人精度

标题1.背景 码垛机器人是一种用于工业自动化的机器人,专门设计用来将物品按照一定的顺序和结构堆叠起来,通常用于仓库、物流中心和生产线上,它们可以自动执行重复的、高强度的搬运和堆垛任务。 图1 码垛机器人 传统调整码垛机器人的方法&a…

通信工程学习:什么是DIP数据集成点

DIP:数据集成点 DIP数据集成点(Data Integration Point),简称DIP,是物联网技术(IoT)和机器到机器(M2M)通信中的一个重要组成部分。DIP在数据集成和传输过程中扮演着关键角…