昇思25天学习打卡营第16天|Vision Transformer图像分类

昇思25天学习打卡营第16天|Vision Transformer图像分类

  • 前言
  • Vision Transformer图像分类
    • Vision Transformer(ViT)简介
      • 模型结构
      • 模型特点
    • 环境准备与数据读取
    • 模型解析
      • Transformer基本原理
        • Attention模块
      • Transformer Encoder
      • ViT模型的输入
      • 整体构建ViT
    • 模型训练与推理
      • 模型训练
      • 模型验证
      • 模型推理
    • 总结
  • 个人任务打卡(读者请忽略)
  • 个人思考

前言

  非常感谢华为昇思大模型平台和CSDN邀请体验昇思大模型!从今天起,笔者将以打卡的方式,将原文搬运和个人思考结合,分享25天的学习内容与成果。为了提升文章质量和阅读体验,笔者会将思考部分放在最后,供大家探索讨论。同时也欢迎各位领取算力,免费体验昇思大模型!

Vision Transformer图像分类

感谢ZOMI酱对本文的贡献。

Vision Transformer(ViT)简介

近些年,随着基于自注意(Self-Attention)结构的模型的发展,特别是Transformer模型的提出,极大地促进了自然语言处理模型的发展。由于Transformers的计算效率和可扩展性,它已经能够训练具有超过100B参数的空前规模的模型。

ViT则是自然语言处理和计算机视觉两个领域的融合结晶。在不依赖卷积操作的情况下,依然可以在图像分类任务上达到很好的效果。

模型结构

ViT模型的主体结构是基于Transformer模型的Encoder部分(部分结构顺序有调整,如:Normalization的位置与标准Transformer不同),其结构图[1]如下:

vit-architecture

模型特点

ViT模型主要应用于图像分类领域。因此,其模型结构相较于传统的Transformer有以下几个特点:

  1. 数据集的原图像被划分为多个patch(图像块)后,将二维patch(不考虑channel)转换为一维向量,再加上类别向量与位置向量作为模型输入。
  2. 模型主体的Block结构是基于Transformer的Encoder结构,但是调整了Normalization的位置,其中,最主要的结构依然是Multi-head Attention结构。
  3. 模型在Blocks堆叠后接全连接层,接受类别向量的输出作为输入并用于分类。通常情况下,我们将最后的全连接层称为Head,Transformer Encoder部分为backbone。

下面将通过代码实例来详细解释基于ViT实现ImageNet分类任务。

注意,本教程在CPU上运行时间过长,不建议使用CPU运行。

环境准备与数据读取

开始实验之前,请确保本地已经安装了Python环境并安装了MindSpore。

首先我们需要下载本案例的数据集,可通过http://image-net.org下载完整的ImageNet数据集,本案例应用的数据集是从ImageNet中筛选出来的子集。

运行第一段代码时会自动下载并解压,请确保你的数据集路径如以下结构。

.dataset/├── ILSVRC2012_devkit_t12.tar.gz├── train/├── infer/└── val/
%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 查看当前 mindspore 版本
!pip show mindspore

在这里插入图片描述

from download import downloaddataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"path = download(dataset_url, path, kind="zip", replace=True)

在这里插入图片描述

import osimport mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transformsdata_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)trans_train = [transforms.RandomCropDecodeResize(size=224,scale=(0.08, 1.0),ratio=(0.75, 1.333)),transforms.RandomHorizontalFlip(prob=0.5),transforms.Normalize(mean=mean, std=std),transforms.HWC2CHW()
]dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

模型解析

下面将通过代码来细致剖析ViT模型的内部结构。

Transformer基本原理

Transformer模型源于2017年的一篇文章[2]。在这篇文章中提出的基于Attention机制的编码器-解码器型结构在自然语言处理领域获得了巨大的成功。模型结构如下图所示:

transformer-architecture

其主要结构为多个Encoder和Decoder模块所组成,其中Encoder和Decoder的详细结构如下图[2]所示:

encoder-decoder

Encoder与Decoder由许多结构组成,如:多头注意力(Multi-Head Attention)层,Feed Forward层,Normaliztion层,甚至残差连接(Residual Connection,图中的“Add”)。不过,其中最重要的结构是多头注意力(Multi-Head Attention)结构,该结构基于自注意力(Self-Attention)机制,是多个Self-Attention的并行组成。

所以,理解了Self-Attention就抓住了Transformer的核心。

