浅谈 EMP-SSL + 代码解读:自监督对比学习的一种极简主义风

论文链接:https://arxiv.org/pdf/2304.03977.pdf

代码:https://github.com/tsb0601/EMP-SSL

其他学习链接:突破自监督学习效率极限!马毅、LeCun联合发布EMP-SSL:无需花哨trick,30个epoch即可实现SOTA


主要思想

如图,一张图片裁剪成不同的 patch,对不同的 patch 做数据增强,分别输入 encoder,得到多个 embedding,对它们求均值,得到 \bar z 作为这张图片的 embedding。最后,拉近每个 patch 的 embedding 和图片的 embedding(\bar z)之间的余弦距离;再用 Total Coding Rate(TCR) 防止坍塌(即 encoder 对所有输入都输出相同的 embedding)

图片

图片

Total Coding Rate(TCR)

公式如下:

图片

其中,det 表示求矩阵的行列式,d 是 feature vector 的 dimension,b 是 batch size

查了查该公式的含义:expand all features of Z as large as possible,即尽可能拉远矩阵中特征之间的距离。

源自 PPT 第 24 页:

https://s3.amazonaws.com/sf-web-assets-prod/wp-content/uploads/2021/06/15175515/Deep_Networks_from_First_Principles.pdf

至于为什么最大化该公式的值就可以拉远矩阵中特征之间的距离,这背后的数学原理真难啃啊 /(ㄒoㄒ)/~~


核心代码解读

数据处理

https://github.com/tsb0601/EMP-SSL/blob/main/dataset/aug.py#L116C1-L138C27

class ContrastiveLearningViewGenerator(object):def __init__(self, num_patch = 4):self.num_patch = num_patchdef __call__(self, x):normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])aug_transform = transforms.Compose([transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),transforms.RandomGrayscale(p=0.2),GBlur(p=0.1),transforms.RandomApply([Solarization()], p=0.1),transforms.ToTensor(),  normalize])augmented_x = [aug_transform(x) for i in range(self.num_patch)]return augmented_x

由此看出返回的 数据 为:长度为 num_patches 个 tensor 的列表。其中,每个 tensor 的 shape 为 (B, C, H, W)。

主函数

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L148C9-L162C63

for step, (data, label) in tqdm(enumerate(dataloader)):net.zero_grad()opt.zero_grad()data = torch.cat(data, dim=0) data = data.cuda()z_proj = net(data)z_list = z_proj.chunk(num_patches, dim=0)z_avg = chunk_avg(z_proj, num_patches)# Contractive Lossloss_contract, _ = contractive_loss(z_list, z_avg)loss_TCR = cal_TCR(z_proj, criterion, num_patches)

这里要稍微注意一下几个变量的 shape:

  • data 被 cat 完后:(num_patches * B,C,H,W)
  • z_proj:(num_patches * B,C)
  • z_list:(num_patches,B,C)
  • z_avg:(B,C)

其中,chunk_avg 就是对来自同一张图片的不同 patch 的 embedding 求均值(\bar z):

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L67

def chunk_avg(x,n_chunks=2,normalize=False):x_list = x.chunk(n_chunks,dim=0)x = torch.stack(x_list,dim=0)if not normalize:return x.mean(0)else:return F.normalize(x.mean(0),dim=1)

loss

contractive_loss 就是计算每个 patch 的 embedding 和均值(\bar z)的余弦距离:

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L76

class Similarity_Loss(nn.Module):def __init__(self, ):super().__init__()passdef forward(self, z_list, z_avg):z_sim = 0num_patch = len(z_list)z_list = torch.stack(list(z_list), dim=0)z_avg = z_list.mean(dim=0)z_sim = 0for i in range(num_patch):z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()z_sim = z_sim/num_patchz_sim_out = z_sim.clone().detach()return -z_sim, z_sim_out

TCR loss:最大化矩阵之间特征的距离,即拉远负样本(不是来自同一个样本的 patches)之间的距离

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L96

def cal_TCR(z, criterion, num_patches):z_list = z.chunk(num_patches,dim=0)loss = 0for i in range(num_patches):loss += criterion(z_list[i])loss = loss/num_patchesreturn loss

