Vision Transformer (ViT) 基本原理

Vision Transformer (ViT) 基本原理

flyfish

Vision Transformer (ViT) 是一种基于 Transformer 架构的计算机视觉模型


一、ViT 的基本原理

ViT 的核心思想是将一张图像视为一组序列,将其嵌入到 Transformer 的输入中,通过自注意力机制捕获全局上下文信息,从而进行分类或其他视觉任务。

  • 传统卷积神经网络 (CNN): 使用卷积核逐层提取局部特征,通常关注图像的局部模式。
  • ViT 的创新点: 不使用卷积操作,而是将图像划分为小块 (patches),通过 Transformer 模型直接处理全局特征。

二、ViT 的核心架构

1. 输入处理

1.1 图像分块 (Patch Partitioning)
将输入图像 x ∈ R H × W × C x \in \mathbb{R}^{H \times W \times C} xRH×W×C H H H: 高, W W W: 宽, C C C: 通道数)分割为 N N N 个固定大小的图像块 (patch)。
每个 patch 的大小为 P × P P \times P P×P,则总的 patch 数量为:

N = H P × W P N = \frac{H}{P} \times \frac{W}{P} N=PH×PW

每个 patch 被展平成向量,形状为 R P 2 ⋅ C \mathbb{R}^{P^2 \cdot C} RP2C

1.2 Patch 嵌入 (Patch Embedding)
通过一个线性投影将每个 patch 转换为 D D D-维的向量:

z 0 = [ z 0 1 , z 0 2 , … , z 0 N ] 其中  z 0 i ∈ R D z_0 = [z_0^1, z_0^2, \dots, z_0^N] \quad \text{其中 } z_0^i \in \mathbb{R}^D z0=[z01,z02,,z0N]其中 z0iRD

公式:

z 0 i = W e ⋅ Flatten ( x i ) + b e z_0^i = W_e \cdot \text{Flatten}(x_i) + b_e z0i=WeFlatten(xi)+be

其中, W e W_e We 是嵌入权重, b e b_e be 是偏置。

1.3 加入位置编码 (Positional Encoding)
因为 Transformer 缺乏对序列顺序的感知,需要加入位置编码:

z 0 = z 0 + E p o s z_0 = z_0 + E_{pos} z0=z0+Epos

E p o s ∈ R N × D E_{pos} \in \mathbb{R}^{N \times D} EposRN×D 是位置编码矩阵。


2. Transformer 编码器 (Transformer Encoder)

Transformer 编码器由多层组成,每层包含两个主要模块:

2.1 多头自注意力机制 (Multi-Head Self-Attention, MHSA)

自注意力机制的核心思想是捕捉输入序列中每个元素与其他元素之间的关系。

计算步骤

  1. 对输入 z l − 1 ∈ R N × D z_{l-1} \in \mathbb{R}^{N \times D} zl1RN×D 进行线性变换得到 Q , K , V Q, K, V Q,K,V
    Q = z l − 1 W Q , K = z l − 1 W K , V = z l − 1 W V Q = z_{l-1}W_Q, \quad K = z_{l-1}W_K, \quad V = z_{l-1}W_V Q=zl1WQ,K=zl1WK,V=zl1WV

  2. 计算注意力权重:
    Attention ( Q , K , V ) = softmax ( Q K ⊤ D k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{D_k}}\right)V Attention(Q,K,V)=softmax(Dk QK)V

    其中, D k D_k Dk Q Q Q K K K 的维度,用于缩放防止梯度爆炸。

  3. 多头注意力的输出为:
    MHSA ( z l − 1 ) = [ head 1 , … , head h ] W O \text{MHSA}(z_{l-1}) = [\text{head}_1, \dots, \text{head}_h]W_O MHSA(zl1)=[head1,,headh]WO

    W Q , W K , W V , W O W_Q, W_K, W_V, W_O WQ,WK,WV,WO 是学习参数, h h h 是注意力头数。

2.2 前向全连接层 (Feed-Forward Network, FFN)

每个位置的特征通过一个两层全连接网络进行变换:

FFN ( x ) = ReLU ( x W 1 + b 1 ) W 2 + b 2 \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2 FFN(x)=ReLU(xW1+b1)W2+b2

2.3 残差连接与归一化

每个模块后有残差连接和 LayerNorm:

