简单谈谈 EMP-SSL:自监督对比学习的一种极简主义风

论文链接:https://arxiv.org/pdf/2304.03977.pdf

代码:https://github.com/tsb0601/EMP-SSL

其他学习链接:突破自监督学习效率极限!马毅、LeCun联合发布EMP-SSL:无需花哨trick,30个epoch即可实现SOTA


主要思想

如图,一张图片裁剪成不同的 patch,对不同的 patch 做数据增强,分别输入 encoder,得到多个 embedding,对它们求均值,得到 \bar z 作为这张图片的 embedding。最后,拉近每个 patch 的 embedding 和图片的 embedding(\bar z)之间的余弦距离;再用 Total Coding Rate(TCR) 防止坍塌(即 encoder 对所有输入都输出相同的 embedding)

图片

图片

Total Coding Rate(TCR)

公式如下:

图片

其中,det 表示求矩阵的行列式,d 是 feature vector 的 dimension,b 是 batch size

查了查该公式的含义:expand all features of Z as large as possible,即尽可能拉远矩阵中特征之间的距离。

源自 PPT 第 24 页:

https://s3.amazonaws.com/sf-web-assets-prod/wp-content/uploads/2021/06/15175515/Deep_Networks_from_First_Principles.pdf

至于为什么最大化该公式的值就可以拉远矩阵中特征之间的距离,这背后的数学原理真难啃啊 /(ㄒoㄒ)/~~


核心代码解读

数据处理

https://github.com/tsb0601/EMP-SSL/blob/main/dataset/aug.py#L116C1-L138C27

class ContrastiveLearningViewGenerator(object):def __init__(self, num_patch = 4):self.num_patch = num_patchdef __call__(self, x):normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])aug_transform = transforms.Compose([transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),transforms.RandomGrayscale(p=0.2),GBlur(p=0.1),transforms.RandomApply([Solarization()], p=0.1),transforms.ToTensor(),  normalize])augmented_x = [aug_transform(x) for i in range(self.num_patch)]return augmented_x

由此看出返回的 数据 为:长度为 num_patches 个 tensor 的列表。其中,每个 tensor 的 shape 为 (B, C, H, W)。

主函数

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L148C9-L162C63

for step, (data, label) in tqdm(enumerate(dataloader)):net.zero_grad()opt.zero_grad()data = torch.cat(data, dim=0) data = data.cuda()z_proj = net(data)z_list = z_proj.chunk(num_patches, dim=0)z_avg = chunk_avg(z_proj, num_patches)# Contractive Lossloss_contract, _ = contractive_loss(z_list, z_avg)loss_TCR = cal_TCR(z_proj, criterion, num_patches)

这里要稍微注意一下几个变量的 shape:

  • data 被 cat 完后:(num_patches * B,C,H,W)
  • z_proj:(num_patches * B,C)
  • z_list:(num_patches,B,C)
  • z_avg:(B,C)

其中,chunk_avg 就是对来自同一张图片的不同 patch 的 embedding 求均值(\bar z):

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L67

def chunk_avg(x,n_chunks=2,normalize=False):x_list = x.chunk(n_chunks,dim=0)x = torch.stack(x_list,dim=0)if not normalize:return x.mean(0)else:return F.normalize(x.mean(0),dim=1)

loss

contractive_loss 就是计算每个 patch 的 embedding 和均值(\bar z)的余弦距离:

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L76

class Similarity_Loss(nn.Module):def __init__(self, ):super().__init__()passdef forward(self, z_list, z_avg):z_sim = 0num_patch = len(z_list)z_list = torch.stack(list(z_list), dim=0)z_avg = z_list.mean(dim=0)z_sim = 0for i in range(num_patch):z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()z_sim = z_sim/num_patchz_sim_out = z_sim.clone().detach()return -z_sim, z_sim_out

TCR loss:最大化矩阵之间特征的距离,即拉远负样本(不是来自同一个样本的 patches)之间的距离

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L96

def cal_TCR(z, criterion, num_patches):z_list = z.chunk(num_patches,dim=0)loss = 0for i in range(num_patches):loss += criterion(z_list[i])loss = loss/num_patchesreturn loss

需要注意:函数输入的 z 是 z_proj,形状为(num_patches * B,C)。

