在自定义数据集上实现OpenAI CLIP

在2021年1月,OpenAI宣布了两个新模型:DALL-E和CLIP,它们都是以某种方式连接文本和图像的多模态模型。CLIP全称是Contrastive Language–Image Pre-training,一种基于对比文本-图像对的预训练方法。为什么要介绍CLIP呢?因为现在大火得Stable Diffusion 并不是单一模型,而是多个模型组成。其中会用到一个 Text encoder 将用户的文本输入进行编码,这个 text encoder 就是 CLIP 模型中 text encoder

CLIP模型在训练时,可以给它一个输入句子,并提取最相关的图像来配合它。CLIP学习了一个完整的句子和它所描述的图像之间的关系。也就是说它是在完整的句子上训练的,而不是像“汽车”、“狗”等离散的分类,这一点对于应用至关重要。当训练完整的短语时,模型可以学习更多的东西,并识别照片和文本之间的模式。他们还证明,当在相当大的照片和与之相对应的句子数据集上进行训练时,该模型是可以作为分类器的。CLIP在发布的时候能在无任何微调的情况下(zero-shot ),在 ImageNet 数据集上的分类表现超 ResNets-50 微调后的效果,也就是说他是非常有用的。

所以在本文中,我们将使用PyTorch中从头开始实现CLIP模型,以便我们对CLIP有一个更好的理解

这里就需要用到2个库:timm和transformers,我们先导入代码

 import osimport cv2import gcimport numpy as npimport pandas as pdimport itertoolsfrom tqdm.autonotebook import tqdmimport albumentations as Aimport matplotlib.pyplot as pltimport torchfrom torch import nnimport torch.nn.functional as Fimport timmfrom transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer

下一步就是预处理数据和通用配置config。config是一个普通的python文件,我们将所有的超参数放在里面,如果使用Jupyter Notebook的情况下,它是一个在Notebook开头定义的类。

 class CFG:debug = Falseimage_path = "../input/flickr-image-dataset/flickr30k_images/flickr30k_images"captions_path = "."batch_size = 32num_workers = 4head_lr = 1e-3image_encoder_lr = 1e-4text_encoder_lr = 1e-5weight_decay = 1e-3patience = 1factor = 0.8epochs = 2device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model_name = 'resnet50'image_embedding = 2048text_encoder_model = "distilbert-base-uncased"text_embedding = 768text_tokenizer = "distilbert-base-uncased"max_length = 200pretrained = True # for both image encoder and text encodertrainable = True # for both image encoder and text encodertemperature = 1.0# image sizesize = 224# for projection head; used for both image and text encodersnum_projection_layers = 1projection_dim = 256 dropout = 0.1

还有一些我们自定义指标的辅助类

 class AvgMeter:def __init__(self, name="Metric"):self.name = nameself.reset()def reset(self):self.avg, self.sum, self.count = [0] * 3def update(self, val, count=1):self.count += countself.sum += val * countself.avg = self.sum / self.countdef __repr__(self):text = f"{self.name}: {self.avg:.4f}"return textdef get_lr(optimizer):for param_group in optimizer.param_groups:return param_group["lr"]

我们的目标是描述图像和句子。所以数据集必须同时返回句子和图像。所以需要使用DistilBERT标记器对句子(标题)进行标记,然后将标记id (input_ids)和注意掩码提供给DistilBERT。DistilBERT比BERT 模型要小,但是模型的结果都差不多,所以我们选择使用它。

