解决Vision Transformer在任意尺寸图像上微调的问题:使用timm库

解决Vision Transformer在任意尺寸图像上微调的问题:使用timm库

文章目录

          • 一、ViT的微调问题的本质
          • 二、Positional Embedding如何处理
            • 1,绝对位置编码
            • 2,相对位置编码
            • 3,对位置编码进行插值
          • 三、Patch Embedding Layer如何处理
          • 四、使用timm库来对任意尺寸进行微调

一、ViT的微调问题的本质

自从ViT被提出以来,在CV领域引起了新的研究热潮。理论上来说,Transformer的输入是一个序列,并且其参数主要来自于Transformer Block中的Linear层,因此Transformer可以处理任意长度的输入序列。但是在Vision Transformer中,由于需要将二维的图像通过Patch Embedding Layer映射为一个一维的序列,并且需要添加pos_embedding来保留位置信息。因此当patch_size和img_size发生改变时,会造成pos_embbeding的长度和Patch Embedding Layer的参数发生改变,从而导致预训练权重无法直接加载。更多有关ViT的实现细节和原理,可以参考Vision Transformer , 通用 Vision Backbone 超详细解读 (原理分析+代码解读)。

二、Positional Embedding如何处理

在Vision Transformer中有两种主流的编码方式:相对位置编码和绝对位置编码。

1,绝对位置编码

绝对位置编码依据token每个的绝对位置分配一个固定的值,其本质上是一组一维向量,有两种实现方式:

# 可学习的位置编码,ViT中使用, +1是因为有cls_tokenself.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))# 根据正余弦获取位置编码, Transformer中使用
def get_positional_embeddings(sequence_length,dim):result = torch.ones(sequence_length,dim)for i in range(sequence_length):for j in range(dim):result[i][j] = np.sin(i/(10000**(j/dim))) if j %2==0 else np.cos(i/(10000**((j-1)/dim)))return result

在forward过程中,绝对位置编码会在最开始直接和token相加:

	tokens += self.pos_embedding[:, :(n + 1)]
2,相对位置编码

相对位置编码,依据每个token的query相对于key的位置来分配位置编码,典型例子就是swin transformer,其本质是构建一个可学习的二维table,然后依据相对位置索引(x,y)来从table中取值,具体可以参考:有关swin transformer相对位置编码的理解

不过,在swin transformer中,query和key都是来自于同一个window,因此query和key的数量相同,构建位置编码的方式相对来说比较简单。如果query和key的数量不同,例如Focal Transformer中多层次的self-attention,其位置编码的方式可以参考:Focal Transformer。

对于相对位置编码的构造,还有一种方式是CrossFormer中提出的Dynamic Position Bias。其核心思想为构建一个MLP,其输入是二维的相对位置索引,输出是指定dim的位置偏置。这个和根据正余弦获取位置编码有点类似,只不过一个是依据一维的绝对坐标来生成位置编码,一个是依据二维的相对坐标来生成位置编码。

image-20231122172434141

在forward过程中,相对位置编码不会在一开始与token相加,而是在Attention Layer中以Bias的形式参与self-attention计算,核心代码如下:

        attn = (q @ k.transpose(-2, -1))relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nHrelative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwattn = attn + relative_position_bias.unsqueeze(0)
3,对位置编码进行插值

综上,我们可以依据实现方式将位置编码分为两大类:可学习的位置编码(例如,ViT、Swin Transformer、Focal Transformer等)和生成式的位置编码(例如,正余弦位置编码和CrossFormer中的DPB)。更多有关位置编码的内容,可以参考论文:Rethinking and Improving Relative Position Encoding for Vision Transformer。对于生成式的位置编码而言,其编码方式与序列长度无关,因此当patch_size和img_size改变而造成num_patches改变时,仍然可以加载与位置编码有关的预训练权重。

但是,对于可学习的位置编码而言,num_patches改变时,无法直接加载与位置编码的预训练权重。以ViT为例,其参数一般是一个shape为[N+1, C]的tensor。与cls_token有关的位置编码不用改变,我们只需要关心与img patch相关的位置编码即可,其shape为[N, C]。当num_patches变为n时,所需要位置编码shape为[n, C]。这显然无法直接加载预训练权重。

Pytorch官方提供了一种思路,通过插值算法,来获取新的权重。我们不妨将原始的位置编码想象为一个shape为[ N , N , C \sqrt{N}, \sqrt{N}, C N ,N ,C]的tensor,将所需要的位置编码想象为一个shape为[ n , n , C \sqrt{n}, \sqrt{n}, C n ,n ,C]。这样我们就可以通过插值算法,将原始的权重映射到所需要的权重上。核心代码如下:

# (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)new_seq_length_1d = image_size // patch_size# Perform interpolation.# (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)new_pos_embedding_img = nn.functional.interpolate(pos_embedding_img,size=new_seq_length_1d,mode=interpolation_mode,align_corners=True,)# (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)# (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)

不过,Pytorch官方的这个代码,只能适配当num_patches是一个完全平方数的情况,因为需要开根号操作。实际上,num_patches一般是通过如下方式计算获得,理论上来说通过插值算法是可以适配到任意尺寸的num_patches的。

