【Transformer系列(4)】基于vision transformer(ViT)实现猫狗二分类项目实战

文章目录

  • 一、vision transformer(ViT)结构解释
  • 二、Patch Embedding部分
    • 2.1 图像Patch化
    • 2.2 cls token
    • 2.3 位置编码(positional embedding)
  • 三、Transformer Encoder部分
    • (1) Multi-head Self-Attention
    • (2) encoder block
  • 四、head部分
  • 五、vision transformer(ViT)完整代码
  • 六、基于vision transformer(ViT)实现猫狗二分类项目实战


一、vision transformer(ViT)结构解释

vision transformerViT)结构大致流程如下图

+------------+       +--------------+
|   Input    | ----> |    Patch     |
+------------+       +--------------+|v+-------+|  Embed  |+-------+|v+-------------------+|   Transformer     |+-------------------+|v+-------+|  Pool |+-------+|v+-------+|  MLP  |+-------+|v+-------+|  Class|+-------+|vOutput

Vision TransformrerViT)是一种基于自注意力机制的图像分类模型,它试图将图像分类任务转化为自然语言处理中的序列建模问题。与传统的卷积神经网络不同,ViT使用Transformer作为它的基本结构。

ViT的整体结构可以分为两个部分:Patch EmbeddingTransformer Encoder

Patch Embedding阶段,输入的图像首先被划分为多个小的固定尺寸的图像块,称为patch。每个patch经过一个线性投影层和一个位置编码层得到相应的向量表示。这些向量表示被展平为序列并通过一个可训练的嵌入层得到输入序列。

Transformer Encoder阶段,输入序列通过多个堆叠的Transformer Encoder层进行处理。每个Transformer Encoder层由多个注意力机制和多层感知机组成。注意力机制用于捕捉全局和局部的上下文信息,通过计算输入序列中不同位置的相互关系来获取注意力权重。多层感知机则用于在每个位置上对向量进行非线性转换。

ViT的最后,经过多个Transformer Encoder层处理后的序列经过一个全局平均池化层得到固定长度的表示,再通过一个线性分类层进行分类预测。

总的来说,ViT的结构利用自注意力机制,将输入的图像转化为序列,并通过多个Transformer Encoder层对序列进行处理,最后通过全局平均池化和线性分类层得到分类结果。这种结构在图像分类任务上取得了不错的性能,并且能够处理较大尺寸的图像。

二、Patch Embedding部分

Patch Embedding部分主要由(1)图像Patch化(2)cls token(3)positional embeding构成

2.1 图像Patch化

在这里插入图片描述
代码实现

# 序列组合位置编码
class PatchEmbed(nn.Module):def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):super().__init__()# 196 = (224 // 16) * (224 // 16) Patch化self.num_patches    = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)# Trueself.flatten        = flatten# 注意: kernel_size = stride 才能实现patch之间不相交self.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)self.norm = norm_layer(num_features) if norm_layer else nn.Identity()def forward(self, x):# Step 1. Patch using Conv2d with 'kernel_size = patch_size'# [1,3,224,224] -> [1,768,14,14]x = self.proj(x)if self.flatten:# x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC# Step 2. H*W -> N 宽高维度平铺,形成序列# BCHW -> BCN   N = H*W# [1,768,14,14] -> [1,768,196]x = x.flatten(2)# BCN -> BNC 交换1维和2维,即CN transpose一次只能对两个维度进行操作# [1,768,196] -> [1,196,768]x = x.transpose(1, 2)x = self.norm(x)return x

在这里插入图片描述
在这里插入图片描述
PatchEmbed之后,维度由[1,3,224,224]变为了[1,196,768]

2.2 cls token

ViT模型中,每个图块都经过一系列的Transformer编码器层,这些编码器层处理图块之间的局部关系。而cls token则是在第一个编码器层的输入中插入的一个特殊令牌。它作为整个图像的表示引入了全局信息。

cls token的计算方式与其他图块的计算方式相同,它经过自注意力机制和前馈神经网络进行特征转换。然后,将经过编码器层处理后的cls token的输出连接到分类器中,用于图像分类任务的最终预测。