z l ′ = LayerNorm ( z l − 1 + MHSA ( z l − 1 ) ) z_l' = \text{LayerNorm}(z_{l-1} + \text{MHSA}(z_{l-1})) zl=LayerNorm(zl1+MHSA(zl1))
z l = LayerNorm ( z l ′ + FFN ( z l ′ ) ) z_l = \text{LayerNorm}(z_l' + \text{FFN}(z_l')) zl=LayerNorm(zl+FFN(zl))


3. 分类头 (Classification Head)

通过一个 learnable 的分类标记 z cls z_{\text{cls}} zcls 获取全局特征,最后用全连接层输出分类结果:

y = softmax ( W h e a d ⋅ z cls + b h e a d ) y = \text{softmax}(W_{head} \cdot z_{\text{cls}} + b_{head}) y=softmax(Wheadzcls+bhead)


三、与传统 Transformer 的主要区别

  1. 输入类型

    • 传统 Transformer: 输入为一维序列(如词向量)。
    • ViT: 输入为二维图像分块。
  2. 位置编码

    • 传统 Transformer: 通常使用正弦或可学习的位置编码。
    • ViT: 通常直接添加 learnable 的二维位置编码。
  3. 任务目标

    • 传统 Transformer: 多用于 NLP 任务(如翻译、文本分类)。
    • ViT: 多用于视觉任务(如图像分类、目标检测)。

四、ViT 中的注意力机制如何运作

ViT 的注意力机制核心是 自注意力 (Self-Attention),通过计算每个 patch 之间的相关性,捕获全局信息。

1. 自注意力权重的计算

α i j = exp ⁡ ( q i ⋅ k j D k ) ∑ j = 1 N exp ⁡ ( q i ⋅ k j D k ) \alpha_{ij} = \frac{\exp\left(\frac{q_i \cdot k_j}{\sqrt{D_k}}\right)}{\sum_{j=1}^N \exp\left(\frac{q_i \cdot k_j}{\sqrt{D_k}}\right)} αij=j=1Nexp(Dk qikj)exp(Dk qikj)

其中:

  • α i j \alpha_{ij} αij 表示 patch i i i j j j 的注意力权重。
  • q i q_i qi k j k_j kj 分别是 query 和 key 向量。
2. 输出特征的加权求和

Attention ( Q , K , V ) = ∑ j = 1 N α i j v j \text{Attention}(Q, K, V) = \sum_{j=1}^N \alpha_{ij} v_j Attention(Q,K,V)=j=1Nαijvj


五、示例计算流程

假设输入图像大小为 224 × 224 × 3 224 \times 224 \times 3 224×224×3,每个 patch 大小为 16 × 16 16 \times 16 16×16,则:

  • 总 patch 数 N = 224 16 × 224 16 = 196 N = \frac{224}{16} \times \frac{224}{16} = 196 N=16224×16224=196
  • 每个 patch 被展平为向量,形状为 16 × 16 × 3 = 768 16 \times 16 \times 3 = 768 16×16×3=768
  • 经过线性投影后,转换为 D D D-维(如 D = 768 D=768 D=768)。

经过多层 Transformer 编码器后,分类标记 z cls z_{\text{cls}} zcls 被用于分类。

ViT 是计算机视觉领域的重要进步,通过将 Transformer 应用于图像任务,突破了传统 CNN 的局限,在大规模数据集上表现出色。

Vision Transformer (ViT) 整体流程

图像处理流程:
图像切分 Patch → Patch 序列化 → Patch + Position Embedding → Transformer Encoder → MLP Head → 分类结果

  1. 输入预处理:将图像划分为固定大小的 Patch,转化为 Token 序列,加入 [CLS] Token 和位置编码,形成符合 Transformer 的输入格式。
  2. Transformer Encoder:通过自注意力机制捕获全局特征,堆叠多个编码器以提取深层次特征。
  3. 分类头:利用 [CLS] Token 表示全局特征,通过 MLP 进行分类。

详细些就是
在这里插入图片描述