Attention模块

以下是Self-Attention的解释,其核心内容是为输入向量的每个单词学习一个权重。通过给定一个任务相关的查询向量Query向量,计算Query和各个Key的相似性或者相关性得到注意力分布,即得到每个Key对应Value的权重系数,然后对Value进行加权求和得到最终的Attention数值。

在Self-Attention中:

  1. 最初的输入向量首先会经过Embedding层映射成Q(Query),K(Key),V(Value)三个向量,由于是并行操作,所以代码中是映射成为dim x 3的向量然后进行分割,换言之,如果你的输入向量为一个向量序列( x 1 x_1 x1 x 2 x_2 x2 x 3 x_3 x3),其中的 x 1 x_1 x1 x 2 x_2 x2 x 3 x_3 x3都是一维向量,那么每一个一维向量都会经过Embedding层映射出Q,K,V三个向量,只是Embedding矩阵不同,矩阵参数也是通过学习得到的。这里大家可以认为,Q,K,V三个矩阵是发现向量之间关联信息的一种手段,需要经过学习得到,至于为什么是Q,K,V三个,主要是因为需要两个向量点乘以获得权重,又需要另一个向量来承载权重向加的结果,所以,最少需要3个矩阵。

{ q i = W q ⋅ x i k i = W k ⋅ x i , i = 1 , 2 , 3 … v i = W v ⋅ x i (1) \begin{cases} q_i = W_q \cdot x_i & \\ k_i = W_k \cdot x_i,\hspace{1em} &i = 1,2,3 \ldots \\ v_i = W_v \cdot x_i & \end{cases} \tag{1} qi=Wqxiki=Wkxi,vi=Wvxii=1,2,3(1)

self-attention1

  1. 自注意力机制的自注意主要体现在它的Q,K,V都来源于其自身,也就是该过程是在提取输入的不同顺序的向量的联系与特征,最终通过不同顺序向量之间的联系紧密性(Q与K乘积经过Softmax的结果)来表现出来。Q,K,V得到后就需要获取向量间权重,需要对Q和K进行点乘并除以维度的平方根,对所有向量的结果进行Softmax处理,通过公式(2)的操作,我们获得了向量之间的关系权重。

{ a 1 , 1 = q 1 ⋅ k 1 / d a 1 , 2 = q 1 ⋅ k 2 / d a 1 , 3 = q 1 ⋅ k 3 / d (2) \begin{cases} a_{1,1} = q_1 \cdot k_1 / \sqrt d \\ a_{1,2} = q_1 \cdot k_2 / \sqrt d \\ a_{1,3} = q_1 \cdot k_3 / \sqrt d \end{cases} \tag{2} a1,1=q1k1/d a1,2=q1k2/d a1,3=q1k3/d (2)

self-attention3

S o f t m a x : a ^ 1 , i = e x p ( a 1 , i ) / ∑ j e x p ( a 1 , j ) , j = 1 , 2 , 3 … (3) Softmax: \hat a_{1,i} = exp(a_{1,i}) / \sum_j exp(a_{1,j}),\hspace{1em} j = 1,2,3 \ldots \tag{3} Softmax:a^1,i=exp(a1,i)/jexp(a1,j),j=1,2,3(3)

self-attention2

  1. 其最终输出则是通过V这个映射后的向量与Q,K经过Softmax结果进行weight sum获得,这个过程可以理解为在全局上进行自注意表示。每一组Q,K,V最后都有一个V输出,这是Self-Attention得到的最终结果,是当前向量在结合了它与其他向量关联权重后得到的结果。

b 1 = ∑ i a ^ 1 , i v i , i = 1 , 2 , 3... (4) b_1 = \sum_i \hat a_{1,i}v_i,\hspace{1em} i = 1,2,3... \tag{4} b1=ia^1,ivi,i=1,2,3...(4)

通过下图可以整体把握Self-Attention的全部过程。

self-attention

多头注意力机制就是将原本self-Attention处理的向量分割为多个Head进行处理,这一点也可以从代码中体现,这也是attention结构可以进行并行加速的一个方面。

总结来说,多头注意力机制在保持参数总量不变的情况下,将同样的query, key和value映射到原来的高维空间(Q,K,V)的不同子空间(Q_0,K_0,V_0)中进行自注意力的计算,最后再合并不同子空间中的注意力信息。

所以,对于同一个输入向量,多个注意力机制可以同时对其进行处理,即利用并行计算加速处理过程,又在处理的时候更充分的分析和利用了向量特征。下图展示了多头注意力机制,其并行能力的主要体现在下图中的 a 1 a_1 a1 a 2 a_2 a2是同一个向量进行分割获得的。

multi-head-attention

以下是Multi-Head Attention代码,结合上文的解释,代码清晰的展现了这一过程。

from mindspore import nn, opsclass Attention(nn.Cell):def __init__(self,dim: int,num_heads: int = 8,keep_prob: float = 1.0,attention_keep_prob: float = 1.0):super(Attention, self).__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = ms.Tensor(head_dim ** -0.5)self.qkv = nn.Dense(dim, dim * 3)self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)self.out = nn.Dense(dim, dim)self.out_drop = nn.Dropout(p=1.0-keep_prob)self.attn_matmul_v = ops.BatchMatMul()self.q_matmul_k = ops.BatchMatMul(transpose_b=True)self.softmax = nn.Softmax(axis=-1)def construct(self, x):"""Attention construct."""b, n, c = x.shapeqkv = self.qkv(x)qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))q, k, v = ops.unstack(qkv, axis=0)attn = self.q_matmul_k(q, k)attn = ops.mul(attn, self.scale)attn = self.softmax(attn)attn = self.attn_drop(attn)out = self.attn_matmul_v(attn, v)out = ops.transpose(out, (0, 2, 1, 3))out = ops.reshape(out, (b, n, c))out = self.out(out)out = self.out_drop(out)return out

