购物网站建设思路,中国建设网站红黑榜名单,销售系统软件,宽城区建设局网站前言
2021年2月份#xff0c;CLIP模型被提出#xff0c;想法很简单#xff0c;性能高效#xff0c;而且具备很好的泛化性。我在这里简单谈论下我对CLIP模型的理解#xff0c;以及发现的一些问题。
我是在沐神的视频中了解的CLIP, 里面提到CLIP最大的贡献在于打破了固定类… 前言
2021年2月份CLIP模型被提出想法很简单性能高效而且具备很好的泛化性。我在这里简单谈论下我对CLIP模型的理解以及发现的一些问题。
我是在沐神的视频中了解的CLIP, 里面提到CLIP最大的贡献在于打破了固定类别标签范式。我对这句话是这样理解的就拿一般的分类任务来说每一张图片对应一个类别类别数量都是固定的当模型训练好后在实际使用过程中一但出现一个从未出现的类别模型是无法识别出来的。但是CLIP模型不一样CLIP在训练的过程中是将句子和图片匹配然后在推理过程中找到与之最接近的模板句子。举个例子CLIP模型在训练过程中用到了4亿组图像文本对可以说是涵盖了自然界中的大部分场景在迁移学习时即使从未见过三轮车这个类别但一定见过与三轮车描述相关的图像文本对从而在推理过程中将其识别为三轮车类。
CLIP模型的训练以及推理过程
数据集是若干的图像文本对CLIP用了近4亿组。在训练过程中取一个batch_size的图像文本对图像经过Image Encode, 文本经过Text Encoder然后在向量之间计算余弦相似度结果就如图像所示对象线上的元素分别是一一对应的那么文本编码和图像编码之间的相似度的也该是最高的即在对比学习中对角线上的元素即为正样本其余非对角线元素为负样本。因此这个模型经过训练后能实现的最终理想目标就是一组图像文本对图像经过Image Encoder编码和文本经过Text Encoder的编码应该是一摸一样的显然并不可能但是可以保证两个编码的相似度尽可能的高。
接下来就是推理过程了可以看出CLIP训练好的模型并不具备分类头得到的最终结果就是两个Encoder同一组图像文本对经这两个Encoder的编码相似度会很高。 推理过我们需要先给出类别模型即将一个类别标签变成一个句子 这些类别标签的句子讲过Text Encoder后会生成对应的文本编码在推理过程中给出一张图片经过Image Encoder后得到图像编码我们只需要比较图像编码和哪个类别文本编码的相似度最高图像即为对应类别。
CLIP模型伪代码
CLIP论文中并未给出训练过程仅给出了伪代码将在下面展示以及较为权威的huggingface团队实现的CLIP源码。 然后是huggingface团队在CLIPModel中的损失函数实现
image_embeds vision_outputs[1]image_embeds self.visual_projection(image_embeds)text_embeds text_outputs[1]text_embeds self.text_projection(text_embeds)# normalized featuresimage_embeds image_embeds / image_embeds.norm(p2, dim-1, keepdimTrue)text_embeds text_embeds / text_embeds.norm(p2, dim-1, keepdimTrue)# cosine similarity as logitslogit_scale self.logit_scale.exp()logits_per_text torch.matmul(text_embeds, image_embeds.t()) * logit_scalelogits_per_image logits_per_text.t()loss Noneif return_loss:loss clip_loss(logits_per_text)# contrastive loss function, adapted from# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.htmldef contrastive_loss(logits: torch.Tensor) - torch.Tensor:return nn.functional.cross_entropy(logits, torch.arange(len(logits), devicelogits.device))def clip_loss(similarity: torch.Tensor) - torch.Tensor:caption_loss contrastive_loss(similarity)image_loss contrastive_loss(similarity.t())return (caption_loss image_loss) / 2.0
下面是自己的理解
首先是图像文本编码器编码结果维度并不一致无法计算相似度因此一个learn prob将维度统一编码结果归一化对编码计算相似度矩阵计算对比损失
对我来说这个对比损失是最难理解的部分为什么通过交叉熵损失即实现了对角线全为正样本其余均为负样本的效果。下面来看交叉熵损失的原理。 从伪代码可以看出对于相似度矩阵沿行这个维度来看可以看成是每张图片与各个文本的相似度这个一个多分类问题与之对应的label恰好是第i行这个数字i。
这里可以看出CLIP模型所用的对比损失函数只考虑了如何拉近正样本对之间的距离并未考虑负样本之间的关系。即它只关心对于正样本对之间相似性忽略了负样本至之间的差异性。这在CLIP模型中并无太大影响因为CLIP模型的训练数据太多同一个Batch Size中很难出现重复数据自然所有负样本的差异性没有区别。但是我自己在训练过程中涉及到的负样本十分接近这时候如果不考虑负样本之间的差异性模型很难拟合。