在 Vision Transformer (ViT) 中,CLS token(分类 token)是一个特殊的标记,主要用于在模型的分类任务中提取全局图像的特征表示。以下是 CLS token 的作用及其工作原理的详细解释:
-
CLS Token 的引入
• CLS Token 是一种借鉴自 NLP 领域(Transformer 模型中)的设计。在自然语言处理中,CLS token 通常用于表示整个输入序列的全局语义信息。
• 在 ViT 中,CLS token 的目的是为图像输入提供一个全局的代表符号。 -
CLS Token 的工作原理
输入阶段:
1. 图像切分为 Patch:
• 输入图像被分割成固定大小的 patches(如 ),每个 patch 被展平成一个向量。
• 这些 patches 被嵌入为固定长度的特征向量,称为 patch embeddings。
2. 添加 CLS Token:
• 在 patch embeddings 序列的最前面添加一个特殊的 CLS token。
• CLS token 通常是一个可以学习的向量(随机初始化),其维度与 patch embedding 相同。
3. 位置编码:
• 对每个 patch 和 CLS token 加入位置编码,帮助模型感知输入的位置信息。
Transformer 阶段:
1.自注意力机制:
• CLS token 会参与所有层的自注意力计算,与每个 patch embedding 交互。
• 在每一层中,CLS token 会不断汇聚来自所有 patch 的全局信息。
2. 最终输出:
• 在最后一层 Transformer 中,CLS token 的特征向量被提取出来,作为整个输入图像的全局表征。
• 这个表征被传递到分类头(如全连接层)进行分类任务。
- CLS Token 的作用
(1) 全局特征聚合
• CLS token 通过自注意力机制,从所有图像 patches 中收集全局上下文信息。
• 它的最终表示可以被视为整个图像的全局特征。
(2) 任务输出
• 对于分类任务,CLS token 的特征被直接用作分类头的输入。
• 它替代了 CNN 中常用的全局池化操作(如 GAP,全局平均池化)。
(3) 灵活扩展
• 在非分类任务中,CLS token 的概念也可以扩展为其他全局任务,例如图像生成或回归任务中的全局控制标记。
-
CLS Token 与其他机制的对比
- 与 CNN 的全局池化 (GAP):
• CNN 通常通过 GAP 将特征图转为全局特征。
• ViT 使用 CLS token,则通过注意力机制直接学习到全局特征,而无需显式池化操作。 - 与无 CLS Token 的方法(Mean Pooling):
• ViT 的变种(如 DeiT)中,也有方法省略 CLS token,直接对所有 patch embeddings 进行平均池化(mean pooling)。
• CLS token 提供了一个更加灵活和可学习的全局表征。
- 与 CNN 的全局池化 (GAP):
-
为什么 CLS Token 有效?
• Transformer 的核心是自注意力机制,CLS token 在所有注意力头中与其他 patches 交互,能够有效地聚合整个输入的上下文信息。
• 由于 CLS token 是一个可学习的参数,它能够在训练过程中优化以适应具体任务。 -
示例:ViT 的架构中的 CLS Token
以下是一个包含 CLS token 的简化示例代码(使用 PyTorch):
import torch
import torch.nn as nnclass VisionTransformer(nn.Module):def __init__(self, num_patches, embed_dim, num_classes):super().__init__()self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # 可学习的 CLS tokenself.patch_embeddings = nn.Linear(16*16*3, embed_dim) # 假设 patch 大小为 16x16x3self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))self.transformer = nn.Transformer(embed_dim, nhead=8, num_encoder_layers=12)self.classifier = nn.Linear(embed_dim, num_classes)def forward(self, x):# 假设 x: [batch_size, num_patches, 16*16*3]batch_size = x.size(0)patch_embeddings = self.patch_embeddings(x) # [batch_size, num_patches, embed_dim]# 添加 CLS tokencls_tokens = self.cls_token.expand(batch_size, -1, -1) # [batch_size, 1, embed_dim]embeddings = torch.cat((cls_tokens, patch_embeddings), dim=1) # [batch_size, num_patches + 1, embed_dim]# 添加位置编码embeddings += self.position_embeddings# 输入 Transformertransformer_output = self.transformer(embeddings) # [batch_size, num_patches + 1, embed_dim]# 提取 CLS token 特征cls_output = transformer_output[:, 0, :] # [batch_size, embed_dim]# 分类logits = self.classifier(cls_output) # [batch_size, num_classes]return logits
总结
• CLS token 是 ViT 中的一个关键设计,用于作为图像的全局表征。
• 它通过自注意力机制从所有 patch 中提取全局上下文信息,并作为分类任务的输入。
• CLS token 的引入使 Transformer 模型能够在视觉任务中充分利用其灵活性和全局建模能力。