(表征学习论文阅读)A Simple Framework for Contrastive Learning of Visual Representations

Chen T, Kornblith S, Norouzi M, et al. A simple framework for contrastive learning of visual representations[C]//International conference on machine learning. PMLR, 2020: 1597-1607.

1. 前言

本文作者为了了解对比学习是如何学习到有效的表征,对本文所提出的三大组件进行了全面的研究:

  1. 各种数据增强手段的组合在表征学习中起到了重要作用;
  2. 在表征和对比损失之间引入非线性变换能够有效提高表征质量;
  3. 对比学习相较于监督学习需要更大的batch size和更多的训练步数。

在没有人类标注或者监督的情况下学习数据的有效表征是一个长期存在的难题,目前的主要工作可以分为两类:

  1. 基于生成模型的方法
    例如VQ-VAE,MAE,BERT
  2. 基于判别模型的方法
    例如MoCo,CLIP

2. 方法

本文提出了一个框架SimCLR,通过最大化同一数据的不同数据增强处理后的两个视角之间的相似度来学习有效表征。
在这里插入图片描述

  1. 如图所示,本文首先将数据 x x x进行两个不同的增强,这里作者使用了三种简单的数据增强方法:随机裁剪后再调整到原始大小、随机颜色失真、高斯模糊。
  2. f ( ∙ ) f(\bullet) f()代表编码器,这里作者使用的是同一个编码器来对两个视角数据进行编码
  3. 最后编码器输出的结果通过非线性变换 g ( ∙ ) g(\bullet) g()得到 z i z_i zi z j z_j zj,两个向量构成了一组正例,进行相似度计算,也就是简单的单位向量内积计算出余弦相似度。目标就是最大化两者的余弦相似度。同时,一个batch中其他的数据构成了负例,最小化与负例的相似度。注意最终训练完成的编码器我们是需要舍弃掉非线性变换的。
    本文使用的损失函数就是最基本的InfoNCE损失,具体可以参考我的另一篇讲解InfoNCE的博文。
    在这里插入图片描述
    在这里插入图片描述

3. 代码

这里仅提供文章提到的两个点的代码:

  1. 数据增强
    高斯模糊
import numpy as np
import torch
from torch import nn
from torchvision.transforms import transformsnp.random.seed(0)class GaussianBlur(object):"""blur a single image on CPU"""def __init__(self, kernel_size):radias = kernel_size // 2kernel_size = radias * 2 + 1self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),stride=1, padding=0, bias=False, groups=3)self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),stride=1, padding=0, bias=False, groups=3)self.k = kernel_sizeself.r = radiasself.blur = nn.Sequential(nn.ReflectionPad2d(radias),self.blur_h,self.blur_v)self.pil_to_tensor = transforms.ToTensor()self.tensor_to_pil = transforms.ToPILImage()def __call__(self, img):img = self.pil_to_tensor(img).unsqueeze(0)sigma = np.random.uniform(0.1, 2.0)x = np.arange(-self.r, self.r + 1)x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))x = x / x.sum()x = torch.from_numpy(x).view(1, -1).repeat(3, 1)self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))with torch.no_grad():img = self.blur(img)img = img.squeeze()img = self.tensor_to_pil(img)return img

组合各类增强手段

class ContrastiveLearningDataset:def __init__(self, root_folder=r"D:\pyproject\representation_learning\data"):self.root_folder = root_folder@staticmethoddef get_simclr_pipeline_transform(size, s=1):"""Return a set of data augmentation transformations as described in the SimCLR paper."""color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),transforms.RandomHorizontalFlip(),transforms.RandomApply([color_jitter], p=0.8),transforms.RandomGrayscale(p=0.2),GaussianBlur(kernel_size=int(0.1 * size)),transforms.ToTensor()])return data_transformsdef get_dataset(self, name, n_views):valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True,transform=ContrastiveLearningViewGenerator(self.get_simclr_pipeline_transform(32),n_views),download=True),'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled',transform=ContrastiveLearningViewGenerator(self.get_simclr_pipeline_transform(96),n_views),download=True)}try:dataset_fn = valid_datasets[name]except KeyError:raise InvalidDatasetSelection()else:return dataset_fn()
  1. 非线性变换
