昇思25天学习打卡营第16天 | Vision Transformer图像分类

昇思25天学习打卡营第16天 | Vision Transformer图像分类

文章目录

  • 昇思25天学习打卡营第16天 | Vision Transformer图像分类
    • Vision Transform(ViT)模型
      • Transformer
        • Attention模块
        • Encoder模块
      • ViT模型输入
    • 模型构建
      • Multi-Head Attention模块
      • Encoder模块
      • Patch Embedding模块
      • ViT网络
    • 总结
    • 打卡

Vision Transform(ViT)模型

ViT是NLP和CV领域的融合,可以在不依赖于卷积操作的情况下在图像分类任务上达到很好的效果。

ViT模型的主体结构是基于Transformer的Encoder部分。

Transformer

Transformer由很多Encoder和Decoder模块构成,包括多头注意力(Multi-Head Attention)层,Feed Forward层,Normalization层和残差连接(Residual Connection)。
encoder-decoder
多头注意力结构基于自注意力机制(Self-Attention),是多个Self-Attention的并行组成。

Attention模块

Attention的核心在于为输入向量的每个单词学习一个权重。

  1. 最初的输入向量首先经过Embedding层映射为Q(Query),K(Key),V(Value)三个向量。
  2. 通过将Q和所有K进行点乘初一维度平方根,得到向量间的相似度,通过softmax获取每词向量之间的关系权重。
  3. 利用关系权重对词向量的V加权求和,得到自注意力值。
    self-attention
    多头注意力机制只是对self-attention的并行化:
    multi-head-attention
Encoder模块

ViT中的Encoder相对于标准Transformer,主要在于将Normolization放在self-attention和Feed Forward之前,其他结构与标准Transformer相同。
vit-encoder

ViT模型输入

传统Transformer主要应用于自然语言处理的一维词向量,而图像时二维矩阵的堆叠。
在ViT中:

  1. 通过卷积将输入图像在每个channel上划分为 16 × 16 16\times 16 16×16个patch。如果输入 224 × 224 224\times224 224×224的图像,则每一个patch的大小为 14 × 14 14\times 14 14×14
  2. 将每一个patch拉伸为一个一维向量,得到近似词向量堆叠的效果。如将 14 × 14 14\times14 14×14展开为 196 196 196的向量。
    这一部分Patch Embedding用来替换Transformer中Word Embedding,用作网络中的图像输入。

模型构建

Multi-Head Attention模块

from mindspore import nn, opsclass Attention(nn.Cell):def __init__(self,dim: int,num_heads: int = 8,keep_prob: float = 1.0,attention_keep_prob: float = 1.0):super(Attention, self).__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = ms.Tensor(head_dim ** -0.5)self.qkv = nn.Dense(dim, dim * 3)self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)self.out = nn.Dense(dim, dim)self.out_drop = nn.Dropout(p=1.0-keep_prob)self.attn_matmul_v = ops.BatchMatMul()self.q_matmul_k = ops.BatchMatMul(transpose_b=True)self.softmax = nn.Softmax(axis=-1)def construct(self, x):"""Attention construct."""b, n, c = x.shapeqkv = self.qkv(x)qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))q, k, v = ops.unstack(qkv, axis=0)attn = self.q_matmul_k(q, k)attn = ops.mul(attn, self.scale)attn = self.softmax(attn)attn = self.attn_drop(attn)out = self.attn_matmul_v(attn, v)out = ops.transpose(out, (0, 2, 1, 3))out = ops.reshape(out, (b, n, c))out = self.out(out)out = self.out_drop(out)return out

Encoder模块

from typing import Optional, Dictclass FeedForward(nn.Cell):def __init__(self,in_features: int,hidden_features: Optional[int] = None,out_features: Optional[int] = None,activation: nn.Cell = nn.GELU,keep_prob: float = 1.0):super(FeedForward, self).__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.dense1 = nn.Dense(in_features, hidden_features)self.activation = activation()self.dense2 = nn.Dense(hidden_features, out_features)self.dropout = nn.Dropout(p=1.0-keep_prob)def construct(self, x):"""Feed Forward construct."""x = self.dense1(x)x = self.activation(x)x = self.dropout(x)x = self.dense2(x)x = self.dropout(x)return xclass ResidualCell(nn.Cell):def __init__(self, cell):super(ResidualCell, self).__init__()self.cell = celldef construct(self, x):"""ResidualCell construct."""return self.cell(x) + xclass TransformerEncoder(nn.Cell):def __init__(self,dim: int,num_layers: int,num_heads: int,mlp_dim: int,keep_prob: float = 1.,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: nn.Cell = nn.LayerNorm):super(TransformerEncoder, self).__init__()layers = []for _ in range(num_layers):normalization1 = norm((dim,))normalization2 = norm((dim,))attention = Attention(dim=dim,num_heads=num_heads,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob)feedforward = FeedForward(in_features=dim,hidden_features=mlp_dim,activation=activation,keep_prob=keep_prob)layers.append(nn.SequentialCell([ResidualCell(nn.SequentialCell([normalization1, attention])),ResidualCell(nn.SequentialCell([normalization2, feedforward]))]))self.layers = nn.SequentialCell(layers)def construct(self, x):"""Transformer construct."""return self.layers(x)