首先,要对图像进行切分操作,将完整的图像切割成一个个的小部分,也就是所谓的 “Patch”。
接着,把这些切分好的 “Patch” 按照一定顺序进行序列化处理,使其能够以有序的形式来参与后续的流程。
然后,为序列化后的 “Patch” 添加上位置嵌入信息(Position Embedding),通过这样的方式让模型能够知晓每个 “Patch” 在图像中的位置情况,便于后续准确地处理。
之后,把带有位置嵌入的 “Patch” 送入到 Transformer Encoder 当中,Transformer Encoder 会运用其自身的机制对输入的内容进行特征提取等相关处理,进一步挖掘和分析其中蕴含的特征信息。
再之后,经过 Transformer Encoder 处理后的结果会传递到多层感知机头部(MLP Head),由多层感知机对这些特征做进一步的整合、变换等操作。
最终,经过前面一系列的处理后,多层感知机头部输出相应的分类结果,以此来判断图像属于哪一类别的内容。


1. 图像切分 Patch

给定一张 RGB 图像,尺寸为 ( 224 , 224 , 3 ) (224, 224, 3) (224,224,3),假设 Patch 的大小为 16 × 16 16 \times 16 16×16

  • Patch 数量
    N = 22 4 2 1 6 2 = 196 N = \frac{224^2}{16^2} = 196 N=1622242=196

  • 每个 Patch 的尺寸
    每个 Patch 为 ( 16 , 16 , 3 ) (16, 16, 3) (16,16,3),展开后为一维向量,维度大小为:
    16 × 16 × 3 = 768 16 \times 16 \times 3 = 768 16×16×3=768

  • 线性嵌入矩阵
    使用一个线性投影矩阵将每个 Patch 映射到 D D D-维嵌入空间,这里 D = 768 D = 768 D=768。嵌入后的 Patch 表示为 ( 196 , 768 ) (196, 768) (196,768)

处理到这里,已经将原始的视觉问题转化为 NLP 问题:输入序列由一系列一维的 Token 表示。


2. [CLS] Token

ViT 借鉴了 BERT 中的 [CLS] 特殊 Token,用于表示整个图像的全局信息:

  • [CLS] Token 的位置信息始终固定为 0,但通过注意力机制,它能够与所有的 Patch Token 交互,从而获取全局特征。
  • 在 [CLS] Token 的基础上,仅需根据其输出的嵌入来进行最终分类任务。

加入 [CLS] Token 后,输入序列长度由 196 196 196 变为 197 197 197,输入维度变为:
( 197 , 768 ) (197, 768) (197,768)

实验对比:ViT 对比了两种获取全局特征的方式:[CLS] Token 和 Global Average Pooling (GAP)。实验表明二者效果相当,但为了与原始 Transformer 保持一致,ViT 选择了 [CLS] Token。


3. 位置编码 (Positional Embedding)

Transformer 缺乏对序列位置信息的直接感知,因此需要加入位置编码:

  • 对每个 Patch 编号(1~196),通过映射表生成一个 768 维的位置向量作为位置编码。
  • 位置编码与 Patch 的嵌入直接相加,得到最终送入 Transformer 的向量,维度仍为 ( 197 , 768 ) (197, 768) (197,768)

实验对比
ViT 探索了多种位置编码方式,包括:

  1. 一维位置编码 (1D Positional Embedding)
  2. 二维位置编码 (2D Positional Embedding)
  3. 相对位置编码 (Relative Positional Embedding)

结果表明三者均能很好地学习位置信息,这可能是因为 ViT 处理的是 Patch-Level 而非 Pixel-Level 的特征,Patch 数量较少(196 个),学习位置关系较为简单。


4. Transformer 编码器 (Transformer Encoder)

经过上述数据预处理,得到输入矩阵:
Input:  ( 197 , 768 ) \text{Input: } (197, 768) Input: (197,768)

Transformer Encoder 的组成:

  • 多头自注意力机制 (Multi-Head Self-Attention, MHSA)

    1. 输入序列 ( 197 , 768 ) (197, 768) (197,768) 通过线性变换生成 Q , K , V Q, K, V Q,K,V 三组矩阵,维度分别为 ( 197 , 768 ) (197, 768) (197,768)
    2. 假设有 H = 12 H=12 H=12 个注意力头,每个头的维度为 768 12 = 64 \frac{768}{12} = 64 12768=64。通过线性投影,降维得到 ( 197 , 64 ) (197, 64) (197,64) 的表示。
    3. 每个头独立计算注意力,输出 H H H ( 197 , 64 ) (197, 64) (197,64) 矩阵,最后拼接成 ( 197 , 768 ) (197, 768) (197,768)
  • 前向全连接网络 (MLP)
    通常包含两层全连接层,伴随激活函数:
    ( 197 , 768 ) → 升维 ( 197 , 3072 ) → 降维 ( 197 , 768 ) (197, 768) \xrightarrow{\text{升维}} (197, 3072) \xrightarrow{\text{降维}} (197, 768) (197,768)升维 (197,3072)降维 (197,768)

  • 残差连接和归一化
    每个模块后通过残差连接和 LayerNorm,使输入输出维度一致,便于堆叠多个 Transformer Block。