class ResNetSimCLR(nn.Module):def __init__(self, base_model, out_dim):super(ResNetSimCLR, self).__init__()self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),"resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}self.backbone = self._get_basemodel(base_model)dim_mlp = self.backbone.fc.in_features# add mlp projection head# 修改resnet最后一层的全连接层即可self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)def _get_basemodel(self, model_name):try:model = self.resnet_dict[model_name]except KeyError:raise InvalidBackboneError("Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")else:return modeldef forward(self, x):return self.backbone(x)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/301932.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

3. Django 初探路由

3. 初探路由 一个完整的路由包含: 路由地址, 视图函数(或者视图类), 可选变量和路由命名. 本章讲述Django的路由编写规则与使用方法, 内容分为: 路由定义规则, 命名空间与路由命名, 路由的使用方式.3.1 路由定义规则 路由称为URL (Uniform Resource Locator, 统一资源定位符)…

蓝桥杯简单模板

目录 最大公约数 两个数的最大公约数 多个数的最大公约数 最小公倍数 两个数的最小公倍数 多个数的最小公倍数 素数 ​编辑 位数分离 正写 ​编辑 反写 闰年 最大公约数 两个数的最大公约数 之前看见的是辗转相除法,例如现在让算一个49,21…

该主机与 Cloudera Manager Server 失去联系的时间过长。 该主机未与 Host Monitor 建立联系

该主机与 Cloudera Manager Server 失去联系的时间过长。 该主机未与 Host Monitor 建立联系 这个去集群主机cm界面上看会出现这个错误 排查思路: 一般比较常见的原因可能是出问题的主机和集群主节点的时间对应不上了。还有就是cm agent服务出现问题了 去该主机的…

React - 你使用过高阶组件吗

难度级别:初级及以上 提问概率:55% 高阶组件并不能单纯的说它是一个函数,或是一个组件,在React中,函数也可以做为一种组件。而高阶组件就是将一个组件做为入参,被传入一个函数或者组件中,经过一定的加工处理,最终再返回一个组件的组合…

不使用 Docker 构建 Triton 服务器并在 Google Colab 平台上部署 HuggingFace 模型

Build Triton server without docker and deploy HuggingFace models on Google Colab platform EnvironmentBuilding Triton serverDeploying HuggingFace models客户端推荐阅读参考 Environment 根据Triton 环境对应表 ,Colab 环境缺少 tensorrt-8.6.1&#xff0…

IP地址到底有什么用

IP地址在计算机网络中的作用至关重要,它不仅是设备在网络中的唯一标识,更是实现网络通信、网络管理和安全的关键要素。下面,我们将从多个方面详细阐述IP地址的作用。 首先,IP地址作为设备的唯一标识,为网络通信提供了…

再探Java为面试赋能(二)Java基础知识(二)反射机制、Lambda表达式、多态

文章目录 前言1.4 反射机制1.4.1 Class对象的获取1.4.2 Class类的方法1.4.3 通过反射机制修改只读类的属性 1.5 Lambda表达式1.5.1 函数式接口1.5.2 Lambda表达式的使用 1.6 多态1.6.1 多态的概念1.6.2 多态的实现条件1.6.3 重载(Overload)和重写&#x…

用Python+OpenCV截取视频中所有含有字幕的画面

1、需求背景 有的视频文件的字幕已经压制到了视频的图像中,不能单独提取出字幕文件。网上的 “提取视频字幕” 网站多为提取视频中的字幕文件,而非识别视频图像中的字幕。少数通过OCR技术识别画面中字幕的工具需要在线运行、运行速度较慢,或…

力扣2- 两数相加

给你两个 非空 的链表,表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的,并且每个节点只能存储 一位 数字。 请你将两个数相加,并以相同形式返回一个表示和的链表。 你可以假设除了数字 0 之外,这两个数都不会以 0 …