下一步就是使用HuggingFace tokenizer进行标记化。在__init__中获得的tokenizer对象,将在模型运行时加载。标题被填充并截断到预定的最大长度。在加载相关图像之前,我们将在**getitem**中加载一个编码的标题,这是一个带有键input_ids和attention_mask的字典,并对其进行转换和扩充(如果有的话)。然后把它变成一个张量,并以“image”作为键存储在字典中。最后我们将标题的原始文本与关键字“标题”一起输入字典。

 class CLIPDataset(torch.utils.data.Dataset):def __init__(self, image_filenames, captions, tokenizer, transforms):"""image_filenames and cpations must have the same length; so, if there aremultiple captions for each image, the image_filenames must have repetitivefile names """self.image_filenames = image_filenamesself.captions = list(captions)self.encoded_captions = tokenizer(list(captions), padding=True, truncation=True, max_length=CFG.max_length)self.transforms = transformsdef __getitem__(self, idx):item = {key: torch.tensor(values[idx])for key, values in self.encoded_captions.items()}image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)image = self.transforms(image=image)['image']item['image'] = torch.tensor(image).permute(2, 0, 1).float()item['caption'] = self.captions[idx]return itemdef __len__(self):return len(self.captions)def get_transforms(mode="train"):if mode == "train":return A.Compose([A.Resize(CFG.size, CFG.size, always_apply=True),A.Normalize(max_pixel_value=255.0, always_apply=True),])else:return A.Compose([A.Resize(CFG.size, CFG.size, always_apply=True),A.Normalize(max_pixel_value=255.0, always_apply=True),])

图像和文本编码器:我们将使用ResNet50作为图像编码器。

 class ImageEncoder(nn.Module):"""Encode images to a fixed size vector"""def __init__(self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable):super().__init__()self.model = timm.create_model(model_name, pretrained, num_classes=0, global_pool="avg")for p in self.model.parameters():p.requires_grad = trainabledef forward(self, x):return self.model(x)

使用DistilBERT作为文本编码器。使用CLS令牌的最终表示来获得句子的整个表示。

 class TextEncoder(nn.Module):def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):super().__init__()if pretrained:self.model = DistilBertModel.from_pretrained(model_name)else:self.model = DistilBertModel(config=DistilBertConfig())for p in self.model.parameters():p.requires_grad = trainable# we are using the CLS token hidden representation as the sentence's embeddingself.target_token_idx = 0def forward(self, input_ids, attention_mask):output = self.model(input_ids=input_ids, attention_mask=attention_mask)last_hidden_state = output.last_hidden_statereturn last_hidden_state[:, self.target_token_idx, :]

上面的代码已经将图像和文本编码为固定大小的向量(图像2048,文本768),我们需要图像和文本具有相似的尺寸,以便能够比较它们,所以我们把2048维和768维向量投影到256维(projection_dim),只有维度相同我们才能比较它们。

 class ProjectionHead(nn.Module):def __init__(self,embedding_dim,projection_dim=CFG.projection_dim,dropout=CFG.dropout):super().__init__()self.projection = nn.Linear(embedding_dim, projection_dim)self.gelu = nn.GELU()self.fc = nn.Linear(projection_dim, projection_dim)self.dropout = nn.Dropout(dropout)self.layer_norm = nn.LayerNorm(projection_dim)def forward(self, x):projected = self.projection(x)x = self.gelu(projected)x = self.fc(x)x = self.dropout(x)x = x + projectedx = self.layer_norm(x)return x

所以最后我们的CLIP模型就是这样:

 class CLIPModel(nn.Module):def __init__(self,temperature=CFG.temperature,image_embedding=CFG.image_embedding,text_embedding=CFG.text_embedding,):super().__init__()self.image_encoder = ImageEncoder()self.text_encoder = TextEncoder()self.image_projection = ProjectionHead(embedding_dim=image_embedding)self.text_projection = ProjectionHead(embedding_dim=text_embedding)self.temperature = temperaturedef forward(self, batch):# Getting Image and Text Featuresimage_features = self.image_encoder(batch["image"])text_features = self.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])# Getting Image and Text Embeddings (with same dimension)image_embeddings = self.image_projection(image_features)text_embeddings = self.text_projection(text_features)# Calculating the Losslogits = (text_embeddings @ image_embeddings.T) / self.temperatureimages_similarity = image_embeddings @ image_embeddings.Ttexts_similarity = text_embeddings @ text_embeddings.Ttargets = F.softmax((images_similarity + texts_similarity) / 2 * self.temperature, dim=-1)texts_loss = cross_entropy(logits, targets, reduction='none')images_loss = cross_entropy(logits.T, targets.T, reduction='none')loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)return loss.mean()#这里还加了一个交叉熵函数def cross_entropy(preds, targets, reduction='none'):log_softmax = nn.LogSoftmax(dim=-1)loss = (-targets * log_softmax(preds)).sum(1)if reduction == "none":return losselif reduction == "mean":return loss.mean()

