YOLOv5改进 | 注意力机制 | 迈向高质量像素级回归的极化自注意力【全网独家】

秋招面试专栏推荐 :深度学习算法工程师面试问题总结【百面算法工程师】——点击即可跳转   


💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡


专栏目录: 《YOLOv5入门 + 改进涨点》专栏介绍 & 专栏目录 |目前已有40+篇内容,内含各种Head检测头、损失函数Loss、Backbone、Neck、NMS等创新点改进


虽然深度卷积神经网络中的注意力机制已经变得流行,用于增强长距离依赖,但元素特定的注意力,如Nonlocal块,学习起来非常复杂且对噪声敏感,而且大多数简化的注意力混合方法试图在多种类型的任务之间达到最佳的折衷。为此研究人员提出了极化自注意力(PSA)块,它结合了两个关键设计,以实现高质量的像素级回归。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改并将修改后的完整代码放在文章的最后方便大家一键运行小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。

专栏地址YOLOv5改进+入门——持续更新各种有效涨点方法——点击即可跳转

目录

1.原理

2. 将极化自注意力添加到YOLOv5中

2.1 PSA代码实现

2.2 新增yaml文件

2.3 注册模块

2.4 执行程序

3. 完整代码分享

4. GFLOPs

5. 进阶

6. 总结


1.原理

论文地址:Polarized Self-Attention: Towards High-quality Pixel-wise Regression——点击即可跳转

官方代码:官方代码仓库——点击即可跳转

极化自注意力(Polarized Self-Attention,PSA)是一种针对像素级回归任务设计的注意力机制,旨在提高模型在关键点估计和语义分割等任务上的性能。其主要原理可以概括为以下几点: 1. 极化过滤

  • PSA 将注意力机制应用于输入特征图,将其分为两个分支:通道分支和空间分支。

  • 在通道分支中,PSA 沿着通道维度保持高分辨率,同时将特征图在空间维度上完全折叠,从而减少了计算量和内存占用。

  • 在空间分支中,PSA 沿着空间维度保持高分辨率,同时将特征图在通道维度上完全折叠。

  • 这种“极化”的设计方式有效地保留了高分辨率信息,避免了传统自注意力机制中由于池化或下采样造成的分辨率损失。 2. 非线性增强

  • PSA 在通道分支和空间分支中分别使用了 Softmax 和 Sigmoid 函数的组合来增强非线性。

  • Softmax 函数用于将通道分支中的特征图转化为概率分布,使其更适合表示高斯分布(例如关键点热图)。

  • Sigmoid 函数用于将空间分支中的特征图转化为二值分布,使其更适合表示二项分布(例如分割掩码)。

  • 这种非线性组合的设计能够更好地拟合像素级回归任务的输出分布,从而提高模型的预测精度。

3. 布局灵活

  • PSA 支持两种布局方式:并行和串行。

  • 在并行布局中,通道分支和空间分支的结果直接相加。

  • 在串行布局中,通道分支的结果首先通过空间分支,然后两者相加。

  • 实验表明,两种布局方式在性能上没有显著差异,这表明 PSA 的设计已经充分挖掘了通道和空间维度上的信息。

4. 优势

  • PSA 在不显著增加计算量和内存占用的情况下,能够显著提高像素级回归任务的性能。

  • PSA 的设计能够更好地拟合像素级回归任务的输出分布,从而提高模型的预测精度。

  • PSA 的布局方式灵活,可以根据具体任务进行调整。

总结

PSA 通过“极化”的设计方式,有效地保留了高分辨率信息,并通过非线性增强来拟合像素级回归任务的输出分布,从而显著提高了模型的性能。

2. 将极化自注意力添加到YOLOv5中

2.1 PSA代码实现

关键步骤一:将下面代码粘贴到/yolov5-6.1/models/common.py文件中

