昇思25天学习打卡营第六天|应用实践/计算机视觉/Vision Transformer图像分类

心得

运行模型似乎有点靠天意?每次跑模型之前先来个焚香沐浴?总之今天是机器视觉的最后一课了,尽管课程里强调模型跑得慢,可是我的这次运行,居然很快的就看到结果了。

如果一直看我这个系列文章的小伙伴,应该可以看出,我每次课程都是成功跑完整的。每次的文章中,不仅有课程的初始代码,还有跑成功后的显示结果,一起学习的小伙伴遇到问题时可以对着看看。

"

实际上,pos_embedding总共有4种方案。但是经过作者的论证,只有加上pos_embedding和不加pos_embedding有明显影响,至于pos_embedding是一维还是二维对分类结果影响不大,所以,在我们的代码中,也是采用了一维的pos_embedding,由于class_embedding是加在pos_embedding之前,所以pos_embedding的维度会比patch拉伸后的维度加1。

总的而言,ViT模型还是利用了Transformer模型在处理上下文语义时的优势,将图像转换为一种“变种词向量”然后进行处理,而这样转换的意义在于,多个patch之间本身具有空间联系,这类似于一种“空间语义”,从而获得了比较好的处理效果。

"

也许有小伙伴愿意尝试一下二维呢?加油哦!

打卡截图

Vision Transformer图像分类

感谢ZOMI酱对本文的贡献。

Vision Transformer(ViT)简介

近些年,随着基于自注意(Self-Attention)结构的模型的发展,特别是Transformer模型的提出,极大地促进了自然语言处理模型的发展。由于Transformers的计算效率和可扩展性,它已经能够训练具有超过100B参数的空前规模的模型。

ViT则是自然语言处理和计算机视觉两个领域的融合结晶。在不依赖卷积操作的情况下,依然可以在图像分类任务上达到很好的效果。

模型结构

ViT模型的主体结构是基于Transformer模型的Encoder部分(部分结构顺序有调整,如:Normalization的位置与标准Transformer不同),其结构图[1]如下:

vit-architecture

模型特点

ViT模型主要应用于图像分类领域。因此,其模型结构相较于传统的Transformer有以下几个特点:

  1. 数据集的原图像被划分为多个patch(图像块)后,将二维patch(不考虑channel)转换为一维向量,再加上类别向量与位置向量作为模型输入。
  2. 模型主体的Block结构是基于Transformer的Encoder结构,但是调整了Normalization的位置,其中,最主要的结构依然是Multi-head Attention结构。
  3. 模型在Blocks堆叠后接全连接层,接受类别向量的输出作为输入并用于分类。通常情况下,我们将最后的全连接层称为Head,Transformer Encoder部分为backbone。

下面将通过代码实例来详细解释基于ViT实现ImageNet分类任务。

注意,本教程在CPU上运行时间过长,不建议使用CPU运行。

环境准备与数据读取

开始实验之前,请确保本地已经安装了Python环境并安装了MindSpore。

首先我们需要下载本案例的数据集,可通过http://image-net.org下载完整的ImageNet数据集,本案例应用的数据集是从ImageNet中筛选出来的子集。

运行第一段代码时会自动下载并解压,请确保你的数据集路径如以下结构。

.dataset/├── ILSVRC2012_devkit_t12.tar.gz├── train/├── infer/└── val/

[1]:

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14

[2]:

# 查看当前 mindspore 版本
!pip show mindspore
Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 

[3]:

