当CV遇上transformer(二)MAE模型及源码分析

当CV遇上transformer(二)MAE模型

  • 2020年10月,Dosovitskiy首次将纯Transformer的网络结构应用于图像分类任务中(ViT),并取得了当时最优的分类效果,其研究成果是Transformer完全替代标准卷积的首次尝试。
  • 大神何恺明在2021年11月基于(ViT)架构,提出了用于CV领域的自监督学习模型MAE(Masked Autoencoders)。
  • MAE想法很简单,以一定比例随机 mask 掉图片中的一些图像块(patch),然后重建这些部分的像素值。MAE基于两个核心设计(如下),使得能够高效(加快训练速度,作者在原文中表示至少快3倍)有效地(提高准确性)训练大型模型:
    • 非对称的(asymmetric)编码器-解码器架构,编码器仅仅对可见的patches进行编码,不对mask tokens进行任何处理,轻量级的解码器将编码器的输出(latent representation)和mask tokens作为输入,重构image;
    • 使用较高的mask比例(如75%)。
  • 基于ViT模型,作者在原文中也提出了三个模型(Base Large Huge)。在下游任务中,MAE展现了很强的迁移性能。其中MAE-ViT-Huge模型在ImageNet-1K数据集上达到了87.8%的top-1准确率,可扩展性极强(scalable)。
  • 今天我们来了解下MAE模型。
    • 论文链接:https://arxiv.org/pdf/2111.06377
    • 官方源码:https://github.com/facebookresearch/mae

1 MAE模型架构

1.1 研究背景

  • 在NLP领域,自监督预训练使用十分广泛。我们知道在BERT中,以一定比例 mask 掉输入文本中的一些部分,让模型去预测这批被 mask 掉的内容。这样,利用数据本身就可以作为监督(模型要预测的目标来源于数据本身,并非人工构造),无需复杂的人工标注。同时,使用大量的数据让拥有大规模参数量的模型能够学到通用的知识,从而拥有良好的泛化能力。
  • 但是在CV领域,大多数预训练还是采用监督方式。那么为什么自监督在CV领域的发展要滞后于NLP呢?
  • 作者解释如下:
    • 架构(architecture)差异
      • CV 和 NLP 的网络架构不一致,CV在过去一直被 CNN 统治,它基于方正的局部窗口来操作,不方便集成 mask token 以及 position embedding 这类带有指示性的可学习因子。
      • 不过,ViT(Vision Transformer) 已经在CV领域取得不错的效果,现在看来应该可以解决了。
    • 信息密度(information density)不同
      • NLP和CV的信息密度(information density)差异巨大;
      • NLP是强语义的,高信息密度的(highly semantic and information-dense),在NLP中即使只mask一个token,对模型来说可能都是很难的任务,因此模型可以通过学习获得复杂的语言理解能力(sophisticated language understanding);
      • 但是对CV来说,信息是高度冗余的,缺失一个patch,可能并不会让模型产生多少困惑,模型可以通过周围的像素信息进行推断。所以MAE做的一件事就是mask很高比例的patches,制造高难度的学习任务,方法简单但是极其有效。
    • 解码的目标不一致
      • NLP 解码输出的是对应被 mask 掉的词语,本身包含了丰富的语义信息。因此,NLP 的解码器可以很简单,比如 BERT,严格来说它并没有解码器,最后用 MLP 也可以搞定。因为来自编码器的特征也是高度语义的,与需要解码的目标之间的 gap 较小;
      • 而 CV 要重建的是被 mask 掉的图像块(像素值),是低语义的。因此CV 的解码器设计则需要“谨慎”考虑了,因为它要将来自编码器的高级语义特征解码至低级语义层级

1.2 MAE模型架构

