论文解读:Masked Generative Distillation

文章汇总

话题

知识蒸馏

创新点

带掩盖的生成式蒸馏

方法旨在通过学生的遮罩特征来生成老师的特征(通过遮盖学生部分的特征来生成老师的特征),来帮助学生获得更好的表现

输入:老师:T,学生:S,输入:x,标签:y,超参数:\alpha,\lambda

1:使用S得到输入x的特征fea^S和输出\hat{y}

2:使用T得到输入x的特征fea^T

3:计算模型的原始损失:L_{original}(\hat{y},y)

4:计算公式5中的蒸馏损失:

其中:

G表示投影层,包括两个卷积层:W_{l1}W_{l2},一个激活层ReLU。在本文中,我们采用1×1卷积层为适配层f_{align}, 3×3为投影层W_{l1}W_{l2}的卷积层。

5:使用L_{all}=L_{original}+\alpha*L_{dis}更新S

输出:S

想改进的地方

摘要

知识蒸馏已成功地应用于各种任务中。目前的蒸馏算法通常通过模仿老师的输出来提高学生的表现。本文表明,教师还可以通过引导学生特征恢复来提高学生的代表性。从这个角度来看,我们提出了掩膜生成蒸馏(mask Generative Distillation, MGD),它很简单:我们掩膜学生特征的随机像素,并通过一个简单的块强制其生成教师的完整特征。MGD是一种真正通用的基于特征的蒸馏方法,可用于各种任务,包括图像分类、目标检测、语义分割和实例分割。我们用大量的数据集对不同的模型进行了实验,结果表明所有的学生都取得了很好的进步。值得注意的是,我们将ResNet-18的ImageNet顶级1精度从69.90%提高到71.69%,将ResNet-50主干的RetinaNet从37.4提高到41.0 Boundingbox mAP,基于ResNet-50的SOLO从33.1提高到36.2 Mask mAP,以及基于ResNet-18的DeepLabV3从73.20提高到76.02 mIoU。我们的代码可在https://github.com/yzd-v/MGD上获得。

关键词:知识蒸馏,图像分类,目标检测,语义分割,实例分割

介绍

深度卷积神经网络(cnn)已广泛应用于各种计算机视觉任务中。一般来说,较大的模型具有较好的性能,但推理速度较低,难以在有限的源下部署。为了克服这一问题,知识蒸馏被提出[18]。按蒸馏类型可分为两种。第一种是专门为不同的任务而设计的,例如用于分类的基于logit的蒸馏[18,40]和用于检测的基于head的蒸馏[10,39]。第二种是基于特征的蒸馏[28,17,4]。由于各种网络之间只有头部或投影仪后的特征是不同的,从理论上讲,基于特征的蒸馏方法可以可用于各种任务。然而,为特定任务设计的蒸馏方法通常不适用于其他任务。例如,OFD[17]和KR[4]对探测器的改进有限。FKD[37]和FGD[35]是专门为探测器设计的,由于缺乏颈部,无法用于其他任务。

以往基于特征的提炼方法,由于教师的特征具有更强的表征能力,通常会让学生尽可能地模仿教师的输出。然而,我们认为没有必要直接模仿老师来提高学生特征的表征能力。用于蒸馏的特征通常是通过深度网络获得的高阶语义信息。特征像素已经在一定程度上包含了相邻像素的信息。因此,如果我们可以使用部分像素通过一个简单的块来还原教师的全部特征,那么这些使用的像素的代表性也可以得到提高。从这个角度出发,我们提出了一种简单有效的基于特征的蒸馏方法——掩膜生成蒸馏(mask Generative Distillation, MGD)。

如图2所示,我们首先对学生特征的随机像素进行遮罩,然后通过一个简单的块将遮罩后的特征生成教师的完整特征。由于每次迭代都使用随机像素,因此在整个训练过程中都会使用所有像素,这意味着特征将更加鲁棒,并且其表示能力将得到提高。在我们的方法中,老师只是引导学生还原特征,并不要求学生直接模仿

为了验证我们的假设,即在不直接模仿教师的情况下,掩蔽特征生成可以提高学生的特征表征能力,我们从学生和教师的颈部对特征注意力进行了可视化。如图1所示,学生和教师的特征有很大的不同。