Transformer Encoder

在了解了Self-Attention结构之后,通过与Feed Forward,Residual Connection等结构的拼接就可以形成Transformer的基础结构,下面代码实现了Feed Forward,Residual Connection结构。

from typing import Optional, Dictclass FeedForward(nn.Cell):def __init__(self,in_features: int,hidden_features: Optional[int] = None,out_features: Optional[int] = None,activation: nn.Cell = nn.GELU,keep_prob: float = 1.0):super(FeedForward, self).__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.dense1 = nn.Dense(in_features, hidden_features)self.activation = activation()self.dense2 = nn.Dense(hidden_features, out_features)self.dropout = nn.Dropout(p=1.0-keep_prob)def construct(self, x):"""Feed Forward construct."""x = self.dense1(x)x = self.activation(x)x = self.dropout(x)x = self.dense2(x)x = self.dropout(x)return xclass ResidualCell(nn.Cell):def __init__(self, cell):super(ResidualCell, self).__init__()self.cell = celldef construct(self, x):"""ResidualCell construct."""return self.cell(x) + x

接下来就利用Self-Attention来构建ViT模型中的TransformerEncoder部分,类似于构建了一个Transformer的编码器部分,如下图[1]所示:

vit-encoder

  1. ViT模型中的基础结构与标准Transformer有所不同,主要在于Normalization的位置是放在Self-Attention和Feed Forward之前,其他结构如Residual Connection,Feed Forward,Normalization都如Transformer中所设计。

  2. 从Transformer结构的图片可以发现,多个子encoder的堆叠就完成了模型编码器的构建,在ViT模型中,依然沿用这个思路,通过配置超参数num_layers,就可以确定堆叠层数。

  3. Residual Connection,Normalization的结构可以保证模型有很强的扩展性(保证信息经过深层处理不会出现退化的现象,这是Residual Connection的作用),Normalization和dropout的应用可以增强模型泛化能力。

从以下源码中就可以清晰看到Transformer的结构。将TransformerEncoder结构和一个多层感知器(MLP)结合,就构成了ViT模型的backbone部分。

class TransformerEncoder(nn.Cell):def __init__(self,dim: int,num_layers: int,num_heads: int,mlp_dim: int,keep_prob: float = 1.,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: nn.Cell = nn.LayerNorm):super(TransformerEncoder, self).__init__()layers = []for _ in range(num_layers):normalization1 = norm((dim,))normalization2 = norm((dim,))attention = Attention(dim=dim,num_heads=num_heads,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob)feedforward = FeedForward(in_features=dim,hidden_features=mlp_dim,activation=activation,keep_prob=keep_prob)layers.append(nn.SequentialCell([ResidualCell(nn.SequentialCell([normalization1, attention])),ResidualCell(nn.SequentialCell([normalization2, feedforward]))]))self.layers = nn.SequentialCell(layers)def construct(self, x):"""Transformer construct."""return self.layers(x)