在这里插入图片描述

  • MAE模型在预训练时需要encoder和decoder,预训练后抛弃decoder,只使用encoder做下游任务。

  • mask策略解析。

    • 与ViT一样,首先将图片切成一个个不重叠的patches
    • 采样策略很简单直接:在不替换的情况下,按照均匀分布**(uniform distribution)**对patches进行随机采样,采到的样本保留,剩下的全部mask掉。被 mask 掉的 patches 占所有 patches 的大部分(例如75%),优势如下:
      • patch 在图像中是服从均匀分布来采样的,这样能够避免 patch 的位置大多都分布在靠近图像中心的区域;
      • 采用高掩码比例(mask 掉图中大部分 patches)能够防止模型轻易地根据邻近的可见 patches 推断出这些掩码块;
      • 造就了稀疏的编码器输入,因为 Encoder 只处理可见的 patches,于是能够以更低的代价训练较大规模的 Encoder,因为计算量和内存占用都减少了。
      • mask 策略是至关重要的一个部分,因为其决定了预训练代理任务是否具有足够的挑战性,从而影响着 Encoder 学到的潜在特征表示 以及 Decoder 重建效果的质量。
  • MAE编码器解析。

    • Encoder 仅处理可见(un-masked)的 patches
    • 源码中Encoder 用的是 ViT模型,即对每一块做线性的投影,再加上位置信息。被盖住的patch就不会进去了。
    • 由于 un-masked patches 占所有 patches 的少数,计算消耗和空间需求都减少了,因此可以训练很大的 Encoder。
  • MAE解码器解析。

    • 解码器输入需要所有的patches,包括unmasked的patches以及masked的patches(没有进入编码器),
    • 对于masked的patches,解码器通过同一个向量来表示,这个向量通过学习得到
    • 解码器输入也需要加入位置信息,不然就无法区分它对应的到底是哪一个masked的patch。
    • 解码器主要只在预训练的时候使用,当将模型用于做一些别的任务的时候,只需要用编码器对一个图片进行编码就可以了。
    • 解码器的架构比较小,计算开销不到编码器的1/10。
  • 任务目标:重建像素值。MAE 预训练任务的目标是重建像素值,并且仅仅是 masked patches 的像素值,也就是仅对 masked 的部分计算 mse loss

2 MAE部分实验

2.1 Masking ratio

  • fine-tuning(微调)是在迁移学习中,将预训练模型的所有层都解冻,并使用新的数据集进行端到端的微调。通常,所有层的权重都被更新。
  • linear probing(线性探测)是在迁移学习中,只更新预训练模型的最后一层(通常是分类器层),而不更新其余层的权重。这意味着预训练模型的所有层在微调过程中都保持冻结状态。
  • 由下图实验结果,无论是在 fine-tune 还是 linear probe 的中,mask 比例逐渐升高(但不过分)时,模型性能都会更好,在源码中作者选择75%的masking比例。

在这里插入图片描述

2.2 消融实验

  • Decoder 的设计

    • 下图中(a)和(b)展示了不同的 Decoder 深度(Transformer 层数)和宽度(通道数)对于 fine-tune 和 linear probe 在 ImageNet-1K 下游任务中的表现。可以发现,Decoder 的深度和宽度对于 linear probe 有较为明显的影响,但对于 fine-tune 的影响却不那么突出。
    • 原因是**预训练任务(图像重建)与下游任务(图像识别)之间存在着 gap。**fine-tune 时由于能够调整 Encoder 去适配图像识别任务,因此预训练对其影响程度就相对没那么大了。
  • Mask token

    • 下图中©中,作者比较了Encoder 仅使用unmasked tokens以及全部的tokens效果,可以发现如果Encoder 仅使用unmasked tokens不仅效果好,训练速度也快3倍。
  • 重建目标的比较

    • MAE 的重建目标是 masked patches 的像素值。
    • 下图中(d)中发现,如果预测的是归一化的像素值,那么效果会更好。
  • 数据增强的影响

    • 数据增强能提升精度

    • 下图中(e)中,不做随机缩放(fixed size)和随机缩放(rand size)的效果其实差不多,而采用色彩扰动(color jit)却反而比简单的 crop, fixed size效果差

    • 原因可能是MAE 对图像进行 mask 的做法本身就已经是一种数据增强手段了,因此不需要过份的额外数据增强就能取得较好的效果

    • 值得注意的是,源码中作者在预训练时候做了弱数据增强,但在微调时做了强数据增强。

  • Mask取样策略的比较

    • 下图更加直观显示Mask几种取样策略效果:

    • 在这里插入图片描述

    • 在下图中(f)中,也能发现采用均匀分布的随机采样效果最好

在这里插入图片描述

3 Mae Model代码分析

这里,我们只分析下models_mae.py中模型部分的代码。

  • 官方源码:https://github.com/facebookresearch/mae

3.1 下载预训练模型

  • 我们先下载作者预训练好的模型,按照下面的代码(依据mae/demo/mae_visualize.ipynb改造),执行mae的前向推理过程,方便我们进行调试。

  • 预训练模型有base、large、huge三种模型,这里下载base模型。

    • 和Vit模型参数一致,主要是Layers、Hidden_size、Heads的不同。
    # models_mae.pydef mae_vit_base_patch16_dec512d8b(**kwargs):model = MaskedAutoencoderViT(patch_size=16, embed_dim=768, depth=12, num_heads=12,decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return modeldef mae_vit_large_patch16_dec512d8b(**kwargs):model = MaskedAutoencoderViT(patch_size=16, embed_dim=1024, depth=24, num_heads=16,decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return modeldef mae_vit_huge_patch14_dec512d8b(**kwargs):model = MaskedAutoencoderViT(patch_size=14, embed_dim=1280, depth=32, num_heads=16,decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return model# set recommended archs
    mae_vit_base_patch16  = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
    mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks
    mae_vit_huge_patch14  = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks
    
  • base模型下载连接:https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth

  • 需要注意的是mae源码中使用了timm库,当前版本不支持qk_scale参数,可以删掉此参数(如下),源码中其实也是设置为None,可以放心删除。

