CLIP微调方法总结

文章目录

  • 前言
  • 1️⃣ Tip-Adapter
    • 论文和源码
    • 原理介绍
  • 2️⃣Cross-modal Adaptation(跨模态适应)
    • 论文和源码
    • 原理介绍
  • 3️⃣ FD-Align(Feature Discrimination Alignment,特征判别对齐)
    • 论文和源码
    • 原理介绍
  • 总结


前言

在这里插入图片描述

本文主要介绍和总结了三种不错的 C L I P CLIP CLIP微调方法,包括原理和思想,并且按照自己的理解给出了相应的代码实现,相当于是一个简化版的code实现。
所有代码使用 j i t t o r jittor jittor框架实现,具体代码请请参考👇

Gitlink-Code 或者 Github-Code


1️⃣ Tip-Adapter

论文和源码

🔥 论文地址
🚀 代码地址

原理介绍

  • 本质上就是在 C L I P CLIP CLIP的预测结果 X X X上又加上了一个预测结果 Y Y Y,我们都知道结果 X X X是测试图像和所有分类文本的相似度之间的关系,而 Y Y Y就是测试图像和训练 C L I P CLIP CLIP时的训练图像之间的相似度关系,最终将 X X X Y Y Y加权求和便得到最终的预测结果,所以可以发现他的优势在于: Z e r o − s h o t t r a n s f e r (无需额外训练) Zero-shot\ transfer(无需额外训练 ) Zeroshot transfer(无需额外训练)

  • 下面结合论文给的框架图就能很好理解这个方法(每个变量后面标出了 s h a p e shape shape大小,方便理解):

    T i p − A d a p t e r Tip-Adapter TipAdapter添加之前:假设分类类别数目是 N N N W c T W_{c}^{T} WcT N N N个文本标签经过 C L I P CLIP CLIP T e s t E n c o d e r Test\ Encoder Test Encoder得到的文本特征,大小 N × 512 N×512 N×512
    输入一张测试图像 I t e s t I_{test} Itest → 经过 C L I P 模型的 V i s u a l E n c o d e r 之后 \xrightarrow{经过CLIP模型的Visual\ Encoder之后} 经过CLIP模型的Visual Encoder之后 得到 f t e s t : 1 × 512 f_{test} :1×512 ftest1×512 → 和 C L I P 的 T e s t F e a t u r e s 作相似度,也就是图中的 f t e s t ∗ W c T \xrightarrow{和CLIP的Test\ Features作相似度,也就是图中的f_{test}*W_{c}^{T}} CLIPTest Features作相似度,也就是图中的ftestWcT 得到分类结果(实际上就是和所有文本标签的相似度) X : 1 × N X:1×N X:1×N

    T i p − A d a p t e r Tip-Adapter TipAdapter添加之后:
    上面步骤同样完全相同,得到 X X X
    首先将所有的训练图像 I K I_{K} IK(假设共有 M M M张, M = C × N M=C×N M=C×N C C C是一个系数,因为训练时一般每个类别的图像会有多张) → 同样经过 C L I P 模型的 V i s u a l E n c o d e r \xrightarrow{同样经过CLIP模型的Visual\ Encoder} 同样经过CLIP模型的Visual Encoder 得到 F t r a i n : M × 512 F_{train}:M×512 FtrainM×512 ,并作为缓存模型( c a c h e m o d e l cache\ model cache model)的 k e y key key
    然后将所有训练图像的文本标签经过 O n e H o t One\ Hot One Hot处理,得到 L t r a i n : M × N L_{train}:M×N LtrainM×N,并作为缓存模型的 v a l u e value value;到此便构建了一个缓存模型,相当于多了一份存储有训练样本特征的先验信息。
    接着将之前得到的 f t e s t : 1 × 512 f_{test} :1×512 ftest1×512 → 送入 c a c h e m o d e l , 计算和训练图像之间的特征余弦相似度 \xrightarrow{送入cache\ model,计算和训练图像之间的特征余弦相似度} 送入cache model,计算和训练图像之间的特征余弦相似度 得到 A = e x p ( − β ( 1 − f t e s t F t r a i n T ) ) : 1 × M A=exp(-\beta(1-f_{test}F_{train}^{T})):1×M A=exp(β(1ftestFtrainT)):1×M → 和 c a c h e m o d e l 的 v a l u e s 相乘,得到预测结果 Y \xrightarrow{和cache\ model的values相乘,得到预测结果Y} cache modelvalues相乘,得到预测结果Y Y = A L t r a i n : 1 × N Y=AL_{train}:1×N Y=ALtrain1×N
    最后将 T i p − A d a p t e r Tip-Adapter TipAdapter的预测结果 Y Y Y和原始 C L I P CLIP CLIP预测结果 X X X进行加权求和:
    logits = α A L train + f test W c T = α φ ( f t e s t F t r a i n T ) L t r a i n + f t e s t W c T , \begin{aligned} \begin{aligned} \text{logits}& =\alpha A\mathbf{L}_\text{train}+f_\text{test}W_c^T \\ &=\alpha\varphi(f_{\mathrm{test}}\mathbf{F}_{\mathrm{train}}^T)\mathbf{L}_{\mathrm{train}}+f_{\mathrm{test}}W_c^T, \end{aligned} \end{aligned} logits=αALtrain+ftestWcT=αφ(ftestFtrainT)Ltrain+ftestWcT,