cls token的作用是捕捉整个图像的全局特征。因为Transformer模型是一种自注意力模型,并没有显式的全局信息概念,cls token的引入可以将整个图像的特征聚合成一个向量,使得模型具备对整个图像的全局理解能力。这样,模型就可以利用cls token的特征进行分类任务的预测。
代码实现

# [1,1,768]
cls_token = self.cls_token.expand(batch_size, -1, -1)
# H*W+1
# [1,196,768] -> [1,197,768]
x = torch.cat((cls_token, x), dim=1)

在这里插入图片描述

cls token之后,维度由[1,196,768]变为了[1,197,768]

2.3 位置编码(positional embedding)

由于ViT是基于自注意力机制(self-attention mechanism)构建的,它无法直接处理序列中项目的顺序信息。
位置编码通常是通过将位置信息转换为向量形式,然后将其添加到输入图像的嵌入表示中来实现的。这样,每个嵌入向量就会包含图像中的位置信息。位置编码的加入可以帮助模型在处理图像时更好地理解不同位置之间的关系,它使得模型能够关注图像中的全局和局部结构,从而更好地对图像进行建模。

总结起来,位置编码在ViT中的作用是引入图像中不同位置的位置关系,以帮助模型理解和处理图像中的全局和局部结构。
代码实现

# [1,196+1,768] -> [1,196,768]
img_token_pe = self.pos_embed[:, 1:, :]# old_feature_shape: [1,196,768] -> [1,14,14,768] -> [1,768,14,14]
img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
# new_feature_shape: [1,768,14,14] -> [1,768,14,14]
img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
# [1,768,14,14] -> [1,14,14,768] -> [1,196,768]
img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
# [1,1,768] cat [1,196,768] -> [1,197,768]
pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)# Step 4. residual connection + droppath
# [1,197,768] + [1,197,768] -> [1,197,768]
x = self.pos_drop(x + pos_embed)

位置编码(positional embedding)之后,维度还是[1,197,768],没有改变。

三、Transformer Encoder部分

Encoder部分主要由下面流程构成
(1)Multi-head Self-Attention:在Encoder block中,每个patch embedding都会与其他所有patch embeddings进行注意力计算。这种注意力计算将每个patch embedding与其他patch embeddings进行交互,从而使每个patch能够“看到”其他patch的信息。这种注意力计算可以通过独立的多头注意力机制实现,其中每个注意力头都可以学习不同的关注模式。
(2)Layer Normalization:在注意力计算之后,对每个patch embedding进行层归一化操作,以减少信息波动。
(3)Feed-Forward Network:在层归一化之后,通过一个全连接前馈网络,对每个patch的特征进行非线性转换。这个前馈网络可以是多层感知机(MLP),可以通过两个线性变换和一个激活函数来实现。
(4)Residual Connection:Encoder block中的每个操作都有一个残差连接,将输入与输出相加,以保留输入的信息。
(5)Layer Normalization:在前馈网络之后,再次对每个patch embedding进行层归一化操作。
(6)Dropout:为了防止过拟合,可以在Encoder block中应用dropout操作,以随机丢弃一部分特征。

(1) Multi-head Self-Attention

原理参考:【Transformer系列(2)】Multi-head self-attention 多头自注意力
代码实现

# multi-head self-attention
class Attention(nn.Module):def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):super().__init__()# multi-head,这里有点类似分组卷积self.num_heads  = num_heads# 尺度self.scale      = (dim // num_heads) ** -0.5# qkv通过Linear生成self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop  = nn.Dropout(attn_drop)self.proj       = nn.Linear(dim, dim)self.proj_drop  = nn.Dropout(proj_drop)def forward(self, x):# Step 1. get qkv# N=W*HB, N, C     = x.shape# [B,N,3,num_heads,C//num_heads] -> [3,B,num_heads,N,//num_heads]qkv         = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v     = qkv[0], qkv[1], qkv[2]# Step 2. get attention# q*k的转置,再除以根号维度attn = (q @ k.transpose(-2, -1)) * self.scale# softmax就是attentionattn = attn.softmax(dim=-1)# dropout,随机失活attn = self.attn_drop(attn)# Step 3. use attention on v# 注意力乘vx = (attn @ v).transpose(1, 2).reshape(B, N, C)# Linearx = self.proj(x)# dropout,随机失活x = self.proj_drop(x)return x

