25_Vision Transformer原理详解

1.1 简介

Vision Transformer (ViT) 是一种将Transformer架构从自然语言处理(NLP)领域扩展到计算机视觉(CV)领域的革命性模型,由Google的研究人员在2020年提出。ViT的核心在于证明了Transformer架构不仅在处理序列数据(如文本)方面非常有效,在处理图像数据时也能取得卓越性能,挑战了卷积神经网络(CNNs)在视觉任务中的主导地位。

ViT的基本思想

传统的CNN通过局部感受野和池化操作逐步提取图像特征,而ViT则采取了一种完全不同的思路。它首先将输入图像划分为多个固定大小的 patches(例如16x16像素),然后对每个patch进行线性映射(Flattening + Linear Projection),将其转换成一个向量。这些向量连同一个特殊的分类token([CLS])一起作为Transformer的输入序列。这样,图像就被转换成了一个序列数据,从而可以直接应用自注意力机制(Self-Attention)进行特征提取。

ViT模型结构

  1. Patch Embedding: 输入图像被分割成多个非重叠的patches,每个patch被展平并经过一个线性层转换成一个固定长度的向量,这个过程称为嵌入。通常还会加入位置编码(Positional Encoding),以保留patch之间的空间信息。

  2. Transformer Encoder: 这是ViT的核心部分,由多层Transformer编码器组成。每层包括一个多头自注意力模块(Multi-Head Self-Attention, MHSA)和一个前馈神经网络(Feed Forward Network, FFN),两侧通常还会有LayerNorm层,并可能伴有残差连接(Residual Connections)。MHSA允许不同patch间的交互,FFN则进一步加工这些特征。

  3. MLP Head: 在Transformer编码器的输出之后,通常会添加一个多层感知机(MLP)用于最终的任务输出,比如在图像分类任务中预测类别。

  4. Classification Token ([CLS]): 在输入序列的开始处添加一个特殊token,Transformer的输出中这个token的表示将用于分类任务,即包含了整个图像的全局特征。

关键创新点

  • 直接应用自注意力机制于视觉任务:打破了CNN在视觉领域的垄断,展示了自注意力机制在图像处理中的潜力。
  • 灵活的序列化处理图像:通过将图像视为一系列向量的序列,使得模型能够更好地理解全局上下文。
  • 简化模型结构:相比复杂的CNN架构,ViT的结构更为简洁,易于理解和调整。

训练与优化

ViT的成功依赖于大规模数据集(如ImageNet)的预训练以及特定的训练策略,包括大数据量、长时间训练以及数据增强技术(如裁剪、翻转等)。为了防止过拟合,还采用了正则化技术,如Dropout或Stochastic Depth。

总结

Vision Transformer展示了Transformer架构在计算机视觉领域的强大潜力,推动了后续一系列视觉Transformer模型的发展,如DeiT(Data-efficient Image Transformers)、Swin Transformer等,这些模型在效率和准确性上进行了进一步的优化。ViT的成功标志着深度学习架构在视觉任务处理上的一次重要革新,为未来的研究开辟了新的方向。

1.2 模型结构

Vision Transformer (ViT) 的工作流程可以分为几个关键步骤,从原始图像输入到最终的分类或预测输出。以下是其详细的工作流程:

1. 图像分割成Patch(Patchification)

  • 输入图像:首先,模型接收一张图像作为输入。
  • 分割:将图像划分为多个大小一致的矩形区域,这些区域被称为“patches”。标准做法是使用16×16像素的大小,但这个尺寸可以根据需求调整。例如,对于一个224×224像素的图像,分割后会得到14×14=196个这样的patches。

2. Patch嵌入(Patch Embedding)

  • 展平和线性映射:每个patch被展平成一个向量,然后通过一个全连接层(也称为嵌入层)转换为一个更高维度的向量(比如D=768维),这一步是为了将像素信息转化为特征表示。
  • 位置编码(Positional Encoding):为了保持序列中patch的位置信息,每个嵌入向量还会加上一个位置编码,这可以是绝对位置编码或者学习得到的编码。

3. 添加分类Token

  • 在所有patch的嵌入序列之前,会添加一个特殊的分类Token(例如一个学习到的向量),这个Token的作用是在Transformer的输出中代表整个图像的全局信息,对于分类任务至关重要。