与教师相比,学生特征的背景有更高的反应。教师的mAP也显著高于学生,为41.0比37.4。采用最先进的蒸馏方法FGD蒸馏后[35],迫使学生用心模仿老师的特征,学生的特征与老师的特征更加相似,mAP很大提高到40.7。而经过MGD培训后,学生与教师的特征仍有显著差异,但学生对背景的反应却大大降低。令我们惊讶的是,该学生的成绩超过了FGD,甚至达到了与老师相同的mAP。这也说明用MGD训练可以提高学生特征的表征能力。此外,我们还在图像分类和密集预测任务上做了大量的实验。结果表明,MGD对图像分类、目标检测、语义分割和实例分割等任务都有较大的改善。MGD还可以与其他基于logit或基于head的蒸馏方法相结合,以获得更大的性能收益。综上所述,本文的贡献有:

1. 我们提出了一种新的基于特征的知识提炼方法,使学生利用被掩盖的特征来生成教师的特征,而不是直接模仿教师的特征。

2. 本文提出了一种新的基于特征的蒸馏方法——掩膜生成蒸馏,该方法简单易用,只需要两个超参数。

3. 我们通过在不同数据集上的大量实验验证了我们的方法在各种模型上的有效性。对于图像分类和密集预测任务,学生在MGD的帮助下都取得了显著的进步。

相关工作

面向分类的知识蒸馏

知识蒸馏最早是由Hinton等人[18]提出的,其中学生受到来自教师最后一个线性层的标签和软标签的监督。

然而,除了logit之外,更多的蒸馏方法是基于特征映射的。FitNet[28]从中间层提取语义信息。AT[36]总结了跨渠道维度的价值,并将注意力知识转移给学生。OFD[17]提出了余量ReLU,并设计了一个测量蒸馏距离的新函数。CRD[30]利用对比学习将知识传递给学生。最近,KR[4]建立了一个审查机制,并利用多层次信息进行蒸馏。SRRL[33]将表示学习和分类解耦,利用老师的分类器来训练学生的倒数第二层特征。WSLD[40]从偏方差权衡的角度提出了加权的蒸馏软标签。

面向语义分割的知识蒸馏

Liu等人[23]提出了成对和整体蒸馏,在学生和教师的输出之间执行成对和高阶一致性。他等人[16]将教师网络的输出重新解释为一个重新表示的潜在域,并从教师网络中捕获长期依赖关系。CWD[29]最小化了概率图之间的Kullback-Leibler (KL)散度,该散度是通过对每个通道的激活图进行归一化计算得到的。

方法

对于不同的任务,模型的体系结构差别很大。此外,大多数蒸馏方法都是为特定任务而设计的。然而,基于特征的精馏可以同时应用于分类和密集预测。特征蒸馏的基本方法可表述为:

式中,F^TF^S分别表示教师和学生的特征,f_{align}是将学生的特征F^S与教师的特征F^T

对齐的适应层。C, H, W表示特征映射的形状。

这种方法有助于学生直接模仿老师的特征。然而,我们提出了掩蔽生成蒸馏(MGD),其目的是迫使学生产生教师的特征,而不是模仿它,给学生带来了分类和密集预测方面的显着改善。MGD的架构如图2所示,我们将在本节中专门介绍它。

带掩盖特征的生成

对于基于cnn的模型,更深层的特征具有更大的接受域和更好的原始输入图像表征。换句话说,特性的图像素在一定程度上已经包含了相邻像素的信息。

因此,我们可以使用部分像素来恢复完整的特征映射。我们的方法旨在通过学生的遮罩特征来生成老师的特征(通过遮盖学生部分的特征来生成老师的特征),这样可以帮助学生获得更好的表现。
我们用T^l \in R^{C \times H \times W},S^l \in R^{C \times H \times W}(l=1,...,L)表示分别为教师和学生的第l个特征图。首先我们设置第l个随机掩码来覆盖学生的第l个特征,可以表示为:

其中R_{i,j}^l为(0,1)中的随机数,i,j分别为特征图的横坐标和纵坐标。λ是表示掩码比的超参数。第
l个特征映射被第l个随机掩码覆盖。