# models_mae.py    # 堆叠Transformer Block
self.blocks = nn.ModuleList([# 删除qk_scale参数# Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)for i in range(depth)])
......
# 堆叠Transformer Block
self.decoder_blocks = nn.ModuleList([# 删除qk_scale参数# Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)for i in range(decoder_depth)])
  • 我们下载作者在readme中的图像,然后运行下面代码,可以获取下面图像,后面我们可以运行此代码,就可以去models_mae.py中愉快的Debug了。

在这里插入图片描述

import sys
import osimport torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Imagemodule_path = r'/root/autodl-tmp/transformers-code/huggingface/AIGC/mae/models_mae.py'
# 将模块路径添加到系统路径
sys.path.append(os.path.dirname(module_path))
import models_maedef show_image(image, title=''):# image is [H, W, 3]assert image.shape[2] == 3plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())plt.title(title, fontsize=16)plt.axis('off')returndef prepare_model(chkpt_dir, arch='mae_vit_base_patch16'):# build modelmodel = getattr(models_mae, arch)()# load modelcheckpoint = torch.load(chkpt_dir, map_location='cpu')msg = model.load_state_dict(checkpoint['model'], strict=False)print(msg)return modeldef run_one_image(img, model):x = torch.tensor(img)# make it a batch-likex = x.unsqueeze(dim=0)x = torch.einsum('nhwc->nchw', x)# run MAEloss, y, mask = model(x.float(), mask_ratio=0.75)y = model.unpatchify(y)y = torch.einsum('nchw->nhwc', y).detach().cpu()# visualize the maskmask = mask.detach()mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0] ** 2 * 3)  # (N, H*W, p*p*3)mask = model.unpatchify(mask)  # 1 is removing, 0 is keepingmask = torch.einsum('nchw->nhwc', mask).detach().cpu()x = torch.einsum('nchw->nhwc', x)# masked imageim_masked = x * (1 - mask)# MAE reconstruction pasted with visible patchesim_paste = x * (1 - mask) + y * mask# make the plt figure largerplt.rcParams['figure.figsize'] = [24, 24]plt.subplot(1, 4, 1)show_image(x[0], "original")plt.subplot(1, 4, 2)show_image(im_masked[0], "masked")plt.subplot(1, 4, 3)show_image(y[0], "reconstruction")plt.subplot(1, 4, 4)show_image(im_paste[0], "reconstruction + visible")plt.show()# plt.savefig('fox_r.jpg')if __name__ == '__main__':imagenet_mean = np.array([0.485, 0.456, 0.406])imagenet_std = np.array([0.229, 0.224, 0.225])# 1、加载图像# 图像地址:https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpgimg = Image.open('fox.jpg')img = img.resize((224, 224))img = np.array(img) / 255.assert img.shape == (224, 224, 3)# 2、标准化img = img - imagenet_meanimg = img / imagenet_std# 3、加载作者在ImageNet数据集上训练好的模型(mae_vit_base_patch16)chkpt_dir = r'/root/autodl-fs/models/mae/mae_visualize_vit_base.pth'model_mae = prepare_model(chkpt_dir, 'mae_vit_base_patch16')print('Model loaded.')# 4、还原torch.manual_seed(2)print('MAE with pixel reconstruction:')run_one_image(img, model_mae)

3.2 MAE的预训练过程概述

  1. 将图像划分成 patches:(B,C,H,W)->(B,N,PxPxC);
  2. 对各个 patch 进行 embedding(实质是通过全连接层),生成 tokens,并加入位置信息(position embeddings):(B,N,PxPxC)->(B,N,dim);
  3. 根据预设的掩码比例(paper 中提倡的是 75%),使用服从均匀分布的随机采样策略采样一部分 tokens 送给 Encoder,另一部分扔掉(mask 掉)
  4. 将 Encoder 编码后的 tokens 与 加入位置信息后的 masked tokens 按照原先在 patch 形态时对应的次序拼在一起,然后喂给 Decoder 。Encoder 编码后的 token 的维度与 Decoder 要求的输入维度不一致,需要先经过 linear projection 将维度映射到符合 Decoder 的要求;
  5. Decoder 解码后取出 masked tokens 对应的部分送入到全连接层,对 masked patches 的像素值进行预测,最后将预测结果与 masked patches 进行比较,计算 MSE loss。
	# models_mae.pydef forward(self, imgs, mask_ratio=0.75):latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]loss = self.forward_loss(imgs, pred, mask)return loss, pred, mask
  • models_mae.py中前向传播forward函数如上所示。
  • 前向传播forward主要包括:forward_encoder、forward_decoder以及forward_loss。

3.3 初始化

class MaskedAutoencoderViT(nn.Module):""" Masked Autoencoder with VisionTransformer backboneLayers Hidden_size MLP_size HeadsViT-Base :     12       768      768*4    12ViT-Large:     24       1024     1024*4   16  (MAE默认)ViT-Huge :     32       1280     1280*4   16"""def __init__(self, img_size=224, patch_size=16, in_chans=3,embed_dim=1024, depth=24, num_heads=16,decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):super().__init__()# --------------------------------------------------------------------------# MAE encoder specificsself.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)# patch数量 = (img_size/patch_size)^2 = 14 * 14 = 196num_patches = self.patch_embed.num_patches# 参考在ViT中,在一系列输入序列中插入一个专门用于分类的标志位(Class Token)self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))# encoder中的位置编码,使用2d的sincos绝对位置编码。由于加了cls_token,因此num_patches需要加1self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding# 堆叠Transformer Blockself.blocks = nn.ModuleList([# Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)for i in range(depth)])# 层归一化self.norm = norm_layer(embed_dim)# --------------------------------------------------------------------------# --------------------------------------------------------------------------# MAE decoder specifics# 由于Encoder 编码后的 token 的维度与 Decoder 要求的输入维度不一致,先经过 linear projection 将维度映射到符合Decoder的要求# 构建线性映射层,将1024维的embed_dim 转换为 512维的decoder_embed_dimself.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)# 被mask住的块 用一个共享的、可训练的向量进行表示self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))# decoder中的位置编码,使用2d的sincos绝对位置编码self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding# 堆叠Transformer Blockself.decoder_blocks = nn.ModuleList([# Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)for i in range(decoder_depth)])self.decoder_norm = norm_layer(decoder_embed_dim)# 解码后取出 masked tokens 对应的部分送入到全连接层self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch# --------------------------------------------------------------------------# 是否对每个patch中的数据进行标准化,默认Falseself.norm_pix_loss = norm_pix_loss# 权重初始化self.initialize_weights()

3.3.1 2d绝对位置编码

下面代码实现了常见的位置编码,包括MAE模型:

  • MAE中使用了基于正弦余弦的2d绝对位置编码,是在 x, y 方向上分别独立进行绝对位置编码
  • Transformer中绝对位置编码公式如下:
    在这里插入图片描述
import torch
import torch.nn as nn# 1、Transformer
def create_1d_absolute_sincos_embeddings(n_pos_vec, dim):# n_pos_vec: torch.arange(n_pos)# 初始化position_embeddingassert dim % 2 == 0, "wrong dimension"position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)omega = torch.arange(dim // 2, dtype=torch.float)omega = 2. * omega / dimomega = 1.0 / (10000 ** omega)out = n_pos_vec[:, None] @ omega[None, :]  # shape = (n_pos, dim // 2)position_embedding_sin = torch.sin(out)position_embedding_cos = torch.cos(out)# 赋值position_embedding[:, 0::2] = position_embedding_sinposition_embedding[:, 1::2] = position_embedding_cosreturn position_embedding# 2、ViT
def create_1d_absolute_learnable_embeddings(n_pos_vec, dim):position_embedding = nn.Embedding(n_pos_vec.numel(), dim)# 初始化nn.init.constant_(position_embedding.weight, 0.)return position_embedding# 3、MAE
def create_2d_absolute_sincos_embeddings(height, width, dim):assert dim % 4 == 0, "wrong dimension"position_embedding = torch.zeros(height*width, dim, dtype=torch.float)coords = torch.stack(torch.meshgrid(torch.arange(height, dtype=torch.float),torch.arange(width, dtype=torch.float))) # [2, height, width]height_embedding = create_1d_absolute_sincos_embeddings(torch.flatten(coords[0]), dim// 2)width_embedding = create_1d_absolute_sincos_embeddings(torch.flatten(coords[1]), dim// 2)position_embedding[:, :dim // 2] = height_embeddingposition_embedding[:, dim // 2:] = width_embeddingreturn position_embeddingif __name__ == '__main__':n_pos_vec, dim = torch.arange(4, dtype=torch.float), 4create_1d_absolute_sincos_embeddings(n_pos_vec, dim)create_1d_absolute_learnable_embeddings(n_pos_vec, dim)create_2d_absolute_sincos_embeddings(height=2, width=2, dim=dim)

3.4 forward_encoder函数

3.4.1 Patch Embedding

  • Patch Embedding和ViT一样,可以参考:当CV遇上transformer(一)ViT模型
    def forward_encoder(self, x, mask_ratio):# embed patches# 1、先将图像从 (B,C,H,W) reshape 成 (B,N,PxPxC)# N为 patch 数量,N = (img_size/patch_size)^2 = (224 / 16)^2 = 14*14=196# PxPxC = in_chans * patch_size * patch_size = 3*16*16 = 768# 在PatchEmbed源码中,主要是利用卷积Conv2d(3, 768, kernel_size=16, stride=16)完成# 即:x(B, 3, 224, 224)# ->torch.Size([B, 768, 14, 14])【卷积】# ->torch.Size([B, 768, 196])   【宽高flatten】# ->torch.Size([B, 196, 768])   【转换维度】x = self.patch_embed(x)# add pos embed w/o cls token# 2、添加2d的sincos绝对位置编码# ->torch.Size([B, 196, 768])  【添加位置编码,不包含cls_token】x = x + self.pos_embed[:, 1:, :]......

3.4.2 核心代码random_masking

    def forward_encoder(self, x, mask_ratio):......# masking: length -> length * mask_ratio# 3、【核心代码random_masking】  x->torch.Size([B, 49, 768])x, mask, ids_restore = self.random_masking(x, mask_ratio)
  • 我们这里单独建一个py文件,将这段代码摘出来,传入模拟数据,了解这段核心代码。
  • 这里面很巧妙的利用了torch.argsort和torch.gather函数,对于torch.gather函数,可以参考:Pytorch常用的函数(九)torch.gather()用法
import torch
import torch.nn as nntorch.manual_seed(seed=42)def random_masking(x, mask_ratio=0.75):"""Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape  # batch, length, dimlen_keep = int(L * (1 - mask_ratio))  # 计算unmasked的片数# 利用0-1均匀分布进行采样,避免潜在的【中心归纳偏好】noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]# sort noise for each sample【核心代码】ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is removeids_restore = torch.argsort(ids_shuffle, dim=1)# keep the first subsetids_keep = ids_shuffle[:, :len_keep]x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))# generate the binary mask: 0 is keep, 1 is removemask = torch.ones([N, L], device=x.device)mask[:, :len_keep] = 0# unshuffle to get the binary maskmask = torch.gather(mask, dim=1, index=ids_restore)return x_masked, mask, ids_restoredef forward_decoder(x, ids_restore):mask_token = nn.Parameter(torch.ones(1, 1, 4))mask_tokens = mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)# 将unmasked tokens和masked tokens在dim=1维度concat起来x_ = torch.cat([x, mask_tokens], dim=1)  # no cls token# unshufflex_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))return x_if __name__ == '__main__':x = torch.arange(16).reshape(1, 4, 4)x_masked, mask, ids_restore = random_masking(x)forward_decoder(x_masked, ids_restore)
  • 核心代码的解释如下:
原始数据:
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11],[12, 13, 14, 15]]])我们要从原始数据中采用随机采样25%的作为unmasked tokens(在这个示例中,就只采样一行数据)
具体做法如下:
1、利用0-1均匀分布进行采样,避免潜在的【中心归纳偏好】
noise = torch.rand(N, L, device=x.device)noise:
tensor([[0.8823, 0.9150, 0.3829, 0.9593]])2、对noise从小到大排列,并获取索引
ids_shuffle = torch.argsort(noise, dim=1)ids_shuffle:
tensor([[2, 0, 1, 3]])我们只需要获取前25%作为unmasked tokens
ids_keep = ids_shuffle[:, :len_keep]ids_keep:
tensor([[2]])因为是获取一行数据,因此需要对ids_keep进行复制
index=ids_keep.unsqueeze(-1).repeat(1, 1, D)index:
tensor([[[2, 2, 2, 2]]])3、我们有了index,就可以利用torch.gather函数获取unmasked tokens
x_masked[0, 0, :]在dim=1上,替换为[0, 2, :],即获取x上[0, 2, :]的数据([ 8,  9, 10, 11])
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))x_masked:
tensor([[[ 8,  9, 10, 11]]])4、在预训练时,只计算masked tokens的mse loss,因此需要记录原始图像块中哪一块masked 哪一块unmasked
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0mask before gather:
tensor([[0., 1., 1., 1.]])因为我们是随机采样的,实际上x_masked=tensor([[[ 8,  9, 10, 11]]]),在原始x中为[0, 2, :]
因此mask并不是tensor([[0., 1., 1., 1.]]),而是mask=tensor([[1., 1., 0., 1.]])
那么如何获取真实的mask呢?作者利用torch.gather函数很巧妙的实现了。具体做法如下:
我们对ids_shuffle再次排序,作为index,然后在dim=1上,继续利用torch.gather函数
ids_restore = torch.argsort(ids_shuffle, dim=1)ids_restore:
tensor([[1, 2, 0, 3]])mask = torch.gather(mask, dim=1, index=ids_restore)
new mask[0, 0]= old mask[0, 1]
new mask[0, 1]= old mask[0, 2]
new mask[0, 2]= old mask[0, 0]
new mask[0, 3]= old mask[0, 3]如此一来new mask = tensor([[1., 1., 0., 1.]]),获取了真实的mask在预训练时,只保留这些masked tokens的loss(即值为1的数,可以使用loss[N, L] * mask[N, L]实现)
mask:tensor([[1., 1., 0., 1.]])例如:loss =  torch.tensor([[0.5, 0.6, 0.7, 0.4]]) 
loss * mask = torch.tensor([[0.5, 0.6, 0, 0.4]]) 5、ids_restore也要用在图像的unshuffle中我们知道对于masked的patches,解码器通过同一个向量来表示,这个向量通过学习得到
mask_token = nn.Parameter(torch.ones(1, 1, 4))
因为masked token有多个,显然我们需要复制mask_token,这里我们复制3份
mask_tokens = mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
tensor([[[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.]]], grad_fn=<RepeatBackward0>)我们需要把unmasked tokens([1,1,4])和masked tokens[1,3,4]拼接起来送入到decoder中
x_masked(unmasked tokens)如下:
tensor([[[ 8,  9, 10, 11]]])>>> x_ = torch.cat([x, mask_tokens], dim=1)
tensor([[[ 8.,  9., 10., 11.],[ 1.,  1.,  1.,  1.],[ 1.,  1.,  1.,  1.],[ 1.,  1.,  1.,  1.]]], grad_fn=<CatBackward0>)tensor([[[ 8,  9, 10, 11]]])位置和原始的x不一致
我们继续利用torch.gather进行恢复
>>> index
tensor([[[1, 1, 1, 1],[2, 2, 2, 2],[0, 0, 0, 0],[3, 3, 3, 3]]])x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))	new x_[0, 0, :] = x_[0, 1, :]
new x_[0, 1, :] = x_[0, 2, :]
new x_[0, 2, :] = x_[0, 0, :]
new x_[0, 3, :] = x_[0, 3, :]>>> x_		 
tensor([[[ 1.,  1.,  1.,  1.],[ 1.,  1.,  1.,  1.],[ 8.,  9., 10., 11.],[ 1.,  1.,  1.,  1.]]], grad_fn=<GatherBackward0>)

3.4.3 剩余代码

  • 了解完核心代码后,下面代码就很容易理解了。
  • MAE为了和ViT保持一致,拼接了cls token,但实际上并未使用此信息。
 def forward_encoder(self, x, mask_ratio):......# 4、拼接cls tokencls_token = self.cls_token + self.pos_embed[:, :1, :]cls_tokens = cls_token.expand(x.shape[0], -1, -1)x = torch.cat((cls_tokens, x), dim=1) # x->torch.Size([B, 50, 768])# 5、apply Transformer blocks and normfor blk in self.blocks:x = blk(x)x = self.norm(x)return x, mask, ids_restore

3.5 forward_decoder函数

  • 将 Encoder 编码后的 tokens 与 加入位置信息后的 masked tokens 按照原先在 patch 形态时对应的次序拼在一起,然后喂给 Decoder 。
  • Encoder 编码后的 token 的维度与 Decoder 要求的输入维度不一致,需要先经过 linear projection 将维度映射到符合 Decoder 的要求;
  • Decoder 解码后取出 masked tokens 对应的部分送入到全连接层,对 masked patches 的像素值进行预测.
  • 了解完核心代码后,下面代码就很好理解了。
    def forward_decoder(self, x, ids_restore):# embed tokens# 1、x->torch.Size([B, 50, 768]) 线性映射层将768维的embed_dim 转换为 512维的decoder_embed_dimx = self.decoder_embed(x)# append mask tokens to sequence# 2、复制(masked token的所占的patch数 + 1【cls token】)份mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)# 3、将unmasked tokens和masked tokens在dim=1维度concat起来x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token# 4、unshufflex_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))# 5、append cls tokenx = torch.cat([x[:, :1, :], x_], dim=1)# add pos embedx = x + self.decoder_pos_embed# apply Transformer blocksfor blk in self.decoder_blocks:x = blk(x)x = self.decoder_norm(x)# predictor projectionx = self.decoder_pred(x)# remove cls tokenx = x[:, 1:, :]return x

3.6 forward_loss函数

  • 将预测结果与 masked patches 进行比较,计算 MSE loss

  • 需要注意的是,只计算masked tokens的loss

  • 我们后面再分析MAE的其他代码。

 def forward_loss(self, imgs, pred, mask):"""imgs: [N, 3, H, W]pred: [N, L, p*p*3]mask: [N, L], 0 is keep, 1 is remove, """# 1、将imgs的shape由[N, 3, H, W]变为和pred一致的(N, L, patch_size**2 *3)target = self.patchify(imgs)if self.norm_pix_loss:mean = target.mean(dim=-1, keepdim=True)var = target.var(dim=-1, keepdim=True)target = (target - mean) / (var + 1.e-6)**.5# 计算mse lossloss = (pred - target) ** 2# 2、计算每一个patch的mean lossloss = loss.mean(dim=-1)  # [N, L], mean loss per patch# 3、unmasked tokens的mask=0,masked tokens的mask=1# loss * mask后,只有masked tokens的loss保留下来,这里只计算masked tokens的lossloss = (loss * mask).sum() / mask.sum()  # mean loss on removed patchesreturn loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/323438.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

单单单单单の刁队列

在数据结构的学习中&#xff0c;队列是一种常用的线性数据结构&#xff0c;它遵循先进先出&#xff08;FIFO&#xff09;的原则。而单调队列是队列的一种变体&#xff0c;它在特定条件下保证了队列中的元素具有某种单调性质&#xff0c;例如单调递增或单调递减。单调队列在处理…

AWS Lambda 第一个例子Hello (JAVA)

什么是Serverless&#xff08;无服务器计算&#xff09; 行业通常所说的Serverless&#xff0c;主要是指“无服务器计算&#xff08;Serverless Computing&#xff09;”。无服务器计算&#xff0c;并不是真的不需要服务器&#xff0c;而是说&#xff0c;对于用户&#xff0c;…

C语言指针相关知识(第一篇章)(非常详细版)

文章目录 前言一、指针概念的引入与指针的基本介绍&#xff08;一&#xff09;、内存与地址&#xff08;二&#xff09;、指针变量和地址&#xff08;三&#xff09;、指针变量类型的意义&#xff08;四&#xff09;、const修饰指针 二、指针的运算&#xff08;一&#xff09;、…

锤子蜡烛如何交易?Anzo Capital这样交易10倍收益结束

很多投资者发现以下的情况&#xff0c;就认为反转到来了&#xff0c;颓势即将结束牛市即将来临。什么情况呢&#xff1f;就是在交易中发现这种情况&#xff1a;会在局部低点形成&#xff0c;上影线很小或几乎没有上阴影&#xff0c;收盘价高出 1/4 &#xff0c;烛台总有长长的下…

【数据结构(邓俊辉)学习笔记】栈与队列01——栈应用(栈混洗、前缀后缀表达式、括号匹配)

文章目录 0. 概述1. 操作与接口2. 操作实例3. 实现4. 栈与递归5. 应用5.1 逆序输出5.1.1 进制转换5.1.1.1 思路5.1.1.2 算法实现 5.2 递归嵌套5.2.1 栈混洗5.2.1.1 混洗5.2.1.2 计数5.2.1.3 甄别 5.2.2 括号匹配5.2.2.1 构思5.2.2.2 实现5.2.2.3 实例 5.3 延迟缓冲5.3.1 中缀表…

Gitee 码云与Git 交互

优质博文&#xff1a;IT-BLOG-CN 一、进入码云官方网站&#xff0c;注册用户 码云(Gitee.com)是一个类似于GitHub的在线代码托管平台。 码云提供了包括版本控制、代码托管、协作开发和代码分享等功能&#xff0c;基于Git开发&#xff0c;支持代码在线查看、历史版本查看、Fo…

基于vs和C#的WPF应用之动画3

注&#xff1a;1、在内部和外部使用缓动函数 <Grid.Resources> <PowerEase x:Key"powerease" Power"3" EasingMode"EaseInOut"/> </Grid.Resources> <DoubleAnimation EasingFunction"{StaticResource powerease}&quo…

linux开发笔记(buildroot 增加自己的开发板支持文件)

1、该笔记参考了mangopi r3的buildroot。某宝上卖的LC-PI-200S提供的buildroot就是这个。已经上传到我的资源中&#xff0c;可以下载看看。 2、首先在buildroot目录输入make menuconfig打开buildroot配置。 进入build options查看 可以看到第二行就是buildroot配置的保存位置…

KaiwuDB 解析器之语义解析

KaiwuDB 解析器介绍 解析器是数据库系统的重要组成部分之一&#xff0c;主要的功能是将客户端输入的 SQL 语句分解为语法单元&#xff0c;然后将这些语法单元转化成数据库内部可识别的数据结构&#xff0c;最终生成数据库可以执行的计划。 KaiwuDB 的一条 SQL 执行的整个生命…

达梦数据刷盘测试

达梦数据库为了保证数据故障恢复的一致性&#xff0c;REDO 日志的刷盘必须在数据页刷盘之前进行。 下面我们通过测试来验证是不是这样 执行我们事先准备的SHELL脚本 可以看到第一次strings文件没有输出&#xff0c;说明刚写的数据在数据库的BUFFER缓冲区内&#xff0c;还没有刷…

什么样的人能上百度词条

百度百科是一个向所有互联网用户开放的平台&#xff0c;任何人都可以创建或编辑词条。然而&#xff0c;并不是所有的人物或事物都能被收录到百度百科中&#xff0c;它有一定的收录标准和审结的关于哪些人或事物能上百度百科的条件和流程。 百度百科的收录标准 知名度和影响力&…

太牛了!360大佬编写的《应急响应指导手册》火了!(PDF限时3天领取)

免责声明&#xff1a; 请使用者遵守《中华人民共和国网络安全法》&#xff0c;由于传播、利用本账号所提供的信息而造成的任何直接或者间接的后果及损失&#xff0c;均由使用者本人负责&#xff0c;公众号及作者不为此承担任何责任。 简介 这份《应急响应指导手册》&#xf…

OpenNJet评测,探寻云原生之美

在信息时代的大海上&#xff0c;云原生应用引擎如一艘航行于波涛之间的帆船&#xff0c;承载着创新的梦想和数字化的未来。本文将带领您登上这艘船&#xff0c;聚焦其中之一的OpenNJet&#xff0c;一同探寻其中的奥秘和精妙&#xff0c;领略其独特之美。 OpenNJet 内容浅析 O…

每日Attention学习3——Cross-level Feature Fusion

模块出处 [link] [code] [PR 23] Cross-level Feature Aggregation Network for Polyp Segmentation 模块名称 Cross-level Feature Fusion (CFF) 模块作用 双级特征融合 模块结构 模块代码 import torch import torch.nn as nnclass BasicConv2d(nn.Module):def __init__(…

Python批量备份华为设备配置到FTP服务器

Excel表格存放交换机信息&#xff1a; 备份文件夹效果图&#xff1a; Windows系统配置计划任务定时执行python脚本&#xff1a; Program/script&#xff1a;C:\Python\python.exe Add arguments (optional)&#xff1a; D:\Python_PycharmProjects\JunLan_pythonProje…

AWS Cli Windows安装配置

1. 安装 下载地址&#xff1a;AWS 命令行界面(CLI)_管理AWS服务的统一工具-AWS云服务 检验安装&#xff1a; > aws --version aws-cli/2.15.44 Python/3.11.8 Windows/10 exe/AMD64 prompt/off 2. 创建IAM用户 1) 创建组 选择IAM 点击创建组 填写用户组名&#xff0c;…

c++——类和对象(中)

1.类的六个默认成员函数 在一个空类中真的什么都没有吗&#xff0c;错&#xff01;在创建类的时候&#xff0c;编译器自动生成六个函数&#xff0c;这六个函数叫默认成员函数。但是&#xff0c;如果我们自己实现六个同名函数&#xff08;依旧有默认成员函数的特性&#xff0c;…

Django项目之电商购物商城 -- 创建收货地址

Django项目之电商购物商城 – 创建收货地址 一. 在users中创建新的视图与路由用于创建收货地址 # 设置收货地址 class AddressView(View):def get(self , request):return render(request , "user_center_site.html")# 设置收货地址path(user_center_site/, views.…

金和OAC6 FileDownLoad 任意文件读取漏洞

文章目录 免责声明漏洞描述漏洞原理影响版本漏洞复现修复建议 免责声明 没有网络安全就没有国家安全&#xff0c;该文章只为学习和交流&#xff0c;利用做违法乱纪的事&#xff0c;与本人无关 漏洞描述 金和网络是专业信息化服务商,为城市监管部门提供了互联网监管解决方案,…

AI视频教程下载:零代码创建AI智能体、AI Agents和ChatGPT的Gpts

这门课程专注于提示工程的掌握&#xff0c;教你以精确的方式引导GPT&#xff0c;利用它们的生成能力产生卓越的AI驱动结果。一步一步地&#xff0c;你将学会创建多样化的GPT军团——每个都设计来满足特定的专业需求。 从提供个性化职业变更指导的职业教练AI&#xff0c;到以惊…