基础论文学习(1)——ViT

Vision Transformer(ViT) 模型架构是在 ICLR 2021 上作为会议论文发表的一篇研究论文中介绍的,题为“An Image is Worth 16*16 Words: Transformers for Image Recognition at Scale”。它由Neil Houlsby,Alexey Dosovitskiy和Google研究大脑团队的另外10位作者开发和出版。 调整代码和预先训练的ViT模型可以在 谷歌研究团队的GitHub上找到。你可以在这里找到它们。ViT模型在ImageNet和ImageNet-21k数据集上进行了预训练。

在下文中,我们将重点介绍多年来开发的一些最重要的Vision Transformer。它们基于Transformer架构,该架构最初是在 2017 年为自然语言处理 (NLP) 提出的。

日期模型简述Vision Transformer?
2017Transformer仅基于注意力机制的模型。它在NLP任务上表现出色。
2018BERT预先训练的变压器模型开始主导NLP领域。
2020DETRDETR是一个简单而有效的高级视觉框架,它将对象检测视为直接集合预测问题。是的
2020GPT-3GPT-3 是具有 170B 参数的巨大变压器模型,向通用 NLP 模型迈出了重要的一步。
2020iGPT最初为 NLP 开发的转换器模型也可用于图像预训练。是的
2020ViT对视觉识别有效的纯变压器架构。是的
2020IPT/SETR/CLIP变压器已分别应用于低级视觉、分割和多模态任务。是的
2021以后ViT 变体有几种 ViT 变体,包括 DeiT、PVT、TNT、Swin 和 CSWin (2022)。是的

CNN和ViT之间的区别(ViT与CNN)

与卷积神经网络 (CNN) 相比,视觉转换器 (ViT) 取得了显著的结果,同时获得的用于预训练的计算资源要少得多。与卷积神经网络(CNN)相比,视觉转换器(ViT)显示出通常较弱的归纳偏差(inductive bias),导致在较小的数据集上进行训练时对模型正则化或数据增强(AugReg)的依赖性增加。 ViT 是基于变压器架构的可视化模型,最初是为基于文本的任务而设计的。ViT 模型将输入图像表示为一系列图像补丁,就像使用转换器对文本使用的一系列词嵌入一样,并直接预测图像的类标签。当在足够的数据上进行训练时,ViT 表现出非凡的性能,以 4 倍的计算资源打破了类似最先进的 CNN 的性能。

CNN使用像素阵列,而ViT将输入图像拆分为视觉标记。可视转换器将图像划分为固定大小的图块,正确嵌入每个图块,并将位置嵌入作为转换器编码器的输入。此外,ViT模型在计算效率和准确性方面比CNN高出近四倍。 ViT 中的自我注意层使得在整个图像中全局嵌入信息成为可能。该模型还学习训练数据,以编码图像补丁的相对位置以重建图像的结构。

1. 视觉变压器 ViT 架构

文献中已经提出了几种视觉变压器模型。视觉转换器架构的整体结构包括以下步骤:

  • Split an image into patches (fixed sizes)
  • Flatten the image patches
  • Create lower-dimensional linear embeddings from these flattened image patches
  • Add positional embeddings
  • Feed the sequence as an input to a state-of-the-art transformer encoder
  • Pre-train the ViT model with image labels (fully supervised on a huge dataset)
  • Fine-tune the downstream dataset for image classification

在这里插入图片描述
Vision Transformer(ViT)是一种使用self-attention来处理图像的架构。Vision Transformer架构由一系列变压器块组成。每个Transformer encoder模块由两个子层组成:Patch Embedding层、Multi-Head Self-Attention层、FeedForward层、MLP分类头。

Patch Embedding层,将图像划分为固定大小的Ptach,并将每个补丁映射(MLP)到高维矢量token表示。然后将这些tokens嵌入送入Transformer块进行进一步处理。

Multi-Head Self-Attention层根据图像中每个像素与所有其他像素的关系计算其注意力权重,多头注意力通过允许模型同时关注输入序列的不同部分来扩展此机制。而FeedForward层对自注意层的输出应用非线性变换。

ViT架构的最终输出是类预测,通过将最后一个Transformer模块的输出的CLS token,通过MLP分类头获得,分类头通常由单个全连接层组成。

唯一改变的是这些块的数量Layer。为此,为了进一步证明使用更多数据可以训练更大的ViT变体,提出了3个模型:
在这里插入图片描述
Layer就是encoder堆叠的层数,Heads是指多头注意力的头数,而MLP size是指多层感知器特征维度,但它实际上是一堆线性转换层,Hidden size D是embedding的token大小,在整个图层中保持固定。为什么要保持固定?这样我们就可以使用短的残差边跳跃连接。