在这里插入图片描述

2️⃣Cross-modal Adaptation(跨模态适应)

论文和源码

🔥 论文地址
🚀 代码地址

原理介绍

  • 原理图和伪代码在这里插入图片描述
    在这里插入图片描述
  • 该方法的核心思想就是将多种模态的信息融合在一起,并且论文假设 C L I P CLIP CLIP可以将不同模态的样本映射到同一个特征空间。比如对于文本-图像这种模态形式,在训练过程中,就可以引入这里的文本信息(也就是每个类别的标签),将其作为额外的训练样本,其实就是将每张图像的图像特征和文本特征视作同一个特征来进行训练。
  • 同上面一样,根据伪代码的内容,将维度变换显示出来也非常好理解整个实现过程:
    假设输入的 b a t c h _ s i z e batch\_size batch_size大小为 b b b,分类的类别数为 n u m _ c l a s s num\_class num_class

i m a g e _ e n c o d e r 输出的图像特征 i m _ f : b × 512 t e x t _ e n c o d e r 输出的文本特征 t x _ f : b × 512 在行维度上将两个特征拼接起来并归一化 f e a t u r e s : 2 b × 512 对应的标签也进行拼接 l a b e l s : 2 b × 512 将 f e a t u r e s 通过一个分类器得到每个类别的预测概率 l o g i t s : 2 b × n u m _ c l a s s 最后 l o g i t s 和 l a b e l s 之间作交叉熵损失,并更新分类器、图像编码器和文本编码器的参数 \begin{aligned} image\_encoder输出的图像特征 \quad im\_f:b×512\\ text\_encoder输出的文本特征 \quad tx\_f:b×512\\ 在行维度上将两个特征拼接起来并归一化\quad features:2b×512\\ 对应的标签也进行拼接\quad labels:2b×512\\ 将features通过一个分类器得到每个类别的预测概率 \quad logits:2b×num\_class\\ 最后logits和labels之间作交叉熵损失,并更新分类器、图像编码器和文本编码器的参数 \end{aligned} image_encoder输出的图像特征im_fb×512text_encoder输出的文本特征tx_fb×512在行维度上将两个特征拼接起来并归一化features2b×512对应的标签也进行拼接labels2b×512features通过一个分类器得到每个类别的预测概率logits:2b×num_class最后logitslabels之间作交叉熵损失,并更新分类器、图像编码器和文本编码器的参数

注意:在实现该代码进行训练的过程中发现如果按照伪代码中将cross_logits除以一个常量,loss反而会很难下降,相反乘上一个系数loss下降的更好一些。(直接loss=cross_entropy_loss(logits*3.0,labels)即可),否则loss值很难会下降。

在这里插入图片描述
在这里插入图片描述

3️⃣ FD-Align(Feature Discrimination Alignment,特征判别对齐)