需要注意:函数输入的 z 是 z_proj,形状为(num_patches * B,C)。

所以,函数内部 z_list 的形状为(num_patches,B,C),即将数据分为了 num_patches 个组,每个组包含了来自不同图片里 patch 的 embedding。再分别对每个组求 TCR loss,最大化组内(不同图片的 patch)特征的距离。

所以,公式中的 Z 指的是一组来自不同图片里 patch 的 embedding,形状为(B,C)。

每个组内求 TCR loss 的代码按照公式计算,如下: 

图片

https://github.com/tsb0601/EMP-SSL/blob/main/loss.py#L76

class TotalCodingRate(nn.Module):def __init__(self, eps=0.01):super(TotalCodingRate, self).__init__()self.eps = epsdef compute_discrimn_loss(self, W):"""Discriminative Loss."""p, m = W.shape  #[d, B]I = torch.eye(p,device=W.device)scalar = p / (m * self.eps)logdet = torch.logdet(I + scalar * W.matmul(W.T))return logdet / 2.def forward(self,X):return - self.compute_discrimn_loss(X.T)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/92018.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【需求输出】流程图输出

文章目录 1、什么是流程图2、绘制流程图的工具和基本要素3、流程图的分类和应用场景4、如何根据具体场景输出流程图 1、什么是流程图 2、绘制流程图的工具和基本要素 3、流程图的分类和应用场景 4、如何根据具体场景输出流程图

如何能够写出带货的爆文?

网络推广这个领域,公司众多价格差别很大,就拿软文文案这块来讲,有人报价几十块,也有人报价几千块。作为企业的营销负责人往往会被价格吸引,比价择优选用,结果写出来的文案不满意,修改也无从入手…

LVS简介及LVS-DR搭建