n u m _ p a t c h e s = i m g _ s i z e h p a t c h _ s i z e h i m g _ s i z e w p a t c h _ s i z e w (1) num\_patches=\frac{img\_size_h}{patch\_size_h}\frac{img\_size_w}{patch\_size_w} \tag{1} num_patches=patch_sizehimg_sizehpatch_sizewimg_sizew(1)

从上式可以看出,pos_embedding主要与img_size/patch_size有关,因此当把img_size和patch_size等比例缩放时,是不需要调整pos_embedding的。

在timm库中,提供了resample_abs_pos_embed函数,并将其集成到了VisionTransformer类中,所以我们在使用时无需自己考虑对位置编码进行插值处理。

三、Patch Embedding Layer如何处理

Patch Embedding Layer用于将二维的图像转为一维的输入序列,其实现方式通常有两种,如下所示:

### 基于MLP的实现方式patch_dim = in_channels * patch_height * patch_widthself.patch_embedding = nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), # 使用einops库nn.LayerNorm(patch_dim),nn.Linear(patch_dim, dim),nn.LayerNorm(dim),)### 基于Conv2d的实现方式self.patch_embedding = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

从这两种实现可以看出,Patch Embedding Layer的参数主要与patch_size和in_channels有关,而与img_size无关。Pytorch官方和Timm库都采用基于Conv2d的方式来实现,当patch_size和in_channels改变时,无法直接加载预训练权重。

Pytorch官方并未给出解决方案,timm库通过resample_patch_embed来解决这一问题,并且也集成到了VisionTransformer类中。在使用时,我们也不需要考虑手动对Patch Embedding Layer的权重进行调整。

四、使用timm库来对任意尺寸进行微调

首先需要安装timm库

pip install timm
# 如果安装的Pytorch2.0及以上版本,无需考虑一下步骤
# 如果是其他版本的Pytorch,需要下载functorch库
pip install functorch==版本号
# 具体版本号,需要依据自己环境中的pytorch版本来
# 例如:0.20.0对应Pytorch1.12.0,0.2.1对应Pytorch1.12.1
# 对应关系可以去github上查看:https://github.com/pytorch/functorch/releases

代码示例如下:

import timm
from timm.models.registry import register_model@register_model # 注册模型
def vit_tiny_patch4_64(pretrained: bool = False, **kwargs) -> VisionTransformer:""" ViT-Tiny (Vit-Ti/16)"""# 在model_args中对需要部分参数进行修改,此处调整了img_size, patch_size和in_chansmodel_args = dict(img_size = 64, patch_size=4, in_chans=1, embed_dim=192, depth=12, num_heads=3) # vit_tiny_patch16_224是想要加载的预训练权重对应的模型model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model# 注册模型之后,就可以通过create_model来创建模型了
vit = timm.create_model('vit_tiny_patch4_64', pretrained = True) 

不过,由于预训练权重在线下载一般比较慢,可以通过pretrained_cfg来实现加载本地模型,代码如下:

    vit = timm.create_model('vit_tiny_patch4_64')cfg = vit.default_cfgprint(cfg['url']) # 查看下载的url来手动下载cfg['file'] = 'vit-tiny.npz' # 这里修改为你下载的模型vit = timm.create_model('vit_tiny_patch4_64', pretrained=True, pretrained_cfg=cfg).cuda()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/201213.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【0到1学习Unity脚本编程】第一人称视角的角色控制器

👨‍💻个人主页:元宇宙-秩沅 👨‍💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍💻 本文由 秩沅 原创 👨‍💻 收录于专栏:【0…

7.Gin 路由详解 - 路由分组 - 路由文件抽离

7.Gin 路由详解 - 路由分组 - 路由文件抽离 前言 在前面的示例中,我们直接将路由的定义全部写在 main.go 文件中,如果后面 路由越来越多,那将会越来越不好管理。 所以,下一步我们应该考虑将路由进行分组管理,并且将其抽…

腾讯云代金券怎么领取(腾讯云代金券在哪领取)

腾讯云代金券是可抵扣费用的优惠券,领券之后新购、续费、升级腾讯云相关云产品可以直接抵扣订单金额,节省购买腾讯云的费用,本文将详细介绍腾讯云代金券的领取方法和使用教程。 一、腾讯云代金券领取 1、新用户代金券【点此领取】 2、老用户…

SVN创建分支

一 从本地创建方式可指定版本号进行分支创建。 1、在本地目录右击 -----> 点击branch/tag(分支/标签) From: 源,可指定具体的版本号, To path: 可通过"..."选择分支路径 最后点击确定,交由服务器执行创建。 二 通过SVN客…

3D 纹理渲染如何帮助设计师有效、清晰地表达设计理念

在线工具推荐: 三维数字孪生场景工具 - GLTF/GLB在线编辑器 - Three.js AI自动纹理化开发 - YOLO 虚幻合成数据生成器 - 3D模型在线转换 - 3D模型预览图生成服务 定义 3D 渲染可视化及其用途 3D 可视化是一种艺术形式。这是一个机会。这是进步。借助 3D 纹理…