(2) encoder block

代码实现

# 多个block组成
self.blocks = nn.Sequential(*[Block(dim=num_features,num_heads=num_heads,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,drop=drop_rate,attn_drop=attn_drop_rate,drop_path=dpr[i],norm_layer=norm_layer,act_layer=act_layer) for i in range(depth)]
)

Transformer Encoder之后,维度还是[1,197,768],没有改变。

四、head部分

head部分,(1)会对encoder输出的[1,197,768]向量进行归一化,(2)然后再取出cls token,(3)再将cls token送入Linear层。
(1)归一化

# [1,197,768]
x = self.norm(x)

(2)取出cls token

#  get cls_token 768类似channel
# [1,197,768] -> [1,768]
x= x[:, 0]

维度变化:1,197,768] -> [1,768]
(3)送入Linar层

self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
# [1,768] -> [1,2] 2分类问题
x = self.head(x)

维度变化:[1,768] -> [1,2] 2分类问题

五、vision transformer(ViT)完整代码

class VisionTransformer(nn.Module):def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU):super().__init__()# [224, 224, 3] -> [196, 768]# 196 = (224 // 16) * (224 // 16) Patch化后,再平铺# 768 = 16 * 16 * 3   input_channel = 3 HW分别缩放16倍 output_channel拓宽16*16倍self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans,num_features=num_features)# kernel_size = stride# 196 = (224 // 16) * (224 // 16)num_patches = (224 // patch_size) * (224 // patch_size)self.num_features = num_features# new feature shape: [14,14]self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]# old feature shape: [14,14]self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]# --------------------------------------------------------------------------------------------------------------------##   classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。##   在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。#   此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。#   在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。# --------------------------------------------------------------------------------------------------------------------##   196, 768 -> 197, 768# [1,1,768]self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))# --------------------------------------------------------------------------------------------------------------------##   为网络提取到的特征添加上位置信息。#   以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768#   196 = (224 // 16) * (224 //16) 768 = 16 * 16 * 3#   此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。# --------------------------------------------------------------------------------------------------------------------##   197, 768 -> 197, 768# [1,196,768] -> [1,196+1,768]self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))self.pos_drop = nn.Dropout(p=drop_rate)# -----------------------------------------------##   197, 768 -> 197, 768  12次# -----------------------------------------------## 0~drop_path_rate的等差数列,12位dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]# 多个block组成self.blocks = nn.Sequential(*[Block(dim=num_features,num_heads=num_heads,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,drop=drop_rate,attn_drop=attn_drop_rate,drop_path=dpr[i],norm_layer=norm_layer,act_layer=act_layer) for i in range(depth)])self.norm = norm_layer(num_features)self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()def forward_features(self, x):# Step 1. 序列化:先Patch,再HW平铺# [B,C,H,W] -> BNC N=H*Wx = self.patch_embed(x)# Step 2. 增加cls token# -1:当前维度不拓展# 1batch_size = x.shape[0]# [1,1,768]cls_token = self.cls_token.expand(batch_size, -1, -1)# H*W+1# [1,196,768] -> [1,197,768]x = torch.cat((cls_token, x), dim=1)# Step 3. 位置编码# [1,196+1,768] -> [1,1,768]cls_token_pe = self.pos_embed[:, 0:1, :]# [1,196+1,768] -> [1,196,768]img_token_pe = self.pos_embed[:, 1:, :]# old_feature_shape: [1,196,768] -> [1,14,14,768] -> [1,768,14,14]img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)# new_feature_shape: [1,768,14,14] -> [1,768,14,14]img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)# [1,768,14,14] -> [1,14,14,768] -> [1,196,768]img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)# [1,1,768] cat [1,196,768] -> [1,197,768]pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)# Step 4. residual connection + droppath# [1,197,768] + [1,197,768] -> [1,197,768]x = self.pos_drop(x + pos_embed)# Step 5. multi-head self_attention# [1, 197, 768]x = self.blocks(x)# Step 6. layers_norm# [1,197,768]x = self.norm(x)# Step 7. get cls_token 768类似channel# [1,197,768] -> [1,768]x= x[:, 0]return xdef forward(self, x):# Step 1~6.# [1,3,224,224] -> [1,768]x = self.forward_features(x)# Step 7. Linear# [1,768] -> [1,2] 2分类问题x = self.head(x)return x

