ViT和SwinTransformer详解

ViT是Google brain发表于ICLR'21上的工作,开创性将transformer用在vision领域,且图像识别性能超CNN,至今引用3.8w+;原文:https://arxiv.org/pdf/2010.11929

SwinTransformer是微软亚洲研究院发表于ICCV'21上,获best paper,在多个视觉任务上获sota,打破CNN垄断vision backbone的现状,至今引用1.8w+;原文:https://openaccess.thecvf.com/content/ICCV2021/papers/Liu_Swin_Transformer_Hierarchical_Vision_Transformer_Using_Shifted_Windows_ICCV_2021_paper.pdf

建议读原文,这些文章优雅、简洁、深刻。

下面按照三部分进行,分别是Attention介绍、ViT详解、SwinTransformer详解。与常规文章讲解不同,我会多采用QA进行展开。

1. Attention介绍

这涉及到NeurIPS发表的“Attention is all you need”,这篇文章引用已经12w+,理解注意力机制是学习transformer的核心。

Q: general attention和self attention区别?

A: 相同点是均需要计算qkv,不同之处,self attention的input只有x,而general attention的input除了有x(映射得到kv),还有q(查询query)。

Self attention layer介绍

步骤:1. 输入x,通过映射矩阵Wq,Wk,Wv,得到qkv(D维)

           2. q和k进行对齐操作,如:q0会分别与不同的k进行点乘操作,得e0(e矩阵第一列)

           3. 注意力机制:softmax操作。如,从e0得a0(a矩阵第一列),为0~1之间的注意力权重

           4. 输出:v和注意力权重a的加权和,如:y0为a0和所有v的注意力加权和

=》不同于CNN的局部特性,此处的自注意力很好地体现了全局特性。

仔细观察,可以发现self-attention layer具备permutation invariant的性质(置换不变)

现实中,不管是语言token还是vision patch token,位置不同,显然我们应该得到不同的内容向量y才是合理的。

因此,有必要加入位置编码,将位置信息考虑进来进行自注意力学习。

对于每个输入xj,给出位置编码pj。使用位置编码函数pos,pj=pos(j),将位置j映射到D维向量(因为x是D维)。对于pos函数的选取此处不详细展开。

Multi-head self attention layer介绍

多头自注意力层就是transformer里核心模块。

Q:为什么要multi-head?

A:本质是为提取更好的特征,类似于CNN中卷积核也是多组,以得到多个特征谱。不同的是,CNN中卷积核小,计算量小,特征谱数量都是几十、几百。这儿的Multi-head不会很多,一般不超过10。

2. ViT详解

这篇文章的writing也可以当作范本,反复学习。

Q:标题两个keys,一个是an image is worth 16x16 words, 另一个是at scale,分别突出了什么?

A:前者突出将图像按照文字的处理方式,把一张图表示成了16x16 tokens。另一个关键点at scale,则与transformer的优势关联起来,也暗含了transformer要获的良好性能的前提。

Q:Transformer的天然优势是什么?

A:主要是excellent scalability,当模型和训练集增加时,并没有saturating performance。可以处理超大规模的训练数据。另一个是self-attention带来的computational efficiency,很多计算可以高度并行。

Q:CNN的天然优势是什么?

A:主要是inductive bias,在卷积的过程中,我们使用了translation equaivariance(平移不变性)、locality(局部性)来保留2D相邻结构。这些使得CNN在少量训练数据时候也能获得很好的性能。

Q:什么时候Transformer会比CNN更好?

A:通常,小训练数据集时候,convolutional inductive bias会很有用。当,数据集规模足够大的时候,最终large scaling training会比inductive bias表现好。这是合理的,因为Transformer学习中没有inductive bias,其特征时只能从大规模数据中学习。

Conclusion中提及的几个有前瞻性的点,现在均已经实现:)

1)self-supervised vs. supervised learning,之间的gap已经去掉;

2)scaling law,随着scaling提升,模型性能提升,现在已经是大模型发展遵循的发展规律;