class PSA(nn.Module):
​def __init__(self, channel=512):super().__init__()self.ch_wv = nn.Conv2d(channel, channel // 2, kernel_size=(1, 1))self.ch_wq = nn.Conv2d(channel, 1, kernel_size=(1, 1))self.softmax_channel = nn.Softmax(1)self.softmax_spatial = nn.Softmax(-1)self.ch_wz = nn.Conv2d(channel // 2, channel, kernel_size=(1, 1))self.ln = nn.LayerNorm(channel)self.sigmoid = nn.Sigmoid()self.sp_wv = nn.Conv2d(channel, channel // 2, kernel_size=(1, 1))self.sp_wq = nn.Conv2d(channel, channel // 2, kernel_size=(1, 1))self.agp = nn.AdaptiveAvgPool2d((1, 1))
​def forward(self, x):b, c, h, w = x.size()
​# Channel-only Self-Attentionchannel_wv = self.ch_wv(x)  # bs,c//2,h,wchannel_wq = self.ch_wq(x)  # bs,1,h,wchannel_wv = channel_wv.reshape(b, c // 2, -1)  # bs,c//2,h*wchannel_wq = channel_wq.reshape(b, -1, 1)  # bs,h*w,1channel_wq = self.softmax_channel(channel_wq)channel_wz = torch.matmul(channel_wv, channel_wq).unsqueeze(-1)  # bs,c//2,1,1channel_weight = self.sigmoid(self.ln(self.ch_wz(channel_wz).reshape(b, c, 1).permute(0, 2, 1))).permute(0, 2,1).reshape(b, c, 1, 1)  # bs,c,1,1channel_out = channel_weight * x
​# Spatial-only Self-Attentionspatial_wv = self.sp_wv(channel_out)  # bs,c//2,h,wspatial_wq = self.sp_wq(channel_out)  # bs,c//2,h,wspatial_wq = self.agp(spatial_wq)  # bs,c//2,1,1spatial_wv = spatial_wv.reshape(b, c // 2, -1)  # bs,c//2,h*wspatial_wq = spatial_wq.permute(0, 2, 3, 1).reshape(b, 1, c // 2)  # bs,1,c//2spatial_wq = self.softmax_spatial(spatial_wq)spatial_wz = torch.matmul(spatial_wq, spatial_wv)  # bs,1,h*wspatial_weight = self.sigmoid(spatial_wz.reshape(b, 1, h, w))  # bs,1,h,wspatial_out = spatial_weight * channel_outreturn spatial_out

PSA (Polarized Self-Attention) 是一种针对像素级回归任务设计的注意力机制,其主要流程可以分为以下几个步骤:

1. 特征提取

  • 首先,使用卷积神经网络(CNN)对输入图片进行特征提取,得到特征图。特征图包含了图片中每个像素点的特征信息。

2. 极化自注意力

  • 将特征图输入到 PSA 模块进行特征增强。

  • PSA 模块包含两个分支:通道分支和空间分支。

    • 通道分支:沿着通道维度保持高分辨率,同时将特征图在空间维度上完全折叠。

    • 空间分支:沿着空间维度保持高分辨率,同时将特征图在通道维度上完全折叠。

  • 在通道分支和空间分支中,分别使用 Softmax 和 Sigmoid 函数的组合来增强非线性,从而更好地拟合像素级回归任务的输出分布。

3. 输出

  • 将 PSA 模块处理后的特征图输入到解码器中,进行进一步的预测。

  • 解码器会根据具体的任务类型(例如关键点估计或语义分割)生成相应的输出,例如关键点热图或分割掩码。

4. 优化

  • 使用损失函数来评估模型预测结果与真实标签之间的差距。

  • 通过反向传播算法来更新模型参数,从而优化模型的性能。

总结: PSA 的主要流程可以概括为:特征提取 -> 极化自注意力 -> 输出 -> 优化。通过引入 PSA 模块,可以有效地提高模型在像素级回归任务上的性能。

2.2 新增yaml文件

关键步骤二在下/yolov5-6.1/models下新建文件 yolov5_PSA.yaml并将下面代码复制进去

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 80  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4[-1, 3, C3, [128]],[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, SPPF, [1024, 5]],  # 9]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 6], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [512, False]],  # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 4], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [256, False]],  # 17 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 14], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [512, False]],  # 20 (P4/16-medium)[-1, 1, PSA, [ 512 ]],[-1, 1, Conv, [512, 3, 2]],[[-1, 10], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [1024, False]],  # 23 (P5/32-large)[-1, 1, PSA, [ 1024 ]],[[17, 21, 25], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

温馨提示:本文只是对yolov5l基础上添加模块,如果要对yolov8n/l/m/x进行添加则只需要指定对应的depth_multiple 和 width_multiple。


# YOLOv5n
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple# YOLOv5s
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple# YOLOv5l 
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple# YOLOv5m
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple# YOLOv5x
depth_multiple: 1.33  # model depth multiple
width_multiple: 1.25  # layer channel multiple

2.3 注册模块

关键步骤三在yolo.py中注册 注册”PSA",

elif m is PSA :c1, c2 = ch[f], args[0]if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)c2 = make_divisible(c2 * gw, 8)args = [c1, *args[1:]]

2.4 执行程序

