MoCo v1(CVPR 2020)原理与代码解读

paper:Momentum Contrast for Unsupervised Visual Representation Learning

official implementation:https://github.com/facebookresearch/moco

背景

最近的一些研究提出使用对比损失相关的方法进行无监督视觉表征学习并取得了不错的结果。尽管是受到不同motivation的启发,这些方法都可以看做是在构建一个动态字典。字典中的"keys"(tokens)从数据(图片或图片的patch)中采样并用一个编码器encoder网络来表示。无监督学习训练encoder来执行字典查找:一个encoded "query"应该与它匹配的key相似,而与其它的key不同。学习过程表述为最小化对比损失的过程。

存在的问题

从构建动态字典的角度来看,作者假设构建的字典应该具备两个特点:

  1. large即字典要足够大
  2. 在训练期间字典要保持一致性

从直觉上来说,一个更大的字典可以更好地对连续的、高维的视觉空间进行采样。而字典中的键应该由相同或相似的编码器表示,以便它们与query的比较是一致的。然而,一些使用对比损失的现有方法受限于这两个方面中的一个(具体将在后续的方法介绍中讨论)。

本文的创新点

本文提出了动量对比(Momentum Contrast,MoCo)作为一种构建大型和一致的字典的方法,用于对比损失的无监督学习,如图1所示。

作者维护了一个数据样本的队列作为字典,当前mini-batch的encoded representation进队,队列中最老的表示出队。队列将字典大小和batch size进行解耦从而使得字典可以非常大。此外由于字典的key来源于之前若干个mini-batch,作者提出了一个缓慢变化的key encoder,具体实现为query encoder的基于动量的移动平均值,从而保持一致性。 

无监督学习的一个主要目的是得到一个预训练表示,通过微调可以tranfer到下游任务中。作者通过实验表明,在7个与检测和分割相关的下游任务中,MoCo无监督预训练可以超过ImageNet有监督预训练。

方法介绍

Contrastive Learning as Dictionary Look-up

对比学习可以用来为字典查找任务训练一个编码器。对于一个encoded query \(q\) 和一组encoded样本 \(\{k_0,k_1,k_2,...\}\),后者是字典的keys。假设字典中有一个单独的key(表示为 \(k_+\))与 \(q\) 匹配,对比损失作为一个函数,当 \(q\) 和的positive key \(k_+\) 相似并与所有其它的key(被认为是 \(q\) 的negative keys)不相似时对比损失的值很小。用点积来表示相似性,本文采用了对比损失的一种形式,InfoNCE,如下

其中 \(\tau\) 是是温度超参,结果对一个正样本和 \(K\) 个负样本求和。从直觉上来说,这个损失是一个 \((K+1)\) 类基于softmax分类器的log损失,这个分类器试图将 \(q\) 分为 \(k_+\) 类。对比损失还有其它形式,比如margin-based loss和NCE loss的一些变种。

对比损失作为无监督的目标函数用来训练encoder network来表示queries和keys。一般来说,query representation是 \(q=f_q(x_q)\) 其中 \(f_q\) 是encoder网络,\(x_q\) 是一个query样本(同样,\(k=f_k(x_k)\))。初始化取决于具体的代理任务,输入 \(x_q\) 和 \(x_k\) 可以是图像、patches、或包含一组patches的context。网络 \(f_q\) 和 \(f_k\) 可以是相同的、部分共享的、或不同的。

Momentum Contrast

Dictionary as a queue

本文方法的核心是维护一个数据样本的队列作为字典,这使得我们可以重用前面mini-batch中的encoded keys,队列的引入将字典大小与batch大小进行了解耦,我们的字典可以比普通的batch size大得多,并且可以灵活独立的作为一个超参来设置。

字典中的样本被逐步替换掉,当前mini-batch进入队列,而队列中最老的mini-batch被删除。字典总是表示所有数据的一个采样子集,而维护字典的额外计算是可控的。此外删除最早的mini-batch也是有好处的,因为它的encoded keys是最老的,与最新的编码最不一致。

Momentum update