4. Transformer Encoder 层处理

  • 多头自注意力(Multi-Head Self-Attention, MHSA):每个Encoder层的核心是MHSA模块,它允许模型并行地考虑不同patch之间的相互依赖关系。MHSA通过将输入分成多个“头”并计算不同头间的注意力权重来实现这一点。
  • 层归一化(Layer Normalization, LN):在MHSA前后使用LN来稳定学习过程,提高训练稳定性。
  • 残差连接(Residual Connection):MHSA的输出与输入相加,通过残差连接保留原始信息,有助于解决深度网络中的梯度消失问题。
  • 前馈神经网络(Feed Forward Network, FFN):每个Encoder层还包括一个FFN,通常由两个线性层和一个激活函数(如ReLU)组成,用于进一步的特征变换和非线性处理。
  • Dropout:为了正则化,可能在某些部分应用Dropout以减少过拟合的风险。

5. 输出和分类

  • 分类Token:经过多层Transformer Encoder处理后,最开始添加的分类Token的特征向量被提取出来,该向量包含了关于整个图像的综合信息。
  • 分类头:这个向量接着通过一个分类头,通常是一个简单的全连接层(MLP),用于将特征映射到特定的类别数上,产生最终的分类预测。

6. 训练与优化

  • 损失函数:模型的输出会与真实标签对比,通常使用交叉熵损失函数来计算预测误差。
  • 反向传播与更新:通过反向传播算法计算损失函数关于模型参数的梯度,并使用优化器(如Adam)来更新这些参数,以最小化损失。

总结

ViT的工作流程从图像的分割开始,经过嵌入、位置编码、Transformer编码器层的处理,最终通过分类头输出类别预测。其核心在于利用Transformer的自注意力机制来高效地捕获全局的视觉特征,这一机制使其在图像识别和分类任务中表现出色,尤其是在大数据集上。

首先我们输入一张图片,然后分成一张一张的patches,然后将每个patch输入至embedding层,也就是Linear Projection of Flattened Patches。然后我们就会得到一个个的向量,称为token。紧接着我们会在所有生成的token前加上一个新的,专门用于我们分类的class token,(这里增加一个class token是参考的bert网络),这些patch维度都是相同的。

为了标记位置信息,又加了class token即向量前面的0123456789。然后输入至编码器。编码器的详细结构如右图所示,重复堆叠L次。

然后因为我们要做分类,所以只提取针对我们class token所对应的输出,再通过MLPhead得到最终分类的结果。

1.3 Embedding层

在Vision Transformer (ViT) 中,Embedding层扮演着将图像数据转换为适合Transformer处理的序列化表示的关键角色。这一过程主要涉及以下几个关键步骤:

1. 图像分割(Patch Extraction)

首先,原始图像被分割成多个相同大小的小块,称为patches。每个patch代表了图像的一个局部区域。常见的做法是使用16×16像素的大小,但这个尺寸可以根据具体任务和模型需求调整。例如,对于一个224×224像素的图像,分割后可能会得到14×14=196个这样的patches。

2. Patch Flatten and Linear Projection

  • 展平(Flatten):每一个patch被展平成一维向量,即将其16×16×C(其中C为通道数,例如3对于RGB图像)的形状转换为一个长度为16×16×C的一维数组。
  • 线性映射(Linear Projection):展平后的patch向量通过一个全连接层(也称作线性层)进行映射,目的是将patch的维度变换到一个预定的维度,通常记为D(如D=768)。这一步骤确保了所有patch转换为具有相同维度的向量,为Transformer的输入做好准备。

3. 位置编码(Positional Encoding)

  • 为了保持图像的空间结构信息,每个经过线性映射的patch向量还会加上一个位置编码。位置编码可以是固定的(如正弦/余弦函数),也可以是学习得到的,它的目的是为序列中的每个位置提供唯一的标识。这样,Transformer模型就能区分不同位置的patch,即使它们的视觉内容相似。

4. 分类Token的添加

  • 在所有的patch嵌入前面,通常会添加一个特殊的分类Token(如CLS),该Token的目的是在模型处理过程中聚合所有patch的信息,用于最终的分类任务。这个Token也会经过相同的线性映射处理,以保持维度一致性。

5. 输出形式

经过上述步骤,所有patch的嵌入向量(包括分类Token的嵌入向量)被组织成一个序列,形成Transformer模型的输入。这个序列的形状将是(N+1)×D,其中𝑁N是patch的数量,D是嵌入维度,加1是因为额外的分类Token。

总结