Patch Embedding模块

class PatchEmbedding(nn.Cell):MIN_NUM_PATCHES = 4def __init__(self,image_size: int = 224,patch_size: int = 16,embed_dim: int = 768,input_channels: int = 3):super(PatchEmbedding, self).__init__()self.image_size = image_sizeself.patch_size = patch_sizeself.num_patches = (image_size // patch_size) ** 2self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)def construct(self, x):"""Path Embedding construct."""x = self.conv(x)b, c, h, w = x.shapex = ops.reshape(x, (b, c, h * w))x = ops.transpose(x, (0, 2, 1))return x

ViT网络

from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameterdef init(init_type, shape, dtype, name, requires_grad):"""Init."""initial = initializer(init_type, shape, dtype).init_data()return Parameter(initial, name=name, requires_grad=requires_grad)class ViT(nn.Cell):def __init__(self,image_size: int = 224,input_channels: int = 3,patch_size: int = 16,embed_dim: int = 768,num_layers: int = 12,num_heads: int = 12,mlp_dim: int = 3072,keep_prob: float = 1.0,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: Optional[nn.Cell] = nn.LayerNorm,pool: str = 'cls') -> None:super(ViT, self).__init__()self.patch_embedding = PatchEmbedding(image_size=image_size,patch_size=patch_size,embed_dim=embed_dim,input_channels=input_channels)num_patches = self.patch_embedding.num_patchesself.cls_token = init(init_type=Normal(sigma=1.0),shape=(1, 1, embed_dim),dtype=ms.float32,name='cls',requires_grad=True)self.pos_embedding = init(init_type=Normal(sigma=1.0),shape=(1, num_patches + 1, embed_dim),dtype=ms.float32,name='pos_embedding',requires_grad=True)self.pool = poolself.pos_dropout = nn.Dropout(p=1.0-keep_prob)self.norm = norm((embed_dim,))self.transformer = TransformerEncoder(dim=embed_dim,num_layers=num_layers,num_heads=num_heads,mlp_dim=mlp_dim,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob,drop_path_keep_prob=drop_path_keep_prob,activation=activation,norm=norm)self.dropout = nn.Dropout(p=1.0-keep_prob)self.dense = nn.Dense(embed_dim, num_classes)def construct(self, x):"""ViT construct."""x = self.patch_embedding(x)cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))x = ops.concat((cls_tokens, x), axis=1)x += self.pos_embeddingx = self.pos_dropout(x)x = self.transformer(x)x = self.norm(x)x = x[:, 0]if self.training:x = self.dropout(x)x = self.dense(x)return x

总结

这一节对Transformer进行介绍,包括Attention机制、并行化的Attention以及Encoder模块。由于传统Transformer主要作用于一维的词向量,因此二维图像需要被转换为类似的一维词向量堆叠,在ViT中通过将Patch Embedding解决这一问题,并用来代替传统Transformer中的Word Embedding作为网络的输入。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/380776.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【中项第三版】系统集成项目管理工程师 | 第 5 章 软件工程① | 5.1 - 5.3

前言 第5章对应的内容选择题和案例分析都会进行考查,这一章节属于技术的内容,学习要以教材为准。 目录 5.1 软件工程定义 5.2 软件需求 5.2.1 需求的层次 5.2.2 质量功能部署 5.2.3 需求获取 5.2.4 需求分析 5.2.5 需求规格说明书 5.2.6 需求变…

工业三防平板助力工厂生产数据实时管理

在当今高度数字化和智能化的工业生产环境中,工业三防平板正逐渐成为工厂实现生产数据实时管理的得力助手。这种创新的技术设备不仅能够在恶劣的工业环境中稳定运行,还为工厂的生产流程优化、效率提升和质量控制带来了前所未有的机遇。 工业生产场景通常充…

保护模式下的分页

4KB页 4KB页的构成 该分页方式下,32 位虚拟地址被分为三个位段:页目录索引、页表索引、页内偏移 只有一级页目录,其中包含 1024 个条目 ,每个条目指向一个页表,每个页表中有 1024 个条目 其中一个条目就指向一个物理…

智能听诊器:宠物健康监测的革新者

宠物健康护理领域迎来了一项激动人心的技术革新——智能听诊器。这款创新设备以其卓越的精确度和用户友好的操作,为宠物主人提供了一种全新的健康监测方法。 使用智能听诊器时,只需将其放置在宠物身上,它便能立即捕捉到宠物胸腔的微小振动。…

9.11和9.9哪个大?

没问题 文心一言 通义千问

ThinkPad改安装Windows7系统的操作步骤

ThinkPad:改安装Windows7系统的操作步骤 一、BIOS设置 1、先重新启动计算机,并按下笔记本键盘上“F1”键进入笔记本的BIOS设置界面。 2、进入BIOS设置界面后,按下键盘上“→”键将菜单移动至“Restart“项目,按下键盘上“↓”按键…

frp反向代理的安装与配置、ftp服务的搭建及应用