所以,函数内部 z_list 的形状为(num_patches,B,C),即将数据分为了 num_patches 个组,每个组包含了来自不同图片里 patch 的 embedding。再分别对每个组求 TCR loss,最大化组内(不同图片的 patch)特征的距离。

所以,公式中的 Z 指的是一组来自不同图片里 patch 的 embedding,形状为(B,C)。

每个组内求 TCR loss 的代码按照公式计算,如下: 

图片

https://github.com/tsb0601/EMP-SSL/blob/main/loss.py#L76

class TotalCodingRate(nn.Module):def __init__(self, eps=0.01):super(TotalCodingRate, self).__init__()self.eps = epsdef compute_discrimn_loss(self, W):"""Discriminative Loss."""p, m = W.shape  #[d, B]I = torch.eye(p,device=W.device)scalar = p / (m * self.eps)logdet = torch.logdet(I + scalar * W.matmul(W.T))return logdet / 2.def forward(self,X):return - self.compute_discrimn_loss(X.T)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/89832.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue3 setup tsx 子组件向父组件传值 emit

需求:Vue3 setup 父组件向子组件传值,子组件接收父组件传入的值;子组件向父组件传值,父组件接收的子组件传递的值。 父组件:parent.tsx: import { defineComponent, ref, reactive } from vue; import To…

【STM32】利用CubeMX对FreeRTOS用按键控制任务

对于FreeRTOS中的操作,最常用的就是创建、删除、暂停和恢复任务。 此次实验目标: 1.创建任务一:LED1每间隔1秒闪烁一次,并通过串口打印 2.创建任务二:LED2每间隔0.5秒闪烁一次,并通过串口打印 3.创建任…

【工作记录】mysql中实现分组统计的三种方式

前言 实际工作中对范围分组统计的需求还是相对普遍的,本文记录下在mysql中通过函数和sql完成分组统计的实现过程。 数据及期望 比如我们获取到了豆瓣电影top250,现在想知道各个分数段的电影总数. 表数据如下: 期望结果: 实现方案 主要思路是根据s…

SpringMVC拦截器

1.拦截器简介 拦截器(Interceptor)是一种动态拦截方法调用的机制,在SpringMVC中动态拦截控制器方法的执行 作用: 在指定的方法调用前后执行预先设定的代码 阻止原始方法的执行 总结:拦截器就是用来做增强 看完以后&#xff0…

【在一个升序数组中插入一个数仍升序输出】

在一个升序数组中插入一个数仍升序输出 题目举例: 有一个升序数组nums,给一个数字data,将data插入数组nums中仍旧保证nums升序,返回数组中有效元素个数。 比如:nums[100] {1, 2, 3, 5, 6, 7, 8, 9} size 8 data 4 …

elementUi表单恢复至初始状态并不触发表单验证

elementUi表单恢复至初始状态并不触发表单验证 1.场景再现2.解决方法 1.场景再现 左侧是树形列表,右侧是显示节点的详情,点击按钮应该就是新增一个规则的意思,表单内容是没有改变的,所以就把需要把表单恢复至初始状态并不触发表单…

正则表达式试炼

序 我希望在这里列出我很多想写的正则表达式,很多我想写,但是不知道怎么写的。分享点滴案例。未来这个文章会越来越长 前言 互联网时代,除了文本还有更好的学习方式,下面是几个不错的练习网站,如果你想系统地学习&a…

深入了解Linux运维的重要性与最佳实践

Linux作为开源操作系统的代表,在企业级环境中的应用越来越广泛。而在保障Linux系统的正常运行和管理方面,Linux运维显得尤为关键。本文将介绍Linux运维的重要性以及一些最佳实践,帮助读者更好地了解和掌握Linux系统的运维技巧。 首先&#xf…

如何更快地执行 Selenium 测试用例?

前言: 当我们谈论自动化时,首先想到的工具之一是 Selenium。我们都知道Selenium WebDriver 是一个出色的 Web 自动化工具。实施Selenium 自动化测试的主要原因是加速 selenium 测试。在大多数情况下,Selenium 的性能比手动的要好得多。但是&…

离线安装vscode插件,导出 Visual Studio Code 的扩展应用,并离线安装