经过一个 Transformer Block 后,输入维度从 ( 197 , 768 ) (197, 768) (197,768) 转换为相同的 ( 197 , 768 ) (197, 768) (197,768),可以堆叠多个 Block(例如 ViT-B 中堆叠 12 层)。


5. 分类头 (MLP Head)

最终通过 Transformer 编码器得到的 [CLS] Token 表示 z c l s z_{cls} zcls,维度为 ( 1 , 768 ) (1, 768) (1,768)
通过全连接层生成最终的分类结果:
Output: softmax ( W ⋅ z c l s + b ) \text{Output: } \text{softmax}(W \cdot z_{cls} + b) Output: softmax(Wzcls+b)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/485712.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

工业异常检测-CVPR2024-新的3D异常数据合成办法和自监督网络IMRNet

论文:https://arxiv.org/pdf/2311.14897v3.pdf 项目:https://github.com/chopper-233/anomaly-shapenet 这篇论文主要关注的是3D异常检测和定位,这是一个在工业质量检查中至关重要的任务。作者们提出了一种新的方法来合成3D异常数据&#x…

三款电容麦的对比

纸面参数 第一款麦克风 灵敏度: -36 dB 2 dB(0 dB1V/Pa at 1 kHz) 灵敏度较低,需要更高的增益来拾取同样的音量。频率响应: 40 Hz - 18 kHz 响应范围较窄,尤其在高频区域。等效噪音级: ≤18 dB(A计权) 噪…

easyexcel 导出日期格式化

1.旧版本 在新的版本中formate已经被打上废弃标记。那么不推荐使用这种方式。 2.推荐方式 推荐使用另外一种方式【 Converter 】代码如下,例如需要格式化到毫秒【yyyy-MM-dd HH:mm:ss SSS】级别 创建一个公共Converter import com.alibaba.excel.converters.Conv…

PPT怎样做的更加精美

目录 PPT怎样做的更加精美 3D的GIF图片 3维空间图​编辑 结果有明显的对比 阅读高质量文献,采用他们的图 PPT怎样做的更加精美 3D的GIF图片 3维空间图 结果有明显的对比

插入排序⁻⁻⁻⁻直接插入排序希尔排序

引言 所谓的排序,就是使一串记录按照其中的某个或某些关键字的大小,递增或递减的排列起来的操作。 常见的排序算法有: 今天我们主要学习插入排序的直接插入排序和希尔排序。 直接插入排序 什么是直接插入排序? 直接插入排序其…

鸿蒙UI开发——亮/暗色模式适配

1、概 述 系统存在深浅色两种显示模式,为了给用户更好的使用体验,应用最好适配暗色和亮色两种模式。从应用与系统配置关联的角度来看,适配暗色和亮色模式可以分为下面两种情况: 应用跟随系统的深浅色模式; 应用主动设…

推荐在线Sql运行

SQL Fiddle 1、网址:SQL Fiddle - Online SQL Compiler for learning & practiceDiscover our free online SQL editor enhanced with AI to chat, explain, and generate code. Support SQL Server, MySQL, MariaDB, PostgreSQL, and SQLite.http://www.sqlfi…

在Ubuntu-22.04 [WSL2]中配置Docker

文章目录 0. 进入Ubuntu-22.041. 更新系统软件包2. 安装Docker相关依赖包3. 添加Docker官方GPG密钥4. 添加Docker软件源5. 安装Docker Engine5.1 更新软件包列表5.2 安装Docker相关软件包 6. 验证Docker安装是否成功6.1 查看Docker版本信息6.2 启动Docker6.3 配置镜像加速器6.4…

AI大模型ollama结合Open-webui