3)transformer在segmentation、detection上的发展,现在已经横扫这些视觉任务。

# Transformer Encoder (depth x)class Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):super().__init__()self.norm = nn.LayerNorm(dim)self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([Attention(dim, heads=heads, dim_head=dim_head,dropout=dropout),FeedForward(dim,mlp_dim,dropout=dropout)]))def forward(self, x):for att, ff in self.layers:x = attn(x)+xx = ff(x)+xreturn self.norm(x)
# Multi-Head Attention
# 与self-attention layer中的operation保持一致class Attention(nn.Module):def __init__(self, dim, heads=8, dim_head=64, dropot=0.):super().__init__()inner_dim = dim_head * headsproject_out = not(heads==1 and dim_head==dim)self.heads = headsself.scale = dim_head ** -0.5self.norm = nn.LayerNorm(dim)self.attend = nn.Softmax(dim=-1)    self.dropout = nn.Dropout(dropout)self.to_qkv = nn.Linear(dim, inner_dim*3, bias=False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):x = self.norm(x)qkv = self.to_qkv(x).chunk(3, dim=-1)q,k,v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) # batch(b), sequence length(n), heads(h), dim(d) dots = torch.matmul(q, k.transpose(-1,-2))*self.scaleattn = self.attend(dots)attn = self.dropout(attn)out = torch.matmul(attn, v)out = rearrange(out, 'b h n d -> b n (h d)')return self.to_out(out)
# FeedForward (Transformer Encoder第二个部分)class FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout=0.):super().__init__()self.net = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)
# ViT framework
# 包括输入图像的处理方式以及具体的任务class ViT(nn.Module):def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):super().__init__()image_height, image_width = pair(image_size)patch_height, patch_width = pair(patch_size)assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'num_pathes = (image_height//patch_height)*(image_width//patch_width) # sequence lengthpatch_dim = channels*patch_height*patch_widthassert pool in {'cls','mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' #输入序列的第一个位置会添加一个特殊的标记,称为 [CLS] 标记self.to_patch_embedding = nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),nn.LayerNorm(patch_dim),nn.Linear(patch_dim, dim),nn.LayerNorm(dim),)self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))self.cls_token = nn.Parameter(torch.randn(1,1,dim))self.dropout = nn.Dropout(emb_dropout)self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)self.pool = poolself.to_latent = nn.Identity()self.mlp_head = nn.Linear(dim, num_classes) # classification taskdef forward(self, img):x = self.to_patch_embedding(img)b,n,_ = x.shapecls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)x = torch.cat((cls_tokens, x), dim=1)x += self.pos_embedding[:,:(n+1)]x = self.dropout(x)x = self.transformer(x)x = x.mean(dim=1) if self.pool=='mean' else x[:,0]x = self.to_latent(x)return self.mlp_head(x)  

此处,pos_embedding是随机给的,transformer的输出后pool只能选cls或者mean中之一,然后进行MLP对任务的预测。

这里没有涉及到transformer decoder设计。

3. SwinTransformer详解

Q:ViT不好吗,SwinTransformer主要解决哪些关键问题?

A:如果图像分辨率变大,按照patch的size进行切分,这时候图像块的数量会增加,相应的计算复杂度quadratic增加,除此外切分的patch也相对较大(下采样倍数高),特征提取信息不准。不能很好处理高分辨率图。除此外,ViT这种固定的图像块切分方法对于不同大小的视觉实体而言不是很合理,当物体远小于或者大于patch时候,很难有效提取特征。不同物体的尺寸和比例差异很大,不像单词的长度相对固定。不能很好处理大小变化的视觉物体

个人感觉SwinTransformer中窗口概念,类似与CNN中卷积核,窗口shift类似于CNN中stride,不同这个shift(向右、向下)更加灵活。不同在于,CNN针对局部图像感受野直接去求W,而SwinTransformer则是利用Self-attention更高级的方式去求局部图像的特征。

很多技巧都是用于减少(分窗口、窗口移动)运算量。

Q:SwinTransformer主要贡献?

