[yolo系列:YOLOV7改进-添加CoordConv,SAConv.]

文章目录

    • 概要
    • CoordConv
    • SAConv

概要

CoordConv(Coordinate Convolution)和SAConv(Spatial Attention Convolution)是两种用于神经网络中的特殊卷积操作,用于处理图像数据或其他多维数据。以下是它们的简要介绍:
CoordConv(Coordinate Convolution)

CoordConv 是由Uber AI Labs的研究人员提出的一种卷积操作,用于处理图像中的坐标信息。在传统的卷积操作中,卷积核在图像上滑动并执行卷积操作,但是它们对于图像中的位置信息是不敏感的。CoordConv 的目标是使卷积操作变得位置敏感,它在输入特征图中加入了位置信息作为额外的通道。这个位置信息可以是像素的坐标,也可以是归一化的坐标值,具体取决于应用的场景。

通过将坐标信息与输入特征图拼接在一起,CoordConv 能够帮助神经网络更好地学习到输入数据中的空间关系,从而提高模型的性能。它在需要考虑输入数据的空间位置信息时,特别有用。
SAConv(Spatial Attention Convolution)

SAConv 是一种引入了空间注意力机制的卷积操作。传统的卷积操作在所有位置都应用相同的卷积核,而SAConv 具有可学习的空间注意力权重,这意味着它能够动态地调整不同位置的卷积核权重。

SAConv 的关键思想是,在进行卷积操作之前,先计算每个位置的空间注意力权重。这些权重由神经网络学习得出,然后被用来加权输入特征图的不同位置,从而生成具有位置敏感性的特征表示。这种机制使得神经网络在处理输入数据时能够更加关注重要的区域,从而提高了模型的感知能力和性能。

总的来说,CoordConv 和 SAConv 都是为了增强神经网络对输入数据的空间信息处理能力而提出的方法。CoordConv 引入了位置信息通道,使得网络对位置信息更敏感,而 SAConv 引入了空间注意力机制,使得网络能够动态地调整卷积核的权重,提高了对不同位置信息的关注度。这两种方法在特定的任务和场景下都能够带来性能的提升。

CoordConv

common.py添加如下

class AddCoords(nn.Module):def __init__(self, with_r=False):super().__init__()self.with_r = with_rdef forward(self, input_tensor):"""Args:input_tensor: shape(batch, channel, x_dim, y_dim)"""batch_size, _, x_dim, y_dim = input_tensor.size()xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)xx_channel = xx_channel.float() / (x_dim - 1)yy_channel = yy_channel.float() / (y_dim - 1)xx_channel = xx_channel * 2 - 1yy_channel = yy_channel * 2 - 1xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)ret = torch.cat([input_tensor,xx_channel.type_as(input_tensor),yy_channel.type_as(input_tensor)], dim=1)if self.with_r:rr = torch.sqrt(torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2))ret = torch.cat([ret, rr], dim=1)return retclass CoordConv(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, with_r=False):super().__init__()self.addcoords = AddCoords(with_r=with_r)in_channels += 2if with_r:in_channels += 1self.conv = Conv(in_channels, out_channels, k=kernel_size, s=stride)def forward(self, x):x = self.addcoords(x)x = self.conv(x)return x

在yolo.py

在这里插入图片描述