Embedding层在ViT中起着桥梁作用,它将图像从像素空间转换到特征空间,使得Transformer能够以序列化的方式理解和处理图像信息。通过展平、线性映射、位置编码以及分类Token的加入,ViT能够捕捉到图像的局部特征、保持空间信息,并为后续的自注意力机制提供合适的输入形式。

/16代表分割的每张图片大小是16x16的。叠加位置编码时是对应元素直接相加的。

那么不加位置编码会出现什么情况呢?论文也做了实验。

不加位置编码的话准确率是61.3,加了位置编码变为了64.2。其他的位置编码效果差别都不大。

下图为,训练得到的位置编码,它的每个位置上与其他位置上的一个余弦相似度。

这里的patches大小是32x32的 ,224/32=7,所以是7x7大小。

1.4 Encoder层

在Vision Transformer (ViT) 中,Encoder层是模型的核心部分,负责对经过Embedding层处理过的序列化图像特征进行深层次的处理和特征提取。Encoder层由多个相同的Transformer Encoder Block堆叠而成,每个Block包含以下几个关键组件:

1. Layer Normalization (LN)层归一化

  • 在许多实现中,每个Encoder Block的开始会应用Layer Normalization (LN),而不是原始Transformer架构中的Post-Norm(即在自注意力和前馈网络之后)。LN通过对每个特征维度进行标准化处理,帮助稳定训练过程,加速收敛,并且在视觉任务中通常优于Batch Normalization (BN)。

2. Multi-Head Self-Attention (MHSA)

  • Query, Key, Value计算:MHSA是Transformer的核心,它通过计算输入序列中不同位置(或patch)之间的关系来捕获全局依赖性。每个位置的输入被线性映射为Query、Key和Value矩阵,分别用于衡量位置间的关系、查找相关性和聚合信息。
  • 多头注意力:MHSA将Query、Key、Value分成多个头(Head),每个头独立计算注意力,这样可以让模型在不同的表示子空间中并行地学习不同类型的依赖关系,之后再将结果合并。这样做增加了模型的表达能力,使其能更好地处理复杂的数据结构。
  • 注意力分数与加权求和:基于Query和Key的点积计算注意力分数,通过softmax函数归一化后,用以加权求和Value,从而获得每个位置的上下文依赖表示。

3. Add & Normalize with Residual Connection

  • 自注意力模块的输出会与输入(即LN的输出)相加,形成残差连接(Residual Connection),这有助于梯度流动并防止信息丢失。之后,再次应用Layer Normalization以稳定特征表示。

4. Feed-Forward Network (FFN)

  • 每个Encoder Block还包括一个两层的全连接网络,即FFN。它通常包含两个线性层,中间夹着一个非线性激活函数(如ReLU或GELU),用于对注意力模块输出的特征进行进一步的非线性变换和特征提炼。
  • 第一层线性变换通常会增加特征维度,第二层则将特征维度转换回原始维度,保持输入输出维度一致以便于残差连接。

5. Dropout

  • 在某些情况下,为了增强模型的泛化能力,会在自注意力输出或FFN输出上应用Dropout,随机丢弃一部分神经元的输出,以减少过拟合的风险。

6. 重复堆叠

  • 上述过程(LN → MHSA → Add & Normalize → FFN → Add & Normalize)在一个Encoder Block内完成,整个Encoder层由这样的Block重复堆叠L次(L是超参数),每增加一层,模型的表达能力都会增强,能够学习更复杂的特征交互。

综上所述,ViT的Encoder层通过多头自注意力机制捕捉全局依赖,结合前馈网络进行特征变换,通过残差连接和Layer Normalization保证信息流通和训练稳定性,最终输出经过深层次特征提取的序列化特征表示,为图像分类或其他下游任务提供强大的特征支持。

1.5 MLPHead层

在Vision Transformer (ViT) 模型中,MLP Head(也称为分类头或预测头)是模型的最后一部分,其主要任务是将Transformer Encoder输出的特征向量转换为最终的分类预测或回归输出。具体工作流程如下:

1. 提取分类Token

  • 经过一系列Transformer Encoder层处理后,序列中的第一个Token(通常是初始化时添加的分类Token,如𝐶𝐿𝑆CLS)被认为汇总了整个输入图像的全局信息。因此,在进入MLP Head之前,模型会从Encoder的输出序列中提取这个分类Token的特征表示。