A:第一,层级的特征谱方式使计算复杂度对于图像尺寸而言是linear而不是quadratic,可以处理高分辨率的图像。第二、shifted window很好解决了视觉物体大小变化的特点。

Q:SwinTransformer主要设计思想?

A:全局的注意力机制只在小范围内做,然后在不同层级上提特征(W-MSA,提出窗口的概念,窗口内进行多头注意力机制)。此外,利用shifted window将各个窗口之间的信息进行通信,完美达到捕获全局的上下文信息的优势(SW-MSA,此处就是滑动窗口的多头注意力机制)。这两部分就是Swin Transformer blocks的主要组成部分。

对比:

1)SwinTransformer有很多窗口(红色框),且在不同的层级上,窗口的划分是不同的。ViT将整图作为一个窗口,一直进行全局注意力机制计算。

2)SwinTransformer先进行4x下采样,将4*4个pixels作为一个小patch,在划定的窗口内进行注意力计算,然后是8x下采样,最后是16x下采样。ViT直接下采样16倍,后面保持相同的下采样规律。

SwinTransformer Blocks

1. W-MSA介绍(窗口间不涉及信息传递)

这个提出的目的是在窗口内进行kqv的求解,既能减少计算复杂度,也能使用更小的patch size,使下采样倍数不用很大。

Q:具体减少了多少运算量?

A:运算量主要分为三部分:1)to kqv,2)qk对齐,3)与v加权和。Att(Q,K,V)=Softmax(\frac{QK^{T}}{\sqrt{d}})V

1) X^{hw\times C}通过矩阵运算W^{C\times C}生成Q^{hw\times C },K^{hw\times C},V^{hw\times C}.总运算量为3hwC^{2}

2) qk对齐,运算量为(hw)^{2}C,得A^{hw\times hw}

3) 与v加权,得B^{hw\times C},运算量为(hw)^{2}C

4) 多头注意力机制,多了一个融合矩阵W,B^{hw\times C}\cdot W^{C\times C}=O^{hw\times C},计算量hwC^{2}

总计,4hwC^{2}+2(hw)^{2}C 公式一

假设W-MSA的窗口长和宽为M,代入上面公式为,

4M^{2}C^{2}+2M^{4}C

\frac{h}{M}\times \frac{w}{M}窗口,所以为,4hwC^{2}+2M^{2}hwC 公式二

缺点:减少了运算量,但窗口之间由于没有任何通信,导致确实全局感受野。

2.SW-MSA介绍(窗口间进行信息传递)

两层之间发生了窗口的移动(Shift),偏移的量是:往右、往下偏移M/2个像素。移动后,划分出的第二列3个窗口能够完成相邻窗口的信息交流。

缺点:原来4个窗口,移动后变成9个窗口,且大小不一。总之,移动后窗口的数量增多,从\frac{h}{M}\times \frac{w}{M}变成(\frac{h}{M}+1)\times (\frac{w}{M}+1),有些窗口会变小。

解决办法:

naive方案,把所有变小的窗口pad后,计算attention时候把pad数值掩膜。但这样,存在很多没必要的运算。

Efficient batch computation approach by cyclic-shifting toward the top-left direction

主要思想:将移动模式作为flag,只对有相邻关系的子窗口计算,不相邻的,减去100,使得softmax计算后概率接近0。

示意图传递了跟之前的窗口类似的计算,对于不相关的信息加上了mask,softmax后得到的概率接近0,使其达到mask的作用。最后再通过reverse cyclic shift移动回去。

SwinTransformer Framework

整体架构的思想:4个阶段,每个阶段构建不同大小的特征图,不断缩小分辨率,类似CNN逐渐增大感受野。

  • Patch partition, 本质就是矩阵的reshape,以4*4为一个图像块,对输入图片进行分块,然后在channel方向上进行拼接
  • Linear embedding, 经过线性变换,通道数从48变成C
  • Patch merging, 本质就是降采样,只是比pooling的方式来得更复杂一些,有学习的参数