AI大模型Ollama结合Open-webui 作者:行癫(盗版必究) 一:认识 Ollama 1.什么是Ollama ​ Ollama是一个开源的 LLM(大型语言模型)服务工具,用于简化在本地运行大语言模型,降低使用大语言模型的门槛,使得大模型的开发者、研究人员和爱好者能够在本地环境快速实验、管理和…

使用ensp搭建内外互通,使用路由跨不同vlan通信。

1.网络拓扑图 2.规则 (1)允许 (自己)ping通内外网,内外网随便一个pc就可以. (2) 允许(电信)ping通内外网,内外网随便一个pc就可以 (时间问题不做…

gRPC 快速入门 — SpringBoot 实现(1)

目录 一、什么是 RPC 框架 ? 二、什么是 gRPC 框架 ? 三、传统 RPC 与 gRPC 对比 四、gRPC 的优势和适用场景 五、gRPC 在分布式系统中应用场景 六、什么是 Protocol Buffers(ProtoBuf)? 特点 使用场景 简单的…

Python实现BBS论坛自动签到【steamtools论坛】

一、知识点分析 1.requests模块介绍 ‌requests模块是Python中用于发送HTTP请求的一个库,它封装了urllib3库,提供了更加便捷的API接口。‌ 通过使用requests模块,用户可以模拟浏览器的请求,发送HTTP请求到指定的URL,并获取响应内容。与urllib相比,requests模块的API更加…

Probabilistic Face Embeddings 论文阅读

Probabilistic Face Embeddings 论文阅读 Abstract1. Introduction2. Related Work3. Limitations of Deterministic Embeddings4. Probabilistic Face Embeddings4.1. Matching with PFEs4.2. Fusion with PFEs4.3. Learning 5. Experiments5.1. Experiments on Different Bas…

重磅升级:OpenAI o1模型上手实测,从芯片架构分析到象棋残局判断的全能表现

引言 昨日,在圣诞节系列发布会的第一天,OpenAI终于给我们带来了令人振奋的更新,这些更新有望塑造AI互动的未来。备受期待的OpenAI o1正式版的推出,标志着ChatGPT体验的重大进化,宣告了AI驱动应用新时代的开始。o1现已可…

1.使用docker 部署redis Cluster模式 集群3主3从

1.使用docker 部署redis Cluster模式 集群3主3从 1.1 先安装docker 启动docker服务,拉取redis镜像 3主3从我们要在docker启动6个容器docker run --name redis-node-1 --net host --privilegedtrue -v /data/redis/share/redis-node-1:/data redis:6.0.8 --cluster-…

如何通过 Windows 自带的启动管理功能优化电脑启动程序

在日常使用电脑的过程中,您可能注意到开机后某些程序会自动运行。这些程序被称为“自启动”或“启动项”,它们可以在系统启动时自动加载并开始运行,有时甚至在后台默默工作。虽然一些启动项可能是必要的(如杀毒软件)&a…

记一次跑前端老项目的问题

记一次跑前端老项目的问题 一、前言二、过程1、下载依赖2、启动项目3、打包 一、前言 在一次跑前端老项目的时候,遇到了一些坑,这里记录一下。 二、过程 1、下载依赖 使用 npm install下载很久,然后给我报了个错 core-js2.6.12: core-js…

在米尔FPGA开发板上实现Tiny YOLO V4,助力AIoT应用

学习如何在 MYIR 的 ZU3EG FPGA 开发板上部署 Tiny YOLO v4,对比 FPGA、GPU、CPU 的性能,助力 AIoT 边缘计算应用。 一、 为什么选择 FPGA:应对 7nm 制程与 AI 限制 在全球半导体制程限制和高端 GPU 受限的大环境下,FPGA 成为了中…

Python爬虫之selenium库驱动浏览器

目录 一、简介 二、使用selenium库前的准备 1、了解selenium库驱动浏览器的原理 (1)、WebDriver 协议 (2)、 浏览器驱动(Browser Driver) (3)、 Selenium 客户端库 &#xff0…

从零开始学TiDB(2)深入了解TiDB Server模块

TiDB Server 架构 TiDB Server 的主要功能: 一条SQL的执行流程: 1.将整个SQL语句解析成一个个的token,生成一个树形结构。 2.编译模块 1.首先需要做一个合法性验证,比如表存不存在等。 2.做逻辑优化:依据关系型代数等…