# yolov7 head
head:[[-1, 1, SPPCSPC, [512]], # 51[-1, 1, CoordConv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[37, 1, CoordConv, [256, 1, 1]], # route backbone P4[[-1, -2], 1, Concat, [1]],[-1, 1, Conv, [256, 1, 1]],[-2, 1, Conv, [256, 1, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],[-1, 1, Conv, [256, 1, 1]], # 63[-1, 1, CoordConv, [128, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[24, 1, CoordConv, [128, 1, 1]], # route backbone P3[[-1, -2], 1, Concat, [1]],[-1, 1, Conv, [128, 1, 1]],[-2, 1, Conv, [128, 1, 1]],[-1, 1, Conv, [64, 3, 1]],[-1, 1, Conv, [64, 3, 1]],[-1, 1, Conv, [64, 3, 1]],[-1, 1, Conv, [64, 3, 1]],[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],[-1, 1, Conv, [128, 1, 1]], # 75[-1, 1, MP, []],[-1, 1, Conv, [128, 1, 1]],[-3, 1, Conv, [128, 1, 1]],[-1, 1, Conv, [128, 3, 2]],[[-1, -3, 63], 1, Concat, [1]],[-1, 1, Conv, [256, 1, 1]],[-2, 1, Conv, [256, 1, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],[-1, 1, Conv, [256, 1, 1]], # 88[-1, 1, MP, []],[-1, 1, Conv, [256, 1, 1]],[-3, 1, Conv, [256, 1, 1]],[-1, 1, Conv, [256, 3, 2]],[[-1, -3, 51], 1, Concat, [1]],[-1, 1, Conv, [512, 1, 1]],[-2, 1, Conv, [512, 1, 1]],[-1, 1, Conv, [256, 3, 1]],[-1, 1, Conv, [256, 3, 1]],[-1, 1, Conv, [256, 3, 1]],[-1, 1, Conv, [256, 3, 1]],[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],[-1, 1, Conv, [512, 1, 1]], # 101[75, 1, CoordConv, [256, 3, 1]],[88, 1, CoordConv, [512, 3, 1]],[101, 1, CoordConv, [1024, 3, 1]],[[102,103,104], 1, IDetect, [nc, anchors]],   # Detect(P3, P4, P5)]

SAConv

在common.py添加

class ConvAWS2d(nn.Conv2d):def __init__(self,in_channels,out_channels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias=True):super().__init__(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation,groups=groups,bias=bias)self.register_buffer('weight_gamma', torch.ones(self.out_channels, 1, 1, 1))self.register_buffer('weight_beta', torch.zeros(self.out_channels, 1, 1, 1))def _get_weight(self, weight):weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,keepdim=True).mean(dim=3, keepdim=True)weight = weight - weight_meanstd = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)weight = weight / stdweight = self.weight_gamma * weight + self.weight_betareturn weightdef forward(self, x):weight = self._get_weight(self.weight)return super()._conv_forward(x, weight, None)def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,missing_keys, unexpected_keys, error_msgs):self.weight_gamma.data.fill_(-1)super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,missing_keys, unexpected_keys, error_msgs)if self.weight_gamma.data.mean() > 0:returnweight = self.weight.dataweight_mean = weight.data.mean(dim=1, keepdim=True).mean(dim=2,keepdim=True).mean(dim=3, keepdim=True)self.weight_beta.data.copy_(weight_mean)std = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)self.weight_gamma.data.copy_(std)class SAConv2d(ConvAWS2d):def __init__(self,in_channels,out_channels,kernel_size,s=1,p=None,g=1,d=1,act=True,bias=True):super().__init__(in_channels,out_channels,kernel_size,stride=s,padding=autopad(kernel_size, p),dilation=d,groups=g,bias=bias)self.switch = torch.nn.Conv2d(self.in_channels,1,kernel_size=1,stride=s,bias=True)self.switch.weight.data.fill_(0)self.switch.bias.data.fill_(1)self.weight_diff = torch.nn.Parameter(torch.Tensor(self.weight.size()))self.weight_diff.data.zero_()self.pre_context = torch.nn.Conv2d(self.in_channels,self.in_channels,kernel_size=1,bias=True)self.pre_context.weight.data.fill_(0)self.pre_context.bias.data.fill_(0)self.post_context = torch.nn.Conv2d(self.out_channels,self.out_channels,kernel_size=1,bias=True)self.post_context.weight.data.fill_(0)self.post_context.bias.data.fill_(0)self.bn = nn.BatchNorm2d(out_channels)self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())def forward(self, x):# pre-contextavg_x = torch.nn.functional.adaptive_avg_pool2d(x, output_size=1)avg_x = self.pre_context(avg_x)avg_x = avg_x.expand_as(x)x = x + avg_x# switchavg_x = torch.nn.functional.pad(x, pad=(2, 2, 2, 2), mode="reflect")avg_x = torch.nn.functional.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)switch = self.switch(avg_x)# sacweight = self._get_weight(self.weight)out_s = super()._conv_forward(x, weight, None)ori_p = self.paddingori_d = self.dilationself.padding = tuple(3 * p for p in self.padding)self.dilation = tuple(3 * d for d in self.dilation)weight = weight + self.weight_diffout_l = super()._conv_forward(x, weight, None)out = switch * out_s + (1 - switch) * out_lself.padding = ori_pself.dilation = ori_d# post-contextavg_x = torch.nn.functional.adaptive_avg_pool2d(out, output_size=1)avg_x = self.post_context(avg_x)avg_x = avg_x.expand_as(out)out = out + avg_xreturn self.act(self.bn(out))

