项目快过:知识蒸馏 | 目标检测 |FGD | Focal and Global Knowledge Distillation for Detectors

公开时间:2022年3月9号
项目地址:https://github.com/yzd-v/FGD
论文地址:https://arxiv.org/pdf/2111.11837
在这里插入图片描述
知识蒸馏已成功地应用于图像分类。然而,目标检测要复杂得多,大多数知识蒸馏方法都失败了。本文指出,在目标检测中,教师和学生的特征在不同的领域有很大的差异,特别是在前景和背景上。如果我们平均地提取它们,特征图之间的不均匀差异将会对蒸馏产生负面影响。因此,我们提出了聚焦蒸馏和全局蒸馏(FGD)。聚焦蒸馏将前景和背景分开,迫使学生专注于教师的临界像素和通道全局蒸馏重建了不同像素之间的关系,并将其从教师转移到学生身上,补偿了聚焦蒸馏中全局信息的缺失

由于我们的方法只需要计算特征图上的损失,因此FGD可以应用于各种检测器。 我们在不同骨架的各种检测器上进行了实验,结果表明,该学生检测器取得了良好的mAP改进,为2~3个点。

1、核心观点

1.1 区分FG与BG的蒸馏差异

教师和学生的特征在不同的领域有很大的差异,特别是在前景和背景上。
在这里插入图片描述
作者通过实验表明,对fg与bg不做取得的蒸馏,还不如单独对fg或bg进行蒸馏。这里fb是是bbox对应的特征图区域,bg是背景对应的特征图区域。
在这里插入图片描述

1.2 具体实现

1、对backbone的输出进行Global Distillation操作,使教师模型与学生模型的输出解决
2、在neck的输出上,根据bbox区分前景与背景,分别进行蒸馏,然后loss加权
在这里插入图片描述
总体loss实现:在这里插入图片描述

Focal Distillation
对前景与背景分别设定loss权重进行蒸馏,同时附加spatial和chanel的attention蒸馏结构,使学生模型模拟教师模型
在这里插入图片描述

Global Distillation
在这里插入图片描述

1.3 有益效果

基于表3可以发现FGD的蒸馏方式,对于各类任务(目标检测、实力分割、关键点检测)均有提升效果,基本能提升3个点左右。
在这里插入图片描述
与其他目标检测蒸馏策略相比,FGD方法能提升02~0.7个点的精度,同时蒸馏后的S模型精度比T模型要略高。
在这里插入图片描述

蒸馏后的特征图变化
在这里插入图片描述

2、消融实验

2.1 focal and global distillation

基于这里的对比可以发现,仅蒸馏backbone或对neck进行有区别蒸馏,均能取得良好效果。但
两个一起蒸馏能额外取得0.2个点的提升。
在这里插入图片描述

2.2 Spatial attention 与 Channel attention

这里的蒸馏效果差异如下,同样是结合2个维度蒸馏,能提升0.1~0.2个点。同时表明spatial蒸馏更有效
在这里插入图片描述

2.3 GcBlock作用

通常蒸馏是直接对比教师模型与学生模型的差异,而本文中提到基于GcBlock对二者进行高维度映射后在计算loss。这里可以发现GcBlock是蒸馏有效的基本条件,否则涨点幅度较小。
在这里插入图片描述

2.4 蒸馏温度

在neck中进行蒸馏时,考虑了教师输出的spatial与chanel的分布特征,具体如下所示
在这里插入图片描述
这里通过消融实验,表明蒸馏温度对效果的影响。0.5或0.8为最佳值,这表明需要对教师的输出进行加热,体现出显著的分布特征,学生模型才能学习好。
在这里插入图片描述

3、实现代码

基于mmdet进行实现

3.1 配置文件

https://github.com/yzd-v/FGD/blob/master/configs/distillers/fgd/fgd_faster_rcnn_r101_fpn_2x_distill_faster_rcnn_r50_fpn_2x_coco.py

基于对配置文件的分析,博主认为只有一个针对neck层的FeatureLoss