论文和源码

🔥 论文地址
🚀 代码地址

原理介绍

  • 原理图:
    在这里插入图片描述
  • 论文中提出了一个概念:虚假关联性的鲁棒性,它指的是模型是否具有区分出样本中和类别相关信息(因果信息)以及(背景、风格等)类别无关信息(虚假信息)的能力。同时注意到全微调的CLIP的OOD性能会下降,因此提出了一种不影响模型对虚假特征识别能力的微调方法来保证微调后的模型对虚假关联性的鲁棒性。从模型框架图中看,实际上就是在微调的过程中通过约束微调后的CLIP模型和原始的CLIP模型对虚假特征的分布保持一致,从而在一定程度上避免微调过程中CLIP的OOD性能下降。
  • 该方法相对于前两个方法稍显复杂,先熟悉它定义的几个符号意义,再来结合框架图看一下它的整个模型原理:

首先假设存在一个小样本数据集 D ⊂ X × Y ,( X 表示图像, Y 表示标签) 有 M 个提示模板 ( P 1 , … , P M ) , C L I P 模型的 t e x t − e n c o d e r 和 i m a g e − e n c o d e r 分别表示为 g 0 和 f 0 ; 假设任意的一个类别 y ,那么 y 的原型表示为: μ y class  ,也被称为类的原型 首先假设存在一个小样本数据集D\subset X\times Y,(X表示图像,Y表示标签)\\ 有M个提示模板(P_1,\ldots,P_M),CLIP模型的text-encoder和image-encoder分别表示为g_{0}和f_{0};\\ 假设任意的一个类别y,那么y的原型表示为:\mu_y^\text{class },也被称为类的原型 首先假设存在一个小样本数据集DX×Y,(X表示图像,Y表示标签)M个提示模板(P1,,PM)CLIP模型的textencoderimageencoder分别表示为g0f0;假设任意的一个类别y,那么y的原型表示为:μyclass ,也被称为类的原型
μ y class  : = 1 M ∑ j = 1 M g 0 ( [ P j , y ] ) . \begin{aligned} \mu_y^\text{class }:=\frac{1}{M}\sum_{j=1}^Mg_0([P_j,y]). \end{aligned} μyclass :=M1j=1Mg0([Pj,y]).
因此第一个损失函数 L c l a s s \mathcal{L}_{\mathrm{class}} Lclass和clip模型中的损失函数本质上相同的,约束图像-文本之间的相似度,只不过这里的文本不在是单个的prompt,而是多个prompt取平均值得到的。
L class = − 1 ∣ D ∣ ∑ ( x i , y i ) ∈ D log ⁡ exp ⁡ ( s ( f t ( x i ) , μ y i class ) ) ∑ y ∈ Y exp ⁡ ( s ( f t ( x i ) , μ y class ) ) 其中, s ( : ) 表示余弦相似度 \begin{aligned} \mathcal{L}_{\text{class}}=-\frac{1}{|\mathcal{D}|}\sum_{(x_i,y_i)\in\mathcal{D}}\log\frac{\exp(s(f_t(x_i),\mu_{y_i}^{\text{class}}))}{\sum_{y\in\mathcal{Y}}\exp(s(f_t(x_i),\mu_y^{\text{class}}))}\\ 其中,s(:)表示余弦相似度 \end{aligned} Lclass=D1(xi,yi)DlogyYexp(s(ft(xi),μyclass))exp(s(ft(xi),μyiclass))其中,s(:)表示余弦相似度
紧接着,定义提示模板( p r o m p t )的原型:每个 P j 在所有类中的特征平均值,公式为: 紧接着,定义提示模板(prompt)的原型:每个P_{j}在所有类中的特征平均值,公式为: 紧接着,定义提示模板(prompt)的原型:每个Pj在所有类中的特征平均值,公式为:
μ P j spurious : = 1 ∣ Y ∣ ∑ y ∈ Y g 0 ( [ P j , y ] ) \begin{aligned} \mu_{P_j}^\text{spurious}:=\frac{1}{|\mathcal{Y}|}\sum_{y\in\mathcal{Y}}g_0([P_j,y]) \end{aligned} μPjspurious:=Y1yYg0([Pj,y]) 现在希望的是在微调过程中保持模型对虚假相关性的鲁棒性 , 即保持模型在微调前后提取的虚假特征不变。 所以需要知道模型在虚假特征上的分布——即将微调模型提取的特征与虚假原型之间的相似度定义为模型虚假特征的分布。 现在希望的是在微调过程中保持模型对虚假相关性的鲁棒性,即保持模型在微调前后提取的虚假特征不变。\\所以需要知道模型在虚假特征上的分布——即将微调模型提取的特征与虚假原型之间的相似度定义为模型虚假特征的分布。 现在希望的是在微调过程中保持模型对虚假相关性的鲁棒性,即保持模型在微调前后提取的虚假特征不变。所以需要知道模型在虚假特征上的分布——即将微调模型提取的特征与虚假原型之间的相似度定义为模型虚假特征的分布。