ViT模型的输入

传统的Transformer结构主要用于处理自然语言领域的词向量(Word Embedding or Word Vector),词向量与传统图像数据的主要区别在于,词向量通常是一维向量进行堆叠,而图片则是二维矩阵的堆叠,多头注意力机制在处理一维词向量的堆叠时会提取词向量之间的联系也就是上下文语义,这使得Transformer在自然语言处理领域非常好用,而二维图片矩阵如何与一维词向量进行转化就成为了Transformer进军图像处理领域的一个小门槛。

在ViT模型中:

  1. 通过将输入图像在每个channel上划分为16*16个patch,这一步是通过卷积操作来完成的,当然也可以人工进行划分,但卷积操作也可以达到目的同时还可以进行一次而外的数据处理;例如一幅输入224 x 224的图像,首先经过卷积处理得到16 x 16个patch,那么每一个patch的大小就是14 x 14。

  2. 再将每一个patch的矩阵拉伸成为一个一维向量,从而获得了近似词向量堆叠的效果。上一步得到的14 x 14的patch就转换为长度为196的向量。

这是图像输入网络经过的第一步处理。具体Patch Embedding的代码如下所示:

class PatchEmbedding(nn.Cell):MIN_NUM_PATCHES = 4def __init__(self,image_size: int = 224,patch_size: int = 16,embed_dim: int = 768,input_channels: int = 3):super(PatchEmbedding, self).__init__()self.image_size = image_sizeself.patch_size = patch_sizeself.num_patches = (image_size // patch_size) ** 2self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)def construct(self, x):"""Path Embedding construct."""x = self.conv(x)b, c, h, w = x.shapex = ops.reshape(x, (b, c, h * w))x = ops.transpose(x, (0, 2, 1))return x

输入图像在划分为patch之后,会经过pos_embedding 和 class_embedding两个过程。

  1. class_embedding主要借鉴了BERT模型的用于文本分类时的思想,在每一个word vector之前增加一个类别值,通常是加在向量的第一位,上一步得到的196维的向量加上class_embedding后变为197维。

  2. 增加的class_embedding是一个可以学习的参数,经过网络的不断训练,最终以输出向量的第一个维度的输出来决定最后的输出类别;由于输入是16 x 16个patch,所以输出进行分类时是取 16 x 16个class_embedding进行分类。

  3. pos_embedding也是一组可以学习的参数,会被加入到经过处理的patch矩阵中。

  4. 由于pos_embedding也是可以学习的参数,所以它的加入类似于全链接网络和卷积的bias。这一步就是创造一个长度维197的可训练向量加入到经过class_embedding的向量中。

实际上,pos_embedding总共有4种方案。但是经过作者的论证,只有加上pos_embedding和不加pos_embedding有明显影响,至于pos_embedding是一维还是二维对分类结果影响不大,所以,在我们的代码中,也是采用了一维的pos_embedding,由于class_embedding是加在pos_embedding之前,所以pos_embedding的维度会比patch拉伸后的维度加1。

总的而言,ViT模型还是利用了Transformer模型在处理上下文语义时的优势,将图像转换为一种“变种词向量”然后进行处理,而这样转换的意义在于,多个patch之间本身具有空间联系,这类似于一种“空间语义”,从而获得了比较好的处理效果。

整体构建ViT

以下代码构建了一个完整的ViT模型。

from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameterdef init(init_type, shape, dtype, name, requires_grad):"""Init."""initial = initializer(init_type, shape, dtype).init_data()return Parameter(initial, name=name, requires_grad=requires_grad)class ViT(nn.Cell):def __init__(self,image_size: int = 224,input_channels: int = 3,patch_size: int = 16,embed_dim: int = 768,num_layers: int = 12,num_heads: int = 12,mlp_dim: int = 3072,keep_prob: float = 1.0,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: Optional[nn.Cell] = nn.LayerNorm,pool: str = 'cls') -> None:super(ViT, self).__init__()self.patch_embedding = PatchEmbedding(image_size=image_size,patch_size=patch_size,embed_dim=embed_dim,input_channels=input_channels)num_patches = self.patch_embedding.num_patchesself.cls_token = init(init_type=Normal(sigma=1.0),shape=(1, 1, embed_dim),dtype=ms.float32,name='cls',requires_grad=True)self.pos_embedding = init(init_type=Normal(sigma=1.0),shape=(1, num_patches + 1, embed_dim),dtype=ms.float32,name='pos_embedding',requires_grad=True)self.pool = poolself.pos_dropout = nn.Dropout(p=1.0-keep_prob)self.norm = norm((embed_dim,))self.transformer = TransformerEncoder(dim=embed_dim,num_layers=num_layers,num_heads=num_heads,mlp_dim=mlp_dim,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob,drop_path_keep_prob=drop_path_keep_prob,activation=activation,norm=norm)self.dropout = nn.Dropout(p=1.0-keep_prob)self.dense = nn.Dense(embed_dim, num_classes)def construct(self, x):"""ViT construct."""x = self.patch_embedding(x)cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))x = ops.concat((cls_tokens, x), axis=1)x += self.pos_embeddingx = self.pos_dropout(x)x = self.transformer(x)x = self.norm(x)x = x[:, 0]if self.training:x = self.dropout(x)x = self.dense(x)return x