然后我们使用相应的掩码覆盖学生的特征图,并尝试用左边的像素生成教师的特征图,可以表示为:

G表示投影层,包括两个卷积层:W_{l1}W_{l2},一个激活层ReLU。在本文中,我们采用1×1卷积层为适配层f_{align}, 3×3为投影层W_{l1}W_{l2}的卷积层。

根据该方法,我们设计了MGD的蒸馏损失L_{dis}:

其中L为蒸馏层数和,C、H、W为特征映射的形状。S和T分别表示学生和教师的特征。

总体损失

利用提出的MGD蒸馏损失L_{dis},我们用总损失训练所有模型如下:

其中L_{original}为所有任务中模型的原始损失,α为平衡损失的超参数。

MGD是一种简单有效的蒸馏方法,可方便地应用于各种任务。算法1总结了我们的方法的过程。

方法过程汇总

算法1:带掩盖的生成式蒸馏

输入:老师:T,学生:S,输入:x,标签:y,超参数:\alpha,\lambda

1:使用S得到输入x的特征fea^S和输出\hat{y}

2:使用T得到输入x的特征fea^T

3:计算模型的原始损失:L_{original}(\hat{y},y)

4:计算公式5中的蒸馏损失:L_{dis}(fea^S, fea^T)

5:使用L_{all}=L_{original}+\alpha*L_{dis}更新S

输出:S

主要实验

MGD是一种基于特征的蒸馏,可以很容易地应用于各种任务的不同模型。在本文中,我们对分类、目标检测、语义分割和实例分割等任务进行了实验。我们针对不同的任务使用不同的模型和数据集进行实验,所有模型都通过MGD获得了出色的改进。

分类

数据集

对于分类任务,我们在包含1000个对象类别的ImageNet[11]上评估我们的知识蒸馏方法。我们用120万张图片进行训练,用5万张图片进行测试,完成所有的分类实验。我们用准确性来评价模型。

实现细节

对于分类任务,我们计算来自主干的最后一个特征映射的蒸馏损失。有关消融的研究见5.5节。MGD使用超参数α来平衡方程6中的蒸馏损失。另一个超参数λ用于调整公式2中的屏蔽比。所有分类实验均采用超参数{α = 7 × 10^(−5),λ = 0.5}。我们使用SGD优化器训练所有模型100个epoch,其中动量为0.9,权重衰减为0.0001。我们初始化学习率为0.1,并每30次衰减一次。此设置基于8个gpu。实验采用基于Pytorch[26]的MMClassification[6]和MMRazor[7]进行。

分类结果

我们用两种常用的蒸馏设置进行实验,包括均相蒸馏和非均相蒸馏。

第一个蒸馏设置是从ResNet-34[15]到ResNet-18,另一个设置是从ResNet-50到MobileNet[19]。如表1所示,我们比较了各种知识蒸馏方法[18、36、17、25、30、4、40、33],包括基于特征的方法、基于逻辑的方法和结合的方法。使用我们的方法,学生ResNet-18和MobileNet的Top-1准确率分别提高了1.68和3.14。此外,如上所述,MGD只需要计算特征图上的蒸馏损失,并且可以与其他基于逻辑的图像分类方法相结合。因此,我们尝试在WSLD中加入基于logit的蒸馏损失[40]。这样,两位同学的Top-1准确率分别达到了71.80和72.59,分别提高了0.22和0.24。

表1。不同蒸馏方法在ImageNet数据集上的结果。T和S分别表示老师和学生。

目标检测和实例分割

数据集

我们在COCO2017数据集[22]上进行实验,该数据集包含80个对象类别。我们使用120k的训练图像进行训练,5k的val图像进行测试。用平均精度对模型的性能进行了评价。

表2。不同蒸馏方法在COCO上的目标检测结果。

实现细节

我们从颈部计算所有特征映射的蒸馏损失。我们采用超参数{α = 2 × 10^(−5),λ = 0.65}对所有的单阶段模型,{α = 5 × 10^(−7),λ = 0.45}对所有的两阶段模型。我们使用SGD优化器训练所有模型,其中动量为0.9,权重衰减为0.0001。除非特别说明,否则我们训练模型为24个epoch。我们使用继承策略[20,35],用教师的颈部和头部参数初始化学生,在头部结构相同的情况下训练学生。实验采用MMDetection进行[2]。