from download import download
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"
path = download(dataset_url, path, kind="zip", replace=True)
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip (489.1 MB)file_sizes: 100%|█████████████████████████████| 513M/513M [00:02<00:00, 220MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./

[4]:

 
import os
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms
data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)
trans_train = [
    transforms.RandomCropDecodeResize(size=224,
                                      scale=(0.08, 1.0),
                                      ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]
dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

模型解析

下面将通过代码来细致剖析ViT模型的内部结构。

Transformer基本原理

Transformer模型源于2017年的一篇文章[2]。在这篇文章中提出的基于Attention机制的编码器-解码器型结构在自然语言处理领域获得了巨大的成功。模型结构如下图所示:

transformer-architecture

其主要结构为多个Encoder和Decoder模块所组成,其中Encoder和Decoder的详细结构如下图[2]所示:

encoder-decoder

Encoder与Decoder由许多结构组成,如:多头注意力(Multi-Head Attention)层,Feed Forward层,Normaliztion层,甚至残差连接(Residual Connection,图中的“Add”)。不过,其中最重要的结构是多头注意力(Multi-Head Attention)结构,该结构基于自注意力(Self-Attention)机制,是多个Self-Attention的并行组成。

所以,理解了Self-Attention就抓住了Transformer的核心。

Attention模块

以下是Self-Attention的解释,其核心内容是为输入向量的每个单词学习一个权重。通过给定一个任务相关的查询向量Query向量,计算Query和各个Key的相似性或者相关性得到注意力分布,即得到每个Key对应Value的权重系数,然后对Value进行加权求和得到最终的Attention数值。

在Self-Attention中:

  1. 最初的输入向量首先会经过Embedding层映射成Q(Query),K(Key),V(Value)三个向量,由于是并行操作,所以代码中是映射成为dim x 3的向量然后进行分割,换言之,如果你的输入向量为一个向量序列(𝑥1𝑥1,𝑥2𝑥2,𝑥3𝑥3),其中的𝑥1𝑥1,𝑥2𝑥2,𝑥3𝑥3都是一维向量,那么每一个一维向量都会经过Embedding层映射出Q,K,V三个向量,只是Embedding矩阵不同,矩阵参数也是通过学习得到的。这里大家可以认为,Q,K,V三个矩阵是发现向量之间关联信息的一种手段,需要经过学习得到,至于为什么是Q,K,V三个,主要是因为需要两个向量点乘以获得权重,又需要另一个向量来承载权重向加的结果,所以,最少需要3个矩阵。

𝑞𝑖=𝑊𝑞⋅𝑥𝑖𝑘𝑖=𝑊𝑘⋅𝑥𝑖,𝑣𝑖=𝑊𝑣⋅𝑥𝑖𝑖=1,2,3…(1)(1){𝑞𝑖=𝑊𝑞⋅𝑥𝑖𝑘𝑖=𝑊𝑘⋅𝑥𝑖,𝑖=1,2,3…𝑣𝑖=𝑊𝑣⋅𝑥𝑖

self-attention1

  1. 自注意力机制的自注意主要体现在它的Q,K,V都来源于其自身,也就是该过程是在提取输入的不同顺序的向量的联系与特征,最终通过不同顺序向量之间的联系紧密性(Q与K乘积经过Softmax的结果)来表现出来。Q,K,V得到后就需要获取向量间权重,需要对Q和K进行点乘并除以维度的平方根,对所有向量的结果进行Softmax处理,通过公式(2)的操作,我们获得了向量之间的关系权重。

𝑎1,1=𝑞1⋅𝑘1/𝑑⎯⎯√𝑎1,2=𝑞1⋅𝑘2/𝑑⎯⎯√𝑎1,3=𝑞1⋅𝑘3/𝑑⎯⎯√(2)(2){𝑎1,1=𝑞1⋅𝑘1/𝑑𝑎1,2=𝑞1⋅𝑘2/𝑑𝑎1,3=𝑞1⋅𝑘3/𝑑

self-attention3

𝑆𝑜𝑓𝑡𝑚𝑎𝑥:𝑎̂ 1,𝑖=𝑒𝑥𝑝(𝑎1,𝑖)/∑𝑗𝑒𝑥𝑝(𝑎1,𝑗),𝑗=1,2,3…(3)(3)𝑆𝑜𝑓𝑡𝑚𝑎𝑥:𝑎^1,𝑖=𝑒𝑥𝑝(𝑎1,𝑖)/∑𝑗𝑒𝑥𝑝(𝑎1,𝑗),𝑗=1,2,3…

self-attention2

  1. 其最终输出则是通过V这个映射后的向量与Q,K经过Softmax结果进行weight sum获得,这个过程可以理解为在全局上进行自注意表示。每一组Q,K,V最后都有一个V输出,这是Self-Attention得到的最终结果,是当前向量在结合了它与其他向量关联权重后得到的结果。

𝑏1=∑𝑖𝑎̂ 1,𝑖𝑣𝑖,𝑖=1,2,3...(4)(4)𝑏1=∑𝑖𝑎^1,𝑖𝑣𝑖,𝑖=1,2,3...

通过下图可以整体把握Self-Attention的全部过程。

self-attention

多头注意力机制就是将原本self-Attention处理的向量分割为多个Head进行处理,这一点也可以从代码中体现,这也是attention结构可以进行并行加速的一个方面。

总结来说,多头注意力机制在保持参数总量不变的情况下,将同样的query, key和value映射到原来的高维空间(Q,K,V)的不同子空间(Q_0,K_0,V_0)中进行自注意力的计算,最后再合并不同子空间中的注意力信息。

所以,对于同一个输入向量,多个注意力机制可以同时对其进行处理,即利用并行计算加速处理过程,又在处理的时候更充分的分析和利用了向量特征。下图展示了多头注意力机制,其并行能力的主要体现在下图中的𝑎1𝑎1和𝑎2𝑎2是同一个向量进行分割获得的。

multi-head-attention

以下是Multi-Head Attention代码,结合上文的解释,代码清晰的展现了这一过程。

[5]:

 
from mindspore import nn, ops
class Attention(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = ms.Tensor(head_dim ** -0.5)
        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0-keep_prob)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)
    def construct(self, x):
        """Attention construct."""
        b, n, c = x.shape
        qkv = self.qkv(x)
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = ops.unstack(qkv, axis=0)
        attn = self.q_matmul_k(q, k)
        attn = ops.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        out = self.attn_matmul_v(attn, v)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)
        return out

Transformer Encoder

在了解了Self-Attention结构之后,通过与Feed Forward,Residual Connection等结构的拼接就可以形成Transformer的基础结构,下面代码实现了Feed Forward,Residual Connection结构。

[6]:

 
from typing import Optional, Dict
class FeedForward(nn.Cell):
    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 activation: nn.Cell = nn.GELU,
                 keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
    def construct(self, x):
        """Feed Forward construct."""
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)
        return x