目录 一. LVS简介: 1.简介 2. LVS工作模式: 3. LVS调度算法: 4. LVS-DR集群介绍: 二.LVS-DR搭建 1.RS配置 1)两台RS,需要下载好httpd软件并准备好配置文件 2)添加虚拟IP(vip&…

openeuler服务器 ls 和ll 命令报错 command not found...

在openeuler服务器执行 ls 和ll 命令报错 command not found... 大概是系统环境变量导致的问题。 我在安装redis是否没有安装成功后就出现了这样的情况。编辑profile文件没有写正确,导致在命令行下ls 和 ll 等命令不能够识别。 重新设置一下环境变量。 export PAT…

excel快速选择数据、选择性粘贴、冻结单元格

一、如何快速选择数据 在excel中,希望选择全部数据,通常使用鼠标选择数据然后往下拉,当数据很多时,也可单击单元格使用ctrl A选中全部数据,此外,具体介绍另一种方法。 操作:ctrl shift 方向…

【第三阶段】kotlin语言空合并操作符

1.空操作符?: xxx?:“如果是null执行” 如果xxx是null,就执行?:后面的逻辑,如果不是null就执行?:前面的逻辑,后面的不在执行 fun main() {var name:String?"kotlin" namenullvar …

MAC环境,在IDEA执行报错java: -source 1.5 中不支持 diamond 运算符

Error:(41, 51) java: -source 1.5 中不支持 diamond 运算符 (请使用 -source 7 或更高版本以启用 diamond 运算符) 进入设置 修改java版本 pom文件中加入 <plugin><groupId>org.apache.maven.plugins</groupId><artifactId>maven-compiler-plugin&l…

docker发展历史

docker 一、docker发展历史很久以前2013年2014年2015年2016年2017年2018年2019年及未来 二、 docker概述定义&#xff1a;docker底层运行原理:docker简述核心概念容器特点Docker与虚拟机的区别: 三、容器在内核中支持两种重要技术四、namespace的六项隔离五、虚拟化产品有哪些1…

ChatGPT等人工智能编写文章的内容今后将成为常态

BuzzFeed股价上涨200%可能标志着“转向人工智能”媒体趋势的开始。 周四&#xff0c;一份内部备忘录被华尔街日报透露BuzzFeed正计划使用ChatGPT聊天机器人-风格文本合成技术来自OpenAI&#xff0c;用于创建个性化盘问和将来可能的其他内容。消息传出后&#xff0c;BuzzFeed的…

QGIS3.28的二次开发九:添加矢量要素

对矢量要素的编辑是 GIS 软件很重要的功能点之一&#xff0c;也是最难实现的功能点之一。编辑矢量要素涉及到很多方面的考虑&#xff0c;包括且不限于矢量要素的几何类型&#xff0c;拓扑关系&#xff0c;构成要素的节点的增删改&#xff0c;编辑会话 (session) 的启动、回溯和…

MYSQL 作业三

创建一个student表格&#xff1a; create table student( id int(10) not null unique primary key, name varchar(20) not null, sex varchar(4), birth year, department varchar(20), address varchar(50) ); 创建一个score表格 create table score( id int(10) n…

IPv4分组

4.3.1 IPv4分组 IP协议定义数据传送的基本单元——IP分组及其确切的数据格式 1. IPv4分组的格式 IPv4分组由首部和数据部分&#xff08;TCP、UDP段&#xff09;组成&#xff0c;其中首部分为固定部分&#xff08;20字节&#xff09;和可选字段&#xff08;长度可变&#xff0…

使用MAT分析OOM问题

OOM和内存泄漏在我们的工作中&#xff0c;算是相对比较容易出现的问题&#xff0c;一旦出现了这个问题&#xff0c;我们就需要对堆进行分析。 一般情况下&#xff0c;我们生产应用都会设置这样的JVM参数&#xff0c;以便在出现OOM时&#xff0c;可以dump出堆内存文件&#xff…

Monge矩阵

Monge矩阵 对一个m*n的实数矩阵A&#xff0c;如果对所有i&#xff0c;j&#xff0c;k和l&#xff0c;1≤ i<k ≤ m和1≤ j<l ≤ n&#xff0c;有 A[i,j]A[k,l] ≤ A[i,l]A[k,j] 那么&#xff0c;此矩阵A为Monge矩阵。 换句话说&#xff0c;每当我们从矩阵中挑…

jQuery EasyUI datagrid 无记录时,增加“暂无数据“提示

1、在onLoadSuccess中添加如下代码&#xff1a; if (data.total 0) {var body $(this).data().datagrid.dc.body2;body.find(table tbody).append(<tr><td width" body.width() " style"height: 35px; text-align: center;"><h5>暂…

C++的IO流

目录 C语言的输入与输出 流是什么 CIO流 C标准IO流 C文件IO流 stringstream的简单介绍 在C语言中&#xff0c;如果想要将一个整形变量的数据转化为字符串格式&#xff0c;如何去做&#xff1f; 将数值类型数据格式化为字符串 字符串拼接 序列化和反序列化结构数据 注…

管理类联考——逻辑——综合推理——汇总篇——要点

一、真话假话题 #mermaid-svg-gmlWWCoVLQr21gdi {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-gmlWWCoVLQr21gdi .error-icon{fill:#552222;}#mermaid-svg-gmlWWCoVLQr21gdi .error-text{fill:#552222;stroke:#552…

00 - 环境配置

查看所有文章链接&#xff1a;&#xff08;更新中&#xff09;GIT常用场景- 目录 文章目录 1. 环境说明2. 安装配置2.1 配置user信息2.2 config的三个作用域 3. 建git仓库3.1 把已有的项目代码纳入git管理3.2 新建的项目直接用git管理3.3 配置local的user和email3.4 优先级&…

图像像素梯度

梯度 在高数中&#xff0c;梯度是一个向量&#xff0c;是有方向有大小。假设一二元函数f(x,y)&#xff0c;在某点的梯度有&#xff1a; 结果为&#xff1a; 即方向导数。梯度的方向是函数变化最快的方向&#xff0c;沿着梯度的方向容易找到最大值。 图像梯度 在一幅模糊图…

《电路》基础知识入门学习笔记

文章目录&#xff1a; 一&#xff1a;电路模型和电路规律 1.电路概述 2.电路模型 3.基本电路物理量&#xff1a;电流、电压、电功率和能量 4.电流和电压的参考方向 5.电路元件—电阻 6. 电路元件—电压源和电流源 7.受控电源 8.基尔霍夫&#xff08;后面都要用这个方法…