ViT在大型数据集上进行预训练,然后微调为小数据集。唯一的修改是丢弃预测头(MLP 头)并附加一个新的D×K线性图层,其中 K 是小数据集的种类数。

1 图片分块和降维

因为transformer encoder的输入需要序列,所以最简单做法就是把图片切分为patch,然后拉成序列即可。 假设输入图片大小是256x256,打算分成64个patch,每个patch是32x32像素:

x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

这个写法是采用了爱因斯坦表达式,具体是采用了einops库实现,内部集成了各种算子,rearrange就是其中一个,非常高效。p就是patch大小,假设输入是[b,3,256,256],则rearrange操作是先变成(b,3,8x32,8x32),最后变成(b,8x8,32x32x3)即(b,64,3072),将每张图片切分成64个小块,每个小块长度是32x32x3=3072,也就是说输入长度为64的图像序列,每个元素采用3072长度进行编码。

考虑到3072有点大,故作者先进行降维:

# 将3072变成dim,假设是1024
self.patch_to_embedding = nn.Linear(patch_dim, dim)
x = self.patch_to_embedding(x)

2 增加CLS token

仔细看论文上图,可以发现假设切成9个块,但是最终到transfomer输入是10个向量,额外追加了一个0和。为啥要追加?原因是我们现在没有解码器了,而是编码后直接就进行分类预测,那么该编码器就要负责一点点解码器功能,那就是:需要一个类似开启解码标志,非常类似于标准transformer解码器中输入的目标嵌入向量右移一位操作。试下如果没有额外输入,9个块输入9个编码向量输出,那么对于分类任务而言,我应该取哪个输出向量进行后续分类呢?选择任何一个都说不通,所以作者追加了一个可学习嵌入向量输入。那么额外的可学习嵌入向量为啥要设计为可学习,而不是类似nlp中采用固定的token代替?个人不负责任的猜测这应该就是图片领域和nlp领域的差别,nlp里面每个词其实都有具体含义,是离散的,但是图像领域没有这种真正意义上的离散token,有的只是一堆连续特征或者图像像素,如果不设置为可学习,那还真不知道应该设置为啥内容比较合适,全0和全1也说不通。 自此现在就是变成10个向量输出,输出也是10个编码向量,然后取第0个编码输出进行分类预测即可。从这个角度看可以认为编码器多了一点点解码器功能。具体做法超级简单,0就是位置编码向量,是可学习的patch嵌入向量。

# dim=1024
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
# 变成(b,64,1024)
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
# 额外追加token,变成b,65,1024
x = torch.cat((cls_tokens, x), dim=1)

3 位置编码PE

1-D 位置编码:例如3x3共9个patch,patch编码为1到9

2-D 位置编码:patch编码为11,12,13,21,22,23,31,32,33,即同时考虑X和Y轴的信息,每个轴的编码维度是D/2

即使应用了许多位置嵌入方案,也没有发现显着差异。这可能是由于变压器编码器在patch级别工作。捕获patch(空间信息)之间的顺序关系的学习嵌入并不是那么重要。理解 P x P 的patch块之间的关系比完整图像H x W 之间的关系相对容易。
在这里插入图片描述

这里做的比较简单,没有采用sincos编码,而是直接设置为可学习,效果差不多。相邻位置有相近的位置编码向量,整体呈现2d空间位置排布一样。将patch嵌入向量和位置编码向量相加即可作为编码器输入:

# num_patches=64,dim=1024,+1是因为多了一个cls开启解码标志
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)

对训练好的pos_embedding进行可视化,如下所示:
在这里插入图片描述

4 Transformer Encoder

作者采用的是没有任何改动的transformer,故没有啥说的。

self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)

假设输入是(b,65,1024),那么transformer输出也是(b,65,1024)

5 分类head

在编码器后接fc分类器head即可

self.mlp_head = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, mlp_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(mlp_dim, num_classes))# 65个输出里面只需要第0个输出进行后续分类即可
self.mlp_head(x[:, 0])

总体架构

到目前为止就全部写完了,是不是非常简单,外层整体流程为:
在这里插入图片描述