class ResidualCell(nn.Cell):
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell
    def construct(self, x):
        """ResidualCell construct."""
        return self.cell(x) + x

接下来就利用Self-Attention来构建ViT模型中的TransformerEncoder部分,类似于构建了一个Transformer的编码器部分,如下图[1]所示:

vit-encoder

  1. ViT模型中的基础结构与标准Transformer有所不同,主要在于Normalization的位置是放在Self-Attention和Feed Forward之前,其他结构如Residual Connection,Feed Forward,Normalization都如Transformer中所设计。

  2. 从Transformer结构的图片可以发现,多个子encoder的堆叠就完成了模型编码器的构建,在ViT模型中,依然沿用这个思路,通过配置超参数num_layers,就可以确定堆叠层数。

  3. Residual Connection,Normalization的结构可以保证模型有很强的扩展性(保证信息经过深层处理不会出现退化的现象,这是Residual Connection的作用),Normalization和dropout的应用可以增强模型泛化能力。

从以下源码中就可以清晰看到Transformer的结构。将TransformerEncoder结构和一个多层感知器(MLP)结合,就构成了ViT模型的backbone部分。

[7]:

 
class TransformerEncoder(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []
        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)
            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)
            layers.append(
                nn.SequentialCell([
                    ResidualCell(nn.SequentialCell([normalization1, attention])),
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))
                ])
            )
        self.layers = nn.SequentialCell(layers)
    def construct(self, x):
        """Transformer construct."""
        return self.layers(x)

ViT模型的输入

传统的Transformer结构主要用于处理自然语言领域的词向量(Word Embedding or Word Vector),词向量与传统图像数据的主要区别在于,词向量通常是一维向量进行堆叠,而图片则是二维矩阵的堆叠,多头注意力机制在处理一维词向量的堆叠时会提取词向量之间的联系也就是上下文语义,这使得Transformer在自然语言处理领域非常好用,而二维图片矩阵如何与一维词向量进行转化就成为了Transformer进军图像处理领域的一个小门槛。