这里需要说明下,CLIP使用 symmetric cross entropy 作为损失函数,可以降低噪音影响,提高模型鲁棒性,我们这里为了简单只是用cross entropy 。

我们可以进行测试:

 # A simple Examplebatch_size = 4dim = 256embeddings = torch.randn(batch_size, dim)out = embeddings @ embeddings.Tprint(F.softmax(out, dim=-1))

下一步就是训练了,有一些函数可以帮助我们加载训练和验证的dataloader

 def make_train_valid_dfs():dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")max_id = dataframe["id"].max() + 1 if not CFG.debug else 100image_ids = np.arange(0, max_id)np.random.seed(42)valid_ids = np.random.choice(image_ids, size=int(0.2 * len(image_ids)), replace=False)train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)return train_dataframe, valid_dataframedef build_loaders(dataframe, tokenizer, mode):transforms = get_transforms(mode=mode)dataset = CLIPDataset(dataframe["image"].values,dataframe["caption"].values,tokenizer=tokenizer,transforms=transforms,)dataloader = torch.utils.data.DataLoader(dataset,batch_size=CFG.batch_size,num_workers=CFG.num_workers,shuffle=True if mode == "train" else False,)return dataloader

然后就是训练和评估

 def train_epoch(model, train_loader, optimizer, lr_scheduler, step):loss_meter = AvgMeter()tqdm_object = tqdm(train_loader, total=len(train_loader))for batch in tqdm_object:batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}loss = model(batch)optimizer.zero_grad()loss.backward()optimizer.step()if step == "batch":lr_scheduler.step()count = batch["image"].size(0)loss_meter.update(loss.item(), count)tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))return loss_meterdef valid_epoch(model, valid_loader):loss_meter = AvgMeter()tqdm_object = tqdm(valid_loader, total=len(valid_loader))for batch in tqdm_object:batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}loss = model(batch)count = batch["image"].size(0)loss_meter.update(loss.item(), count)tqdm_object.set_postfix(valid_loss=loss_meter.avg)return loss_meter

最后整合起来就是全部流程

 def main():train_df, valid_df = make_train_valid_dfs()tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)train_loader = build_loaders(train_df, tokenizer, mode="train")valid_loader = build_loaders(valid_df, tokenizer, mode="valid")model = CLIPModel().to(CFG.device)params = [{"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},{"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},{"params": itertools.chain(model.image_projection.parameters(), model.text_projection.parameters()), "lr": CFG.head_lr, "weight_decay": CFG.weight_decay}]optimizer = torch.optim.AdamW(params, weight_decay=0.)lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=CFG.patience, factor=CFG.factor)step = "epoch"best_loss = float('inf')for epoch in range(CFG.epochs):print(f"Epoch: {epoch + 1}")model.train()train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)model.eval()with torch.no_grad():valid_loss = valid_epoch(model, valid_loader)if valid_loss.avg < best_loss:best_loss = valid_loss.avgtorch.save(model.state_dict(), "best.pt")print("Saved Best Model!")lr_scheduler.step(valid_loss.avg)

应用:获取图像嵌入并找到匹配。

我们训练完成后如何实际应用呢?我们需要编写一个函数加载训练后的模型,为其提供验证集中的图像,并返回形状(valid_set_size, 256)和模型本身的image_embeddings。

 def get_image_embeddings(valid_df, model_path):tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)valid_loader = build_loaders(valid_df, tokenizer, mode="valid")model = CLIPModel().to(CFG.device)model.load_state_dict(torch.load(model_path, map_location=CFG.device))model.eval()valid_image_embeddings = []with torch.no_grad():for batch in tqdm(valid_loader):image_features = model.image_encoder(batch["image"].to(CFG.device))image_embeddings = model.image_projection(image_features)valid_image_embeddings.append(image_embeddings)return model, torch.cat(valid_image_embeddings)_, valid_df = make_train_valid_dfs()model, image_embeddings = get_image_embeddings(valid_df, "best.pt")def find_matches(model, image_embeddings, query, image_filenames, n=9):tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)encoded_query = tokenizer([query])batch = {key: torch.tensor(values).to(CFG.device)for key, values in encoded_query.items()}with torch.no_grad():text_features = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])text_embeddings = model.text_projection(text_features)image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)dot_similarity = text_embeddings_n @ image_embeddings_n.Tvalues, indices = torch.topk(dot_similarity.squeeze(0), n * 5)matches = [image_filenames[idx] for idx in indices[::5]]_, axes = plt.subplots(3, 3, figsize=(10, 10))for match, ax in zip(matches, axes.flatten()):image = cv2.imread(f"{CFG.image_path}/{match}")image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)ax.imshow(image)ax.axis("off")plt.show()

