CLIP模型原理

CLIP模型

CLIP(Contrastive Language-Image Pre-Training) 模型是 OpenAI 在 2021 年初发布的用于匹配图像文本的预训练神经网络模型,是近年来在多模态研究领域的经典之作。OpenAI 收集了 4 亿对图像 - 文本对(一张图像和它对应的文本描述),分别将图像和文本进行编码,使用 metric learning进行训练。希望通过对比学习,模型能够学习到图像 - 文本对的匹配关系。

CLIP的论文地址

CLIP模型共有3个阶段:1阶段用作训练,2、3阶段用作推理。

  1. Contrastive pre-training:预训练阶段,使用图片 - 文本对进行对比学习训练;
  2. Create dataset classifier from label text:提取预测类别文本特征;
  3. Use for zero-shot predictiion:进行 zero-shot 推理预测;

在这里插入图片描述

1、训练阶段

通过计算目标图像和对应文本描述的余弦相似度从而获取预测值。CLIP第一阶段主要包含以下两个子模型;

  • Image Encoder:用来提取图像的特征,可以采用常用CNN模型或者vision transformer模型;(视觉模型)
  • Text Encoder:用来提取文本的特征,可以采用NLP中常用的text transformer模型;(文本模型)

在这里插入图片描述
这里举例一个包含N个文本-图像对的训练batch,对提取的文本特征和图像特征进行训练的过程:

  1. 输入图片 —> 图像编码器 —> 图片特征向量;输入文字 —> 文字编码器 —> 文字特征向量;并进行线性投射,得到相同维度;
  2. N个文本特征和N个图像特征两两组合,形成一个具有N2个元素的矩阵;
  3. CLIP模型会预测计算出这N2个文本-图像对的相似度(文本特征和图像特征的余弦相似性即为相似度);
  4. 对角线上的N个元素因为图像-标签对应正确被作为训练的正样本,剩下的N2-N个元素作为负样本;
  5. CLIP的训练目标就是最大化N个正样本的相似度,同时最小化N2-N个负样本的相似度;

2、推理过程

CLIP的预测推理过程主要有以下两步:

  1. 提取预测类别的文本特征:由于CLIP 预训练文本端的输出输入都是句子,因此需要将任务的分类标签按照提示模板 (prompt template)构造成描述文本(由单词构造成句子):A photo of {object}.,然后再送入Text Encoder得到对应的文本特征。如果预测类别的数目为N,那么将得到N个文本特征。
  2. 进行 zero-shot 推理预测:将要预测的图像送入Image Encoder得到图像特征,然后与上述的N个文本特征计算余弦相似度(和训练过程一致),然后选择相似度最大的文本对应的类别作为图像分类预测结果。进一步地,可以将这些相似度看成输入,送入softmax后可以得到每个类别的预测概率。

在这里插入图片描述

3、补充:zero-shot 零样本学习

zero-shot :零样本学习,域外泛化问题。利用训练集数据训练模型,使得模型能够对测试集的对象进行分类,但是训练集类别和测试集类别之间没有交集,期间需要借助类别的描述,来建立训练集和测试集之间的联系,从而使得模型有效。

在计算机视觉中,即便想迁移VGGMobileNet这种预训练模型,也需要新数据经过预训练、微调等手段,才能学习新数据集所持有的数据特征,CLIP可以直接实现zero-shot的图像分类,即:不需要训练任何新数据,就能在某个具体下游任务上实现分类,这也是CLIP亮点和强大之处。

我的猜测:CLIPzero-shot能力是依赖于它预训练的4亿对图像-文本对,样本空间非常大,下游任务的类别也不过是CLIP样本空间的子集,并不是真正的零样本学习,和解决域外泛化问题。和人脸比对的原理相似,依靠大量样本来学习分类对象的特征空间,区别在于人脸比对是image-to-imageCLIPimage-to-text

4、代码: CLIP实现zero-shot分类

OpenAI有关CLIP的代码链接地址

环境:

pip install ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.gitTorch version: 1.9.0+cu102

4.1、模型加载