在ViT模型中:

  1. 通过将输入图像在每个channel上划分为1616个patch,这一步是通过卷积操作来完成的,当然也可以人工进行划分,但卷积操作也可以达到目的同时还可以进行一次而外的数据处理;*例如一幅输入224 x 224的图像,首先经过卷积处理得到16 x 16个patch,那么每一个patch的大小就是14 x 14。

  2. 再将每一个patch的矩阵拉伸成为一个一维向量,从而获得了近似词向量堆叠的效果。上一步得到的14 x 14的patch就转换为长度为196的向量。

这是图像输入网络经过的第一步处理。具体Patch Embedding的代码如下所示:

[8]:

 
class PatchEmbedding(nn.Cell):
    MIN_NUM_PATCHES = 4
    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 input_channels: int = 3):
        super(PatchEmbedding, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)
    def construct(self, x):
        """Path Embedding construct."""
        x = self.conv(x)
        b, c, h, w = x.shape
        x = ops.reshape(x, (b, c, h * w))
        x = ops.transpose(x, (0, 2, 1))
        return x

输入图像在划分为patch之后,会经过pos_embedding 和 class_embedding两个过程。

  1. class_embedding主要借鉴了BERT模型的用于文本分类时的思想,在每一个word vector之前增加一个类别值,通常是加在向量的第一位,上一步得到的196维的向量加上class_embedding后变为197维。

  2. 增加的class_embedding是一个可以学习的参数,经过网络的不断训练,最终以输出向量的第一个维度的输出来决定最后的输出类别;由于输入是16 x 16个patch,所以输出进行分类时是取 16 x 16个class_embedding进行分类。

  3. pos_embedding也是一组可以学习的参数,会被加入到经过处理的patch矩阵中。

  4. 由于pos_embedding也是可以学习的参数,所以它的加入类似于全链接网络和卷积的bias。这一步就是创造一个长度维197的可训练向量加入到经过class_embedding的向量中。

实际上,pos_embedding总共有4种方案。但是经过作者的论证,只有加上pos_embedding和不加pos_embedding有明显影响,至于pos_embedding是一维还是二维对分类结果影响不大,所以,在我们的代码中,也是采用了一维的pos_embedding,由于class_embedding是加在pos_embedding之前,所以pos_embedding的维度会比patch拉伸后的维度加1。

总的而言,ViT模型还是利用了Transformer模型在处理上下文语义时的优势,将图像转换为一种“变种词向量”然后进行处理,而这样转换的意义在于,多个patch之间本身具有空间联系,这类似于一种“空间语义”,从而获得了比较好的处理效果。

整体构建ViT

以下代码构建了一个完整的ViT模型。

[9]:

 
from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameter
def init(init_type, shape, dtype, name, requires_grad):
    """Init."""
    initial = initializer(init_type, shape, dtype).init_data()
    return Parameter(initial, name=name, requires_grad=requires_grad)
class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 pool: str = 'cls') -> None:
        super(ViT, self).__init__()
        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches
        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)
        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)
        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
        self.dense = nn.Dense(embed_dim, num_classes)
    def construct(self, x):
        """ViT construct."""
        x = self.patch_embedding(x)
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)
        x += self.pos_embedding
        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)
        return x

整体流程图如下所示:

data-process

模型训练与推理

模型训练

模型开始训练前,需要设定损失函数,优化器,回调函数等。

完整训练ViT模型需要很长的时间,实际应用时建议根据项目需要调整epoch_size,当正常输出每个Epoch的step信息时,意味着训练正在进行,通过模型输出可以查看当前训练的loss值和时间等指标。