调用方法如下:

 find_matches(model, image_embeddings,query="one dog sitting on the grass",image_filenames=valid_df['image'].values,n=9)

可以看到我们自定义效果还是不错的(但是图里面有个猫,哈)。也就是说CLIP这种方法在小数据集上自定义也是可行的。

以下是本文的代码和数据集:

https://avoid.overfit.cn/post/25295aa8daee45fc8336b2e86a29106a

作者:Jyoti Dabass, Ph.D

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/129725.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C#】关于Array.Copy 和 GC

关于Array.Copy 和 GC //一个简单的 数组copy 什么情况下会触发GC呢[ReliabilityContract(Consistency.MayCorruptInstance, Cer.MayFail)]public static void Copy(Array sourceArray,long sourceIndex,Array destinationArray,long destinationIndex,long length);当源和目…

浅析Open vSwitch数据结构:哈希表hmap/smap/shash

文章目录 概述hmaphmap数据结构初始化hmap插入节点扩展hmap空间resize函数 删除节点遍历所有节点辅助函数hmap_first辅助函数hmap_next smapsmap数据结构插入节点删除节点查找节点遍历所有节点 shashshash数据结构插入节点删除节点查找节点遍历所有节点 概述 在OVS软件中&…

【网络安全带你练爬虫-100练】第23练:文件内容的删除+写入

目录 0x00 前言&#xff1a; 0x02 解决&#xff1a; 0x00 前言&#xff1a; 本篇博文可能会有一点点的超级呆 0x02 解决&#xff1a; 你是不是也会想&#xff1a; 使用pyrhon将指定文件夹位置里面的1.txt中数据全部删除以后---->然后再将参数req_text的值写入到1.txt …

rsa加密解密java和C#互通

前言 因为第三方项目是java的案例&#xff0c;但是原来的项目使用的是java&#xff0c;故需要将java代码转化为C#代码&#xff0c;其中核心代码就是RSA加密以及加签和验签&#xff0c;其他的都是api接口请求难度不大。 遇到的问题 java和c#密钥格式不一致&#xff0c;java使…

LeetCode(力扣)491. 递增子序列Python

LeetCode491. 递增子序列 题目链接代码 题目链接 https://leetcode.cn/problems/non-decreasing-subsequences/ 代码 class Solution:def backtracking(self, nums, index, result, path):if len(path) > 1:result.append(path[:])uset set()for i in range(index, len…

分类预测 | Matlab实现基于LFDA-SVM局部费歇尔判别数据降维结合支持向量机的多输入分类预测

分类预测 | Matlab实现基于LFDA-SVM局部费歇尔判别数据降维结合支持向量机的多输入分类预测 目录 分类预测 | Matlab实现基于LFDA-SVM局部费歇尔判别数据降维结合支持向量机的多输入分类预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 基于局部费歇尔判别数据降维的L…

日常开发小汇总(5)元素跟随鼠标移动(在视口下移动)

<div class"mark"><h1>title</h1><div><p>title 鼠标移动 盒子整体移动</p><p>test</p><p>Lorem ipsum dolor sit amet, consectetur adipisicing elit. Modi, porro.</p></div></div>cons…

springboot MongoDB 主从 多数据源

上一篇&#xff0c;我写了关于用一个map管理mongodb多个数据源&#xff08;每个数据源&#xff0c;只有单例&#xff09;的内容。 springboot mongodb 配置多数据源 临到部署到阿里云的测试环境&#xff0c;发现还需要考虑一下主从的问题&#xff0c;阿里云买的数据库&#x…

入门力扣自学笔记279 C++ (题目编号:1123)