整体流程图如下所示:

data-process

模型训练与推理

模型训练

模型开始训练前,需要设定损失函数,优化器,回调函数等。

完整训练ViT模型需要很长的时间,实际应用时建议根据项目需要调整epoch_size,当正常输出每个Epoch的step信息时,意味着训练正在进行,通过模型输出可以查看当前训练的loss值和时间等指标。

from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train# define super parameter
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()# construct model
network = ViT()# load ckpt
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)# define learning rate
lr = nn.cosine_decay_lr(min_lr=float(0),max_lr=0.00005,total_step=epoch_size * step_size,step_per_epoch=step_size,decay_epoch=10)# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)# define loss function
class CrossEntropySmooth(LossBase):"""CrossEntropy."""def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):super(CrossEntropySmooth, self).__init__()self.onehot = ops.OneHot()self.sparse = sparseself.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)def construct(self, logit, label):if self.sparse:label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)loss = self.ce(logit, label)return lossnetwork_loss = CrossEntropySmooth(sparse=True,reduction="mean",smooth_factor=0.1,num_classes=num_classes)# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")# train model
model.train(epoch_size,dataset_train,callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],dataset_sink_mode=False,)

在这里插入图片描述

模型验证

模型验证过程主要应用了ImageFolderDataset,CrossEntropySmooth和Model等接口。

ImageFolderDataset主要用于读取数据集。

CrossEntropySmooth是损失函数实例化接口。

Model主要用于编译模型。

与训练过程相似,首先进行数据增强,然后定义ViT网络结构,加载预训练模型参数。随后设置损失函数,评价指标等,编译模型后进行验证。本案例采用了业界通用的评价标准Top_1_Accuracy和Top_5_Accuracy评价指标来评价模型表现。

在本案例中,这两个指标代表了在输出的1000维向量中,以最大值或前5的输出值所代表的类别为预测结果时,模型预测的准确率。这两个指标的值越大,代表模型准确率越高。

dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)trans_val = [transforms.Decode(),transforms.Resize(224 + 32),transforms.CenterCrop(224),transforms.Normalize(mean=mean, std=std),transforms.HWC2CHW()
]dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)# construct model
network = ViT()# load ckpt
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)network_loss = CrossEntropySmooth(sparse=True,reduction="mean",smooth_factor=0.1,num_classes=num_classes)# define metric
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),'Top_5_Accuracy': train.Top5CategoricalAccuracy()}if ascend_target:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")# evaluate model
result = model.eval(dataset_val)
print(result)

从结果可以看出,由于我们加载了预训练模型参数,模型的Top_1_Accuracy和Top_5_Accuracy达到了很高的水平,实际项目中也可以以此准确率为标准。如果未使用预训练模型参数,则需要更多的epoch来训练。
在这里插入图片描述

模型推理

在进行模型推理之前,首先要定义一个对推理图片进行数据预处理的方法。该方法可以对我们的推理图片进行resize和normalize处理,这样才能与我们训练时的输入数据匹配。

本案例采用了一张Doberman的图片作为推理图片来测试模型表现,期望模型可以给出正确的预测结果。

dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)trans_infer = [transforms.Decode(),transforms.Resize([224, 224]),transforms.Normalize(mean=mean, std=std),transforms.HWC2CHW()
]dataset_infer = dataset_infer.map(operations=trans_infer,input_columns=["image"],num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)