目标检测和实例分割结果

对于目标检测,我们在三种不同类型的检测器上进行了实验,包括两级检测器(Faster RCNN[27])、基于锚点的一级检测器(RetinaNet[21])和无锚点的一级检测器(RepPoints[34])。

我们将MGD与最近三种最先进的检测器蒸馏方法进行比较[37,29,35]。以分割为例,我们在SOLO[32]和Mask RCNN[14]两个模型上进行了实验。如表2和表3所示,我们的方法在两种目标检测和实例分割方面都优于其他最先进的方法。学生在MGD的帮助下获得了显著的AP改善,例如基于ResNet-50的retanet和SOLO在COCO数据集上分别获得了3.6个Boundingbox mAP和3.1个Mask mAP的改善。

表3。不同蒸馏方法对实例分割的结果。MS的意思是多尺度训练。这里的AP指掩模AP。

参考资料

论文下载(ECCV 2区 2022)

https://arxiv.org/pdf/2205.01529.pdf

📎Masked Generative Distillation.pdf

代码地址

GitHub - yzd-v/MGD: Masked Generative Distillation (ECCV 2022)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/259412.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

水质监测站工作原理!

TH-LSZ06】水质监测站的工作原理基于现代化学和生物学技术,主要通过化学分析和生物检测两种方法来检测水中有害物质。化学分析技术包括酸碱度、氧化还原电位、重金属离子、有机物、氮和磷等,而生物检测技术则主要关注病毒、细菌、真菌等微生物。 在水质…

Mac M1芯片编译openjdk报错问题解决

使用命令: sudo sh configure --with-target-bits64 用mac m1芯片编译openjdk一直报错: configure: The tested number of bits in the target (64) differs from the number of bits expected to be found in the target (32) configure: error: Cann…

【前端工程化面试题目】webpack 的热更新原理

可以在顺便学习一下 vite 的热更新原理,请参考这篇文章。 首先有几个知识点需要明确 热更新是针对开发过程中的开发服务器的,也就是 webpack-dev-serverwebpack 的热更新不需要额外的插件,但是需要在配置文件中 devServer 属性中配置 hot&a…

云原生之容器编排实践-基于CentOS7搭建三个节点的Kubernetes集群

背景 前面采用 minikube 作为 Kubernetes 环境来体验学习 Kubernetes 基本概念与操作,这样避免了初学者在裸金属主机上搭建 Kubernetes 集群的复杂度,但是随着产品功能的逐渐完善,我们需要过渡到生产环境中的 K8S 集群模式;而在实…

【代码移植】UNIX/Linux/POSIX代码程序移植到Windows系统平台技术汇总与经验分享

​ 图片来源 UNIX (Linux) to Windows代码移植技术路线 MinGW MinGW/MinGW-W64是用Windows原生系统API实现的,在Windows上运行的GCC编译工具链,可以编译出Windows原生应用程序。 MinGW编译工具链的生态位和微软官方的MSVC类似。 优点 MinGW编译出…

计算机网络-数据通信基础

目录 前言 一、数据通信基本概念 二、数据通信相关知识1 总结 前言 正在学习计算机网络体系,把每日所学的知识梳理出来,既能够当作读书笔记,又能分享出来和大家一同学习讨论。 一、数据通信基本概念 基本概念:信源、信道、信宿&…

第二篇【传奇开心果系列】Python的文本和语音相互转换库技术点案例示例:深度解读pyttsx3支持多种语音引擎

传奇开心果短博文系列 系列短博文目录Python的文本和语音相互转换库技术点案例示例系列 短博文目录前言一、三种语音引擎支持介绍和示例代码二、SAPI5引擎适用场景介绍和示例代码三、nsss引擎适用场景介绍和示例代码四、eSpeak适用场景介绍和示例代码五、归纳总结 系列短博文目…

红队学习笔记Day6 --->干货分享

今天看到这样的一个东西,好好好,有点恐怖😓😓😱😱😱😱 我就想网安是不是也有这种东西? 我来试试 icmp,RDP,arp,dhcp,nat&a…