_base_ = ['../../_base_/datasets/coco_detection.py','../../_base_/schedules/schedule_2x.py', '../../_base_/default_runtime.py'
]
# model settings
find_unused_parameters=True
temp=0.5
alpha_fgd=0.00005
beta_fgd=0.000025
gamma_fgd=0.00005
lambda_fgd=0.0000005
distiller = dict(type='DetectionDistiller',teacher_pretrained = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_2x_coco/faster_rcnn_r101_fpn_2x_coco_bbox_mAP-0.398_20200504_210455-1d2dac9c.pth',init_student = True,distill_cfg = [ dict(student_module = 'neck.fpn_convs.3.conv',teacher_module = 'neck.fpn_convs.3.conv',output_hook = True,methods=[dict(type='FeatureLoss',name='loss_fgd_fpn_3',student_channels = 256,teacher_channels = 256,temp = temp,alpha_fgd=alpha_fgd,beta_fgd=beta_fgd,gamma_fgd=gamma_fgd,lambda_fgd=lambda_fgd,)]),dict(student_module = 'neck.fpn_convs.2.conv',teacher_module = 'neck.fpn_convs.2.conv',output_hook = True,methods=[dict(type='FeatureLoss',name='loss_fgd_fpn_2',student_channels = 256,teacher_channels = 256,temp = temp,alpha_fgd=alpha_fgd,beta_fgd=beta_fgd,gamma_fgd=gamma_fgd,lambda_fgd=lambda_fgd,)]),dict(student_module = 'neck.fpn_convs.1.conv',teacher_module = 'neck.fpn_convs.1.conv',output_hook = True,methods=[dict(type='FeatureLoss',name='loss_fgd_fpn_1',student_channels = 256,teacher_channels = 256,temp = temp,alpha_fgd=alpha_fgd,beta_fgd=beta_fgd,gamma_fgd=gamma_fgd,lambda_fgd=lambda_fgd,)]),dict(student_module = 'neck.fpn_convs.0.conv',teacher_module = 'neck.fpn_convs.0.conv',output_hook = True,methods=[dict(type='FeatureLoss',name='loss_fgd_fpn_0',student_channels = 256,teacher_channels = 256,temp = temp,alpha_fgd=alpha_fgd,beta_fgd=beta_fgd,gamma_fgd=gamma_fgd,lambda_fgd=lambda_fgd,)]),])student_cfg = 'configs/faster_rcnn/faster_rcnn_r50_fpn_2x_coco.py'
teacher_cfg = 'configs/faster_rcnn/faster_rcnn_r101_fpn_2x_coco.py'
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
data = dict(samples_per_gpu=2,workers_per_gpu=2,)

3.2 forward_train函数

detection_distiller.py 中的forward_train函数定义了模型蒸馏的前向推理流程,可以发现就是针对配置文件中的layer计算FeatureLoss

    def forward_train(self, img, img_metas, **kwargs):"""Args:img (Tensor): Input images of shape (N, C, H, W).Typically these should be mean centered and std scaled.img_metas (list[dict]): A List of image info dict where each dicthas: 'img_shape', 'scale_factor', 'flip', and may also contain'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.For details on the values of these keys see:class:`mmdet.datasets.pipelines.Collect`.Returns:dict[str, Tensor]: A dictionary of loss components(student's losses and distiller's losses)."""with torch.no_grad():self.teacher.eval()feat = self.teacher.extract_feat(img)student_loss = self.student.forward_train(img, img_metas, **kwargs)buffer_dict = dict(self.named_buffers())for item_loc in self.distill_cfg:student_module = 'student_' + item_loc.student_module.replace('.','_')teacher_module = 'teacher_' + item_loc.teacher_module.replace('.','_')student_feat = buffer_dict[student_module]teacher_feat = buffer_dict[teacher_module]for item_loss in item_loc.methods:loss_name = item_loss.namestudent_loss[loss_name] = self.distill_losses[loss_name](student_feat,teacher_feat,kwargs['gt_bboxes'], img_metas)return student_loss

3.3 Focal Global Distillation 代码

代码地址:
https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py

这里的代码实现比较复杂,博主认为是将Focal Distillation部分+Global 部分的GcBlock针对同一layer对象进行实现,并没有像论文示意图中作用于不同的layer

import torch.nn as nn
import torch.nn.functional as F
import torch
from mmcv.cnn import constant_init, kaiming_init
from ..builder import DISTILL_LOSSES@DISTILL_LOSSES.register_module()
class FeatureLoss(nn.Module):"""PyTorch version of `Focal and Global Knowledge Distillation for Detectors`Args:student_channels(int): Number of channels in the student's feature map.teacher_channels(int): Number of channels in the teacher's feature map. temp (float, optional): Temperature coefficient. Defaults to 0.5.name (str): the loss name of the layeralpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.001lambda_fgd (float, optional): Weight of relation_loss. Defaults to 0.000005"""def __init__(self,student_channels,teacher_channels,name,temp=0.5,alpha_fgd=0.001,beta_fgd=0.0005,gamma_fgd=0.001,lambda_fgd=0.000005,):super(FeatureLoss, self).__init__()self.temp = tempself.alpha_fgd = alpha_fgdself.beta_fgd = beta_fgdself.gamma_fgd = gamma_fgdself.lambda_fgd = lambda_fgdif student_channels != teacher_channels:self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)else:self.align = Noneself.conv_mask_s = nn.Conv2d(teacher_channels, 1, kernel_size=1)self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1)self.channel_add_conv_s = nn.Sequential(nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),nn.LayerNorm([teacher_channels//2, 1, 1]),nn.ReLU(inplace=True),  # yapf: disablenn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))self.channel_add_conv_t = nn.Sequential(nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),nn.LayerNorm([teacher_channels//2, 1, 1]),nn.ReLU(inplace=True),  # yapf: disablenn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1)) #FcBlockself.reset_parameters()def forward(self,preds_S,preds_T,gt_bboxes,img_metas):"""Forward function.Args:preds_S(Tensor): Bs*C*H*W, student's feature mappreds_T(Tensor): Bs*C*H*W, teacher's feature mapgt_bboxes(tuple): Bs*[nt*4], pixel decimal: (tl_x, tl_y, br_x, br_y)img_metas (list[dict]): Meta information of each image, e.g.,image size, scaling factor, etc."""assert preds_S.shape[-2:] == preds_T.shape[-2:],'the output dim of teacher and student differ'if self.align is not None:preds_S = self.align(preds_S)N,C,H,W = preds_S.shapeS_attention_t, C_attention_t = self.get_attention(preds_T, self.temp)S_attention_s, C_attention_s = self.get_attention(preds_S, self.temp)Mask_fg = torch.zeros_like(S_attention_t)Mask_bg = torch.ones_like(S_attention_t)wmin,wmax,hmin,hmax = [],[],[],[]for i in range(N):new_boxxes = torch.ones_like(gt_bboxes[i])new_boxxes[:, 0] = gt_bboxes[i][:, 0]/img_metas[i]['img_shape'][1]*Wnew_boxxes[:, 2] = gt_bboxes[i][:, 2]/img_metas[i]['img_shape'][1]*Wnew_boxxes[:, 1] = gt_bboxes[i][:, 1]/img_metas[i]['img_shape'][0]*Hnew_boxxes[:, 3] = gt_bboxes[i][:, 3]/img_metas[i]['img_shape'][0]*Hwmin.append(torch.floor(new_boxxes[:, 0]).int())wmax.append(torch.ceil(new_boxxes[:, 2]).int())hmin.append(torch.floor(new_boxxes[:, 1]).int())hmax.append(torch.ceil(new_boxxes[:, 3]).int())area = 1.0/(hmax[i].view(1,-1)+1-hmin[i].view(1,-1))/(wmax[i].view(1,-1)+1-wmin[i].view(1,-1))for j in range(len(gt_bboxes[i])):Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1] = \torch.maximum(Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1], area[0][j])Mask_bg[i] = torch.where(Mask_fg[i]>0, 0, 1)if torch.sum(Mask_bg[i]):Mask_bg[i] /= torch.sum(Mask_bg[i])fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg, C_attention_s, C_attention_t, S_attention_s, S_attention_t)mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, S_attention_s, S_attention_t)rela_loss = self.get_rela_loss(preds_S, preds_T)loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \+ self.gamma_fgd * mask_loss + self.lambda_fgd * rela_lossreturn lossdef get_attention(self, preds, temp):""" preds: Bs*C*W*H """N, C, H, W= preds.shapevalue = torch.abs(preds)# Bs*W*Hfea_map = value.mean(axis=1, keepdim=True)S_attention = (H * W * F.softmax((fea_map/temp).view(N,-1), dim=1)).view(N, H, W)# Bs*Cchannel_map = value.mean(axis=2,keepdim=False).mean(axis=2,keepdim=False)C_attention = C * F.softmax(channel_map/temp, dim=1)return S_attention, C_attentiondef get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t):loss_mse = nn.MSELoss(reduction='sum')Mask_fg = Mask_fg.unsqueeze(dim=1)Mask_bg = Mask_bg.unsqueeze(dim=1)C_t = C_t.unsqueeze(dim=-1)C_t = C_t.unsqueeze(dim=-1)S_t = S_t.unsqueeze(dim=1)fea_t= torch.mul(preds_T, torch.sqrt(S_t))fea_t = torch.mul(fea_t, torch.sqrt(C_t))fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))fea_s = torch.mul(preds_S, torch.sqrt(S_t))fea_s = torch.mul(fea_s, torch.sqrt(C_t))fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))fg_loss = loss_mse(fg_fea_s, fg_fea_t)/len(Mask_fg)bg_loss = loss_mse(bg_fea_s, bg_fea_t)/len(Mask_bg)return fg_loss, bg_lossdef get_mask_loss(self, C_s, C_t, S_s, S_t):mask_loss = torch.sum(torch.abs((C_s-C_t)))/len(C_s) + torch.sum(torch.abs((S_s-S_t)))/len(S_s)return mask_lossdef spatial_pool(self, x, in_type):batch, channel, width, height = x.size()input_x = x# [N, C, H * W]input_x = input_x.view(batch, channel, height * width)# [N, 1, C, H * W]input_x = input_x.unsqueeze(1)# [N, 1, H, W]if in_type == 0:context_mask = self.conv_mask_s(x)else:context_mask = self.conv_mask_t(x)# [N, 1, H * W]context_mask = context_mask.view(batch, 1, height * width)# [N, 1, H * W]context_mask = F.softmax(context_mask, dim=2)# [N, 1, H * W, 1]context_mask = context_mask.unsqueeze(-1)# [N, 1, C, 1]context = torch.matmul(input_x, context_mask)# [N, C, 1, 1]context = context.view(batch, channel, 1, 1)return contextdef get_rela_loss(self, preds_S, preds_T):loss_mse = nn.MSELoss(reduction='sum')context_s = self.spatial_pool(preds_S, 0)context_t = self.spatial_pool(preds_T, 1)out_s = preds_Sout_t = preds_Tchannel_add_s = self.channel_add_conv_s(context_s)out_s = out_s + channel_add_schannel_add_t = self.channel_add_conv_t(context_t)out_t = out_t + channel_add_trela_loss = loss_mse(out_s, out_t)/len(out_s)return rela_lossdef last_zero_init(self, m):if isinstance(m, nn.Sequential):constant_init(m[-1], val=0)else:constant_init(m, val=0)def reset_parameters(self):kaiming_init(self.conv_mask_s, mode='fan_in')kaiming_init(self.conv_mask_t, mode='fan_in')self.conv_mask_s.inited = Trueself.conv_mask_t.inited = Trueself.last_zero_init(self.channel_add_conv_s)self.last_zero_init(self.channel_add_conv_t)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/482032.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】匿名管道通信场景——进程池

🔥 个人主页:大耳朵土土垚 🔥 所属专栏:Linux系统编程 这里将会不定期更新有关Linux的内容,欢迎大家点赞,收藏,评论🥳🥳🎉🎉🎉 文章目…

Sybase数据恢复—Sybase数据库无法启动,Sybase Central连接报错的处理案例

Sybase数据库数据恢复环境: Sybase数据库版本:SQL Anywhere 8.0。 Sybase数据库故障&分析: Sybase数据库无法启动。 错误提示: 使用Sybase Central连接报错。 数据库数据恢复工程师经过检测,发现Sybase数据库出现…

分布式FastDFS存储的同步方式

目录 一:FatsDFS的结构图 二:FatsDFS文件同步 前言: 1:同步日志所在目录 2:binlog格式 3:同步规则 4:binlog同步过程 1 :获取组内的其他Storage信息 tracker_report_thread_e…

【绘图】数据可视化(python)

对于数据绝对值差异较大(数据离散) 1. 对数坐标直方图(Histogram with Log Scale) import pandas as pd import matplotlib.pyplot as plt import numpy as np# 示例数据 data {count: [10, 20, 55, 90, 15, 5, 45, 80, 1000, …

使用Dify与BGE-M3搭建RAG(检索增强生成)应用-改进一,使用工作流代替Agnet

文章目录 前言Agent vs 工作流编写工作流 前言 在上一篇中,我们实现了一个基本的基于Dify的RAG的示范。 使用Dify与BGE-M3搭建RAG(检索增强生成)应用 这个效果确实很差。 我们一起来看看,该怎么改进。 今天我们就尝试一下&…

【Linux课程学习】:文件第二弹---理解一切皆文件,缓存区

🎁个人主页:我们的五年 🔍系列专栏:Linux课程学习 🌷追光的人,终会万丈光芒 🎉欢迎大家点赞👍评论📝收藏⭐文章 Linux学习笔记: https://blog.csdn.net/d…

【iOS】《Effective Objective-C 2.0》阅读笔记(一)

文章目录 前言了解OC语言的起源在类的头文件中尽量少引入其他头文件多用字面量语法,少用与之等价的方法字面量数值字面量数组字面量字典 多用类型常量,少用#define预处理指令用枚举法表示状态、选项、状态码 总结 前言 最近开始阅读一些iOS开发的相关书籍…

猫狗分类调试过程

一,下载名称为archive数据集 下载方式:机房共享文件夹 二、打开CatDogProject项目 配置环境:选择你所建的环境 三、调试运行 1,报错一:Traceback (most recent call last): File "G:/AI_Project/CatDogPro…

探索Python WebSocket新境界:picows库揭秘

文章目录 探索Python WebSocket新境界:picows库揭秘第一部分:背景介绍第二部分:picows库概述第三部分:安装picows库第四部分:简单库函数使用方法第五部分:场景应用第六部分:常见Bug及解决方案第…

零基础学安全--Burp Suite(4)proxy模块以及漏洞测试理论

目录 学习连接 一些思路 proxy模块 所在位置 功能简介 使用例子 抓包有一个很重要的点,就是我们可以看到一些在浏览器中看不到的传参点,传参点越多就意味着攻击面越广 学习连接 声明! 学习视频来自B站up主 **泷羽sec** 有兴趣的师傅可…

30 基于51单片机的手环设计仿真

目录 一、主要功能 二、硬件资源 三、程序编程 四、实现现象 一、主要功能 基于STC89C52单片机,DHT11温湿度采集温湿度,滑动变阻器连接ADC0832数模转换器模拟水位传感器检测水位,通过LCD1602显示信息,然后在程序里设置好是否…

十一、快速入门go语言之接口和反射

文章目录 接口:one: 接口基础:two: 接口类型断言和空接口:star2: 空接口实现存储不同数据类型的切片/数组:star2: 复制切片到空接口切片:star2: 类型断言 反射 📅 2024年5月9日 📦 使用版本为1.21.5 接口 十、Java类的封装和继承、多态 - 七点半的菜市…

QT6学习第六天 初识QML

QT6学习第六天 创建Qt Quick UI项目使用Qt Quick DesignerQML 语法基础导入语句 import对象 object 和属性 property布局注释表达式和属性绑定QML 编码约定 设置应用程序图标 创建Qt Quick UI项目 如果你有只测试QML相关内容快速显示界面的需求,这时可以创建Qt Qui…

图解RabbitMQ七种工作模式生产者消费者模型的补充

文章目录 1.消费者模型2.生产者-消费者模型注意事项2.1资源释放顺序问题2.2消费者的声明问题2.3虚拟机和用户的权限问题 3.七种工作模式3.1简单模式3.2工作模式3.3发布/订阅模式3.4路由模式3.5通配符模式3.6RPC通信3.7发布确认 1.消费者模型 之前学习的这个消息队列的快速上手…

C-操作符

操作符种类 在C语言中,操作符有以下几种: 算术操作符 移位操作符 位操作符 逻辑操作符 条件操作符 逗号表达式 下标引用,函数调用 拓展:整型提升 我们介绍常用的几个 算术操作符 (加)&#xff…

使用 Spring Boot 和 GraalVM 的原生镜像

🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:历代文学,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编程,高并发设计&#xf…

基于Java Springboot宠物医院微信小程序

一、作品包含 源码数据库设计文档万字PPT全套环境和工具资源部署教程 二、项目技术 前端技术:Html、Css、Js、Vue、Element-ui 数据库:MySQL 后端技术:Java、Spring Boot、MyBatis 三、运行环境 开发工具:IDEA/eclipse 微信…

Tree搜索二叉树、map和set_数据结构

数据结构专栏 如烟花般绚烂却又稍纵即逝的个人主页 本章讲述数据结构中搜索二叉树与HashMap的学习,感谢大家的支持!欢迎大家踊跃评论,感谢大佬们的支持! 目录 搜索二叉树的概念二叉树搜索模拟实现搜索二叉树查找搜索二叉树插入搜索二叉树删除…

C#使用ExcelDataReader读取Xlsx文件为DataTable对象

创建控制台项目 在NuGet中安装ExcelDataReader.DataSet 3.7.0 创建一个xlsx文件 测试代码 读取xlsx文件内容,为一个DataTable对象。 读取xlsx时,xlsx文件不能被其他软件打开,否则会报“进程无法访问此文件”的错。 using ExcelDataRead…

【JavaEE初阶】应是天仙狂醉,乱把白云揉碎 - (重点)线程

本篇博客给大家带来的是线程的知识点, 由于内容较多分几天来写. 🐎文章专栏: JavaEE初阶 🚀若有问题 评论区见 ⭐欢迎大家点赞 评论 收藏 分享 ❤❤❤ 如果你不知道分享给谁,那就分享给薯条. 你们的支持是我不断创作的动力 . 1. 认识线程 1.1 概念 )1 …