因此,计算由微调模型提取的特征和虚假原型之间的相似性,并且如下产生虚假特征的分布: 因此,计算由微调模型提取的特征和虚假原型之间的相似性,并且如下产生虚假特征的分布: 因此,计算由微调模型提取的特征和虚假原型之间的相似性,并且如下产生虚假特征的分布:
P spurious ( x ; f t ) = SoftMax [ s ( f t ( x ) , μ P 1 spurious ) , … , s ( f t ( x ) , μ P M spurious ) ] \begin{aligned} \mathcal{P}_\text{spurious}(x;f_t)=\text{SoftMax}\left[s\left(f_t(x),\mu_{P_1}^\text{spurious}\right),\ldots,s\left(f_t(x),\mu_{P_M}^\text{spurious}\right)\right] \end{aligned} Pspurious(x;ft)=SoftMax[s(ft(x),μP1spurious),,s(ft(x),μPMspurious)]
类似地,将 f t 换成 f 0 ,可以得到微调前模型的虚假特征分布: 类似地,将f_{t}换成f_{0},可以得到微调前模型的虚假特征分布: 类似地,将ft换成f0,可以得到微调前模型的虚假特征分布:
P spurious ( x ; f 0 ) = SoftMax [ s ( f 0 ( x ) , μ P 1 spurious ) , … , s ( f 0 ( x ) , μ P M spurious ) ] \begin{aligned} \mathcal{P}_{\text{spurious}}(x;f_0)=\text{SoftMax}\left[s\left(f_0(x),\mu_{P_1}^{\text{spurious}}\right),\ldots,s\left(f_0(x),\mu_{P_M}^{\text{spurious}}\right)\right] \end{aligned} Pspurious(x;f0)=SoftMax[s(f0(x),μP1spurious),,s(f0(x),μPMspurious)]

因此第二个损失函数的作用就是保持微调前后模型对虚假特征概率分布保持一致:
L spurious = 1 ∣ D ∣ ∑ ( x i , y i ) ∈ D KL ( P spurious ( x i ; f t ) ∣ ∣ P spurious ( x i ; f 0 ) ) \begin{aligned} \mathcal{L}_{\text{spurious}}=\frac{1}{|\mathcal{D}|}\sum_{(x_i,y_i)\in\mathcal{D}}\text{KL}\left(\mathcal{P}_{\text{spurious}}(x_i;f_t)\mid\mid\mathcal{P}_{\text{spurious}}(x_i;f_0)\right) \end{aligned} Lspurious=D1(xi,yi)DKL(Pspurious(xi;ft)∣∣Pspurious(xi;f0))
综上,最终的损失函数为:
L t o t a l = α ⋅ L c l a s s + β ⋅ L s p u r i o u s 论文中取 α = 1 , β = 20 \begin{aligned} \mathcal{L}_{\mathrm{total}}=\alpha\cdot\mathcal{L}_{\mathrm{class}}+\beta\cdot\mathcal{L}_{\mathrm{spurious}} \end{aligned}\\ 论文中取\alpha=1,\beta=20 Ltotal=αLclass+βLspurious论文中取α=1,β=20