2. 全连接层(FC Layer)

  • 提取出的分类Token特征向量会通过一个或多个全连接(Fully Connected,FC)层。第一个FC层通常用于将Transformer的高维特征空间映射到一个中间维度,这有助于模型进行更灵活的特征变换和非线性处理。随后,第二个FC层(如果有)将中间维度的特征映射到最终的输出维度,如对于图像分类任务,就是映射到类别数。

3. 激活函数

  • 在FC层之间,通常会插入激活函数(如ReLU或GELU),以引入非线性,使得模型能够学习和表达更复杂的决策边界。

4. 归一化和Dropout(可选)

  • 在某些实现中,MLP Head之前或之后可能会应用Layer Normalization或Batch Normalization,以进一步稳定训练过程。同时,为了提高模型的泛化能力,可能会在FC层之后应用Dropout,随机丢弃一部分神经元的输出,减少模型对训练数据的过拟合。

5. 输出层

  • 最终的FC层输出通常对应于模型的预测概率分布。对于分类任务,这通常通过Softmax函数来实现,将输出转换为各分类的概率分布。对于回归任务,则可能直接输出预测值或使用其他适合的激活函数(如线性激活)。

6. 损失计算与优化

  • MLP Head的输出会与真实的标签进行比较,计算损失(如交叉熵损失用于分类任务)。这个损失值用于指导反向传播,优化模型参数,以减小预测与实际之间的差距。

综上所述,MLP Head在ViT模型中起到了从高层次特征到具体任务输出的桥梁作用,通过一系列精心设计的线性变换和非线性操作,将Transformer编码得到的特征转化为任务所需的输出形式,是模型进行分类或回归预测的关键组件。

1.6 不同的模型参数

1.7 混合模型

R50指的是ResNet50,和VIT混合。

1.10 模型效果

/16代表分割的每张图片大小是16x16的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/378059.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

算法 —— 高精度(模拟)