import clipclip.available_models()model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_sizeprint("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

4.2、图像、文本数据处理

向模型提供8个示例图像及其文本描述,并比较相应特征之间的相似性

# images in skimage to use and their textual descriptions
descriptions = {"page": "a page of text about segmentation","chelsea": "a facial photo of a tabby cat","astronaut": "a portrait of an astronaut with the American flag","rocket": "a rocket standing on a launchpad","motorcycle_right": "a red motorcycle standing in a garage","camera": "a person looking at a camera on a tripod","horse": "a black-and-white silhouette of a horse", "coffee": "a cup of coffee on a saucer"
}

在这里插入图片描述

4.3、建立图片特征

对图像进行归一化,对每个文本输入进行标记,并运行模型的前向传递以获得图像和文本特征

image_input = torch.tensor(np.stack(images)).cuda()
text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()with torch.no_grad():image_features = model.encode_image(image_input).float()text_features = model.encode_text(text_tokens).float()

4.4、计算余弦相似度

image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().Tcount = len(descriptions)plt.figure(figsize=(20, 14))
plt.imshow(similarity, vmin=0.1, vmax=0.3)
# plt.colorbar()
plt.yticks(range(count), texts, fontsize=18)
plt.xticks([])
for i, image in enumerate(original_images):plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
for x in range(similarity.shape[1]):for y in range(similarity.shape[0]):plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)for side in ["left", "top", "right", "bottom"]:plt.gca().spines[side].set_visible(False)plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])plt.title("Cosine similarity between text and image features", size=20)

在这里插入图片描述

4.5、Zero-Shot图像分类

from torchvision.datasets import CIFAR100cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()with torch.no_grad():text_features = model.encode_text(text_tokens).float()text_features /= text_features.norm(dim=-1, keepdim=True)text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)plt.figure(figsize=(16, 16))for i, image in enumerate(original_images):plt.subplot(4, 4, 2 * i + 1)plt.imshow(image)plt.axis("off")plt.subplot(4, 4, 2 * i + 2)y = np.arange(top_probs.shape[-1])plt.grid()plt.barh(y, top_probs[i])plt.gca().invert_yaxis()plt.gca().set_axisbelow(True)plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])plt.xlabel("probability")plt.subplots_adjust(wspace=0.5)
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/166697.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

shell的for循环与结构化

shell笔记 列表for循环不带列表for循环for循环举例1.例1 所有文件名大写替换为小写2. 例2 读取/etc/passwd文件,依次输出ip段3. 例3 读取/etc/hosts内容for循环,执行ping4. 例4 循环ip列表,输出对应编号5. 例5 批量添加用户 break1. 例1 brea…

FPGA project : IIC_wr_eeprom

简介: 简单双向二线制,同步串行总线。 scl:串行时钟线,用于同步通讯数据。 sda:双向串行数据线。 物理层: 1,支持挂载多设备。 2,二线制。 3,每个设备有其单独的地…

安装visual studio报错“无法安装msodbcsql“

在安装visual studio2022时安装完成后提示无法安装msodbcsql, 查看日志文件详细信息提示:指定账户已存在。 未能安装包“msodbcsql,version17.2.30929.1,chipx64,languagezh-CN”。 搜索 URL https://aka.ms/VSSetupErrorReports?qPackageIdmsodbcsql;PackageActi…

分布式缓存Spring Cache

一、缓存里的数据如何和数据库的数据保持一致? 缓存数据一致性1)、双写模式2)、失效模式1、缓存数据一致性-双写模式 2、 缓存数据一致性-失效模式 我们系统的一致性解决方案: 1、缓存的所有数据都有过期时间,数据过期下一次查询触发主动更新 2、读写数据…

Android 10 中的隐私权变更

Android 10 中的隐私权变更 重大变更外部存储访问权限范围限定为应用文件和媒体在后台运行时访问设备位置信息需要权限以 Android 9 或更低版本为目标平台时自动授予访问权限在设备升级到 Android 10 后访问针对从后台启动 Activity 的限制标识符和数据移除了联系人亲密程度信息…

JIT耗时优化

优质博文:IT-BLOG-CN 一、背景 业务流量突增,机器直接接入大量流量QPS2000,JIT和GC会消耗太多CPU资源,导致1-2分钟时间内的请求超时导致异常,因此采用流量预热的方式,让机器逐步接入流量,需要预…

go语言Array 与 Slice

有的语言会把数组用作常用的基本的数据结构,比如 JavaScript,而 Golang 中的数组(Array),更倾向定位于一种底层的数据结构,记录的是一段连续的内存空间数据。但是在 Go 语言中平时直接用数组的时候不多,大多数场景下我…

【Lua语法】字符串

Lua语言中的字符串是不可变值。不能像在C语言中那样直接改变某个字符串中的某个字符,但是可以通过创建一个新字符串的方式来达到修改的目的 print(add2(1 , 2 ,15,3))a "no one"b string.gsub(a , "no" , "on1111")print(a) print…

微软正式发布开源应用平台 Radius平台