关于relative position bias这里不展开,因为其在图像分类上提高了,但是在目标检测任务上降低了,具体理解可以参考[4]。

Application:下一篇会介绍SwinIR,揭开该方法如何在底层视觉的图像修复上施展魔法。

参考:

[1] cs231n课件

[2] vit-pytorch/vit_pytorch/vit.py at main · lucidrains/vit-pytorch · GitHub

[3] Swin Transformer:屠榜各大CV任务的视觉Transformer模型 (high-level介绍)

[4] Swin Transformer 详解(detail-level理解,很不错)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/392938.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

双回路校园智能电表是什么?什么叫双回路校园智能电表?

在智慧校园的建设浪潮中,双回路校园智能电表作为一种创新的能源计量与管理解决方案,正逐渐成为校园电力系统改造与升级的关键要素。本文旨在深入探讨双回路校园智能电表的概念、工作原理、核心优势及其在校园能源管理中的应用实践。 一、定义与工作原理…

Harbor 仓库一键安装

文章目录 一、场景说明二、脚本职责三、参数说明四、操作示例五、注意事项 一、场景说明 本自动化脚本旨在为提高研发、测试、运维快速部署应用环境而编写。 脚本遵循拿来即用的原则快速完成 CentOS 系统各应用环境部署工作。 统一研发、测试、生产环境的部署模式、部署结构、…

一文理清生产管理的“4管”和“8理”!

一提到生产管理,很多人的第一反应可能是车间里忙碌的身影、流水线上飞速运转的机器,还有一张张密密麻麻的生产计划表。但实际上,生产管理远不止于此。 “科学管理之父”弗雷德里克温斯洛泰勒认为:管理就是确切地知道你要别人干什…

CompletableFuture详解

CompletableFuture详解 学习链接:https://juejin.cn/post/7124124854747398175?searchId20240806151438B643DF2AAD2FC5E6F11E 一、CompletableFuture简介 在JAVA8开始引入了全新的CompletableFuture类,它是Future接口的一个实现类。也就是在Future接口的基础上&a…

计算机网络复习总结

第一章 计算机网络体系结构 1、计算机网络的概念组成和功能 (1)什么是计算机网络? 计算机网络是将一个分散的、具有独立功能的计算机系统,通过通信设备与线路连接起来,由功能完善的软件实现资源共享和信息传递的系统。…

UI自动化常见精华面试题整理

selenium的运行原理 1、基于Web端如何做自动化测试,谈谈你的思路和方向? Web端的自动化测试,基本就是模拟手工测试人员来做功能测试。用机器的自动执行代替人的操作。web端呈现的产品有两个方向可以做自动化测试:接口层和界面操作…

虚拟机Windows10系统安装QEMU

文章目录 1. QEMU安装1.1 安装准备1.1.1 安装平台1.1.2 软件下载 1.2 安装QEMU1.2.1 找到下载的QEMU软件,双击开始安装1.2.2 设置语言1.2.3 安装向导,点击 Next1.2.4 点击“I Agree”1.2.5 点击Next1.2.6 设置软件安装位置1.2.7 点击 finish1.2.8 编辑系…

odoo from样式更新