六、基于vision transformer(ViT)实现猫狗二分类项目实战

项目链接:https://download.csdn.net/download/m0_51579041/89255878
数据集链接:https://download.csdn.net/download/m0_51579041/89255922

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/318796.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小程序账号设置以及request请求的封装

一般开发在小程序时,都会有测试版和正式版,这样在开发时会比较方便。 在开发时。产品经理都会给到测试账号和正式账号,后端给的接口也都会有测试环境用到的接口和正式环境用到的接口。 这里讲一讲我这边如何去做的。 1.在更目录随便命名一…

langchain+qwen1.5-7b-chat搭建本地RAG系统

概念 检索增强生成(Retrieval Augmented Generation, RAG)是一种结合语言模型和信息检索的技术,用于生成更准确且与上下文相关的输出。 通用模型遇到的问题,也是RAG所擅长的: 知识的局限性: RAG 通过从知识库、数据库、企业内部数据等外部数据源中检索相关信息,将其注…

物联网实战--平台篇之(二)基础搭建

目录 一、Qt工程创建 二、数据库知识 三、通信协议 四、名词定义 本项目的交流QQ群:701889554 物联网实战--入门篇https://blog.csdn.net/ypp240124016/category_12609773.html 物联网实战--驱动篇https://blog.csdn.net/ypp240124016/category_12631333.html 一、Qt工程…

nginx--压缩https证书favicon.iconginx隐藏版本号 去掉nginxopenSSL

压缩功能 简介 Nginx⽀持对指定类型的⽂件进行压缩然后再传输给客户端,而且压缩还可以设置压缩比例,压缩后的文件大小将比源文件显著变小,这样有助于降低出口带宽的利用率,降低企业的IT支出,不过会占用相应的CPU资源…

VTK —— 二、教程六 - 为模型加入3D微件(按下i键隐藏或显示)(附完整源码)

代码效果 本代码编译运行均在如下链接文章生成的库执行成功,若无VTK库则请先参考如下链接编译vtk源码: VTK —— 一、Windows10下编译VTK源码,并用Vs2017代码测试(附编译流程、附编译好的库、vtk测试源码) 教程描述 本…

运维笔记:基于阿里云跨地域服务器通信(上)

运维笔记 阿里云:跨地域服务器通信(上) - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite:http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this a…

算法打卡day40

今日任务: 1)139.单词拆分 2)多重背包理论基础(卡码网56携带矿石资源) 3)背包问题总结 4)复习day15 139单词拆分 题目链接:139. 单词拆分 - 力扣(LeetCode) …

【Node.js工程师养成计划】之express框架

一、Express 官网:http://www.expressjs.com.cn express 是一个基于内置核心 http 模块的,一个第三方的包,专注于 web 服务器的构建。 Express 是一个简洁而灵活的 node.js Web应用框架, 提供了一系列强大特性帮助你创建各种 Web 应用&…

网络安全知识点