“ 10 月 18 日,微软 Azure 孵化团队正式发布开源应用平台 Radius,该平台将应用程序置于每个开发阶段的中心,重新定义应用程序的构建、管理与理解方式。” 简单的概括就是,它和Kubernetes不一样,Radius将应用程序放在每…

C语言--程序环境和预处理

前言 本章就是c语言的最后一个板块了,学完这章节,我们将知道写出的代码如何变成可执行程序的,这是非常重要的一个章节,那让我们一起进入本章的学习吧。 本章重点: 程序的翻译环境程序的执行环境详解:C语言程…

周立功ZCANPRO简介和使用

ZCANPRO目录 周立功ZCANPRO简介一、软件安装ZCANPRO官网链接:驱动官网链接 二、ZCANPRO使用1.设备管理2.选择CAN、CANFD波特率计算器使用方法(可选) 3.新建视图CAN视图DBC视图 4.发送数据普通发送DBC发送 三、高级功能UDS诊断 周立功ZCANPRO简…

【java爬虫】使用selenium获取某交易所公司半年报数据

引言 上市公司的财报数据一般都会进行公开,我们可以在某交易所的官方网站上查看这些数据,由于数据很多,如果只是手动收集的话可能会比较耗时耗力,我们可以采用爬虫的方法进行数据的获取。 本文就介绍采用selenium框架进行公司财…

HTML选项框的设计以及根据不同选项的值对应不同的事件

文章目录 HTML选项框的设计JS根据不同的选项框对应出不同的事件 HTML选项框的设计 在前端页面的设计中&#xff0c;多选框的设计用select标签完成实现 全部选项都显示的选项框 <form><select multiple"multiple"><option></option><opti…

视频怎么压缩?视频过大这样压缩变小

在日常生活中&#xff0c;我们常常会遇到需要压缩视频的情况&#xff0c;视频压缩不仅可以减小文件大小&#xff0c;方便存储和传输&#xff0c;还可以在保证质量的同时&#xff0c;满足不同的使用需求。那么&#xff0c;如何有效地压缩视频呢&#xff1f; 方法一&#xff1a;嗨…

web APIs——第一天(上)

变量声明的时候建议 const优先&#xff0c;尽量使用const 原因&#xff1a; const语义化更好很多变量我们声明的时候就知道他不会被更改了&#xff0c;那为什么不用const呢&#xff1f;实际开发中也是&#xff0c;比如react框架&#xff0c;基本const如果你有纠结的时候&…

python中的yolov5结合PyQt5,使用QT designer设计界面没正确启动的解决方法

python中的yolov5结合PyQt5&#xff0c;使用QT designer设计界面没正确启动的解决方法 一、窗体设计test: 默认你已经设计好了窗体后&#xff1a; 这时你需要的是保存生成的untitle.ui到某个文件夹下&#xff0c;然后在命令行中奖.ui转换为.py&#xff08;&#xff0c;通过​​…

css之Flex弹性布局

文章目录 &#x1f415;前言&#xff1a;&#x1f3e8;定义flex容器 display:flex&#x1f3e8;在flex容器中子组件进行排列&#x1fa82;行排列 flex-direction: row&#x1fa82;将行排列进行翻转排列 flex-direction: row-reverse&#x1f3c5;按列排列 flex-direction: col…

2020年亚太杯APMCM数学建模大赛B题美国总统的经济影响分析求解全过程文档及程序

2020年亚太杯APMCM数学建模大赛 B题 美国总统的经济影响分析 原题再现&#xff1a; 美国总统选举每四年举行一次。 2020年是美国总统大选年&#xff0c;共和党候选人唐纳德特朗普和民主党对手乔拜登竞选总统。 甲乙双方候选人在金融贸易&#xff0c;经济金融治理&#xff0c;…

离散低通滤波方法

低通滤波器允许低频信号通过&#xff0c;并抑制高频信号。其核心思想是在频率域上通过移除高频成分来平滑信号。这在去噪、平滑和提取基本频率成分时非常有用。 离散低通滤波方法通常采用一阶低通滤波器进行处理。一阶低通滤波器是一种常见的数字滤波器&#xff0c;能够将信号…

电脑出现xinput1_3.dll的错误提示怎么办?有什么办法可以解决

电脑如果缺失了xinput1_3.dll还是一件比较复杂的事情&#xff0c;那么电脑出现xinput1_3.dll的错误提示怎么办&#xff0c;又有什么办法可以解决xinput1_3.dll&#xff1f;今天我们就来聊聊xinput1_3.dll丢失的解决办法&#xff0c;来看看都有哪些办法可以解决吧。 一.常见的问…