接下来,我们将调用模型的predict方法进行模型。

在推理过程中,通过index2label就可以获取对应标签,再通过自定义的show_result接口将结果写在对应图片上。

import os
import pathlib
import cv2
import numpy as np
from PIL import Image
from enum import Enum
from scipy import ioclass Color(Enum):"""dedine enum color."""red = (0, 0, 255)green = (0, 255, 0)blue = (255, 0, 0)cyan = (255, 255, 0)yellow = (0, 255, 255)magenta = (255, 0, 255)white = (255, 255, 255)black = (0, 0, 0)def check_file_exist(file_name: str):"""check_file_exist."""if not os.path.isfile(file_name):raise FileNotFoundError(f"File `{file_name}` does not exist.")def color_val(color):"""color_val."""if isinstance(color, str):return Color[color].valueif isinstance(color, Color):return color.valueif isinstance(color, tuple):assert len(color) == 3for channel in color:assert 0 <= channel <= 255return colorif isinstance(color, int):assert 0 <= color <= 255return color, color, colorif isinstance(color, np.ndarray):assert color.ndim == 1 and color.size == 3assert np.all((color >= 0) & (color <= 255))color = color.astype(np.uint8)return tuple(color)raise TypeError(f'Invalid type for color: {type(color)}')def imread(image, mode=None):"""imread."""if isinstance(image, pathlib.Path):image = str(image)if isinstance(image, np.ndarray):passelif isinstance(image, str):check_file_exist(image)image = Image.open(image)if mode:image = np.array(image.convert(mode))else:raise TypeError("Image must be a `ndarray`, `str` or Path object.")return imagedef imwrite(image, image_path, auto_mkdir=True):"""imwrite."""if auto_mkdir:dir_name = os.path.abspath(os.path.dirname(image_path))if dir_name != '':dir_name = os.path.expanduser(dir_name)os.makedirs(dir_name, mode=777, exist_ok=True)image = Image.fromarray(image)image.save(image_path)def imshow(img, win_name='', wait_time=0):"""imshow"""cv2.imshow(win_name, imread(img))if wait_time == 0:  # prevent from hanging if windows was closedwhile True:ret = cv2.waitKey(1)closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1# if user closed window or if some key pressedif closed or ret != -1:breakelse:ret = cv2.waitKey(wait_time)def show_result(img: str,result: Dict[int, float],text_color: str = 'green',font_scale: float = 0.5,row_width: int = 20,show: bool = False,win_name: str = '',wait_time: int = 0,out_file: Optional[str] = None) -> None:"""Mark the prediction results on the picture."""img = imread(img, mode="RGB")img = img.copy()x, y = 0, row_widthtext_color = color_val(text_color)for k, v in result.items():if isinstance(v, float):v = f'{v:.2f}'label_text = f'{k}: {v}'cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,font_scale, text_color)y += row_widthif out_file:show = Falseimwrite(img, out_file)if show:imshow(img, win_name, wait_time)def index2label():"""Dictionary output for image numbers and categories of the ImageNet dataset."""metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")meta = io.loadmat(metafile, squeeze_me=True)['synsets']nums_children = list(zip(*meta))[4]meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]_, wnids, classes = list(zip(*meta))[:3]clssname = [tuple(clss.split(', ')) for clss in classes]wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])mapping = {}for index, (_, class_name) in enumerate(wind2class_name):mapping[index] = class_name[0]return mapping# Read data for inference
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):image = image["image"]image = ms.Tensor(image)prob = model.predict(image)label = np.argmax(prob.asnumpy(), axis=1)mapping = index2label()output = {int(label): mapping[int(label)]}print(output)show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG",result=output,out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")

在这里插入图片描述

推理过程完成后,在推理文件夹下可以找到图片的推理结果,可以看出预测结果是Doberman,与期望结果相同,验证了模型的准确性。

infer-result

总结

本案例完成了一个ViT模型在ImageNet数据上进行训练,验证和推理的过程,其中,对关键的ViT模型结构和原理作了讲解。通过学习本案例,理解源码可以帮助用户掌握Multi-Head Attention,TransformerEncoder,pos_embedding等关键概念,如果要详细理解ViT的模型原理,建议基于源码更深层次的详细阅读。

个人任务打卡(读者请忽略)

在这里插入图片描述

个人思考