[10]:

 
from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train
# define super parameter
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()
# construct model
network = ViT()
# load ckpt
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
# define learning rate
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)
# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)
# define loss function
class CrossEntropySmooth(LossBase):
    """CrossEntropy."""
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss
network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)
# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)
# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")
# train model
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False,)
Downloading data from https://download-mindspore.osinfra.cn/vision/classification/vit_b_16_224.ckpt (330.2 MB)file_sizes: 100%|████████████████████████████| 346M/346M [00:18<00:00, 18.7MB/s]
Successfully downloaded file to ./ckpt/vit_b_16_224.ckpt
epoch: 1 step: 125, loss is 2.4269836
Train epoch time: 254275.431 ms, per step time: 2034.203 ms
epoch: 2 step: 125, loss is 1.6255778
Train epoch time: 27142.445 ms, per step time: 217.140 ms
epoch: 3 step: 125, loss is 1.3483672
Train epoch time: 25470.780 ms, per step time: 203.766 ms
epoch: 4 step: 125, loss is 1.1531123
Train epoch time: 25066.077 ms, per step time: 200.529 ms
epoch: 5 step: 125, loss is 1.5591685
Train epoch time: 24959.266 ms, per step time: 199.674 ms
epoch: 6 step: 125, loss is 1.164649
Train epoch time: 24743.208 ms, per step time: 197.946 ms
epoch: 7 step: 125, loss is 1.2841825
Train epoch time: 25585.884 ms, per step time: 204.687 ms
epoch: 8 step: 125, loss is 1.3261327
Train epoch time: 24364.958 ms, per step time: 194.920 ms
epoch: 9 step: 125, loss is 1.2463571
Train epoch time: 24698.548 ms, per step time: 197.588 ms
epoch: 10 step: 125, loss is 1.1417197
Train epoch time: 25268.490 ms, per step time: 202.148 ms

模型验证

模型验证过程主要应用了ImageFolderDataset,CrossEntropySmooth和Model等接口。

ImageFolderDataset主要用于读取数据集。

CrossEntropySmooth是损失函数实例化接口。

Model主要用于编译模型。

与训练过程相似,首先进行数据增强,然后定义ViT网络结构,加载预训练模型参数。随后设置损失函数,评价指标等,编译模型后进行验证。本案例采用了业界通用的评价标准Top_1_Accuracy和Top_5_Accuracy评价指标来评价模型表现。

在本案例中,这两个指标代表了在输出的1000维向量中,以最大值或前5的输出值所代表的类别为预测结果时,模型预测的准确率。这两个指标的值越大,代表模型准确率越高。

[11]:

 
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)
trans_val = [
    transforms.Decode(),
    transforms.Resize(224 + 32),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]
dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)
# construct model
network = ViT()
# load ckpt
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)
# define metric
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
                'Top_5_Accuracy': train.Top5CategoricalAccuracy()}
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")
# evaluate model
result = model.eval(dataset_val)
print(result)
{'Top_1_Accuracy': 0.7495, 'Top_5_Accuracy': 0.928}

从结果可以看出,由于我们加载了预训练模型参数,模型的Top_1_Accuracy和Top_5_Accuracy达到了很高的水平,实际项目中也可以以此准确率为标准。如果未使用预训练模型参数,则需要更多的epoch来训练。

模型推理

在进行模型推理之前,首先要定义一个对推理图片进行数据预处理的方法。该方法可以对我们的推理图片进行resize和normalize处理,这样才能与我们训练时的输入数据匹配。

本案例采用了一张Doberman的图片作为推理图片来测试模型表现,期望模型可以给出正确的预测结果。

[12]:

dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)
trans_infer = [
    transforms.Decode(),
    transforms.Resize([224, 224]),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]
dataset_infer = dataset_infer.map(operations=trans_infer,
                                  input_columns=["image"],
                                  num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)

接下来,我们将调用模型的predict方法进行模型。

在推理过程中,通过index2label就可以获取对应标签,再通过自定义的show_result接口将结果写在对应图片上。

[13]:

 
import os
import pathlib
import cv2
import numpy as np
from PIL import Image
from enum import Enum
from scipy import io
class Color(Enum):
    """dedine enum color."""
    red = (0, 0, 255)
    green = (0, 255, 0)
    blue = (255, 0, 0)
    cyan = (255, 255, 0)
    yellow = (0, 255, 255)
    magenta = (255, 0, 255)
    white = (255, 255, 255)
    black = (0, 0, 0)
