YOLOv7改进:结合CotNet Transformer结构

 

1.简介

京东AI研究院提出的一种新的注意力结构。将CoT Block代替了ResNet结构中的3x3卷积,在分类检测分割等任务效果都出类拔萃

 

论文:Contextual Transformer Networks for Visual Recognition
论文地址:https://arxiv.org/abs/2107.12292

 

有自注意力的Transformer引发了自然语言处理领域的革命,最近还激发了Transformer式架构设计的出现,并在众多计算机视觉任务中取得了具有竞争力的结果。

大多数现有设计直接在2D特征图上使用自注意力来获得基于每个空间位置的独立查询和键对的注意力矩阵,但未充分利用相邻键之间的丰富上下文。在今天分享的工作中,研究者设计了一个新颖的Transformer风格的模块,即Contextual Transformer (CoT)块,用于视觉识别。这种设计充分利用输入键之间的上下文信息来指导动态注意力矩阵的学习,从而增强视觉表示能力。从技术上讲,CoT块首先通过3×3卷积对输入键进行上下文编码,从而产生输入的静态上下文表示。

 

上图a是传统的self-attention仅利用孤立的查询-键对来测量注意力矩阵,但未充分利用键之间的丰富上下文。 b就是CoT块

研究者进一步将编码的键与输入查询连接起来,通过两个连续的1×1卷积来学习动态多头注意力矩阵。学习到的注意力矩阵乘以输入值以实现输入的动态上下文表示。静态和动态上下文表示的融合最终作为输出。CoT块很吸引人,因为它可以轻松替换ResNet架构中的每个3 × 3卷积,产生一个名为Contextual Transformer Networks (CoTNet)的Transformer式主干。通过对广泛应用(例如图像识别、对象检测和实例分割)的大量实验,验证了CoTNet作为更强大的主干的优越性

Attention注意力机制与self-attention自注意力机制

为什么要注意力机制?

在Attention诞生之前,已经有CNN和RNN及其变体模型了,那为什么还要引入attention机制?主要有两个方面的原因,如下:

(1)计算能力的限制:当要记住很多“信息“,模型就要变得更复杂,然而目前计算能力依然是限制神经网络发展的瓶颈。

(2)优化算法的限制:LSTM只能在一定程度上缓解RNN中的长距离依赖问题,且信息“记忆”能力并不高。

什么是注意力机制

在介绍什么是注意力机制之前,先让大家看一张图片。当大家看到下面图片,会首先看到什么内容?当过载信息映入眼帘时,我们的大脑会把注意力放在主要的信息上,这就是大脑的注意力机制。

同样,当我们读一句话时,大脑也会首先记住重要的词汇,这样就可以把注意力机制应用到自然语言处理任务中,于是人们就通过借助人脑处理信息过载的方式,提出了Attention机制。

self attention是注意力机制中的一种,也是transformer中的重要组成部分。自注意力机制是注意力机制的变体,其减少了对外部信息的依赖,更擅长捕捉数据或特征的内部相关性。自注意力机制在文本中的应用,主要是通过计算单词间的互相影响,来解决长距离依赖问题。

传统的自注意力很好地触发了不同空间位置的特征交互,具体取决于输入本身。然而,在传统的自注意力机制中,所有成对的查询键关系都是通过孤立的查询键对独立学习的,而无需探索其间的丰富上下文。这严重限制了自注意力学习在2D特征图上进行视觉表示学习的能力。

为了缓解这个问题,研究者构建了一个新的Transformer风格的构建块,即上图 (b)中的 Contextual Transformer (CoT) 块,它将上下文信息挖掘和自注意力学习集成到一个统一的架构中。

2.YOLOv7改进

 2.1yaml配置文件修改