更多细节的推导和更准确的表述请参考作者的原论文😀

总结

  • 本文介绍了三种CLIP微调方法的原理以及给出了对应的更加简化版代码实现,如果有问题的地方,欢迎评论区指正。
  • 三种方法相比较而言,Tip-Adapter最通用,无论是免训练版本还是训练版本,使用之后均有一定的提升效果;Cross-modal Adaptation思路最简单,但是要想有效果,尝试后发现需要针对自己的数据集不断调节参数大小;FD-Align方法在保持CLIP的zero-shot能力方面是几个方法当中最好的;
  • 觉得有帮助的话,给个赞吧👋👋👋

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/410553.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

USB3.2 摘录(11)

系列文章目录 USB3.2 摘录(一) USB3.2 摘录(二) USB3.2 摘录(三) USB3.2 摘录(四) USB3.2 摘录(五) USB3.2 摘录(六) USB3.2 摘录&…

IO进程day01(标准IO、缓存区)

目录 【1】标准IO 1》概念: 2》特点 【2】缓存区 1》全缓存:和文件相关 2》行缓存:和终端有关 3》不缓存:也就是没有缓存区,标准错误。 【1】标准IO 1》概念: 标准IO: 是在C库中定义的一…

C++ | Leetcode C++题解之第355题设计推特

题目&#xff1a; 题解&#xff1a; class Twitter {struct Node {// 哈希表存储关注人的 Idunordered_set<int> followee;// 用链表存储 tweetIdlist<int> tweet;};// getNewsFeed 检索的推文的上限以及 tweetId 的时间戳int recentMax, time;// tweetId 对应发送…

香港站群服务器优势

香港站群服务器因其独特的地理位置和网络连接优势&#xff0c;在SEO优化、网站群管理和网络营销等方面受到广泛关注。其优势主要体现在以下几个方面&#xff0c;rak小编为您整理发布。 地理位置优越 连接亚洲国际市场&#xff1a;香港作为亚太地区的重要经济中心&#xff0c;具…

代码随想录 刷题记录-18 动态规划(2)01背包问题、习题

一、01背包理论基础 例题&#xff1a;46. 携带研究材料 01 背包 有n件物品和一个最多能背重量为w 的背包。第i件物品的重量是weight[i]&#xff0c;得到的价值是value[i] 。每件物品只能用一次&#xff0c;求解将哪些物品装入背包里物品价值总和最大。 暴力解法&#xff1a…

SpringBoot实现Word转PDF/TXT

背景 研发工作中难免会遇到一些奇奇怪怪的需求&#xff0c;就比如最近&#xff0c;客户提了个新需求&#xff1a;上传一个WORD文档&#xff0c;要求通过系统把该文档转换成PDF和TXT。客户的需求是没得商量的&#xff0c;必须实现&#xff01;承载着客户的期望&#xff0c;我开始…

【计算机网络】应用层HTTP协议

我们已经实现过应用层协议&#xff0c;但也要看一看成熟的应用层协议 目录 1 HTTP协议11 URL12 urlencode 和 urldecode13 HTTP 协议请求与响应格式请求格式响应格式 14 界面的基本处理显示基本主页显示图片页面跳转 15 常见header16 状态码161 404举例162 关于3开头的状态码 1…

yd云手机登录算法分析

yd云手机登录算法分析 yd云手机登录算法分析第一步&#xff1a;抓包-登录第二步&#xff1a;定位加密入口第三步&#xff1a;分析加密算法第四步&#xff1a;算法实现 yd云手机登录算法分析 在这篇文章中&#xff0c;我们将详细解析yd云手机的登录算法&#xff0c;涵盖从抓包到…

96.SAP MII功能详解(09)Workbench-Transaction Debugging

目录 1.About Transaction Debugging Use Features Activities 2.How to Debug Start Debugging Create Breakpoint Watch Variables Debugging logs 1.About Transaction Debugging Use You use this function to monitor and manipulate a transaction while it …

微深节能 堆取料机回转俯仰角度检测系统 格雷母线定位系统