1、frp简介 frp 是⼀个开源、简洁易⽤、⾼性能的内⽹穿透和反向代理软件,⽀持 tcp, udp, http, https等 协议。frp 项⽬官⽹是 https://github.com/fatedier/frp 2、frp⼯作原理 服务端运⾏,监听⼀个主端⼝,等待客户端的连接; …

【Godot4.2】SVGParser - SVG解析器函数库

概述 这是一个基于GDScript内置XMLParser编写的简易SVG文件解析函数库。 目的就是可以将SVG文件解析为GDSCript可以处理的字典或DOM形式,方便SVG渲染和编辑。 目前还只是一个简易实现版本。还需要一些改进。 函数库源码 # # 名称:SVGParser # 类型…

【开源库学习】libodb库学习(三)

4 查询数据库 如果我们不知道我们正在寻找的对象的标识符,我们可以使用查询在数据库中搜索符合特定条件的对象。ODB查询功能是可选的,我们需要使用--generate-query ODB编译器选项显式请求生成必要的数据库支持代码。 ODB提供了一个灵活的查询API&#x…

【机器学习实战】Datawhale夏令营2:深度学习回顾

#DataWhale夏令营 #ai夏令营 文章目录 1. 深度学习的定义1.1 深度学习&图神经网络1.2 机器学习和深度学习的关系 2. 深度学习的训练流程2.1 数学基础2.1.1 梯度下降法基本原理数学表达步骤学习率 α梯度下降的变体 2.1.2 神经网络与矩阵网络结构表示前向传播激活函数…

GESP CCF 图形化编程四级认证真题 2024年6月

一、单选题(共 10 题,每题 2 分,共 30 分) 题号 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 答案 C B C D C D A B D C C D A A B 1、小…

整合Tess4J实现OCR图片识别技术

文章目录 1. 什么是OCR2. 什么是Tess4J库?3. 引入依赖4. 下载默认的训练数据5. 配置训练数据的目录路径6. 测试代码6.1 TesseractOcrConfig6.2 OcrController6.3 OcrService6.4 OcrServiceImpl 7. 功能测试7.1 调试请求接口7.2 测试结果 1. 什么是OCR **OCR (Optic…

Spark的动态资源分配算法

文章目录 前言基于任务需求进行资源请求的整体过程资源申请的生成过程详解资源申请的生成过程的简单例子资源调度算法的代码解析 申请资源以后的处理:Executor的启动或者结束对于新启动的Container的处理对于结束的Container的处理 基于资源分配结果进行任务调度Pen…

AI算法22-决策树算法Decision Tree | DT

目录 决策树算法概述 决策树算法背景 决策树算法简介 决策树算法核心思想 决策树算法的工作过程 特征选择 信息增益 信息增益比 决策树的生成 ID3算法 C4.5的生成算法 决策树的修剪 算法步骤 决策树算法的代码实现 决策树算法的优缺点 优点 缺点 决策树算法的…

Unity游戏开发入门:从安装到创建你的第一个3D场景

目录 引言 一、Unity的安装 1. 访问Unity官网 2. 下载Unity Hub 3. 安装Unity Hub并安装Unity编辑器 二、创建你的第一个项目 1. 启动Unity Hub并创建新项目 2. 熟悉Unity编辑器界面 3. 添加基本对象 4. 调整对象属性 5. 添加光源 三、运行与预览 引言 Unity&…

html 单页面引用vue3和element-plus

引入方式: element-plus基于vue3.0,所以必须导入vue3.0的js文件,然后再导入element-plus自身所需的js以及css文件,导入文件有两种方法:外部引用、下载本地使用 通过外部引用ElementPlus的css和js文件 以及Vue3.0文件 …

冒泡,选择,插入,希尔排序

目录 一. 冒泡排序 1. 算法思想 2. 时间复杂度与空间复杂度 3. 代码实现 二. 选择排序 1. 算法思想 2. 时间复杂度与空间复杂度 3. 代码实现 三.插入排序 1. 直接插入排序 (1). 算法思想 (2). 时间复杂度与空间复杂度 (3). 代码实现 2. 希尔排序 (1). 算法思想 …

昇思25天学习打卡营第23天 | 基于MindSpore的红酒分类实验

学习心得:基于MindSpore的红酒分类实验 在机器学习的学习路径中,理解和实践经典算法是非常重要的一步。最近我进行了一个有趣的实验,使用MindSpore框架实现了K近邻(KNN)算法进行红酒分类。这个实验不仅加深了我对KNN算…

云手机结合自主ADB命令接口 提升海外营销效率

现在,跨境电商直播已经成为在线零售的重要渠道,在大环境下,确保直播应用的稳定性和用户体验至关重要。 云手机支持自主ADB命令接口,为电商直播营销提供了技术支持,使得应用开发、测试、优化和运维更加高效。 什么是A…

卷积神经网络学习问题总结

问题一: 深度学习中的损失函数和应用场景 回归任务: 均方误差函数(MSE)适用于回归任务,如预测房价、预测股票价格等。 import torch.nn as nn loss_fn nn.MSELoss() 分类任务: 交叉熵损失函数&…