使用队列可以使字典更大,但也使得通过反向传播更新key encoder变得困难(梯度应该传播到队列中的所有样本)。一个天真的解决方法是忽略key encoder \(f_k\) 的梯度直接拷贝query encoder \(f_q\),但这种解决方案在实验中得到的结果很差,作者推测这是由于快速变化的encoder减少了key representation的一致性导致的。因此提出了动量更新来解决这个问题。

我们将 \(f_k\) 的参数表示为 \(\theta_k\),\(f_q\) 的参数表示为 \(\theta_q\),然后通过下式更新 \(\theta_k\)

其中 \(m\in[0,1)\) 是动量系数,只有参数 \(\theta_q\) 通过反向传播更新,式(2)中的动量更新使得 \(\theta_k\) 比 \(\theta_q\) 的更新更平滑。因此,尽管队列中的keys是通过不同的encoder编码的(不同的mini-batch),这些encoder之间的差异非常小。后续实验表明,一个更大的动量(例如 \(m=0.999\))比更小的动量(例如 \(m=0.9\))表现得更好,表明一个缓慢更新的key encoder是使用队列的核心。

Relations to previous mechanisms

MoCo是使用对比损失的一种机制,作者将其与其它两种机制进行了对比,如图2所示,它们在字典大小和一致性上表现出不同的属性。

图2(a)是通过反向传播进行end-to-end更新的一种机制,它使用当前mini-batch中的样本作为字典,因此key的编码是一致的(通过相同的一组encoder参数)。但是字典的大小和mini-batch的大小耦合,受限于GPU的内存。同时也受到大mini-batch优化问题的挑战。

另外一种机制是采用memory bank,如图2(b)所示。memory back包含了数据集中所有样本的representation,每个mini-batch的字典是从memory bank中随机采样得到的,且没有反向传播,因此字典的size可以很大。但是,memory bank中一个样本的表示在它最后一次被看到时就更新了,因此采样的keys是过去一个epoch中不同step的encoder得到的,从而缺乏了一致性。

Pretext Task

对比学习可以使用不同的代理任务,由于本文的重点不是设计一个新的代理任务,本文遵循instance discrimination任务使用了一个简单的代理任务。如果一个query和一个key来源于同一张图像,则将它们视为positive pair,否则视为negative pair。我们对同一张图像进行两次随机数据增强得到一个postive pair,queries和keys分别由各自的encoder \(f_q\) 和 \(f_k\) 进行编码,encoder可以是任何的卷积网络。

MoCo的伪代码如下所示,对当前的mini-batch,我们对postive pair分别进行编码得到queries和对应的keys,负样本来源于队列。

Shuffling BN

编码器 \(f_q\) 和 \(f_k\) 中都使用了BN,作者在实验中发现使用BN会阻止模型学习好的表示,模型似乎“欺骗”了代理任务并很容易地找到了一种low-loss的解决方法。这可能是样本之间的batch内的通信(BN引起的)泄露了信息。

作者通过shuffle BN来解决这个问题。具体训练是在多个GPU上进行的,每个GPU独立的对样本执行BN。对于key encoder \(f_k\),在将当前mini-batch分配到不同GPU之前打乱样本顺序(并在编码之后还原顺序),query encoder \(f_q\) 不进行打乱顺序。这保证了用于计算query和对应的positve key的统计信息来自于不同的子集,有效解决了欺骗问题。

代码解析

下面是官方实现,基本上和文章中的伪代码一致,没有什么难以理解的地方。其中encoder_k的参数更新顺序和伪代码不一样,伪代码是f_q和f_k分别forward,然后f_q的loss反向传播,更新f_q的参数,最后f_k进行动量更新。而代码中是f_q先forward,然后f_k更新参数,接着f_k进行forward,最后再根据反向传播更新f_q。

另外,这里包含了MoCo v2的代码,主要的区别就是v2借鉴SimCLR的做法,在encoder的avg pooling层后多加了一层projection layer,即一个MLP。