本章节主要描述了使用昇思大模型完成Vision Transformer图像分类的主要功能。文章首先介绍了ViT模型的架构和特点,经下载数据集后搭建了多头自注意力模块及其解码器,在完成ViT网络的彻底搭建后进行了10代模型的训练,最终实现了分类结果的验证。ViT以其Non-Local的特点,用全局感受野推理模型,准确率更高,但是对训练资源的需求也不容小觑,因此本文不适用于CPU环境下的训练(需要的时间太长了)。综上所述,本文证明了昇思大模型基于Transformer框架对图像分类任务的有效性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/373216.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【自学网络安全】:安全策略与用户认证综合实验

实验拓扑图&#xff1a; 实验任务&#xff1a; 1、DMZ区内的服务器&#xff0c;办公区仅能在办公时间内(9:00-18:00)可以访问&#xff0c;生产区的设备全天可以访问 2、生产区不允许访问互联网&#xff0c;办公区和游客区允许访问互联网 3、办公区设备10.0.2.10不允许访问Dmz区…

代理详解之静态代理、动态代理、SpringAOP实现

1、代理介绍 代理是指一个对象A通过持有另一个对象B&#xff0c;可以具有B同样的行为的模式。为了对外开放协议&#xff0c;B往往实现了一个接口&#xff0c;A也会去实现接口。但是B是“真正”实现类&#xff0c;A则比较“虚”&#xff0c;他借用了B的方法去实现接口的方法。A…

未羽研发测试管理平台

突然有一些觉悟&#xff0c;程序猿不能只会吭哧吭哧的低头做事&#xff0c;应该学会怎么去展示自己&#xff0c;怎么去宣传自己&#xff0c;怎么把自己想做的事表述清楚。 于是&#xff0c;这两天一直在整理自己的作品&#xff0c;也为接下来的找工作多做点准备。接下来…

wordpress外贸建站公司案例英文模板

Indirect Trade WP外贸网站模板 WordPress Indirect Trade外贸网站模板&#xff0c;建外贸独立站用wordpress模板&#xff0c;快速搭建十分便捷。 衣物清洁wordpress独立站模板 洗衣粉、洗衣液、衣物柔顺剂、干洗剂、衣领净、洗衣皂等衣物清洁wordpress独立站模板。 家具wordpr…

Apache部署与配置

概述 介绍 Apache HTTP Server(简称Apache)是Apache的一个开源的网页服务器&#xff0c;它源自NCSAhttpd服务器&#xff0c;并经过多次修改和发展&#xff0c;如今已经成为全球范围内广泛使用的Web服务器软件之一 特点 跨平台&#xff1a;可以运行在几乎所有广泛使用的计算机平…

哪个充电宝口碑比较好?怎么选充电宝?2024年口碑优秀充电宝推荐

在如今快节奏的生活中&#xff0c;充电宝已然成为我们日常生活中的必备品。然而&#xff0c;市场上充电宝品牌众多&#xff0c;质量参差不齐&#xff0c;如何选择一款安全、可靠且口碑优秀的充电宝成为了消费者关注的焦点。安全性能不仅关系到充电宝的使用寿命&#xff0c;更关…

【正点原子i.MX93开发板试用连载体验】项目计划和开箱体验

本文最早发表于电子发烧友&#xff1a;【   】【正点原子i.MX93开发板试用连载体验】基于深度学习的语音本地控制 - 正点原子学习小组 - 电子技术论坛 - 广受欢迎的专业电子论坛! (elecfans.com)https://bbs.elecfans.com/jishu_2438354_1_1.html 有一段时间没有参加电子发…

clickhouse-jdbc-bridge rce

clickhouse-jdbc-bridge 是什么 JDBC bridge for ClickHouse. It acts as a stateless proxy passing queries from ClickHouse to external datasources. With this extension, you can run distributed query on ClickHouse across multiple datasources in real time, whic…

支持向量机 (support vector machine,SVM)

支持向量机 &#xff08;support vector machine&#xff0c;SVM&#xff09; flyfish 支持向量机是一种用于分类和回归的机器学习模型。在分类任务中&#xff0c;SVM试图找到一个最佳的分隔超平面&#xff0c;使得不同类别的数据点在空间中被尽可能宽的间隔分开。 超平面方…

MMGPL: 多模态医学数据分析与图提示学习| 文献速递-基于深度学习的多模态数据分析与生存分析