def check_file_exist(file_name: str):
    """check_file_exist."""
    if not os.path.isfile(file_name):
        raise FileNotFoundError(f"File `{file_name}` does not exist.")
def color_val(color):
    """color_val."""
    if isinstance(color, str):
        return Color[color].value
    if isinstance(color, Color):
        return color.value
    if isinstance(color, tuple):
        assert len(color) == 3
        for channel in color:
            assert 0 <= channel <= 255
        return color
    if isinstance(color, int):
        assert 0 <= color <= 255
        return color, color, color
    if isinstance(color, np.ndarray):
        assert color.ndim == 1 and color.size == 3
        assert np.all((color >= 0) & (color <= 255))
        color = color.astype(np.uint8)
        return tuple(color)
    raise TypeError(f'Invalid type for color: {type(color)}')
def imread(image, mode=None):
    """imread."""
    if isinstance(image, pathlib.Path):
        image = str(image)
    if isinstance(image, np.ndarray):
        pass
    elif isinstance(image, str):
        check_file_exist(image)
        image = Image.open(image)
        if mode:
            image = np.array(image.convert(mode))
    else:
        raise TypeError("Image must be a `ndarray`, `str` or Path object.")
    return image
def imwrite(image, image_path, auto_mkdir=True):
    """imwrite."""
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(image_path))
        if dir_name != '':
            dir_name = os.path.expanduser(dir_name)
            os.makedirs(dir_name, mode=777, exist_ok=True)
    image = Image.fromarray(image)
    image.save(image_path)
def imshow(img, win_name='', wait_time=0):
    """imshow"""
    cv2.imshow(win_name, imread(img))
    if wait_time == 0:  # prevent from hanging if windows was closed
        while True:
            ret = cv2.waitKey(1)
            closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
            # if user closed window or if some key pressed
            if closed or ret != -1:
                break
    else:
        ret = cv2.waitKey(wait_time)
def show_result(img: str,
                result: Dict[int, float],
                text_color: str = 'green',
                font_scale: float = 0.5,
                row_width: int = 20,
                show: bool = False,
                win_name: str = '',
                wait_time: int = 0,
                out_file: Optional[str] = None) -> None:
    """Mark the prediction results on the picture."""
    img = imread(img, mode="RGB")
    img = img.copy()
    x, y = 0, row_width
    text_color = color_val(text_color)
    for k, v in result.items():
        if isinstance(v, float):
            v = f'{v:.2f}'
        label_text = f'{k}: {v}'
        cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
                    font_scale, text_color)
        y += row_width
    if out_file:
        show = False
        imwrite(img, out_file)
    if show:
        imshow(img, win_name, wait_time)
def index2label():
    """Dictionary output for image numbers and categories of the ImageNet dataset."""
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']
    nums_children = list(zip(*meta))[4]
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
    _, wnids, classes = list(zip(*meta))[:3]
    clssname = [tuple(clss.split(', ')) for clss in classes]
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])
    mapping = {}
    for index, (_, class_name) in enumerate(wind2class_name):
        mapping[index] = class_name[0]
    return mapping
# Read data for inference
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = image["image"]
    image = ms.Tensor(image)
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)
    mapping = index2label()
    output = {int(label): mapping[int(label)]}
    print(output)
    show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG",
                result=output,
                out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")
{236: 'Doberman'}

推理过程完成后,在推理文件夹下可以找到图片的推理结果,可以看出预测结果是Doberman,与期望结果相同,验证了模型的准确性。

infer-result

总结

本案例完成了一个ViT模型在ImageNet数据上进行训练,验证和推理的过程,其中,对关键的ViT模型结构和原理作了讲解。通过学习本案例,理解源码可以帮助用户掌握Multi-Head Attention,TransformerEncoder,pos_embedding等关键概念,如果要详细理解ViT的模型原理,建议基于源码更深层次的详细阅读。

[14]:

 
import time
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),'guojun0718')
2024-07-16 08:19:33 guojun0718

[ ]:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/378139.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《ElementUI/Plus 基础知识》el-tree 之修改可拖拽节点的高亮背景和线