微深节能在堆取料机回转俯仰角度检测系统中引入的格雷母线定位系统&#xff0c;是一项重要的技术创新&#xff0c;显著提升了堆取料作业的自动化水平和精确性。以下是对该系统的详细介绍&#xff1a; 一、系统概述 格雷母线定位系统作为高精度、无磨损的非接触式位置检测系统&a…

07 - procfs

---- 整理自 王利涛老师 课程 实验环境&#xff1a;宅学部落 www.zhaixue.cc 文章目录 1. procfs 快速入门2. procfs 文件创建的回调机制3. 在 proc 目录下创建子目录4. 通过 proc 接口修改内核变量5. 通过 proc 接口访问数组6. 序列文件&#xff1a;seq_file 编程接口7. seq_f…

OpenCV绘图函数(1)绘制带箭头的直线函数arrowedLine()的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 绘制一个从第一个点指向第二个点的箭头线段。 cv::arrowedLine 函数在图像中绘制一个从 pt1 到 pt2 的箭头。另见 line 函数。 函数原型 void c…

基于单片机的无线空气质量检测系统设计

本设计以STC89C52单片机为核心&#xff0c;其中包含了温湿度检测模块、光照检测模块、PM2.5检测模块、报警电路、LCD显示屏显示电路、按键输入模块和无线传输模块来完成工作。首先&#xff0c;系统可以通过按键输入模块设置当前的时间和报警值&#xff1b;使用检测模块检测当前…

在Ubuntu 部署 Grafana且监控MySQL数据

一、安装 打开终端按顺序执行以下命令 1.添加 Grafana 的 APT 仓库&#xff1a; sudo apt-get install -y software-properties-common sudo add-apt-repository "deb https://packages.grafana.com/oss/deb stable main" 2.导入Grafana GPG key&#xff1a; wge…

吴光明铸就鱼跃辉煌,科技创新开辟医疗新篇章

在鱼跃集团的发展历程中&#xff0c;创始人吴光明为其树立了最鲜明的品牌标签——创新。吴光明始终坚信&#xff0c;“研发实力代表一个医疗器械企业的核心竞争力”。他很早就认识到&#xff0c;只有从用户需求出发进行创新&#xff0c;才能提升医疗产品的使用体验&#xff0c;…

软件设计原则之接口隔离原则

接口隔离原则&#xff08;Interface Segregation Principle, ISP&#xff09;是面向对象设计中的一个重要原则&#xff0c;它属于SOLID原则之一。这个原则强调客户端&#xff08;即接口的调用者&#xff09;不应该被迫依赖于它们不使用的方法。换句话说&#xff0c;一个类对另一…

SOA通信中间件介绍(一)

一、通信中间件 在软件定义汽车中&#xff0c;应用程序之间的跨进程或跨核通信是一个需要解决的问题。模块化架构为开发人员提供了便利&#xff0c;但也引入了对通信中间件的需求。 在没有使用通信中间件的情况下&#xff0c;开发人员需要自己定义数据的格式、发送方和接收方…

趣味呈现高效农业管理:智慧农场可视化

运用图扑自主研发的 HT 产品&#xff0c;全程零代码搭建 3D 轻量化 Low Poly 风格的智慧农场可视化&#xff0c;通过生动有趣的图形展示农场运作细节&#xff0c;使农业管理更直观易懂&#xff0c;提升管理效率和用户体验。

C++ 基础学习

提示并输入一个字符串&#xff0c;统计该字符串中字母个数、数字个数、空格个数、其他字符的个数 #include <iostream>using namespace std;int main() {cout<<"请输入字符串:";string str;getline(cin,str);int num0;int alp0;int spa0;int other0;int …

网络安全面试经验分享:蘑菇街/网络安全

《网安面试指南》http://mp.weixin.qq.com/s?__bizMzkwNjY1Mzc0Nw&mid2247484339&idx1&sn356300f169de74e7a778b04bfbbbd0ab&chksmc0e47aeff793f3f9a5f7abcfa57695e8944e52bca2de2c7a3eb1aecb3c1e6b9cb6abe509d51f&scene21#wechat_redirect 蘑菇街 介绍…