# Copyright (c) Meta Platforms, Inc. and affiliates.# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.import torch
import torch.nn as nnclass MoCo(nn.Module):"""Build a MoCo model with: a query encoder, a key encoder, and a queuehttps://arxiv.org/abs/1911.05722"""def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):"""dim: feature dimension (default: 128)K: queue size; number of negative keys (default: 65536)m: moco momentum of updating key encoder (default: 0.999)T: softmax temperature (default: 0.07)"""super(MoCo, self).__init__()self.K = Kself.m = mself.T = T# create the encoders# num_classes is the output fc dimensionself.encoder_q = base_encoder(num_classes=dim)self.encoder_k = base_encoder(num_classes=dim)if mlp:  # hack: brute-force replacementdim_mlp = self.encoder_q.fc.weight.shape[1]  # 2048self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):param_k.data.copy_(param_q.data)  # initializeparam_k.requires_grad = False  # not update by gradient# create the queueself.register_buffer("queue", torch.randn(dim, K))# 将张量或缓冲区注册为 nn.Module 的一部分,但不会被视为模型的可学习参数。# 通常情况下,这用于存储模型中的固定参数或状态,例如均值、方差等,这些参数在训练过程中不会被更新。self.queue = nn.functional.normalize(self.queue, dim=0)self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))@torch.no_grad()def _momentum_update_key_encoder(self):"""Momentum update of the key encoder"""for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)@torch.no_grad()def _dequeue_and_enqueue(self, keys):# gather keys before updating queuekeys = concat_all_gather(keys)batch_size = keys.shape[0]ptr = int(self.queue_ptr)assert self.K % batch_size == 0  # for simplicity# replace the keys at ptr (dequeue and enqueue)self.queue[:, ptr: ptr + batch_size] = keys.Tptr = (ptr + batch_size) % self.K  # move pointerself.queue_ptr[0] = ptr@torch.no_grad()def _batch_shuffle_ddp(self, x):"""Batch shuffle, for making use of BatchNorm.*** Only support DistributedDataParallel (DDP) model. ***"""# gather from all gpusbatch_size_this = x.shape[0]x_gather = concat_all_gather(x)batch_size_all = x_gather.shape[0]num_gpus = batch_size_all // batch_size_this# random shuffle indexidx_shuffle = torch.randperm(batch_size_all).cuda()# 打乱索引顺序,比如batch_size_all=8, idx_shuffle=[1,3,5,2,0,4,7,6]# broadcast to all gpustorch.distributed.broadcast(idx_shuffle, src=0)# 将生成的随机索引序列从GPU 0(src=0)广播到所有其他的GPU设备上,以便在分布式训练时,每个GPU都能够获得相同的随机索引序列,以保持数据的同步性。# index for restoringidx_unshuffle = torch.argsort(idx_shuffle)  # tensor([4, 0, 3, 1, 5, 2, 7, 6])# shuffled index for this gpugpu_idx = torch.distributed.get_rank()idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]return x_gather[idx_this], idx_unshuffle@torch.no_grad()def _batch_unshuffle_ddp(self, x, idx_unshuffle):"""Undo batch shuffle.*** Only support DistributedDataParallel (DDP) model. ***"""# gather from all gpusbatch_size_this = x.shape[0]x_gather = concat_all_gather(x)batch_size_all = x_gather.shape[0]num_gpus = batch_size_all // batch_size_this# restored index for this gpugpu_idx = torch.distributed.get_rank()idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]return x_gather[idx_this]def forward(self, im_q, im_k):"""Input:im_q: a batch of query imagesim_k: a batch of key imagesOutput:logits, targets"""# compute query featuresq = self.encoder_q(im_q)  # queries: NxCq = nn.functional.normalize(q, dim=1)# compute key featureswith torch.no_grad():  # no gradient to keysself._momentum_update_key_encoder()  # update the key encoder# 和论文中伪代码的顺序不一样,论文中encoder_k是先forward后更新参数,这里是先更新参数后forward# shuffle for making use of BNim_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)k = self.encoder_k(im_k)  # keys: NxCk = nn.functional.normalize(k, dim=1)# undo shufflek = self._batch_unshuffle_ddp(k, idx_unshuffle)# compute logits# Einstein sum is more intuitive# positive logits: Nx1l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)# negative logits: NxKl_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])# logits: Nx(1+K)logits = torch.cat([l_pos, l_neg], dim=1)# apply temperaturelogits /= self.T# labels: positive key indicatorslabels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()# dequeue and enqueueself._dequeue_and_enqueue(k)return logits, labels# utils
@torch.no_grad()
def concat_all_gather(tensor):"""Performs all_gather operation on the provided tensors.*** Warning ***: torch.distributed.all_gather has no gradient."""tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]torch.distributed.all_gather(tensors_gather, tensor, async_op=False)output = torch.cat(tensors_gather, dim=0)return output