前言 收到需求&#xff0c;PM 觉得可拖拽节点的高亮背景和线样式不明显&#xff01;CSS 样式得改&#xff01; 注意&#xff1a;下述方式适用于ElementUI el-tree 和 ElementPlus el-tree&#xff01; 修改 拖拽被叠加节点的背景色和文字 关键类名 is-drop-inner .el-tree…

汽车零部件制造企业MES系统主要功能介绍

随着汽车工业的不断发展&#xff0c;汽车零部件制造企业面临着越来越高的生产效率、质量控制和成本管理要求。MES系统作为一种综合信息系统&#xff0c;能够帮助企业实现从订单接收到产品交付的全流程数字化管理&#xff0c;优化资源配置&#xff0c;提高生产效率&#xff0c;确…

Java中消耗掉换行符

scanner.nextLine(); // 消耗掉换行符 这行代码的作用是读取并丢弃输入流中的换行符。这是因为在使用 Scanner 对象读取用户输入时&#xff0c;有时候会在输入流中留下未处理的换行符&#xff0c;这可能会导致后续的输入读取出现问题。 具体来说&#xff0c;当你使用 Scanner …

vue2学习笔记7 - Vue中的MVVM模型

MVVM Model-View-viewModel是一种软件架构模式&#xff0c;用于将用户界面&#xff08;View&#xff09;与业务逻辑&#xff08;Model&#xff09;分离&#xff0c;并通过ViewModel进行连接和协调。MVVM模式的目标是实现视图与模型的解耦&#xff0c;提高代码的可读性、可维护…

django报错(一):python manage.py makemigrations,显示“No changes detected”

执行python manage.py makemigrations命令无任何文件生成&#xff0c;结果显示“No changes detected”。 解决方案一&#xff1a; 1、执行命令&#xff1a;python manage.py makemigrations –empty appname 2、删除其中的0001_initial.py文件&#xff08;因为这个文件内容是…

深度解析:disableHostCheck: true引发的安全迷局与解决之道

在Web开发的浩瀚星空中&#xff0c;开发者们时常会遇到各种配置与调优的挑战&#xff0c;其中disableHostCheck: true这一选项&#xff0c;在提升开发效率的同时&#xff0c;也悄然埋下了安全隐患的伏笔。本文将深入探讨这一配置背后的原理、为何会引发报错&#xff0c;以及如何…

Maven学习—如何在IDEA中配置Maven?又如何创建Maven工程?(详细攻略)

目录 前言 1.在IDEA中配置Maven 2.创建Maven项目 &#xff08;1&#xff09;Maven&#xff1a;创建普通Maven工程 &#xff08;2&#xff09;Maven Archetype&#xff1a;创建Maven模板工程 前言 本篇博客将详细的介绍在IDEA中如何配置Maven&#xff0c;以及如何创建一个Ma…

鸿蒙特色物联网实训室

一、 引言 在当今这个万物皆可连网的时代&#xff0c;物联网&#xff08;IoT&#xff09;正以前所未有的速度改变着我们的生活和工作方式。它如同一座桥梁&#xff0c;将实体世界与虚拟空间紧密相连&#xff0c;让数据成为驱动决策和创新的关键力量。随着物联网技术的不断成熟…

关于 Docker Registry (镜像仓库)

什么是镜像仓库 概念 镜像仓库&#xff08;Docker Registry&#xff09;负责存储、管理和分发镜像&#xff0c;并提供了登录认证能力&#xff0c;建立了仓库的索引。 镜像仓库管理多个 Repository&#xff0c;Repository 通过命名来区分。每个 Repository 包含一个或多个镜像…

WAF基础介绍

WAF 一、WAF是什么&#xff1f;WAF能够做什么 二 waf的部署三、WAF的工作原理 一、WAF是什么&#xff1f; WAF的全称是&#xff08;Web Application Firewall&#xff09;即Web应用防火墙&#xff0c;简称WAF。 国际上公认的一种说法是&#xff1a;Web应用防火墙是通过执行一…

[Labview] 表格单元格外边框 二维图片叠加绘图