在train.py中,将cfg的参数路径设置为yolov5_PSA.yaml的路径

建议大家写绝对路径,确保一定能找到

🚀运行程序,如果出现下面的内容则说明添加成功🚀

3. 完整代码分享

https://pan.baidu.com/s/1evAM-6V-IPEyOOglheWNMw?pwd=we72

提取码: we72 

4. GFLOPs

关于GFLOPs的计算方式可以查看:百面算法工程师 | 卷积基础知识——Convolution

未改进的GFLOPs

img

改进后的GFLOPs

5. 进阶

和损失函数可能有意外的收获,这非常有趣,快去试试吧

损失函数相关改进:YOLOv5改进 | 损失函数 | EIoU、SIoU、WIoU、DIoU、FocusIoU等多种损失函数——点击即可跳转

6. 总结

PSA 是一种针对像素级回归任务设计的注意力机制,其核心思想是通过“极化”的方式,在保持高分辨率信息的同时,有效地增强特征表示,从而提高模型的预测精度。具体而言,PSA 将注意力机制应用于输入特征图,将其分为通道分支和空间分支。在通道分支中,PSA 沿着通道维度保持高分辨率,同时将特征图在空间维度上完全折叠;在空间分支中,PSA 沿着空间维度保持高分辨率,同时将特征图在通道维度上完全折叠。这种“极化”的设计方式有效地保留了高分辨率信息,避免了传统自注意力机制中由于池化或下采样造成的分辨率损失。此外,PSA 在通道分支和空间分支中分别使用了 Softmax 和 Sigmoid 函数的组合来增强非线性,从而更好地拟合像素级回归任务的输出分布。这种非线性组合的设计能够更好地拟合像素级回归任务的输出分布,例如高斯分布(例如关键点热图)或二项分布(例如分割掩码),从而提高模型的预测精度。最后,PSA 支持两种布局方式:并行和串行,两种布局方式在性能上没有显著差异,这表明 PSA 的设计已经充分挖掘了通道和空间维度上的信息。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/364163.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux高并发服务器开发(二)系统调用函数

文章目录 1 系统调用2 errno3 虚拟内存空间4 文件描述符5 常用文件IO函数6 阻塞和非阻塞7 lseek 偏移函数8 文件操作函数之stat函数9 文件描述符复制 dup10 fcnlt函数 修改文件属性11 目录相关操作12 时间相关函数 1 系统调用 根据系统调用,获取驱动信息、CPU的信息…

数据平台发展史-从数据仓库数据湖到数据湖仓

做数据的同学经常听到一些数据相关的术语,常见的包括数据仓库,逻辑数据仓库,数据湖,数据湖仓/湖仓一体,数据网格 data mesh,数据编织 data fabric等. 笔者在这里回顾了下数据平台的发展史,也介绍和对比了下…

Nvidia显卡GeForce Experience录屏操作流程