Eclipse - 查看工程或者文件的磁盘路径

Eclipse - 查看工程或者文件的磁盘路径 1. Help -> Eclipse Marketplace -> Find: Explorer -> Eclipse Explorer 4.1.0 -> Install2. right-click -> Open in ExplorerReferences 1. Help -> Eclipse Marketplace -> Find: Explorer -> Eclipse Explo…

【Spring MVC篇】参数的传递及json数据传参

个人主页:兜里有颗棉花糖 欢迎 点赞👍 收藏✨ 留言✉ 加关注💓本文由 兜里有颗棉花糖 原创 收录于专栏【Spring MVC】 本专栏旨在分享学习Spring MVC的一点学习心得,欢迎大家在评论区交流讨论💌 目录 一、普通参数的传…

【Java多线程】线程中几个常见的属性以及状态

目录 Thread的几个常见属性 1、Id 2、Name名称 3、State状态 4、Priority优先级 5、Daemon后台线程 6、Alive存活 Thread的几个常见属性 1、Id ID 是线程的唯一标识,由系统自动分配,不同线程不会重复。 2、Name名称 用户定义的名称。该名称在各种…

【开源】SpringBoot框架开发服装店库存管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 角色管理模块2.3 服装档案模块2.4 服装入库模块2.5 服装出库模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 角色表3.2.2 服装档案表3.2.3 服装入库表3.2.4 服装出库表 四、系统展示五、核心代码5.…

160基于matlab的负熵和峭度信号的盲分离

基于matlab的负熵和峭度信号的盲分离。基于峭度的FastICA算法的收敛速度要快,迭代次数比基于负熵的FastICA算法少四倍以上。SMSE随信噪比增大两种判据下的FastICA算法都逐渐变小,但是基于峭度的算法的SMSE更小,因此基于峭度的FastICA算法性能…

UVa1359/LA3491 Hills

题目链接 本题是2005年ICPC亚洲区域赛杭州欧赛区的H题 题意 平面上有 n(n≤500)条线段,其中每条线段的端点都不会在其他线段上。你的任务是数一数有多少个“没有被其他线段切到”的三角形(即小山)。如下图所示&#x…

VTK Python PyQt 监听键盘 控制 Actor 移动 变色

KeyPressInteractorStyle 在vtk 中有时我们需要监听 键盘或鼠标做一些事; 1. 创建 Actor; Sphere vtk.vtkSphereSource() Sphere.SetRadius(10)mapper vtk.vtkPolyDataMapper() mapper.SetInputConnection(Sphere.GetOutputPort()) actor vtk.vtkAc…

winprop二次开发

winprop二次开发 前言工具1——整合多个天线结果用途代码实现 工具2——wallman辅助工具需求代码实现功能实现参数输入实验 前言 工作需求,对该软件进行简单地二次开发,都是一些挺简单的代码,单纯是为了上传之后将其从本地删除 工具1——整…

嵌入式day24

开课复工啦~ 冲冲冲! 文件IO: read函数和write函数: 📚 write 接口有三个参数: fd:文件描述符buf:要写入的缓冲区的起始地址(如果是字符串,那么就是字符串的起始地址&…

算法学习系列(三十五):贪心(杂)

目录 引言一、合并果子(Huffman树)二、排队打水(排序不等式)三、货仓选址(绝对值不等式)四、耍杂技的牛(推公式) 引言 上一篇文章也说过了这个贪心问题没有一个规范的套路和模板&am…

第三十三回 镇三山大闹青州道 霹雳火夜走瓦砾场-python分割字符串

黄信和刘知寨押解宋江和花荣向青州走,碰到了燕顺等三人来劫囚车,黄信逃走了,刘知寨被抓住,被花荣一刀杀了。 黄信把情况报给青州知府,派来了青州兵马秦统制,人称霹雳火的秦明。秦明与花荣打,花…

UnityShader——06UnityShader介绍

UnityShader介绍 UnityShader的基础ShaderLab UnityShader属性块介绍 Properties {//和public变量一样会显示在Unity的inspector面板上//_MainTex为变量名,在属性里的变量一般会加下划线,来区分参数变量和临时变量//Texture为变量命名//2D为类型&…