网络安全 1. 网络安全的定义,网络安全的属性。 定义:针对各种网络安全威胁研究其安全策略和机制,通过防护、检测和响应,确保网络系统及数据的安全性。 属性:机密性 认证(可鉴别性&#xff09…

【Leetcode每日一题】 分治 - 排序数组(难度⭐⭐)(69)

1. 题目解析 题目链接:912. 排序数组 这个问题的理解其实相当简单,只需看一下示例,基本就能明白其含义了。 2.算法原理 归并排序(Merge Sort)是一种采用“分而治之”(Divide and Conquer)策略…

解决RTC内核驱动的问题bm8563

常用pcf-8563 , 国产平替BM8563(驱动管脚一致); 实时时钟是很常用的一个外设,通过实时时钟我们就可以知道年、月、日和时间等信息。 因此在需要记录时间的场合就需要实时时钟,可以使用专用的实时时钟芯片来完成此功能 RTC 设备驱动是一个标准…

【webrtc】MessageHandler 4: 基于线程的消息处理:以Fake 收发包模拟为例

G:\CDN\rtcCli\m98\src\media\base\fake_network_interface.h// Fake NetworkInterface that sends/receives RTP/RTCP packets.虚假的网络接口,用于模拟发送包、接收包单纯仅是处理一个ST_RTP包 消息的id就是ST_RTP 类型,– 然后给到目的地:mediachannel处理: 最后消息消…

rust前端web开发框架yew使用

构建完整基于 rust 的 web 应用,使用yew框架 trunk 构建、打包、发布 wasm web 应用 安装后会作为一个系统命令,默认有两个特性开启 rustls - 客户端与服务端通信的 tls 库update_check - 用于应用启动时启动更新检查,应用有更新时提示用户更新。nati…

【LeetCode刷题】410. 分割数组的最大值

1. 题目链接2. 题目描述3. 解题方法4. 代码 1. 题目链接 410. 分割数组的最大值 2. 题目描述 3. 解题方法 题目中提到的是某个和的最大值是最小的,这种题目是可以用二分来解决的。 确定区间,根据题目的数据范围,可以确定区间就是[0, 1e9]…

【华为 ICT HCIA eNSP 习题汇总】——题目集20

1、(多选)若两个虚拟机能够互相ping通,则通讯过程中会使用()。 A、虚拟网卡 B、物理网卡 C、物理交换机 D、分布式虚拟交换机 考点:数据通信 解析:(AD) 物理网卡是硬件设…

基于SSM SpringBoot vue宾馆网上预订综合业务服务系统

基于SSM SpringBoot vue宾馆网上预订综合业务服务系统 系统功能 首页 图片轮播 宾馆信息 饮食美食 休闲娱乐 新闻资讯 论坛 留言板 登录注册 个人中心 后台管理 登录注册 个人中心 用户管理 客房登记管理 客房调整管理 休闲娱乐管理 类型信息管理 论坛管理 系统管理 新闻资讯…

记录一下安装cv2的过程

python安装cv2库(命令行安装法,每一步都可复制命令,非常贴心!),手把手安装-CSDN博客 主要是参考的这篇文章 pip install opencv-python关键命令就是这一行,会比较慢 加上清华源吧

Mybatis.net + Mysql

项目文件结构 NuGet下载Mybatis.net相关包:IBatisNet 安装完成后,会显示在,在已安装页面。同时,在管理器中的引用列表中,会多出来两个引用文件 IBatisNet.CommonIBatisNet.DataMapper 安装 Mysql.data。 注意&#xff…

AJ-Report开源数据大屏 verification;swagger-ui RCE漏洞复现

0x01 产品简介 AJ-Report是一个完全开源的BI平台,酷炫大屏展示,能随时随地掌控业务动态,让每个决策都有数据支撑。多数据源支持,内置mysql、elasticsearch、kudu等多种驱动,支持自定义数据集省去数据接口开发,支持17+种大屏组件,不会开发,照着设计稿也可以制作大屏。三…

亚马逊云科技AWS免费证书-EC2服务器设计(含题库)

亚马逊云AWS官方程序员专属免费证书又来了!这次证书是关于AWS EC2实例的设计和搭建,EC2作为AWS服务的核心,是学好AWS的第一步。强推没有任何AWS背景和转码的小伙伴去学!学完也能变成AWS开发大神! 证书名字叫Getting St…