安装软件 首先我们从英伟达官网下载GeForce Experience程序,安装在电脑中GeForce Experience(简称 GFE)自动更新驱动并优化游戏设置 | NVIDIA 登录软件 安装完成后登录 开启录屏功能 登录后点击右上角的设置(小齿轮图标&#x…

(单机版)神魔大陆|v0.51.0|冰火荣耀

前言 今天给大家带来一款单机游戏的架设:神魔大陆v0.51.0:冰火荣耀。 如今市面上的资源参差不齐,大部分的都不能运行,本人亲自测试,运行视频如下: (单机版)神魔大陆 下面我将详细的教程交给大家,请耐心阅…

C语言力扣刷题1——最长回文字串[双指针]

力扣算题1——最长回文字串[双指针] 一、博客声明二、题目描述三、解题思路1、思路说明2、知识补充a、malloc动态内存分配b、free释放内存c、strlen求字符数组长度d、strncpy函数 四、解题代码(附注释) 一、博客声明 找工作逃不过刷题,为了更…

JS在线加密简述

JS在线加密,是指:在线进行JS代码混淆加密。通过混淆、压缩、加密等手段,使得JS源代码难以阅读和理解。从而可以有效防止代码被盗用或抄袭,保护开发者的知识产权和劳动成果。常用的JS在线加密网站有:JShaman、JS-Obfusc…

【ONLYOFFICE 8.1】的安装与使用——功能全面的 PDF 编辑器、幻灯片版式、优化电子表格的协作

🔥 个人主页:空白诗 文章目录 一、引言二、ONLYOFFICE 简介三、安装1. Windows/Mac 安装2. 文档开发者版安装安装前准备使用 Docker 安装使用 Linux 发行版安装配置 ONLYOFFICE 文档开发者版集成和开发 四、使用1. 功能全面的 PDF 编辑器PDF 查看和导航P…

llama.cpp

https://github.com/echonoshy/cgft-llm 【大模型量化】- Llama.cpp轻量化模型部署及量化_哔哩哔哩_bilibili github.com/ggerganov/llama.cpp cd ~/code/llama.cpp/build_cuda/bin ./quantize --allow-requantize /root/autodl-tmp/models/Llama3-8B-Chinese-Chat-GGUF/Llama…

文心一言4.0免费使用

领取&安装链接:Baidu Comate 领取季卡 视频教程:免费使用文心一言4.0大模型_哔哩哔哩_bilibili 有图有真相 原理:百度comate使用文心一言最新的4.0模型。百度comate目前免费使用,可以借助comate达到免费使用4.0模型目的。 …

网页如何快速被收录?

其实就是要要吸引搜索引擎爬虫更快地抓取你的网页,想让爬虫爬取网页,首要做的自然是创建并提交站点地图。站点地图是搜索引擎了解你网站结构的重要工具。它可以帮助爬虫更快地发现和抓取你网站上的所有重要页面。通过Google Search Console提交站点地图&…

p6spy 组件打印完整的 SQL 语句、执行耗时

一、前言 我们来配置一下 Mybatis Plus 打印 SQL 功能(包括执行耗时),一方面可以了解到每个操作都具体执行的什么 SQL 语句, 另一方面通过打印执行耗时,也可以提前发现一些慢 SQL,提前做好优化&#xff0c…

多线程引发的安全问题

前言👀~ 上一章我们介绍了线程的一些基础知识点,例如创建线程、查看线程、中断线程、等待线程等知识点,今天我们讲解多线程下引发的安全问题 线程安全(最复杂也最重要) 产生线程安全问题的原因 锁(重要…

高精度除法的实现

高精度除法与高精度加法的定义、前置过程都是大致相同的,如果想了解具体内容,可以移步至我的这篇博客:高精度加法计算的实现 在这里就不再详细讲解,只讲解主体过程qwq 主体过程 高精度除法的原理和小学学习的竖式除法是一样的。 …

python中lxml库的使用简介

目录 1.ElementTree 类 2.Element 类 3.ElementTree 类或 Element 类的查找方法 为方便开发人员在程序中使用 XPath 的路径表达式提取节点对应的内容, Python 提供了 第三方库 lxml 。开发人员通过 lxml 库可以轻松地对 HTM…

使用obdumper对oceanbase进行备份,指定2881端口

1.安装obdumper (1)下载软件 OceanBase分布式数据库-海量数据 笔笔算数https://www.oceanbase.com/softwarecenter (2)安装软件 参考:https://www.oceanbase.com/docs/common-oceanbase-dumper-loader-100000000062…

qt实现打开pdf(阅读器)功能用什么库比较合适

关于这个问题,网上搜一下,可以看到非常多的相关博客和例子,可以先看看这个总结性的博客(https://zhuanlan.zhihu.com/p/480973072) 该博客讲得比较清楚了,这里我再补充一下吧(qt官方也给出了一些…

MyBatis Plus条件构造器使用

1Wrapper: 条件构造抽象类,最顶端父类 1.1 AbstractWrapper: 用于查询条件封装,生成 sql 的 where 条件 1.2 QueryWrapper: Entity 对象封装操作类,不是用lambda语法 1.3 UpdateWrapper: Update…

动手学深度学习(Pytorch版)代码实践 -卷积神经网络-29残差网络ResNet

29残差网络ResNet import torch from torch import nn from torch.nn import functional as F import liliPytorch as lp import matplotlib.pyplot as plt# 定义一个继承自nn.Module的残差块类 class Residual(nn.Module):def __init__(self, input_channels, num_chan…

ROS2创建自定义接口

ROS2提供了四种通信方式: 话题-Topics 服务-Services 动作-Action 参数-Parameters 查看系统自定义接口命令 使用ros2 interface package sensor_msgs命令可以查看某一个接口包下所有的接口 除了参数之外,话题、服务和动作(Action)都支持自定义接口&am…

微服务实战系列之云原生

前言 话说博主的微服务实战系列从去年走到今天,已过去了半年多了。本系列,博主主要围绕微服务实践过程中的主要组件或工具展开介绍。其中基本覆盖了我们项目或产品研发过程中,经常使用的中间件或第三方工具。至此,该系列也该朝着…