然后在yolo.py里面添加
在这里插入图片描述
在这里插入图片描述
和可变形卷积加法一样,但是不建议加太多,也是只替换3x3卷积上面。比普通卷积复杂度高,不建议加太多,推理速度变慢,尽量少用,提高精度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/169022.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【RNA structures】RNA-seq 分析: RNA转录的重构和前沿测序技术

文章目录 RNA转录重建1 先简单介绍一下测序相关技术2 Map to Genome Methods2.1 Step1 Mapping reads to the genome2.2 Step2 Deal with spliced reads2.3 Step 3 Resolve individual transcripts and their expression levels 3 Align-de-novo approaches3.1 Step 1: Generat…

2023年中国调速器产量、销量及市场规模分析[图]

调速器行业是指生产、销售和维修各种调速器设备的行业。调速器是一种能够改变机械传动系统输出转速的装置,通过调整输入和输出的转速比来实现转速调节的功能。 调速器行业分类 资料来源:共研产业咨询(共研网) 随着工业自动化程度…

C语言代码把时间戳字符串转换成日期时间格式以及修正bug的测试方法

时间戳是一种用来表示日期和时间的数字格式,在不同的编程语言里时间戳的长度和单位都不一样: C:以秒为单位,目前的时间戳是10位数。 Python:以秒为单位并且有精确到7位小数的毫秒,目前的时间戳整数部分是…

基于springboot小区团购管理系统

基于springboot小区团购管理系统的设计与实现 摘要 小区团购管理系统是一款基于Spring Boot框架的Web应用,为小区居民提供了一个方便的平台,以协调和管理各种团购活动。该系统的主要目标是促进小区居民之间的互助合作,通过集中采购来降低商品…

Ubuntu 22.04 中安装 fcitx5

Ubuntu 22.04 中安装 fcitx5 可以按照以下步骤进行: 添加 fcitx5 的 PPA 首先,添加 fcitx5 的官方 PPA: sudo add-apt-repository ppa:fcitx-team/fcitx5更新软件包列表 sudo apt update安装 fcitx5 sudo apt install fcitx5 fcitx5-conf…

【JavaEE初阶】 CAS详解

文章目录 🌲什么是 CAS🚩CAS伪代码 🎋CAS 是怎么实现的🌳CAS的应用🚩实现原子类🚩实现自旋锁 🎄CAS 的 ABA 问题🚩什么是 ABA 问题🚩ABA 问题引来的 BUG🚩解决…

Mac安装nginx(Homebrew)

文章目录 nginx 安装nginx 反向代理nginx 反向代理配置nginx 负载均衡配置 nginx 安装 查看需要安装 nginx 的信息 brew info nginxDocroot 默认为 /usr/local/var/www 在 /opt/homebrew/etc/nginx/nginx.conf 配置文件中默认端口被配置为8080,从而使 nginx 运行…

常用Win32 API的简单介绍

目录 前言: 控制控制台程序窗口的指令: system函数: COORD函数: GetStdHandle函数: GetConsoleCursorInfo函数: CONSOLE_CURSOR_INFO函数: SetConsoleCursorInfo函数: SetC…

Vue 实战项目(智慧商城项目): 完整的订单购物管理功能 内涵资源代码 基于Vant组件库 Vuex态管理 基于企业级项目开发规范

