前言
2021年2月份,CLIP模型被提出,想法很简单,性能高效,而且具备很好的泛化性。我在这里简单谈论下我对CLIP模型的理解,以及发现的一些问题。
我是在沐神的视频中了解的CLIP, 里面提到CLIP最大的贡献在于打破了固定类别标签范式。我对这句话是这样理解的:就拿一般的分类任务来说,每一张图片对应一个类别,类别数量都是固定的,当模型训练好后,在实际使用过程中,一但出现一个从未出现的类别,模型是无法识别出来的。但是CLIP模型不一样,CLIP在训练的过程中,是将句子和图片匹配,然后在推理过程中找到与之最接近的模板句子。举个例子:CLIP模型在训练过程中,用到了4亿组图像文本对,可以说是涵盖了自然界中的大部分场景,在迁移学习时,即使从未见过三轮车这个类别,但一定见过与三轮车描述相关的图像文本对,从而在推理过程中将其识别为三轮车类。
CLIP模型的训练以及推理过程
数据集是若干的图像文本对,CLIP用了近4亿组。在训练过程中,取一个batch_size的图像文本对,图像经过Image Encode, 文本经过Text Encoder,然后在向量之间计算余弦相似度,结果就如图像所示,对象线上的元素分别是一一对应的,那么文本编码和图像编码之间的相似度的也该是最高的,即在对比学习中,对角线上的元素即为正样本,其余非对角线元素为负样本。因此这个模型经过训练后,能实现的最终理想目标就是一组图像文本对,图像经过Image Encoder编码和文本经过Text Encoder的编码应该是一摸一样的(显然,并不可能,但是可以保证两个编码的相似度尽可能的高)。
接下来就是推理过程了,可以看出,CLIP训练好的模型并不具备分类头,得到的最终结果就是两个Encoder,同一组图像文本对经这两个Encoder的编码相似度会很高。 推理过,我们需要先给出类别模型,即将一个类别标签变成一个句子
这些类别标签的句子讲过Text Encoder后会生成对应的文本编码,在推理过程中,给出一张图片,经过Image Encoder后得到图像编码,我们只需要比较图像编码和哪个类别文本编码的相似度最高,图像即为对应类别。
CLIP模型伪代码
CLIP论文中并未给出训练过程,仅给出了伪代码,将在下面展示,以及较为权威的huggingface团队实现的CLIP源码。
然后是huggingface团队在CLIPModel中的损失函数实现:
image_embeds = vision_outputs[1]image_embeds = self.visual_projection(image_embeds)text_embeds = text_outputs[1]text_embeds = self.text_projection(text_embeds)# normalized featuresimage_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)# cosine similarity as logitslogit_scale = self.logit_scale.exp()logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scalelogits_per_image = logits_per_text.t()loss = Noneif return_loss:loss = clip_loss(logits_per_text)# contrastive loss function, adapted from# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.htmldef contrastive_loss(logits: torch.Tensor) -> torch.Tensor:return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))def clip_loss(similarity: torch.Tensor) -> torch.Tensor:caption_loss = contrastive_loss(similarity)image_loss = contrastive_loss(similarity.t())return (caption_loss + image_loss) / 2.0
下面是自己的理解:
- 首先是图像文本编码器,编码结果维度并不一致,无法计算相似度,因此一个learn prob将维度统一
- 编码结果归一化
- 对编码计算相似度矩阵
- 计算对比损失
对我来说这个对比损失是最难理解的部分,为什么通过交叉熵损失即实现了对角线全为正样本,其余均为负样本的效果。下面来看交叉熵损失的原理。
从伪代码可以看出,对于相似度矩阵,沿行这个维度来看,可以看成是每张图片与各个文本的相似度,这个一个多分类问题,与之对应的label恰好是第i行这个数字i。
这里可以看出CLIP模型所用的对比损失函数,只考虑了如何拉近正样本对之间的距离,并未考虑负样本之间的关系。即它只关心对于正样本对之间相似性,忽略了负样本至之间的差异性。这在CLIP模型中并无太大影响,因为CLIP模型的训练数据太多,同一个Batch Size中很难出现重复数据,自然所有负样本的差异性没有区别。但是我自己在训练过程中涉及到的负样本十分接近,这时候如果不考虑负样本之间的差异性,模型很难拟合。