目录 加法高精度 两个正整数相加 两个正小数相加 两正数相加 减法高精度 两个正整数相减 两个正小数相减 两正数相减 加减法总结 乘法高精度 两个正整数相乘 两个正小数相乘 乘法总结 加法高精度 题目来源洛谷:P1601 AB Problem(高精&#x…

老物件线上3D回忆展拓宽了艺术作品的展示空间和时间-深圳华锐视点

在数字技术的浪潮下,3D线上画展为艺术家们开启了一个全新的展示与销售平台。这一创新形式不仅拓宽了艺术作品的展示空间,还为广大观众带来了前所未有的观赏体验。 3D线上画展制作以其独特的互动性,让艺术不再是单一的视觉享受。在这里&#x…

220V降5V芯片输出电压电流封装选型WT

220V降5V芯片输出电压电流封装选型WT 220V降5V恒压推荐:非隔离芯片选型及其应用方案 在考虑220V转低压应用方案时,以下非隔离芯片型号及其封装形式提供了不同的电压电流输出能力: 1. WT5101A(SOT23-3封装)适用于将2…

勒索防御第一关 亚信安全AE防毒墙全面升级 勒索检出率提升150%

亚信安全信舷AE高性能防毒墙完成能力升级,全面完善勒索边界“全生命周期”防御体系,筑造边界勒索防御第一关! 勒索之殇,银狐当先 当前勒索病毒卷携着AI技术,融合“数字化”的运营模式,形成了肆虐全球的网…

SpringBoot使用RedisTemplate、StringRedisTemplate操作Redis

前言 RedisTemplate 是 Spring Boot 访问 Redis 的核心组件,底层通过 RedisConnectionFactory 对多种 Redis 驱动进行集成,上层通过 XXXOperations 提供丰富的 API ,并结合 Spring4 基于泛型的 bean 注入,极大的提供了便利&#x…

【自学安全防御】二、防火墙NAT智能选路综合实验

任务要求: (衔接上一个实验所以从第七点开始,但与上一个实验关系不大) 7,办公区设备可以通过电信链路和移动链路上网(多对多的NAT,并且需要保留一个公网IP不能用来转换) 8,分公司设备可以通过总…

网络安全防御【防火墙安全策略用户认证综合实验】

目录 一、实验拓扑图 二、实验要求 三、实验思路 四、实验步骤 1、打开ensp防火墙的web服务(带内管理的工作模式) 2、在FW1的web网页中网络相关配置 3、交换机LSW6(总公司)的相关配置: 4、路由器相关接口配置&a…

connect by prior 递归查询

connect by prior 以公司组织架构举例,共四个层级,总公司,分公司,中心支公司,支公司 总公司level_code为1 下一层级的parent_id为上一层级的id,建立关联关系 SELECT id, name, LEVEL FROM org_info a STA…

海事无人机解决方案

海事巡察 海事巡察现状 巡查效率低下,存在视野盲区,耗时长,人力成本高。 海事的职能 统一管理水上交通安全和防治船舶污染。 管理通航秩序、通航环境。负责水域的划定和监督管理,维护水 上交通秩序;核定船舶靠泊安…

使用自制Qt工具配合mitmproxy进行网络调试

在软件开发和网络调试过程中,抓包工具是不可或缺的。传统的抓包工具如Fiddler或Charles Proxy通常需要设置系统代理,这会抓到其他应用程序的网络连接,需要设置繁琐的过滤,导致不必要的干扰。为了解决这个问题,我们可以…

linux中关于环境变量的常用的设置方法

一. linux中设置环境变量的方式 1.使用/etc/environment, 是一个全局的环境变量设置文件,它会影响到所有用户和所有进程。当你需要设置一个全局的环境变量时,应该使用这个文件。这个文件的格式是 KEYvalue,每行一个环境变量。 2. 使用/etc/…

【unity笔记】常见问题收集

一 . Unity Build GI data 卡住问题 解决: 参考官方文档,GI(Global Illumination) data 指的是全局照明信息。 在Unity的Edit->Preference中,可以编辑GI缓存路径和分配GI缓存大小。 调出Window->Rendering->Lighting窗口,取消勾选…

【Caffeine】⭐️SpringBoot 项目整合 Caffeine 实现本地缓存

目录 🍸前言 🍻一、Caffeine 🍺二、项目实践 2.1 环境准备 2.2 项目搭建 2.3 接口测试 ​💞️三、章末 🍸前言 小伙伴们大家好,缓存是提升系统性能的一个不可或缺的工具,通过缓存可以避免大…

基于SpringBoot+VueJS+微信小程序技术的图书森林共享小程序设计与实现:7000字论文+源代码参考

博主介绍:硕士研究生,专注于信息化技术领域开发与管理,会使用java、标准c/c等开发语言,以及毕业项目实战✌ 从事基于java BS架构、CS架构、c/c 编程工作近16年,拥有近12年的管理工作经验,拥有较丰富的技术架…

Java面试八股之Redis哨兵机制

Redis哨兵机制 Redis Sentinel(哨兵)模式是一种高可用解决方案,用于监控和自动故障转移Redis主从集群。以下是对哨兵模式详细过程的描述: 1. 初始化与配置 部署哨兵节点:在不同的服务器上部署一个或多个Redis Sentin…

leetcode 周赛(406)全AC留念

纪念第一次 leetcode 周赛(406)全AC 1.(100352. 交换后字典序最小的字符串) 题目描述: 给你一个仅由数字组成的字符串 s,在最多交换一次 相邻 且具有相同 奇偶性 的数字后,返回可以得到的 字典序最小的字符串 。 如…

ubantu22.04安装OceanBase 数据库

1、管理员启动cmd,运行 sudo bash -c "$(curl -s https://obbusiness-private.oss-cn-shanghai.aliyuncs.com/download-center/opensource/service/installer.sh)" 2、提示如下代表安装完成 3、修改数据库配置文件的密码 sudo vim /etc/oceanbase.cnf 然后保存退…

初学SpringMVC之 JSON 篇

JSON(JavaScript Object Notation,JS 对象标记)是一种轻量级的数据交换格式 采用完全独立于编程语言的文本格式来存储和表示数据 JSON 键值对是用来保存 JavaScript 对象的一种方式 比如:{"name": "张三"}…

Redis实战—附近商铺、用户签到、UV统计

本博客为个人学习笔记,学习网站与详细见:黑马程序员Redis入门到实战 P88 - P95 目录 附近商铺 数据导入 功能实现 用户签到 签到功能 连续签到统计 UV统计 附近商铺 利用Redis中的GEO数据结构实现附近商铺功能,常见命令如下图所示。…

牛客TOP101:合并两个排序的链表

文章目录 1. 题目描述2. 解题思路3. 代码实现 1. 题目描述 2. 解题思路 与正常的合并两个有序数组思路一样,这里可以定义一个头节点(虚拟节点),可以方便我们一开始进行连接。用两个指针标记两个链表的结点,进行循环比较…