import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange# helpersdef pair(t):return t if isinstance(t, tuple) else (t, t)# classesclass FeedForward(nn.Module):  # LN + FC + GELU + Dropout + FC + Dropoutdef __init__(self, dim, hidden_dim, dropout = 0.):super().__init__()self.net = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)class Attention(nn.Module):# LN(x)  ->  qkv  ->  Softmax(q*k/dk)*v  ->  FCdef __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):super().__init__()inner_dim = dim_head *  headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.norm = nn.LayerNorm(dim)self.attend = nn.Softmax(dim = -1)self.dropout = nn.Dropout(dropout)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):x = self.norm(x)qkv = self.to_qkv(x).chunk(3, dim = -1)q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaleattn = self.attend(dots)attn = self.dropout(attn)out = torch.matmul(attn, v)out = rearrange(out, 'b h n d -> b n (h d)')return self.to_out(out)class Transformer(nn.Module):# depth=12层Attention + FeedForward -> LNdef __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):super().__init__()self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),FeedForward(dim, mlp_dim, dropout = dropout)]))self.norm = nn.LayerNorm(dim)def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + xreturn self.norm(x)class ViT(nn.Module):# to_patch_embedding + cat_cls_token + add_pos_embedding + transformer + mlp_headdef __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):super().__init__()image_height, image_width = pair(image_size)patch_height, patch_width = pair(patch_size)assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'num_patches = (image_height // patch_height) * (image_width // patch_width)patch_dim = channels * patch_height * patch_widthassert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'self.to_patch_embedding = nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),nn.LayerNorm(patch_dim),nn.Linear(patch_dim, dim),nn.LayerNorm(dim),)self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))self.cls_token = nn.Parameter(torch.randn(1, 1, dim))self.dropout = nn.Dropout(emb_dropout)self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)self.pool = poolself.to_latent = nn.Identity()  # 创建一个恒等映射层,用于不做任何改变地传递输入数据self.mlp_head = nn.Linear(dim, num_classes)def forward(self, img):x = self.to_patch_embedding(img)b, n, _ = x.shapecls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)x = torch.cat((cls_tokens, x), dim=1)x += self.pos_embedding[:, :(n + 1)]x = self.dropout(x)x = self.transformer(x)x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]x = self.to_latent(x)return self.mlp_head(x)

2. 实验分析

作者得出的结论是:cv领域应用transformer需要大量数据进行预训练,在同等数据量的情况下性能不如cnn。一旦数据量上来了,对应的训练时间也会加长很多,那么就可以轻松超越cnn。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/102246.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

常用获取威胁情报数据+信息溯源的平台

获取威胁情报数据 奇安信威胁分析平台深信服威胁情报中心360安全大脑腾讯哈勃分析系统绿盟威胁情报中心安全星图平台安天威胁情报中心VenusEye威胁情报中心VirustotalIBM X-Force威胁情报ThreatBookAlienVaultVirusScan多引擎在线扫描RiskIQThreatMiner 需要注意的是&#xff0…

php 系列题目,包含查看后端源代码

一、弱类型比较问题 原则: 1.字符串和数字比较,字符串回被转换成数字。 "admin" 0(true) admin被转换成数字,由于admin是字符串,转换失败,变成0 int(admin)0,所以比较结果是ture 2.混合字符串转…

【轴承故障诊断】用于轴承故障诊断的集中时频分析研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

Go语言基础之函数

函数 Go语言中支持函数、匿名函数和闭包,并且函数在Go语言中属于“一等公民”。 函数定义 Go语言中定义函数使用func关键字,具体格式如下: func 函数名(参数)(返回值){函数体 }其中: 函数名:由字母、数字、下划线…

借助frp的xtcp+danted代理打通两边局域网p2p方式访问

最终效果 实现C内网所有设备借助c1内网代理访问B内网所有服务器 配置公网服务端A frps 配置frps.ini [common] # 绑定frp穿透使用的端口 bind_port 7000 # 使用token认证 authentication_method token token xxxx./frps -c frps.ini启动 配置service自启(可选) /etc/…

分布式核心知识以及常见微服务框架

分布式中的远程调用 在微服务架构中,通常存在多个服务之间的远程调用的需求。远程调用通常包含两个部分:序列化和通信协议。常见的序列化协议包括json、xml、 hession、 protobuf、thrift、text、 bytes等,目前主流的远程调用技术有基于HTTP…

C#小轮子:MiniExcel,快速操作Excel

文章目录 前言环境安装功能测试普通读写读新建Excel表格完全一致测试:成功大小写测试:严格大小写别名读测试:成功 写普通写别名写内容追加更新模板写 其它功能xlsx和CSV互转 前言 Excel的操作是我们最常用的操作,Excel相当于一个…

Unity 之NavMeshAgent 组件(导航和路径寻找的组件)

文章目录 **作用**:**属性和方法**:**用途**:**注意事项**: NavMeshAgent 是Unity引擎中用于导航和路径寻找的组件。它可以使游戏对象在场景中自动找到可行走的路径,并在避免障碍物的情况下移动到目标位置。 以下是关于…