.xodoo_form {.o_form_sheet {padding-bottom: 0 !important;border-style: solid !important;border-color: white;}.o_inner_group {/* 线框的样式 *//*--line-box-border: 1px solid #666;*//*box-shadow: 0 1px 0 #e6e6e6;*/margin: 0;}.grid {display: grid;gap: 0;}.row …

FFmpeg源码:av_reduce函数分析

AVRational结构体和其相关的函数分析: FFmpeg有理数相关的源码:AVRational结构体和其相关的函数分析 FFmpeg源码:av_reduce函数分析 一、av_reduce函数的声明 av_reduce函数声明在FFmpeg源码(本文演示用的FFmpeg源码版本为7.0…

【计算机操作系统】同步与互斥的基本概念

同步与互斥的基本概念 进程同步的概念 知识点回顾:进程具有异步性,异步性是指,各并发执行的进程以各自独立的、不可预知的速度向前推进 并发性带来了异步性,有时需要通过进程同步解决这种异步问题,有的进程之间需要…

JVM的面试考点

JVM内存划分 1.堆,整个内存区域中,内存最大的区域,放的都是new出来的对象,new类名这一部分存放在堆中, 而这个scanner是一个临时变量,这个scanner的地址存放在栈上,scanner里面存放的值是new类名这个对象的首地址 2.栈,分为JVM虚拟机栈(Java代码),和本地方法栈(C),这个栈包含了…

如何提前预防网络威胁

一、引言 随着信息技术的迅猛进步,网络安全议题愈发凸显,成为社会各界不可忽视的重大挑战。近年来,一系列网络安全事件的爆发,如同惊雷般震撼着个人、企业及国家的安全防线,揭示了信息安全保护的紧迫性与复杂性。每一…

2024年第五届“华数杯”全国大学生数学建模竞赛C题-老外游中国(代码+成品论文+讲解)

目录 💕一、问题重述💕 🐸问题 1🐸 🐸问题 2🐸 🐸问题 3🐸 🐸问题 4🐸 🐸问题 5🐸 💕二、解题思路💕 …

图解RocketMQ之如何实现顺序消息

大家好,我是苍何。 顺序消息是业务中常用的功能之一,而 RocketMQ 默认发送的事普通无序的消息,那该如何发送顺序消息呢? 要保证消息的顺序,要从生产端到 broker 消息存储,再到消费消息都要保证链路的顺序…

【C++】二维数组 数组名

二维数组名用途 1、查看所占内存空间 2、查看二维数组首地址 针对第一种用途&#xff0c;还可以计算数组有多少行、多少列、多少元素 针对第二种用途&#xff0c;数组元素、行数、列数都是连续的&#xff0c;且相差地址是有规律的 下面是一个实例 #include<iostream&g…

Spring源码解析(29)之AOP动态代理对象创建过程分析

一、前言 在上一节中我们已经介绍了在createBean过程中去执行AspectJAutoProxyCreator的after方法&#xff0c;然后去获取当前bean适配的advisor&#xff0c;如果还不熟悉的可以去看下之前的博客&#xff0c;接下来我们分析Spring AOP是如何创建代理对象的&#xff0c;在此之前…

【目标检测类】YOLOv5网络模型结构基本原理讲解

1. 基本概念 YOLOv5模型结构主要包括以下组成部分&#xff1a;‌ 输入端&#xff1a;‌YOLOv5的输入端采用了多种技术来增强模型的性能&#xff0c;‌包括Mosaic数据增强、‌自适应锚框计算、‌以及自适应图片缩放。‌这些技术有助于提高模型的泛化能力和适应不同尺寸的输入图…

MySQL基础操作全攻略:增删改查实用指南(中)

本节目标&#xff1a; NOT NULL - 指示某列不能存储 NULL 值。 UNIQUE - 保证某列的每行必须有唯一的值。 DEFAULT - 规定没有给列赋值时的默认值。 PRIMARY KEY - NOT NULL 和 UNIQUE 的结合。确保某列&#xff08;或两个列多个列的结合&#xff09;有唯一标 识&am…

【C++】模拟实现stack

&#x1f984;个人主页:修修修也 &#x1f38f;所属专栏:实战项目集 ⚙️操作环境:Visual Studio 2022 ​ 目录 一.了解项目功能 &#x1f4cc;了解stack官方标准 &#x1f4cc;了解模拟实现stack 二.逐步实现项目功能模块及其逻辑详解 &#x1f4cc;实现stack成员变量 &…

【Linux】进程间通信(管道通信、共享内存通信)

一.什么是进程间通信 进程间通信这五个字很好理解&#xff0c;就是进程和进程之间通信。 那么为什么要有进程间通信呢&#xff1f; 1.数据传输&#xff1a;一个进程需要将它的数据发送给另一个进程 2.资源共享&#xff1a;多个进程之间共享同样的资源 3.通知事件&#xff1a;一…