最终效果如下所示 转行做Labview都没到三个月&#xff0c;主程居然让我做这么复杂的功能&#xff0c;真是看得起我/(ㄒoㄒ)/~~ 思路大致分为两步 1、确定每个框体的左上/右下单元格位置&#xff0c;转换为表格表格坐标并在二维图片上绘制生成&#xff1b; 2、为二维图片添加…

cuda缓存示意图

一、定义 cuda 缓存示意图gpu 架构示意图gpu 内存访问示意图 二、实现 cuda 缓存示意图 DRAM: 通常指的是GPU的显存&#xff0c;位于GPU芯片外部&#xff0c;通过某种接口&#xff08;如PCIE&#xff09;与GPU芯片相连。它是GPU访问的主要数据存储区域&#xff0c;用于存储大…

MBR40150FCT-ASEMI无人机专用MBR40150FCT

编辑&#xff1a;ll MBR40150FCT-ASEMI无人机专用MBR40150FCT 型号&#xff1a;MBR40150FCT 品牌&#xff1a;ASEMI 封装&#xff1a;TO-220F 批号&#xff1a;最新 最大平均正向电流&#xff08;IF&#xff09;&#xff1a;40A 最大循环峰值反向电压&#xff08;VRRM&a…

shell的变量及赋值

文章目录 一&#xff0c;shell变量是什么二&#xff0c;自定义变量1 . 查看和引用变量的值2 . echo选项3 . unset4 . 特殊符号5 . 括号 三&#xff0c;交互式定义变量1 . read2 . read 选项 四&#xff0c;变量的范围1 . export命令 五&#xff0c;数值变量的运算及特殊变量1 .…

鸿蒙Harmony--文本组件Text属性详解

金樽清酒斗十千&#xff0c;玉盘珍羞直万钱。 停杯投箸不能食&#xff0c;拔剑四顾心茫然。 欲渡黄河冰塞川&#xff0c;将登太行雪满山。 闲来垂钓碧溪上&#xff0c;忽复乘舟梦日边。 行路难&#xff0c;行路难&#xff0c;多歧路&#xff0c;今安在&#xff1f; 长风破浪会有…

浅析stm32启动文件

浅析stm32启动文件 文章目录 浅析stm32启动文件1.什么是启动文件&#xff1f;2.启动文件的命名规则3.stm32芯片的命名规则 1.什么是启动文件&#xff1f; 我们来看gpt给出的答案&#xff1a; STM32的启动文件是一个关键的汇编语言源文件&#xff0c;它负责在微控制器上电或复位…

探索数据结构与算法的奇妙世界 —— Github开源项目推荐《Hello 算法》

在浩瀚的编程与计算机科学领域中&#xff0c;数据结构与算法无疑是每位开发者攀登技术高峰的必经之路。然而&#xff0c;对于初学者而言&#xff0c;这条路往往布满了荆棘与挑战。幸运的是&#xff0c;今天我要向大家推荐一个令人振奋的项目——《Hello Algo》&#xff0c;它正…

在 Vue3 + Electron 中使用预加载脚本(preload)

文章目录 一、什么是预加载脚本(preload)&#xff0c;为什么我们需要它二、通过预加载脚本暴露相关 API 至渲染进程1、实现获取系统默认桌面路径功能2、向剪切板写入内容3、使用系统默认浏览器访问目标 url4、使用文件选择对话框 三、参考资料 一、什么是预加载脚本(preload)&a…

负载测试和功率分析中负载箱的重要作用

在负载测试和功率分析中&#xff0c;负载箱扮演着至关重要的角色。以下是负载箱在这两个方面的重要作用&#xff1a; 一、负载测试中的重要作用 模拟实际负载条件&#xff1a; 负载箱能够模拟各种复杂的负载条件&#xff0c;包括电阻性负载、电感性负载、电容性负载等&#x…

【HarmonyOS开发】弹窗交互(promptAction )

实现效果 点击按钮实现不同方式的弹窗showToast showDialog showActionMenu 代码实现 1.引入’ohos.promptAction’ import promptAction from ohos.promptAction;2.通过promptAction 实现系统既定的弹窗 import promptAction from ohos.promptAction;Entry Component st…