从零玩转系列之微信支付实战PC端装修我的订单页面 | 技术创作特训营第一期

一、前言 欢迎来到本期的博客!本篇文章是 PC 端的结尾了,前面经历过九个章节到本章节刚刚好十章节感谢观看我的文章,那么接下来我们将要编写的是我的订单页面. GGBOM! 本篇完毕后将是 UniApp 的篇章感受移动端的诱惑 💗 本次为前端知识点如果不懂前段可以…

信号波形解读

can波形解读 实际波形 标准帧 发送数据 仲裁段 0x1AA 数据长度为8字节 内容为:0x41, 0x20, 0x38, 0x41, 0x00, 0x16, 0x00, 0x00 波特率 111K

关于stm32推挽带有上下拉电阻的思考、IO口驱动能力是什么

1、发现推挽带有上下拉电阻 1.1、stm32手册 记忆中推挽是不需要上下拉的,没关注过,但是我真的理解上下拉吗,下图来自stm32f4的中文版和英文版的数据手册,没有翻译错,就是“推挽带有上下拉的能力”。 1.2、查找相关信…

基于决策树(Decision Tree)的乳腺癌诊断

决策树(DecisionTree)学习是以实例为基础的归纳学习算法。算法从--组无序、无规则的事例中推理出决策树表示形式的分类规则,决策树也能表示为多个If-Then规则。一般在决策树中采用“自顶向下、分而治之”的递归方式,将搜索空间分为若千个互不相交的子集,在决策树的内部节点(非叶…

DDD 架构分层,MQ消息要放到那一层处理?

作者:小傅哥 博客:https://bugstack.cn 沉淀、分享、成长,让自己和他人都能有所收获!😄 本文的宗旨在于通过简单干净实践的方式教会读者,使用 Docker 配置 RocketMQ 并在基于 DDD 分层结构的 SpringBoot 工…

【Java 动态数据统计图】动态数据统计思路案例(动态,排序,数组)一(112)

需求&#xff1a;&#xff1a; 有一个List<Map<String.Object>>,存储了某年某月的数据&#xff0c; 数据是根据用户查询条件进行显示的&#xff1b;所以查询的数据是动态的&#xff1b;需按月份统计每个年月数据出现的次数&#xff0c;并且按照月份排序&#xff1…

unity 之 Input.GetMouseButtonDown 的使用

文章目录 Input.GetMouseButtonDown Input.GetMouseButtonDown 当涉及到处理鼠标输入的时候&#xff0c;Input.GetMouseButtonDown 是一个常用的函数。它可以用来检测鼠标按键是否在特定帧被按下。下面我会详细介绍这个函数&#xff0c;并举两个例子说明如何使用它。 函数签名…

AI在日常生活中的应用:从语音助手到自动驾驶

文章目录 AI的定义和发展AI在日常生活中的应用1. **智能语音助手**2. **智能家居**3. **智能医疗**4. **自动驾驶** 代码示例&#xff1a;使用Python实现基于机器学习的图片分类AI的未来前景结论 &#x1f389;欢迎来到AIGC人工智能专栏~探索AI在日常生活中的应用 ☆* o(≧▽≦…

Python Opencv实践 - 直方图显示

import cv2 as cv import numpy as np import matplotlib.pyplot as pltimg cv.imread("../SampleImages/pomeranian.png", cv.IMREAD_COLOR) print(img.shape)#图像直方图计算 #cv.calcHist(images, channels, mask, histSize, ranges, hist, accumulate) #images&…

深度学习|CNN卷积神经网络

CNN卷积神经网络 解决的问题人类的视觉原理原理卷积层——提取特征池化层——数据降维全连接层——输出结果 应用图像处理自然语言处理 解决的问题 在CNN没有出现前&#xff0c;图像对人工智能来说非常难处理。 主要原因&#xff1a; 图像要处理的数据量太大了。图像由像素组…

Docker数据管理

目录 一、数据卷 二、数据卷容器 三、容器互联 管理 Docker容器中数据主要有两种方式&#xff1a; 数据卷&#xff08;Data Volumes&#xff09;数据卷容器&#xff08;DataVolumes Containers&#xff09; 一、数据卷 数据卷是一个供容器使用的特殊目录&#xff0c;位于容…

百度云BOS云存储的图片如何在访问时,同时进行格式转换、缩放等处理

前言 之前做了一个图片格式转换和压缩的服务&#xff0c;结果太占内存。后来查到在访问图片链接时&#xff0c;支持进行图片压缩和格式转换&#xff0c;本来想着先格式转换、压缩图片再上传到BOS&#xff0c;现在变成了上传后&#xff0c;访问时进行压缩和格式转换。想了想&am…