前端layui自定义图标的简单使用

iconfont-阿里巴巴矢量图标库 2. 3. 4.追加新图标 5.文件复制追加新图标

TCP/IP协议、HTTP协议和FTP协议等网络协议简介

文章目录 一、常见的网络协议二、TCP/IP协议1、TCP/IP协议模型被划分为四个层次2、TCP/IP五层模型3、TCP/IP七层模型 三、FTP网络协议四、Http网络协议1、Http网络协议简介2、Http网络协议的内容3、HTTP请求协议包组成4、HTTP响应协议包组成 一、常见的网络协议 常见的网络协议…

DIY可视化UniApp表格组件

表格组件在移动端的用处非常广泛,特别是在那些需要展示结构化数据、进行比较分析或提供详细信息的场景中。数据展示与整理:表格是展示结构化数据的理想方式,特别是在需要展示多列和多行数据时。通过表格,用户可以轻松浏览和理解数…

vue 中使 date/time/datetime 类型的 input 支持 placeholder 方法

一般在开发时,设置了 date/time/datetime 等类型的 input 属性 placeholder 提示文本时, 发现实际展示中却并不生效,如图: 处理后效果如图: 处理逻辑 判断表单项未设置值时,则设置其伪类样式,文…

2024-04-08 NO.6 Quest3 自定义交互事件

文章目录 1 交互事件——更改 Cube 颜色2 交互事件——创建 Cube2.1 非代码方式2.2 代码方式 ​ 在开始操作前,我们导入上次操作的场景,相关介绍在 《2024-04-08 NO.5 Quest3 手势追踪进行 UI 交互-CSDN博客》 文章中。 1 交互事件——更改 Cube 颜色 …

知识管理系统|基于Springboot和vue的知识管理系统设计与实现(源码+数据库+文档)

知识管理 目录 基于Springboot和vue的知识管理系统设计与实现 一、前言 二、系统设计 三、系统功能设计 1、前台: 5.2.2 文章信息 5.3.1 论坛交流 2、后台 用户管理 5.1.2 文章分类 5.2.1 资料分类 四、数据库设计 五、核心代码 六、论文参考 七、最…

OpenHarmony实战:Combo解决方案之W800芯片移植案例

本方案基于OpenHarmony LiteOS-M内核,使用联盛德W800芯片的润和软件海王星系列Neptune100开发板,进行开发移植。 移植架构采用Board与SoC分离方案,支持通过Kconfig图形化配置编译选项,增加玄铁ck804ef架构移植,实现了…

老子云、AMRT3D、眸瑞科技

老子云概述 老子云3D可视化快速开发平台,集云压缩、云烘焙、云存储云展示于一体,使3D模型资源自动输出至移动端PC端、Web端,能在多设备、全平台进行展示和交互,是全球领先、自主可控的自动化3D云引擎。 平台架构 平台特性 1、基…

Java | Leetcode Java题解之第18题四数之和

题目&#xff1a; 题解&#xff1a; class Solution {public List<List<Integer>> fourSum(int[] nums, int target) {List<List<Integer>> quadruplets new ArrayList<List<Integer>>();if (nums null || nums.length < 4) {return…

spring cloud gateway openfeign 联合使用产生死锁问题

spring cloud gateway openfeign 联合使用产生死锁问题&#xff0c;应用启动的时候阻塞卡住。 spring.cloud 版本如下 <dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-dependencies</artifactId><vers…

【Leetcode每日一题】 递归 - 二叉树剪枝(难度⭐⭐)(50)

1. 题目解析 题目链接&#xff1a;814. 二叉树剪枝 这个问题的理解其实相当简单&#xff0c;只需看一下示例&#xff0c;基本就能明白其含义了。 2.算法原理 想象一下&#xff0c;你有一堆层层叠叠的积木&#xff0c;你想从底部开始&#xff0c;把那些标记为0的积木拿走。如…