鹏鹏老师的实战开发项目 智慧商城项目 接口文档:安全问题(需要私信即可) 演示地址:跳转项目地址 01. 项目功能演示 1.明确功能模块 启动准备好的代码,演示移动端面经内容,明确功能模块 在这里插入图…

DevExpress WinForms甘特图组件 - 轻松集成项目管理功能到应用

DevExpress WinForms Gantt(甘特图)控件允许您在下一个WinForms桌面应用程序中快速合并项目规划和任务调度功能。 DevExpress WinForms有180组件和UI库,能为Windows Forms平台创建具有影响力的业务解决方案。同时能完美构建流畅、美观且易于…

【超参数研究02】使用随机搜索优化超参数

一、说明 在神经网络训练中,超参数也是需要优化的,然而在超参数较多(大于3个)后,如果用穷举的,或是通过经验约摸实现就显得费时费力,无论如何,这是需要研究、规范、整合的要点&#…

Banana Pi BPI-M4 Berry 采用全志H618芯片,板载2G RAM和8G eMMC

BPI-M4 Berry 开发板作为一款强大的单板计算机(SBC),充分挖掘了全志 H618 系统级芯片(SoC)的功能,为开发人员提供了令人印象深刻的性能和丰富的特性。与树莓派 4b 类似,BPI-M4 Berry 能够展现与…

网站页脚展示备案号并在新标签页中打开超链接

备案时,我们就注意到,备案成功后需要在网站首页底部展示“备案号”,并将备案号链接至https://beian.miit.gov.cn。 这里我使用了WrodPress中的主题,主题自定义中有提供对页脚文本的编辑,支持用css标签定义样式。若是自…

MySQL MVCC机制探秘:数据一致性与并发处理的完美结合,助你成为数据库高手

一、前言 在分析 MVCC 的原理之前,我们先回顾一下 MySQL 的一些内容以及关于 MVCC 的一些简单介绍。(注:下面没有特别说明默认 MySQL 的引擎为 InnoDB ) 1.1 数据库的并发场景 数据库并发场景有三种,分别是: 读-读…

基于springboot实现广场舞团平台系统项目【项目源码+论文说明】计算机毕业设计

基于SPRINGBOOT实现广场舞团平台系统演示 摘要 随着信息技术和网络技术的飞速发展,人类已进入全新信息化时代,传统管理技术已无法高效,便捷地管理信息。为了迎合时代需求,优化管理效率,各种各样的管理系统应运而生&am…

算法通关村第十一关青铜挑战——移位运算详解

大家好,我是怒码少年小码。 计算机到底是怎么处理数字的? 数字在计算机中的表示 机器数 一个数在计算机中的二进制表示形式,叫做这个数的机器数。 机器数是带符号的,在计算机用一个数的最高位存放符号,正数为0&am…

Unity之ShaderGraph如何实现全息投影效果

前言 今天我们来实现一个全息投影的效果,如下所示: 主要节点 Position:提供对网格顶点或片段的Position 的访问,具体取决于节点所属图形部分的有效着色器阶段。使用Space下拉参数选择输出值的坐标空间。 Time:提…

C++入门(3):引用,内联函数

一、引用 1.1 引用特性 引用必须初始化 一个变量可以有多个引用 引用一旦引用一个实体,就不能引用其他实体 int main() {int a 10, C 20;int& b a;b c; // 赋值?还是b变成c的别名?return 0; }1.2 常引用 引用权限可以平移或缩小…

ubuntu双系统安装以及启动时卡死解决办法

目录 一.简介 二.安装 如何安装Ubuntu20.04(详细图文教程-CSDN博客 Ubuntu22.04(非虚拟机)安装教程(2023最新最详细)-CSDN博客 三.ubuntu双系统启动时卡死解决办法(在ubuntu16.04和18.04测试无误) 问题…

vue实现响应式改变scss样式

需求:侧边导航栏点击收起,再次点击展开,但是我这个项目的位置是在左侧菜单栏所以需要自定义 效果图: 实现步骤: 1:定义一个变量(因为我这里会存储菜单栏的状态所以需要存储状态,一…