DeiT:数据高效的图像Transformer及其工作原理详解
随着Transformer架构在自然语言处理(NLP)领域的巨大成功,研究者们开始探索其在计算机视觉领域的应用。Vision Transformer(ViT)是最早将Transformer直接应用于图像分类的模型之一,但其训练需要依赖大规模数据集(如JFT-300M)和强大的计算资源,这限制了其广泛应用。针对这一问题,Facebook AI和Sorbonne University的研究团队提出了DeiT(Data-efficient image Transformers),一种仅使用ImageNet数据集(约130万张图像)即可高效训练的图像Transformer模型。本文将详细介绍DeiT的原理,特别针对熟悉Transformer结构的深度学习研究者,深入探讨其架构设计、训练策略以及创新的蒸馏方法。
下文中图片来自于原论文:https://arxiv.org/pdf/2012.12877
一、DeiT的核心思想与背景
DeiT的目标是解决ViT的一个关键问题:Transformer在视觉任务中对数据量的依赖性。ViT的研究表明,如果仅使用ImageNet这样的中小规模数据集,Transformer模型的性能会显著低于卷积神经网络(CNN)。DeiT通过优化训练策略和引入特定的知识蒸馏方法,成功地在单台8-GPU机器上(训练时间2-3天)实现了与CNN竞争的性能,其参考模型(DeiT-B,86M参数)在ImageNet上达到了83.1%的top-1准确率(单裁剪),甚至在蒸馏后最高可达85.2%。
DeiT的核心贡献包括:
- 数据高效训练:通过强数据增强和正则化策略,使Transformer在ImageNet-only的场景下也能表现出色。
- 新型蒸馏方法:提出了一种专为Transformer设计的“蒸馏token”策略,利用教师模型(可以是CNN或Transformer)的知识进一步提升学生模型性能。
- 迁移能力:在下游任务(如CIFAR、iNaturalist等)上表现出与CNN相当的泛化能力。
接下来,我们将从架构、训练和蒸馏三个方面详细剖析DeiT的原理。
二、DeiT的架构设计
DeiT的架构直接继承自ViT,因此熟悉Transformer的研究者可以快速理解其结构。以下是其关键组件的详细说明:
1. 输入处理:图像分块与嵌入
与ViT相同,DeiT将输入图像(固定分辨率,如224×224)分割成固定大小的patch(通常为16×16像素)。对于224×224的图像,这会生成 ( 14 × 14 = 196 14 \times 14 = 196 14×14=196 ) 个patch。每个patch被展平并通过线性层投影到一个固定维度(例如DeiT-B中为768维),从而形成patch嵌入序列:
- 输入图像 ( X ∈ R H × W × 3 X \in \mathbb{R}^{H \times W \times 3} X∈RH×W×3 ) → ( N N N ) 个patch(( N = H P × W P N = \frac{H}{P} \times \frac{W}{P} N=PH×PW )),其中 ( P = 16 P=16 P=16 )。
- 线性投影:( X p ∈ R N × ( P 2 ⋅ 3 ) → Z 0 ∈ R N × D X_p \in \mathbb{R}^{N \times (P^2 \cdot 3)} \rightarrow Z_0 \in \mathbb{R}^{N \times D} Xp∈RN×(P2⋅3)→Z0∈RN×D )。
2. 位置编码(Positional Encoding)
由于Transformer缺乏CNN的空间归纳偏置(inductive bias),需要显式加入位置信息。DeiT沿用ViT的做法,在patch嵌入后添加位置编码(可以是固定的正弦编码或可学习的参数),以保留patch的空间关系:
- ( Z 0 = [ z p a t c h 1 , z p a t c h 2 , . . . , z p a t c h N ] + E p o s Z_0 = [z_{patch_1}, z_{patch_2}, ..., z_{patch_N}] + E_{pos} Z0=[zpatch1,zpatch2,...,zpatchN]+Epos ),其中 ( E p o s ∈ R N × D E_{pos} \in \mathbb{R}^{N \times D} Epos∈RN×D )。
一个改进点是,DeiT支持在不同分辨率下微调模型(例如从224×224训练到384×384微调)。此时,patch数量 ( N N N ) 会变化,DeiT通过插值(通常为双三次插值)调整位置编码的大小,确保模型适配性。
3. 类token(Class Token)
为了进行分类,DeiT在patch序列前添加一个可学习的类token(class token),其作用类似于NLP中BERT的[CLS] token。类token与patch token一起通过Transformer层处理,最终在最后一层通过线性分类器预测类别:
- 输入序列:( Z 0 = [ z c l a s s , z p a t c h 1 , . . . , z p a t c h N ] Z_0 = [z_{class}, z_{patch_1}, ..., z_{patch_N}] Z0=[zclass,zpatch1,...,zpatchN] )。
4. Transformer块
DeiT的Transformer块与标准结构一致,每个块包括:
- 多头自注意力(Multi-head Self-Attention, MSA):
- ( Attention ( Q , K , V ) = Softmax ( Q K T d ) V \text{Attention}(Q, K, V) = \text{Softmax}(\frac{QK^T}{\sqrt{d}})V Attention(Q,K,V)=Softmax(dQKT)V ),其中 ( Q , K , V Q, K, V Q,K,V ) 由输入序列线性变换生成。
- 多头机制通过 ( h h h ) 个并行注意力头增强表达能力(DeiT-B中 ( h = 12 h=12 h=12 ))。
- 前馈网络(Feed-Forward Network, FFN):
- 两层MLP,中间使用GeLU激活,第一层将维度扩展到 ( 4D ),第二层还原到 ( D D D )。
- 残差连接与层归一化(LayerNorm):
- ( Z ′ = MSA ( LayerNorm ( Z ) ) + Z Z' = \text{MSA}(\text{LayerNorm}(Z)) + Z Z′=MSA(LayerNorm(Z))+Z )。
- ( Z = FFN ( LayerNorm ( Z ′ ) ) + Z ′ Z = \text{FFN}(\text{LayerNorm}(Z')) + Z' Z=FFN(LayerNorm(Z′))+Z′ )。
DeiT-B由12个Transformer块组成,嵌入维度 ( D = 768 D=768 D=768 ),每头维度 ( d = D / h = 64 d = D/h = 64 d=D/h=64 )。
5. 输出层
最后一层的类token经过线性层投影到类别数(如ImageNet的1000类),输出logits用于分类。
三、数据高效训练策略
由于Transformer缺乏CNN的局部性偏置,其训练需要更多数据或更强的正则化。DeiT通过以下策略实现了数据高效性:
1. 强数据增强
DeiT大量借鉴CNN的增强技术,包括:
- Rand-Augment:随机选择增强操作(如旋转、剪切等),参数为9/0.5。
- Mixup(概率0.8):混合两张图像及其标签。
- CutMix(概率1.0):将一张图像的部分替换为另一张图像。
- 随机擦除(Random Erasing)(概率0.25):随机遮挡图像区域。
这些增强显著增加了数据的多样性,帮助Transformer在有限数据下学习鲁棒特征。
2. 正则化与优化
- Stochastic Depth(概率0.1):随机丢弃Transformer块,增强深层网络的训练稳定性。
- Label Smoothing(( ε = 0.1 \varepsilon=0.1 ε=0.1)):平滑标签分布,减少过拟合。
- 优化器:使用AdamW(学习率 ( 5 × 1 0 − 4 × batchsize 512 5 \times 10^{-4} \times \frac{\text{batchsize}}{512} 5×10−4×512batchsize )),权重衰减0.05,配合余弦学习率衰减和5个epoch的warmup。
3. 重复增强(Repeated Augmentation)
DeiT采用重复增强策略,即对同一张图像多次应用不同的增强变换(通常3次),增加训练时的样本多样性。这一策略显著提升了性能,尤其在300 epoch的训练中效果明显。
4. 分辨率调整
DeiT首先在224×224分辨率下预训练(约53小时,8-GPU),然后在更高分辨率(如384×384)微调(约20小时)。微调时通过插值调整位置编码,保持模型一致性。
四、创新的蒸馏方法:Distillation Token
DeiT的一个亮点是提出了专为Transformer设计的蒸馏策略,通过引入蒸馏token(distillation token)增强学生模型的学习。以下是其原理和实现细节:
1. 传统知识蒸馏回顾
传统知识蒸馏(Knowledge Distillation, KD, 具体可以参考笔者的另一篇博客:Hinton提出的知识蒸馏(Knowledge Distillation,简称KD):原理解释和代码实现)通过教师模型的软标签(soft labels)指导学生模型:
- 软蒸馏:学生模型最小化其softmax输出与教师softmax输出之间的KL散度:
- ( L global = ( 1 − λ ) L CE ( ψ ( Z s ) , y ) + λ τ 2 KL ( ψ ( Z s / τ ) , ψ ( Z t / τ ) ) \mathcal{L}_{\text{global}} = (1-\lambda) \mathcal{L}_{\text{CE}}(\psi(Z_s), y) + \lambda \tau^2 \text{KL}(\psi(Z_s/\tau), \psi(Z_t/\tau)) Lglobal=(1−λ)LCE(ψ(Zs),y)+λτ2KL(ψ(Zs/τ),ψ(Zt/τ)) ),
- 其中 ( Z s , Z t Z_s, Z_t Zs,Zt ) 为学生和教师的logits,( τ \tau τ) 为温度,( λ \lambda λ) 为平衡因子。
然而,DeiT发现软蒸馏对Transformer的效果不如预期,因此提出了两种改进:
- 硬蒸馏:直接使用教师的硬标签(( y t = argmax ( Z t ) y_t = \text{argmax}(Z_t) yt=argmax(Zt) ))作为监督信号:
- ( L global hardDistill = 1 2 L CE ( ψ ( Z s ) , y ) + 1 2 L CE ( ψ ( Z s ) , y t ) \mathcal{L}_{\text{global}}^{\text{hardDistill}} = \frac{1}{2} \mathcal{L}_{\text{CE}}(\psi(Z_s), y) + \frac{1}{2} \mathcal{L}_{\text{CE}}(\psi(Z_s), y_t) LglobalhardDistill=21LCE(ψ(Zs),y)+21LCE(ψ(Zs),yt) )。
硬蒸馏在Transformer上表现更好,因为它简单且无需调参,同时能适应数据增强带来的标签变化。
2. 蒸馏Token的设计
DeiT进一步提出了一种Transformer特有的蒸馏方法:
- 在输入序列中添加一个额外的蒸馏token,与类token并存:
- 输入序列变为:( Z 0 = [ z c l a s s , z d i s t i l l , z p a t c h 1 , . . . , z p a t c h N ] Z_0 = [z_{class}, z_{distill}, z_{patch_1}, ..., z_{patch_N}] Z0=[zclass,zdistill,zpatch1,...,zpatchN] )。
- 蒸馏token通过自注意力机制与patch token和类token交互,目标是重现教师模型的预测(硬标签 ( y t y_t yt ))。
- 在最后一层,蒸馏token通过独立的线性分类器输出预测,与类token的输出互补。
3. 训练与推理
- 训练时:损失函数结合真标签(作用于类token)和教师标签(作用于蒸馏token),两者的权重相等。
- 推理时:可以单独使用类token或蒸馏token的分类器,也可以融合两者(late fusion,softmax输出相加),融合方式通常效果最佳。
4. 为什么有效?
- 互补性:实验表明,类token和蒸馏token在训练后收敛到不同的向量(初始余弦相似度0.06,最后层0.93),表明它们捕获了不同信息。
- 教师偏置:当使用CNN(如RegNetY-16GF)作为教师时,蒸馏token能引入卷积的局部性偏置,使Transformer受益于CNN的归纳能力。
- 性能提升:相比传统硬蒸馏,DeiT的蒸馏token方法将准确率从83.0%提升至84.5%(DeiT-B,224分辨率)。
五、实验结果与分析
1. ImageNet性能
- DeiT-B(无蒸馏):83.1% top-1(384分辨率)。
- DeiT-B蒸馏(DeiT-B(\pi)):85.2% top-1,超越EfficientNet和JFT-300M预训练的ViT-B。
- 小模型:DeiT-S(22M参数)和DeiT-Ti(5M参数)分别达到81.2%和74.5%,适合资源受限场景。
2. 与CNN的对比
DeiT在吞吐量(images/sec)与准确率的权衡上接近EfficientNet,尤其在蒸馏后甚至超越,显示出Transformer的潜力。
3. 迁移学习
在CIFAR-10、Flowers-102等任务上,DeiT的top-1准确率(如99.1%、98.9%)与CNN相当,证明其泛化能力。
4. 教师选择的影响
使用CNN(如RegNetY-16GF,82.9%准确率)作为教师比Transformer更有效,可能是因为CNN的偏置对Transformer的训练更有指导意义。
六、总结与展望
DeiT通过优化训练策略和创新的蒸馏token方法,成功地将Transformer引入数据受限的视觉任务中,其性能已接近甚至超越经过多年优化的CNN。未来研究可以探索:
- 针对Transformer的专用数据增强方法。
- 更高效的架构设计,进一步降低计算复杂度。
- 在更大规模任务(如检测、分割)中的应用。
对于熟悉Transformer的研究者来说,DeiT提供了一个高效的起点,其开源代码(https://github.com/facebookresearch/deit)也便于复现和扩展实验。DeiT不仅是视觉Transformer的一个里程碑,也预示着Transformer可能成为计算机视觉的主流范式之一。
DeiT(Data-efficient image Transformers)的示例代码
以下是基于PyTorch实现的DeiT(Data-efficient image Transformers)的示例代码,包括训练代码和推理代码。由于DeiT的完整实现涉及较多细节(例如数据增强、蒸馏策略等),将提供一个简化的版本,重点展示其核心结构和逻辑。完整的实现可以参考官方代码库(https://github.com/facebookresearch/deit)。
前提条件
- PyTorch 1.7+
- torchvision
- timm(可选,用于预训练模型和增强)
- 数据集:这里以ImageNet为例,假设你已准备好数据加载器。
1. DeiT模型定义
首先定义DeiT的核心模型结构,基于ViT并添加蒸馏token。
import torch
import torch.nn as nn
import torch.nn.functional as Fclass DeiT(nn.Module):def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., drop_rate=0.1):super().__init__()self.num_patches = (img_size // patch_size) ** 2self.patch_size = patch_sizeself.embed_dim = embed_dim# Patch embeddingself.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)# Positional encodingself.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 2, embed_dim)) # +2 for cls and distill tokensself.pos_drop = nn.Dropout(p=drop_rate)# Class and distillation tokensself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))self.distill_token = nn.Parameter(torch.zeros(1, 1, embed_dim))# Transformer blocksself.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=int(embed_dim * mlp_ratio), dropout=drop_rate)for _ in range(depth)])# Layer normself.norm = nn.LayerNorm(embed_dim)# Classification headsself.head = nn.Linear(embed_dim, num_classes)self.head_distill = nn.Linear(embed_dim, num_classes)# Initialize weightsnn.init.trunc_normal_(self.pos_embed, std=0.02)nn.init.trunc_normal_(self.cls_token, std=0.02)nn.init.trunc_normal_(self.distill_token, std=0.02)def forward(self, x, return_both=False):B = x.shape[0]# Patch embeddingx = self.patch_embed(x).flatten(2).transpose(1, 2) # [B, num_patches, embed_dim]# Add class and distillation tokenscls_tokens = self.cls_token.expand(B, -1, -1)distill_tokens = self.distill_token.expand(B, -1, -1)x = torch.cat((cls_tokens, distill_tokens, x), dim=1) # [B, num_patches + 2, embed_dim]# Add positional encodingx = x + self.pos_embedx = self.pos_drop(x)# Transformer blocksfor block in self.blocks:x = block(x)x = self.norm(x)# Extract class and distillation tokenscls_output = x[:, 0] # [B, embed_dim]distill_output = x[:, 1] # [B, embed_dim]# Classificationcls_logits = self.head(cls_output)distill_logits = self.head_distill(distill_output)if return_both:return cls_logits, distill_logitsreturn cls_logits # 默认返回cls token的输出# 示例模型实例化
model = DeiT(img_size=224, patch_size=16, num_classes=1000, embed_dim=768, depth=12, num_heads=12)
2. 训练代码
以下是训练DeiT的示例代码,包含蒸馏逻辑和常见数据增强。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import timm # 用于加载教师模型# 数据增强和加载器
train_transforms = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])val_transforms = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 假设ImageNet数据集路径为 './data/imagenet'
train_dataset = datasets.ImageFolder('./data/imagenet/train', transform=train_transforms)
val_dataset = datasets.ImageFolder('./data/imagenet/val', transform=val_transforms)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=4)# 设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 学生模型
student = DeiT(img_size=224, patch_size=16, num_classes=1000, embed_dim=768, depth=12, num_heads=12)
student = student.to(device)# 教师模型(预训练的CNN,例如RegNetY-16GF)
teacher = timm.create_model('regnety_160', pretrained=True, num_classes=1000)
teacher = teacher.to(device)
teacher.eval() # 教师模型固定# 损失函数和优化器
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(student.parameters(), lr=5e-4, weight_decay=0.05)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300) # 300 epochs# 训练循环
def train_epoch(student, teacher, loader, optimizer, criterion, epoch):student.train()running_loss = 0.0for i, (images, labels) in enumerate(loader):images, labels = images.to(device), labels.to(device)optimizer.zero_grad()# 前向传播cls_logits, distill_logits = student(images, return_both=True)with torch.no_grad():teacher_logits = teacher(images)teacher_labels = teacher_logits.argmax(dim=1) # 硬标签# 损失:真标签(cls token)+ 教师硬标签(distill token)loss_cls = criterion(cls_logits, labels)loss_distill = criterion(distill_logits, teacher_labels)loss = 0.5 * loss_cls + 0.5 * loss_distill# 反向传播loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:print(f'[Epoch {epoch+1}, Batch {i+1}] Loss: {running_loss / 100:.3f}')running_loss = 0.0# 验证函数
def evaluate(model, loader):model.eval()correct, total = 0, 0with torch.no_grad():for images, labels in loader:images, labels = images.to(device), labels.to(device)outputs = model(images) # 使用cls token输出_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()return 100. * correct / total# 训练主循环
num_epochs = 300
for epoch in range(num_epochs):train_epoch(student, teacher, train_loader, optimizer, criterion, epoch)val_acc = evaluate(student, val_loader)print(f'Epoch {epoch+1}, Validation Accuracy: {val_acc:.2f}%')scheduler.step()# 保存模型
torch.save(student.state_dict(), 'deit_b.pth')
3. 推理代码
推理代码用于加载训练好的模型并对单张图像进行预测。
import torch
from PIL import Image
from torchvision import transforms# 加载模型
model = DeiT(img_size=224, patch_size=16, num_classes=1000, embed_dim=768, depth=12, num_heads=12)
model.load_state_dict(torch.load('deit_b.pth'))
model = model.to(device)
model.eval()# 图像预处理
def preprocess_image(image_path):image = Image.open(image_path).convert('RGB')transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])image = transform(image).unsqueeze(0) # [1, 3, 224, 224]return image.to(device)# 推理函数
def predict(image_path):image = preprocess_image(image_path)with torch.no_grad():cls_logits, distill_logits = model(image, return_both=True)# 融合预测(late fusion)probs = F.softmax(cls_logits, dim=1) + F.softmax(distill_logits, dim=1)_, predicted = probs.max(1)return predicted.item()# 示例推理
image_path = 'example.jpg'
pred_class = predict(image_path)
print(f'Predicted class: {pred_class}')
注意事项
- 完整实现:上述代码是简化版,未包含所有训练细节(如Rand-Augment、Repeated Augmentation等)。建议参考官方代码(
deit/main.py
)获取完整训练流程。 - 教师模型:这里使用预训练的RegNetY-16GF作为教师,你可以替换为其他模型(如EfficientNet)。
- 硬件需求:训练DeiT-B需要至少8GB显存的GPU,建议使用多GPU加速。
- 微调:若需在更高分辨率(如384×384)微调,需调整位置编码并重新定义数据加载器。
获取预训练模型
如果你不想从头训练,可以直接从timm
库加载预训练的DeiT模型:
import timm
model = timm.create_model('deit_base_patch16_224', pretrained=True)
希望这些代码能帮助你快速上手DeiT的实现!如需更深入的定制或优化,请参考官方文档和论文中的超参数设置。
后记
2025年3月22日16点15分于上海,在grok 3大模型辅助下完成。