在没有网络的情况下,如何安装vscode插件 1.使用之前电脑安装过的插件包 Visual Studio Code 的扩展应用安装位置在文件夹 .vscode/extensions 下。不同平台,它位于: Windows %USERPROFILE%\.vscode\extensions Mac ~/.vscode/extensions L…

C字符串练习题(6.3.1)

编写一个程序&#xff0c;从键盘上读入一个小于1000的正整数&#xff0c;然后创建并输出一个字符串&#xff0c;说明该整数的值。例如&#xff0c;输入941&#xff0c;程序产生的字符串是“Nine hundred and forty one”。 #include<stdlib.h> #include<string.h>…

【JAVA】我们常常谈到的方法是指什么?

个人主页&#xff1a;【&#x1f60a;个人主页】 系列专栏&#xff1a;【❤️初识JAVA】 文章目录 前言方法方法的分类方法的定义方法调用方法重载 前言 在之前的文章中我们总是会介绍到类中的各式各样的方法&#xff0c;也许在应用中我们对它已经有了初步的了解&#xff0c;今…

# ⛳ Docker 安装、配置和详细使用教程-Win10专业版

目录 ⛳ Docker 安装、配置和详细使用教程-Win10专业版&#x1f69c; 一、win10 系统配置&#x1f3a8; 二、Docker下载和安装&#x1f3ed; 三、Docker配置&#x1f389; 四、Docker入门使用 ⛳ Docker 安装、配置和详细使用教程-Win10专业版 &#x1f69c; 一、win10 系统配…

使用docker快速搭建wordpress服务,并指定域名访问

文章目录 引入使用docker快速跑起服务创建数据库安装wordpress服务配置域名 引入 wordpress是一个基于PHP语言编写的开源的内容管理系统&#xff08;CMS&#xff09;&#xff0c;它有丰富的插件和主题&#xff0c;可以非常简单的创建各种类型的网站&#xff0c;包括企业网站、…

vuejs 设计与实现 - 渲染器 - 挂载与更新

渲染器的核心功能:挂载与更新 1.挂载子节点和元素的属性 1.2挂载子节点 (vnode.children) vnode.children可以是字符串类型的&#xff0c;也可以是数组类型的&#xff0c;如下&#xff1a; const vnode {type: div,children: [{type: p,children: hello}] } 可以看到&#…

【前端|Javascript第4篇】详解Javascript的事件模型:小白也能轻松搞懂!

前言 在当今数字时代&#xff0c;前端技术正日益成为塑造用户体验的关键。而其中一个不可或缺的核心概念就是JavaScript的事件模型。或许你是刚踏入前端领域的小白&#xff0c;或者是希望深入了解事件模型的开发者&#xff0c;不论你的经验如何&#xff0c;本篇博客都将带你揭开…

聚类与回归

聚类 聚类属于非监督式学习&#xff08;无监督学习&#xff09;&#xff0c;往往不知道因变量。 通过观察学习&#xff0c;将数据分割成多个簇。 回归 回归属于监督式学习&#xff08;有监督学习&#xff09;&#xff0c;知道因变量。 通过有标签样本的学习分类器 聚类和…

SpringCloud中 Sentinel 限流的使用

引入依赖 <dependency><groupId>com.alibaba.cloud</groupId><artifactId>spring-cloud-starter-alibaba-sentinel</artifactId> </dependency>手动编写限流规则&#xff0c;缺点是不够灵活&#xff0c;如果需要改变限流规则需要修改源码…

打破传统直播,最新数字化升级3DVR全景直播

导语&#xff1a; 近年来&#xff0c;随着科技的不断创新和发展&#xff0c;传媒领域也正经历着一场前所未有的变革。在这个数字化时代&#xff0c;直播已经不再仅仅是在屏幕上看到一些人的视频&#xff0c;而是将观众带入一个真实世界的全新体验。其中&#xff0c;3DVR全景直…

数据结构:力扣OJ题(每日一练)

题一&#xff1a;有效的括号 给定一个只包括 (&#xff0c;)&#xff0c;{&#xff0c;}&#xff0c;[&#xff0c;] 的字符串 s &#xff0c;判断字符串是否有效。 有效字符串需满足&#xff1a; 左括号必须用相同类型的右括号闭合。左括号必须以正确的顺序闭合。每个右括号…