1123. 最深叶节点的最近公共祖先 题目&#xff1a; 给你一个有根节点 root 的二叉树&#xff0c;返回它 最深的叶节点的最近公共祖先 。 回想一下&#xff1a; 叶节点 是二叉树中没有子节点的节点树的根节点的 深度 为 0&#xff0c;如果某一节点的深度为 d&#xff0c;那它…

【Flutter】引入网络图片时,提示:Failed host lookup: ‘[图片host]‘

在使用 NetworkImage 组件加载外部图片时&#xff0c;提示 Failed host lookup: [图片host] 错误。 排查方向 1、清理缓存 解决方案&#xff1a; 尝试flutter clean清空缓存后重新安装依赖flutter pub get重新构建项目flutter create . 走完上述三个步骤后&#xff0c;再次…

MySql学习笔记11——DBA命令介绍

DBA命令 数据导入 要进入Mysql 创建数据库 create database database_name;使用数据库 use database_name;初始化数据库 source .sql文件地址&#xff0c;不能加双引号&#xff1b;数据导出 要在windows的dos环境下进行 导出数据库 mysqldump database_name > 存放…

Android 文字转语音播放实现

1&#xff0c;TextToSpeech类是android自带的&#xff0c;但是部分设备需要支持TTS需要增加语音库&#xff0c;我使用的是讯飞语音&#xff08;离线的哦&#xff09;。请自行下载并安装讯飞语音APK&#xff0c;然后到系统设置中设置TTS功能默认使用该选项。有自带TTS库的可以省…

《存储IO路径》专题:块设备层多队列blk-mq架构

我们想象一下&#xff0c;你是一个餐厅的厨师&#xff0c;你要准备很多不同的菜肴&#xff0c;而每种菜肴需要不同的食材和烹饪时间。如果每道菜都按照需要的顺序来准备&#xff0c;那么你的工作效率一定会非常低。为了提高效率&#xff0c;你会怎么做呢&#xff1f; 在linux架…

基于SSM的家政服务网站

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用JSP技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…

windows 下载安装 mysql

windows 下载安装 mysql 官网地址&#xff1a;https://dev.mysql.com/ 下载地址&#xff1a;https://cdn.mysql.com//Downloads/MySQLInstaller/mysql-installer-community-8.0.34.0.msi 点击 Downloads 点击 MySQL Community (GPL) Downloads 点击 MySQL Installer for Window…

usb学习笔记

框架 usb 驱动是基于usb core 的&#xff0c;设备插上之后&#xff0c;host 层自然会进行识别&#xff0c;设备驱动通过core层的接口操作设备&#xff0c;而不用直接面对usb硬件。对于应用层需要封装成一个usb 的设备。 驱动是基于urb 数据进行操作的。 49 static void usb_mo…

Windows环境下Springboot3+Graalvm+Idea 打包成原生镜像 踩坑

https://github.com/oracle/graal/https://github.com/graalvm/graalvm-ce-builds/releases/对应关系graalvm-ce-java17-windows-amd64-X.X.X.zipnative-image-installable-svm-java17-windows-amd64-X.X.X.jar本人使用:graalvm-ce-java17-windows-amd64-23.0.1.zipnative-imag…

第4章_瑞萨MCU零基础入门系列教程之瑞萨 MCU 源码设计规范

本教程基于韦东山百问网出的 DShanMCU-RA6M5开发板 进行编写&#xff0c;需要的同学可以在这里获取&#xff1a; https://item.taobao.com/item.htm?id728461040949 配套资料获取&#xff1a;https://renesas-docs.100ask.net 瑞萨MCU零基础入门系列教程汇总&#xff1a; ht…

【C++11】{}初始化、std::initializer_list、decltype、STL新增容器

文章目录 1. C11简介2. 统一的列表初始化2.1 &#xff5b;&#xff5d;初始化2.2 std::initializer_list 3. 声明3.1 auto3.2 decltype 4. nullptr5. 范围for循环6. 智能指针7. C11STL中的一些变化8. 演示代码 1. C11简介 在2003年C标准委员会曾经提交了一份技术勘误表(简称TC1…

英语单词(1)

1.void:空的 2.main:主要的 3.class:类 4.system:系统 5.out: 输出 6.print:打印 7.public:公共的,公用的 8.static:静态的,静止的 9.oracle:甲骨文公司 10.eclipse: java编程语言