CNVD-2023-12632:泛微E-cology9 browserjsp SQL注入漏洞复现 [附POC]

文章目录 泛微E-cology9 browserjsp SQL注入漏洞(CNVD-2023-12632)漏洞复现 [附POC]0x01 前言0x02 漏洞描述0x03 影响版本0x04 漏洞环境0x05 漏洞复现1.访问漏洞环境2.构造POC3.复现 0x06 修复建议 泛微E-cology9 browserjsp SQL注入漏洞(CNVD-2023-12632)漏洞复现 [附POC] 0x…

visionOS空间计算实战开发教程Day 1:环境安装和编写第一个程序

安装 截至目前visionOS还未在Xcode稳定版中开放,所以需要下载​​Xcode Beta版​​。比如我们可以下载Xcode 15.1 beta 2,注意Xcode 15要求系统的版本是macOS Ventura 13.5或更新,也就是说2017年的MacBook Pro基本可以勉强一战,基…

海外IP代理:数据中心代理IP是什么?好用吗?

数据中心代理是代理IP中最常见的类型,也被称为机房IP。这些代理服务器为用户分配不属于 ISP(互联网服务提供商)而来自第三方云服务提供商的 IP 地址。数据中心代理的最大优势——它们允许在访问网络时完全匿名。 如果你正在寻找海外代理IP&am…

【JavaEE】Servlet实战案例:表白墙网页实现

一、功能展示 输入信息: 点击提交: 二、设计要点 2.1 明确前后端交互接口 🚓接口一:当用户打开页面的时候需要从服务器加载已经提交过的表白数据 🚓接口二:当用户新增一个表白的时候,…

MEMS制造的基本工艺——晶圆键合工艺

晶圆键合是一种晶圆级封装技术,用于制造微机电系统 (MEMS)、纳米机电系统 (NEMS)、微电子学和光电子学,确保机械稳定和气密密封。用于 MEMS/NEMS 的晶圆直径范围为 100 毫米至 200 毫米(4 英寸至 8 英寸),用于生产微电…

单链表(数据结构与算法)

✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅ ✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨ 🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿&#x1…

PostgreSQL导出表结构带注释

我们在平时开发过程中,经常会在字段的注释中,加上中文,解释字段的相关含义,也可以避免时间太久忘记这个字段代表什么,毕竟英文水平不好。我们可能要经常整理数据库表结构,提供他人去收集数据,但…

【SpringBoot】通过profiles设置环境

效果图&#xff0c;通过profiles设置环境 在父级pom.xml中添加配置 <profiles><profile><id>dev</id><properties><application.environment>dev</application.environment></properties><activation><activeByDefau…

前端数组方法汇总集锦

前言 数组主要使用场景有&#xff1a; 存储和处理数据&#xff1a;数组是一种有序的数据结构&#xff0c;可以用来存储和处理多个相关的数据。在前端开发中&#xff0c;我们经常使用数组来存储和处理列表、表格、选项等数据。 循环和遍历&#xff1a;数组提供了循环和遍历的功能…

DependencyProperty.Register:wpf 向别的xaml传递参数

一.使用背景&#xff1a;在A.xaml中嵌入B.xaml&#xff0c;并且向B.xaml传递参数。 函数介绍&#xff1a; public static DependencyProperty Register(string name, Type propertyType, Type ownerType );name&#xff08;string&#xff09;&#xff1a; 依赖属性的名称。在…

CmakeLists编译的动态库.so移动到其他位置后,提示找不到该库的依赖库解决办法

主要问题&#xff1a; 最近在搞海康SDK调用相机&#xff0c;发现在linux下一直调用不起来相机&#xff0c;总是提示error code&#xff1a;29&#xff0c;注册失败&#xff0c;重新编译优惠存在找不到依赖库的问题。 1.异常 CmakeLists编译的动态库.so移动到其他位置后&#…

【擎标】CCID信息系统服务商交付能力等级认证标准

为顺应信息技术服务业发展趋势及市场需求&#xff0c;维护市场秩序&#xff0c;加强行业自律&#xff0c;促进信息系统服务商交付能力的不断提高&#xff0c;增强信息系统服务商创新能力和国际竞争力&#xff0c;支撑信息系统服务商转型提升&#xff0c;中国软件行业协会、企业…

Python (十二) 文件

程序员的公众号&#xff1a;源1024&#xff0c;获取更多资料&#xff0c;无加密无套路&#xff01; 最近整理了一份大厂面试资料《史上最全大厂面试题》&#xff0c;Springboot、微服务、算法、数据结构、Zookeeper、Mybatis、Dubbo、linux、Kafka、Elasticsearch、数据库等等 …

CCFCSP试题编号:201912-2试题名称:回收站选址

这题只要比较坐标的四周&#xff0c;然后计数就可以了。 #include <iostream> using namespace std;int main() {int n;cin >> n;int arr[1005][2] { 0 };int res[5] { 0 };int up 0;int down 0;int left 0;int right 0;int score 0;for (int i 0; i <…