ViT的极简pytorch实现及其即插即用

先放一张ViT的网络图
在这里插入图片描述
可以看到是把图像分割成小块,像NLP的句子那样按顺序进入transformer,经过MLP后,输出类别。每个小块是16x16,进入Linear Projection of Flattened Patches, 在每个的开头加上cls token和位置信息,也就是position embedding。
去掉数据读取部分,直接上一个极简的ViT代码:

import torch
from torch import nnfrom einops import rearrange, repeat
from einops.layers.torch import Rearrange# helpersdef pair(t):return t if isinstance(t, tuple) else (t, t)# classesclass PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fndef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)class FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout = 0.):super().__init__()self.net = nn.Sequential(nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)class Attention(nn.Module):def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):super().__init__()inner_dim = dim_head *  headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.attend = nn.Softmax(dim = -1)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):qkv = self.to_qkv(x).chunk(3, dim = -1)## 对tensor张量分块 x :1 197 1024   qkv 最后是一个元祖,tuple,长度是3,每个元素形状:1 197 1024q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaleattn = self.attend(dots)out = torch.matmul(attn, v)out = rearrange(out, 'b h n d -> b n (h d)')return self.to_out(out)class Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):super().__init__()self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))]))def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + xreturn xclass ViT(nn.Module):def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):super().__init__()image_height, image_width = pair(image_size)   # 224*224patch_height, patch_width = pair(patch_size)   # 16 * 16assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'num_patches = (image_height // patch_height) * (image_width // patch_width)patch_dim = channels * patch_height * patch_widthassert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'self.to_patch_embedding = nn.Sequential(# (b,3,224,224) -> (b,196,768)    14*14=196  16*16*3=768Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),nn.Linear(patch_dim, dim),    # (b,196,1024))self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))self.cls_token = nn.Parameter(torch.randn(1, 1, dim))self.dropout = nn.Dropout(emb_dropout)self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)self.pool = poolself.to_latent = nn.Identity()self.mlp_head = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, num_classes))def forward(self, img):x = self.to_patch_embedding(img)        # img 1 3 224 224  输出形状x : 1 196 1024b, n, _ = x.shape                       # 1 196cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)    # (1,1,1024)x = torch.cat((cls_tokens, x), dim=1)   # (1,197,1024)x += self.pos_embedding[:, :(n + 1)]    # (1,197,1024)x = self.dropout(x)                     # (1,197,1024)x = self.transformer(x)                 # (1,197,1024)x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]     # (1,1024)x = self.to_latent(x)      # (1,1024)return self.mlp_head(x)    # (1,1000)if __name__ == '__main__':v = ViT(image_size = 224,patch_size = 16,num_classes = 1000,dim = 1024,depth = 6,heads = 16,mlp_dim = 2048,dropout = 0.1,emb_dropout = 0.1)img = torch.randn(1, 3, 224, 224)preds = v(img)        # (1, 1000)print(preds.shape)

去掉cls和最后的全连接分类头,变成即插即用的模块:

import torch
from torch import nnfrom einops import rearrange
from einops.layers.torch import Rearrange# helpersdef pair(t):return t if isinstance(t, tuple) else (t, t)# classesclass PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fndef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)class FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout = 0.):super().__init__()self.net = nn.Sequential(nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)class Attention(nn.Module):def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):super().__init__()inner_dim = dim_head *  headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.attend = nn.Softmax(dim = -1)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):qkv = self.to_qkv(x).chunk(3, dim = -1)## 对tensor张量分块 x :1 197 1024   qkv 最后是一个元祖,tuple,长度是3,每个元素形状:1 197 1024q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaleattn = self.attend(dots)out = torch.matmul(attn, v)out = rearrange(out, 'b h n d -> b n (h d)')return self.to_out(out)class Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):super().__init__()self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))]))def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + xreturn xclass ViT(nn.Module):def __init__(self, *, image_size, patch_size, dim = 1024, depth = 3, heads = 16, mlp_dim = 2048, dim_head = 64, dropout = 0.1, emb_dropout = 0.1):super().__init__()channels, image_height, image_width = image_size   # 256,64,80patch_height, patch_width = pair(patch_size)       # 4*4assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'num_patches = (image_height // patch_height) * (image_width // patch_width)     # 16*20patch_dim = 64 * patch_height * patch_width    # 64*8*10self.conv1 = nn.Conv2d(256, 64, 1)self.to_patch_embedding = nn.Sequential(# (b,64,64,80) -> (b,320,1024)    16*20=320  4*4*64=1024Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),nn.Linear(patch_dim, dim),    # (b,320,1024))self.to_img = nn.Sequential(# b c (h p1) (w p2) -> (b,64,64,80)      16*20=320  4*4*64=1024Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', \p1 = patch_height, p2 = patch_width, h = image_height // patch_height, w = image_width // patch_width),nn.Conv2d(64, 256, 1),      # (b,64,64,80) -> (b,256,64,80))# 位置编码self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))self.dropout = nn.Dropout(emb_dropout)self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)def forward(self, img):x = self.conv1(img)                     # img 1 256 64 80 -> 1 64 64 80x = self.to_patch_embedding(x)          # 1 320 1024b, n, _ = x.shape                       # 1 320x += self.pos_embedding[:, :(n + 1)]    # (1,320,1024)x = self.dropout(x)                     # (1,320,1024)x = self.transformer(x)                 # (1,320,1024)x = self.to_img(x)return x                                # (1 256 64 80)if __name__ == '__main__':v = ViT(image_size = (256,64,80), patch_size = 4)img = torch.randn(1, 256, 64, 80)preds = v(img)         # (1, 256, 64, 80)print(preds.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/226999.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

自检服务器,无需服务器、不用编程。

自检服务器,无需服务器、不用编程。 大家好,我是JavaPub. 这几年自媒体原来热,很多人都知道了个人 IP 的重要性。连一个搞中医的朋友都要要做一个自己的网站,而且不想学编程、还不想花 RMB 租云服务。 老读者都知道&#xff0c…

索引的使用