实验结果

无监督模型的常见评估方法是将训练好的encoder的权重freeze,后面接一层全连接层和softmax,然后在目标数据上只训练全连接层,最后在测试集上评估得到的模型效果。下面是MoCo和之前的无监督模型的结果对比,可以看到MoCo取得了最优的结果。

无监督模型的另一个作用是当做下游任务的预训练权重。在VOC目标检测任务上和监督预训练的对比如下,可以看到MoCo比监督预训练权重的效果更好。

 

下面是在COCO数据的目标检测任务和实例分割任务上与随机初始化权重、监督预训练权重的结果对比

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/307865.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

.net 6 集成NLog

.net 6 webapi项目集成NLog 上代码step 1 添加nugetstep 2 添加支持step 3 添加配置文件 结束 上代码 step 1 添加nuget 添加nuget 包 Roc step 2 添加支持 修改program.cs var builder WebApplication.CreateBuilder(args); // 添加NLog日志支持 builder.AddRocNLog();ste…

mysql8.0高可用集群架构实战

MySQL :: MySQL Shell 8.0 :: 7 MySQL InnoDB Cluster 基本概述 InnoDB Cluster是MySQL官方实现高可用读写分离的架构方案,其中包含以下组件 MySQL Group Replication,简称MGR,是MySQL的主从同步高可用方案,包括数据同步及角色选举Mysql Shell 是InnoDB Cluster的管理工具,用…

单链表详解(无哨兵位),实现增删改查

1.顺序表对比单链表的缺点 中间或头部插入时,需要移动数据再插入,如果数据庞大会导致效率降低每次增容就需要申请空间,而且需要拷贝数据,释放旧空间增容造成浪费,因为一般都是以2倍增容 2.链表的基础知识 链表也是线…

【随笔】Git 高级篇 -- 快速定位分支 ^|~(二十三)

💌 所属专栏:【Git】 😀 作  者:我是夜阑的狗🐶 🚀 个人简介:一个正在努力学技术的CV工程师,专注基础和实战分享 ,欢迎咨询! 💖 欢迎大…

[面向对象] 单例模式与工厂模式