Title 题目 MMGPL: Multimodal Medical Data Analysis with Graph Prompt Learning MMGPL: 多模态医学数据分析与图提示学习 01 文献速递介绍 神经学障碍&#xff0c;包括自闭症谱系障碍&#xff08;ASD&#xff09;&#xff08;Lord等&#xff0c;2018年&#xff09;和阿…

kafka的副本replica

指定topic的分区和副本 通过kafka命令行工具 kafka-topics.sh --create --topic myTopic --partitions 3 --replication-factor 1 --bootstrap-server localhost:9092 执行代码时指定分区个数

基于Spring Boot框架的EAM系统设计与实现

摘 要&#xff1a;文章设计并实现一个基于Spring Boot框架的EAM系统&#xff0c;以应对传统人工管理模式存在的低效与信息管理难题。系统利用Java语言、JSP技术、MySQL数据库等技术栈&#xff0c;构建了一个B/S架构的高效管理平台&#xff0c;提升了资产管理的信息化水平。该系…

大小端详解

引例 我们知道整形(int)是4个字节&#xff0c;例如随便举个例子&#xff1a;0x01020304&#xff0c;它一共占了四个地址位&#xff0c;01,02,03,04分别占了一个字节&#xff08;一个字节就对应了一个地址&#xff09;。 那么就会有个问题&#xff1a;我们的01到底是存储在高地…

STM32的 DMA(直接存储器访问) 详解

STM32的DMA&#xff08;Direct Memory Access&#xff0c;直接存储器存取&#xff09;是一种在单片机中用于高效实现数据传输的技术。它允许外设设备直接访问RAM&#xff0c;不需要CPU的干预&#xff0c;从而释放CPU资源&#xff0c;提高CPU工作效率&#xff0c;本文基于STM32F…

C++基础(1)

目录 C的输入输出&#xff1a; 命名空间域&#xff1a; 缺省&#xff08;默认&#xff09;参数&#xff1a; 函数重载&#xff1a; 引用&#xff1a; 内联函数inline&#xff1a; 指针空值nullptr&#xff1a; C的输入输出&#xff1a; 输入&#xff1a; int a; char …

社交论坛圈子系统APP开发社交圈子小程序系统源码开源,带语音派对聊天室/圈子社交论坛及时聊天

功能// 首页左右滑动切换分类 使用资讯类app常见的滑动切换分类&#xff0c;让用户使用更方便。 2信息卡片流展示 每条信息都是一个卡片&#xff0c;头像展示会员标签&#xff0c;单图自动宽度&#xff0c;多图九宫格展示&#xff0c;底部展示信息发布地址&#xff0c;阅读量、…

采用3种稀疏降噪模型对心电信号进行降噪(Matlab R2021B)

心电信号采集自病人体表&#xff0c;是一种无创性的检测手段。因此&#xff0c;心电信号采集过程中&#xff0c;本身也已经包含了机体内部其他生命活动带来的噪声。同时&#xff0c;由于采集设备和环境中存在电流的变化&#xff0c;产生电磁发射等物理现象&#xff0c;会对心电…

Java项目:基于SSM框架实现的中小型企业财务管理系统【ssm+B/S架构+源码+数据库+答辩PPT+开题报告+毕业论文】

一、项目简介 本项目是一套基于SSM框架实现的中小型企业财务管理系统 包含&#xff1a;项目源码、数据库脚本等&#xff0c;该项目附带全部源码可作为毕设使用。 项目都经过严格调试&#xff0c;eclipse或者idea 确保可以运行&#xff01; 该系统功能完善、界面美观、操作简单…

浅谈VPS主机上的数据库性能优化

如何提高网站性能&#xff1f;一个显而易见的解决方案是升级托管账户。您的网站将拥有更多硬件资源&#xff0c;因此可以同时处理更多请求并更快地传递数据。 无论如何&#xff0c;人们都是这么认为的。但事实总是不一样。 现代网站是一个复杂的系统&#xff0c;包含许多必须…

效果惊人!LivePortrait开源数字人技术,让静态照片生动起来

不得了了,快手已经不是众人所知的那个短视频娱乐平台了。 可灵AI视频的风口尚未过去,又推出了LivePortrait--开源的数字人项目。LivePortrait让你的照片动起来,合成逼真的动态人像视频,阿里通义EMO不再是唯一选择。 让图像动起来 LivePortrait 主要提供了对眼睛和嘴唇动作的…