一、索引是什么 索引是一种排序的表,它记录着索引字段的值以及对应行记录的数据所在的物理位置; ●索引是一个排序的列表,在这个列表中存储着索引的值和包含这个值的数据所在行的物理地址(类似于C语言的链表通过指针指向数据记录…

天擎终端安全管理系统clientinfobymid存在SQL注入漏洞

产品简介 奇安信天擎终端安全管理系统是面向政企单位推出的一体化终端安全产品解决方案。该产品集防病毒、终端安全管控、终端准入、终端审计、外设管控、EDR等功能于一体,兼容不同操作系统和计算平台,帮助客户实现平台一体化、功能一体化、数据一体化的…

SAP缓存 表缓存( Table Buffering)

本文主要介绍SAP中的表缓存在查询数据,更新数据时的工作情况以及对应概念。 SAP表缓存的工作 查询数据 更新数据 删除数据 表缓存的概念 表缓存技术设置属性 不允许缓冲: 允许缓冲,但已关闭: 缓冲已激活: 已…

Flask笔记

一:模板渲染 一般的话都序列化成字符串 二:项目拆分 2.1 项目拆分 app.py init.py views.py models.py 模型数据 2.2 蓝图 三:路由参数 3.1 String 重点 3.2 int 3.3 path 3.4 UUID 3.5 any 四:请求方式 五:Requ…

【经典算法】有趣的算法之---蚁群算法梳理

every blog every motto: You can do more than you think. 0. 前言 蚁群算法记录 1. 简介 蚁群算法(Ant Clony Optimization, ACO)是一种群智能算法,它是由一群无智能或有轻微智能的个体(Agent)通过相互协作而表现出智能行为,从而为求解复杂问题提供了一个新的可能性…

WPF 漂亮长方体、正文体简单实现方法 Path实现长方体 正方体方案 WPF快速实现长方体、正方体的方法源代码

这段XAML代码在WPF中实现了一个类似长方体视觉效果的图形 声明式绘制:通过Path、PathGeometry和PathFigure等元素组合,能够以声明方式精确描述长方体每个面的位置和形状,无需编写复杂的绘图逻辑,清晰直观。 层次结构与ZIndex控制…

机器学习距离度量方法

1. 机器学习中为什么要度量距离? 机器学习算法中,经常需要 判断两个样本之间是否相似 ,比如KNN,K-means,推荐算法中的协同过滤等等,常用的套路是 将相似的判断转换成距离的计算 ,距离近的样本相…

MySQL入门教程-触发器

9.触发器 什么是触发器 触发器(trigger):监视某种情况,并进行某种操作,它的执行并不是程序调用,也不是手工启动,而是由事件来触发,例如:对一张表进行操作(插入,更新&…

初识Java并发,一问读懂Java并发知识文集(3)

🏆作者简介,普修罗双战士,一直追求不断学习和成长,在技术的道路上持续探索和实践。 🏆多年互联网行业从业经验,历任核心研发工程师,项目技术负责人。 🎉欢迎 👍点赞✍评论…

全院级PACS系统源码,集成放射科管理RIS系统,支持多种图像处理及三维重建功能

PACS系统是医院影像科室中应用的一种系统,主要用于获取、传输、存档和处理医学影像。它通过各种接口,如模拟、DICOM和网络,以数字化的方式将各种医学影像,如核磁共振、CT扫描、超声波等保存起来,并在需要时能够快速调取…

【项目管理】CMMI-项目总结报告模版

1、文档目录结构 2、计划与实际情况对比 3、开放工作评价

【中小型企业网络实战案例 五】配置可靠性和负载分担

【中小型企业网络实战案例 三】配置DHCP动态分配地址-CSDN博客 【中小型企业网络实战案例 四】配置OSPF动态路由协议 【中小型企业网络实战案例 二】配置网络互连互通-CSDN博客 【中小型企业网络实战案例 一】规划、需求和基本配置_大小企业网络配置实例-CSDN博客 配置VRRP联…

Android MVVM 写法

前言 Model:负责数据逻辑 View:负责视图逻辑 ViewModel:负责业务逻辑 持有关系: 1、ViewModel 持有 View 2、ViewModel 持有 Model 3、Model 持有 ViewModel 辅助工具:DataBinding 执行流程:View &g…

共享单车之数据分析

文章目录 第1关:统计共享单车每天的平均使用时间第2关:统计共享单车在指定地点的每天平均次数第3关:统计共享单车指定车辆每次使用的空闲平均时间第4关:统计指定时间共享单车使用次数第5关:统计共享单车线路流量 第1关…

关于log4j的那些坑

背景:工程中同时存在log4j.xml&log4j2.xml maven依赖如下: 此时工程实际使用的日志文件为log4j.xml 1、当同时设置log4j和log4j2的桥接依赖时 maven依赖如下: 此时启动会有警告日志: 点击告警日志链接:https://…

基于OpenAI的Whisper构建的高效语音识别模型:faster-whisper

1 faster-whisper介绍 faster-whisper是基于OpenAI的Whisper模型的高效实现,它利用CTranslate2,一个专为Transformer模型设计的快速推理引擎。这种实现不仅提高了语音识别的速度,还优化了内存使用效率。faster-whisper的核心优势在于其能够在…

YOLOv8改进 | 主干篇 | EfficientNetV1均衡缩放网络改进特征提取层

一、本文介绍 这次给大家带来的改进机制是EfficientNetV1主干,用其替换我们YOLOv8的特征提取网络,其主要思想是通过均衡地缩放网络的深度、宽度和分辨率,以提高卷积神经网络的性能。这种方法采用了一个简单但有效的复合系数,统一…

mac安装k8s环境

安装kubectl brew install kubectl 确认一下安装的版本 kubectl version --client 如果想在本地运行kubernetes 需要安装minikube brew install minikube 需要注意安装minikube需要本地的docker服务是启动的 启动 默认连接的是google的仓库 minikube start 指定阿…

HTML实战演练之贪吃蛇美食大作战

导入: 一 :粉丝要求 今天一位小伙伴私信我说,想玩HTML贪吃蛇美食大作战,自己也是学HTML的,希望我能安排一下,那么好它来了 需知: 一:别着急先看需要知道的 要用HTML开发贪吃蛇美食…