# YOLOv7 🚀, GPL-3.0 license
# parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 1.0  # layer channel multiple# anchors
anchors:- [12,16, 19,36, 40,28]  # P3/8- [36,75, 76,55, 72,146]  # P4/16- [142,110, 192,243, 459,401]  # P5/32# yolov7 backbone by yoloair
backbone:# [from, number, module, args][[-1, 1, Conv, [32, 3, 1]],  # 0[-1, 1, Conv, [64, 3, 2]],  # 1-P1/2[-1, 1, Conv, [64, 3, 1]],[-1, 1, Conv, [128, 3, 2]],  # 3-P2/4 [-1, 1, C3HB, [128]], [-1, 1, Conv, [256, 3, 2]], [-1, 1, MP, []],[-1, 1, Conv, [128, 1, 1]],[-3, 1, Conv, [128, 1, 1]],[-1, 1, Conv, [128, 3, 2]],[[-1, -3], 1, Concat, [1]],  # 16-P3/8[-1, 1, Conv, [128, 1, 1]],[-2, 1, Conv, [128, 1, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[[-1, -3, -5, -6], 1, Concat, [1]],[-1, 1, Conv, [512, 1, 1]],[-1, 1, MP, []],[-1, 1, Conv, [256, 1, 1]],[-3, 1, Conv, [256, 1, 1]],[-1, 1, Conv, [256, 3, 2]],[[-1, -3], 1, Concat, [1]],[-1, 1, Conv, [256, 1, 1]],[-2, 1, Conv, [256, 1, 1]],[-1, 1, Conv, [256, 3, 1]],[-1, 1, Conv, [256, 3, 1]],[-1, 1, Conv, [256, 3, 1]],[-1, 1, Conv, [256, 3, 1]],[[-1, -3, -5, -6], 1, Concat, [1]],[-1, 1, Conv, [1024, 1, 1]],          [-1, 1, MP, []],[-1, 1, Conv, [512, 1, 1]],[-3, 1, Conv, [512, 1, 1]],[-1, 1, Conv, [512, 3, 2]],[[-1, -3], 1, Concat, [1]],[-1, 1, C3HB, [1024]],[-1, 1, Conv, [256, 3, 1]],]# yolov7 head by yoloair
head:[[-1, 1, SPPCSPC, [512]],[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[31, 1, Conv, [256, 1, 1]],[[-1, -2], 1, Concat, [1]],[-1, 1, CoT3, [128]],[-1, 1, Conv, [128, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[18, 1, Conv, [128, 1, 1]],[[-1, -2], 1, Concat, [1]],[-1, 1, CoT3, [128]],[-1, 1, MP, []],[-1, 1, Conv, [128, 1, 1]],[-3, 1, Conv, [128, 1, 1]],[-1, 1, Conv, [128, 3, 2]],[[-1, -3, 44], 1, Concat, [1]],[-1, 1, CoT3, [256]], [-1, 1, MP, []],[-1, 1, Conv, [256, 1, 1]],[-3, 1, Conv, [256, 1, 1]],[-1, 1, Conv, [256, 3, 2]], [[-1, -3, 39], 1, Concat, [1]],[-1, 3, CoT3, [512]],# 检测头 -----------------------------[49, 1, RepConv, [256, 3, 1]],[55, 1, RepConv, [512, 3, 1]],[61, 1, RepConv, [1024, 3, 1]],[[62,63,64], 1, IDetect, [nc, anchors]],   # Detect(P3, P4, P5)]

2.1common.py配置

./models/common.py文件增加以下模块

class CoT3(nn.Module):def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansionsuper().__init__()c_ = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, c_, 1, 1)self.cv2 = Conv(c1, c_, 1, 1)self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)self.m = nn.Sequential(*[CoTBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])# self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])def forward(self, x):return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))class CoTBottleneck(nn.Module):def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansionsuper(CoTBottleneck, self).__init__()c_ = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, c_, 1, 1)self.cv2 = CoT(c_, 3)self.add = shortcut and c1 == c2def forward(self, x):return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))class CoT(nn.Module):# Contextual Transformer Networks https://arxiv.org/abs/2107.12292def __init__(self, dim=512,kernel_size=3):super().__init__()self.dim=dimself.kernel_size=kernel_sizeself.key_embed=nn.Sequential(nn.Conv2d(dim,dim,kernel_size=kernel_size,padding=kernel_size//2,groups=4,bias=False),nn.BatchNorm2d(dim),nn.ReLU())self.value_embed=nn.Sequential(nn.Conv2d(dim,dim,1,bias=False),nn.BatchNorm2d(dim))factor=4self.attention_embed=nn.Sequential(nn.Conv2d(2*dim,2*dim//factor,1,bias=False),nn.BatchNorm2d(2*dim//factor),nn.ReLU(),nn.Conv2d(2*dim//factor,kernel_size*kernel_size*dim,1))def forward(self, x):bs,c,h,w=x.shapek1=self.key_embed(x) #bs,c,h,wv=self.value_embed(x).view(bs,c,-1) #bs,c,h,wy=torch.cat([k1,x],dim=1) #bs,2c,h,watt=self.attention_embed(y) #bs,c*k*k,h,watt=att.reshape(bs,c,self.kernel_size*self.kernel_size,h,w)att=att.mean(2,keepdim=False).view(bs,c,-1) #bs,c,h*wk2=F.softmax(att,dim=-1)*vk2=k2.view(bs,c,h,w)return k1+k2

2.3yolo.py配置修改

然后找到./models/yolo.py文件下里的parse_model函数,将加入的模块名CoT3加入进去
在 models/yolo.py文件夹下

定位到parse_model函数中
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):内部
对应位置 下方只需要增加 CoT3模块

 修改完成!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/145732.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

技术Leader对下管理的法宝-SMART

SMART方法论 源于国外管理大师的《管理的实践》是管理者能够更加明确员工高效工作的利器,科学、规范的对员工绩效制定考核目标和考核标准5个单词缩写 Specific:目标要具体Measurable:目标成果要可衡量(量化) Attainable:目标要可实现,避免过高/过低Rel…

模块化CSS

1、什么是模块化CSS 模块化CSS是一种将CSS样式表的规则和样式定义封装到模块或组件级别的方法,以便于更好地管理、维护和组织样式代码。这种方法通过将样式与特定的HTML元素或组件相关联,提供了一种更具可维护性、可复用性和隔离性的方式来处理样式。简单…

SpringBoot读取配置的方式

在 Spring Boot 应用中,我们通常需要一些配置信息来指导应用的运行。这些配置信息可以包括如下内容:端口号、数据库连接信息、日志配置、缓存配置、认证配置、等等。Spring Boot 提供了多种方式来读取这些配置信息。读取配置的目的是为了在程序中使用这些…

前端JavaScript入门到精通,javascript核心进阶ES6语法、API、js高级等基础知识和实战 —— Web APIs(三)

思维导图 全选案例 大按钮控制小按钮 小按钮控制大按钮 css伪类选择器checked <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><…

Apache Commons Pool2 池化技术

对象池是一种设计模式&#xff0c;用于管理和重用对象&#xff0c;以提高性能和资源利用率。对象池的概念在许多应用程序中都有广泛应用&#xff0c;特别是在需要频繁创建和销毁对象的情况下&#xff0c;例如数据库连接、线程、HTTP连接等 对象池通过预先创建一组对象并将它们存…

B. Comparison String

题目&#xff1a; 样例&#xff1a; 输入 4 4 <<>> 4 >><< 5 >>>>> 7 <><><><输出 3 3 6 2 思路&#xff1a; 由题意&#xff0c;条件是 又因为要使用尽可能少的数字&#xff0c;这是一道贪心题&#xff0c;所以…

最大内切圆算法计算裂缝宽度

本文这里是对CSDN上另一位博主的代码进行了整理&#xff1a; 基于opencv的裂缝宽度检测算法&#xff08;计算轮廓最大内切圆算法&#xff09; 我觉得这位博主应该是上传了一个代码草稿&#xff0c;我对其进行了重新整理&#xff0c;并添加了详细的注释。 import cv2 import …

ICCV 2023|Occ2Net,一种基于3D 占据估计的有效且稳健的带有遮挡区域的图像匹配方法...

本文为大家介绍一篇入选ICCV 2023的论文&#xff0c;《Occ2Net: Robust Image Matching Based on 3D Occupancy Estimation for Occluded Regions》&#xff0c; 一种基于3D 占据估计的有效且稳健的带有遮挡区域的图像匹配方法。 论文链接&#xff1a;https://arxiv.org/abs/23…

[红明谷CTF 2021]write_shell %09绕过过滤空格 ``执行

目录 1.正常短标签 2.短标签配合内联执行 看看代码 <?php error_reporting(0); highlight_file(__FILE__); function check($input){if(preg_match("/| |_|php|;|~|\\^|\\|eval|{|}/i",$input)){ 过滤了 木马类型的东西// if(preg_match("/| |_||php/&quo…

Java应用生产Full GC或者OOM问题如何定位

1 引言 生产应用服务频繁Full GC却无法释放内存&#xff0c;甚至可能OOM&#xff0c;这种情况很有可能是内存泄露或者堆内存分配不足&#xff0c;此时需要dump堆信息来定位问题&#xff0c;查看是哪些地方内存泄漏。 Dump文件也称为内存转储文件或内存快照文件&#xff0c;是…

【Axure】常见元件、常用交互效果

产品结构图 以微信为例&#xff0c;根据页面层级制作 动态面板 动态面板的作用&#xff1a;动态面板是一个可视区域&#xff0c;如果要把控件放进去&#xff0c;要全部放进去&#xff0c;放多少显示多少 文本框 隐藏边框方法&#xff1a;先拉一个矩形&#xff0c;去掉部分…

基于Java的汽车票网上预订系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09;有保障的售后福利 代码参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作…

在windows的ubuntu LTS中安装及使用EZ-InSAR进行InSAR数据处理

EZ-InSAR&#xff08;曾被称为MIESAR&#xff0c;即Matlab界面用于易于使用的合成孔径雷达干涉测量&#xff09;是一个用MATLAB编写的工具箱&#xff0c;用于通过易于使用的图形用户界面&#xff08;GUI&#xff09;进行干涉合成孔径雷达&#xff08;InSAR&#xff09;数据处理…

vue-cli项目打包体积太大,服务器网速也拉胯(100kb/s),客户打开网站需要等十几秒!!! 尝试cdn优化方案

一、首先用插件webpack-bundle-analyzer查看自己各个包的体积 插件用法参考之前博客 vue-cli项目中&#xff0c;使用webpack-bundle-analyzer进行模块分析&#xff0c;查看各个模块的体积&#xff0c;方便后期代码优化 二、发现有几个插件体积较大&#xff0c;有改成CDN引用的…

亚信科技AntDB数据库 高并发、低延迟、无死锁,深入了解AntDB-M元数据锁的实现

AntDB-M在架构上分为两层&#xff0c;服务层和存储引擎层。元数据的并发管理集中在服务层&#xff0c;数据的存储访问在存储引擎层。为了保证DDL操作与DML操作之间的一致性&#xff0c;引入了元数据锁&#xff08;MDL&#xff09;。 AntDB-M提供了丰富的元数据锁功能&#xff0…

【Axure高保真原型】3D圆柱图_中继器版

今天和大家分享3D圆柱图_中继器版的原型模板&#xff0c;图表在中继器表格里填写具体的数据&#xff0c;调整坐标系后&#xff0c;就可以根据表格数据自动生成对应高度的圆柱图&#xff0c;鼠标移入时&#xff0c;可以查看对应圆柱体的数据……具体效果可以打开下方原型地址体验…

一文带你搞懂Redis持久化

Redis持久化 Redis的数据是存储在内存的&#xff0c;当程序崩溃或者服务器宕机&#xff0c;那么内存里的数据就会丢失。所以避免数据丢失的情况&#xff0c;需要将数据保存到其他的存储设备中。 Redis提供两种方式来持久化&#xff0c;分别是 RDB(Redis Database)&#xff1a…

【LeetCode热题100】--19.删除链表的倒数第N个结点

19.删除链表的倒数第N个结点 注意&#xff1a;先声明一个头结点&#xff0c;它的next指针指向链表的头节点&#xff0c;这样就不需要对头节点进行特殊的判断 /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListN…

【Java每日一题】——第十六题:将数组元素逆序并遍历输出。(2023.09.30)

&#x1f578;️Hollow&#xff0c;各位小伙伴&#xff0c;今天我们要做的是第十五题。 &#x1f3af;问题&#xff1a; 设有数组如下&#xff1a;int[] arr{11,34,47,19,5,87,63,88}; 测试结果如下&#xff1a; &#x1f3af; 答案&#xff1a; int a[]new int [10];Random …

华为云HECS云服务器docker环境下安装nginx

前提&#xff1a;有一台华为云服务器。 华为云HECS云服务器&#xff0c;安装docker环境&#xff0c;查看如下文章。 华为云HECS安装docker-CSDN博客 一、拉取镜像 下载最新版Nginx镜像 (其实此命令就等同于 : docker pull nginx:latest ) docker pull nginx查看镜像 dock…