单例模式 是一种创建模式,保证一个类只有一个实例,且提供访问实例的全局节点。 工厂模式 面向对象其中的三大原则: 单一职责:一个类只有一个职责(Game类负责什么时候创建英雄机,而不需要知道创建英雄机要…

SQL 注入之 Windows/Docker 环境 SQLi-labs 靶场搭建!

在安全测试领域,SQL注入是一种常见的攻击方式,通过应用程序的输入执行恶意SQL查询,从而绕过认证和授权,可以窃取、篡改或破坏数据库中的数据。作为安全测试学习者,如果你要练习SQL注入,在未授权情况下直接去…

DedeCMS 未授权远程命令执行漏洞分析

dedecms介绍 DedeCMS是国内专业的PHP网站内容管理系统-织梦内容管理系统,采用XML名字空间风格核心模板:模板全部使用文件形式保存,对用户设计模板、网站升级转移均提供很大的便利,健壮的模板标签为站长DIY自己的网站提供了强有力…

【AIGC】本地部署通义千问 1.5 (PyTorch)

今天想分享一下 Qwen 1.5 官方用例的二次封装( huggingface 说明页也有提供源码),其实没有太多的技术含量。主要是想记录一下如何从零开始在不使用第三方工具的前提下,以纯代码的方式本地部署一套大模型,相信这对于技术…

Springboot+Vue项目-基于Java+MySQL的旅游网站系统(附源码+演示视频+LW)

大家好!我是程序猿老A,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:Java毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设计 &…

JDX图片识别工具:扫描PDF文件批量改名的应用场景

业务场景 案例1:人事部人员扫描了几百份简历,保存为PDF格式,但是名字是一串数字,不好分辨,有没有能识别里面姓名内容并自动重命名文件的方法 案例2:财务人员有的批量导出很多PDF电子发票,导出…

MATLAB 构建协方差矩阵,解算特征值和特征向量(63)

MATLAB 局部点云构建协方差矩阵,解算特征值和特征向量(63) 一、算法介绍二、算法实现1.代码2.结果一、算法介绍 对于某片有待分析的点云,我们希望构建协方差矩阵,计算特征值和特征向量,这是很多算法必要的分析方法,这里提供完整的计算代码(验证正确) !!! 特别需要注意…

swiftui macOS实现加载本地html文件

import SwiftUI import WebKitstruct ContentView: View {var body: some View {VStack {Text("测试")HTMLView(htmlFileName: "localfile") // 假设你的本地 HTML 文件名为 index.html.frame(minWidth: 100, minHeight: 100) // 设置 HTMLView 的最小尺寸…

分享 GoLand 2024.1 激活的方案,支持JetBrains全家桶

大家好,欢迎来到金榜探云手! GoLand 公司简介 JetBrains 是一家专注于开发工具的软件公司,总部位于捷克。他们以提供强大的集成开发环境(IDE)而闻名,如 IntelliJ IDEA、PyCharm、和 GoLand等。这些工具被广…

华为海思校园招聘-芯片-数字 IC 方向 题目分享——第二套

华为海思校园招聘-芯片-数字 IC 方向 题目分享(共9套,有答案和解析,答案非官方,未仔细校正,仅供参考)——第二套(共九套,每套四十个选择题) 部分题目分享,完整版获取&am…

SQL注入sqli_libs靶场第一题

第一题 联合查询 1)思路: 有回显值 1.判断有无注入点 2.猜解列名数量 3.判断回显点 4.利用注入点进行信息收集 爆用户权限,爆库,爆版本号 爆表,爆列,爆账号密码 2)解题过程&#xff1…

Qt发布可执行exe

文章目录 方法步骤1:在release下编译程序。步骤2:windous R 输入CMD步骤3:切换到你的release目录下。步骤4:打开cmd运行windeployqt xx.exe Qt错误:windeployqt 不是内部或外部命令,也不是可运行的程序 或批…

购物车实现

目录 1.购物车常见的实现方式 2.购物车数据结构介绍 3.实例分析 1.controller层 2.service层 1.购物车常见的实现方式 方式一:存储到数据库 性能存在瓶颈方式二:前端本地存储 localstorage在浏览器中存储 key/value 对,没有过期时间。s…

自动化测试-web

一、自动化测试理论: UI: User Interface (用户接口-用户界面),主要包括:app 和webUI自动化测试:使用工具或代码执行用例的过程什么样的项目适合做自动化: 需要回归测试项目(甲方自…

CH254X 8051芯片手册介绍

1 8051CPU 8051是一种8位元的单芯片微控制器,属于MCS-51单芯片的一种,由英特尔(Intel)公司于1981年制造。Intel公司将MCS51的核心技术授权给了很多其它公司,所以有很多公司在做以8051为核心的单片机,如Atmel、飞利浦、深联华等公…

34-5 CSRF漏洞 - CSRF分类

环境准备:构建完善的安全渗透测试环境:推荐工具、资源和下载链接_渗透测试靶机下载-CSDN博客 1)GET 类型 传参: 参数连接在URL后面 POC构造及执行流程: 构造URL,诱导受害者访问点击利用利用标签进行攻击: 构造虚假URL,在链接上添加payload抓包获取数据包,通过CSRF POC…