公开时间:2022年3月9号
项目地址:https://github.com/yzd-v/FGD
论文地址:https://arxiv.org/pdf/2111.11837
知识蒸馏已成功地应用于图像分类。然而,目标检测要复杂得多,大多数知识蒸馏方法都失败了
。本文指出,在目标检测中,教师和学生的特征在不同的领域有很大的差异,特别是在前景和背景上。如果我们平均地提取它们,特征图之间的不均匀差异将会对蒸馏产生负面影响
。因此,我们提出了聚焦蒸馏和全局蒸馏(FGD)。聚焦蒸馏将前景和背景分开,迫使学生专注于教师的临界像素和通道
。全局蒸馏重建了不同像素之间的关系,并将其从教师转移到学生身上,补偿了聚焦蒸馏中全局信息的缺失
。
由于我们的方法只需要计算特征图上的损失,因此FGD可以应用于各种检测器
。 我们在不同骨架的各种检测器上进行了实验,结果表明,该学生检测器取得了良好的mAP改进,为2~3个点。
1、核心观点
1.1 区分FG与BG的蒸馏差异
教师和学生的特征在不同的领域有很大的差异,特别是在前景和背景上。
作者通过实验表明,对fg与bg不做取得的蒸馏,还不如单独对fg或bg进行蒸馏。这里fb是是bbox对应的特征图区域,bg是背景对应的特征图区域。
1.2 具体实现
1、对backbone的输出进行Global Distillation操作,使教师模型与学生模型的输出解决
2、在neck的输出上,根据bbox区分前景与背景,分别进行蒸馏,然后loss加权
总体loss实现:
Focal Distillation
对前景与背景分别设定loss权重进行蒸馏,同时附加spatial和chanel的attention蒸馏结构,使学生模型模拟教师模型
Global Distillation
1.3 有益效果
基于表3可以发现FGD的蒸馏方式,对于各类任务(目标检测、实力分割、关键点检测)均有提升效果,基本能提升3个点左右。
与其他目标检测蒸馏策略相比,FGD方法能提升02~0.7个点的精度,同时蒸馏后的S模型精度比T模型要略高。
蒸馏后的特征图变化
2、消融实验
2.1 focal and global distillation
基于这里的对比可以发现,仅蒸馏backbone或对neck进行有区别蒸馏,均能取得良好效果。但
两个一起蒸馏能额外取得0.2个点的提升。
2.2 Spatial attention 与 Channel attention
这里的蒸馏效果差异如下,同样是结合2个维度蒸馏,能提升0.1~0.2个点。同时表明spatial蒸馏更有效
2.3 GcBlock作用
通常蒸馏是直接对比教师模型与学生模型的差异,而本文中提到基于GcBlock对二者进行高维度映射后在计算loss。这里可以发现GcBlock是蒸馏有效的基本条件,否则涨点幅度较小。
2.4 蒸馏温度
在neck中进行蒸馏时,考虑了教师输出的spatial与chanel的分布特征,具体如下所示
这里通过消融实验,表明蒸馏温度对效果的影响。0.5或0.8为最佳值,这表明需要对教师的输出进行加热,体现出显著的分布特征,学生模型才能学习好。
3、实现代码
基于mmdet进行实现
3.1 配置文件
https://github.com/yzd-v/FGD/blob/master/configs/distillers/fgd/fgd_faster_rcnn_r101_fpn_2x_distill_faster_rcnn_r50_fpn_2x_coco.py
基于对配置文件的分析,博主认为只有一个针对neck层的FeatureLoss
_base_ = ['../../_base_/datasets/coco_detection.py','../../_base_/schedules/schedule_2x.py', '../../_base_/default_runtime.py'
]
# model settings
find_unused_parameters=True
temp=0.5
alpha_fgd=0.00005
beta_fgd=0.000025
gamma_fgd=0.00005
lambda_fgd=0.0000005
distiller = dict(type='DetectionDistiller',teacher_pretrained = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_2x_coco/faster_rcnn_r101_fpn_2x_coco_bbox_mAP-0.398_20200504_210455-1d2dac9c.pth',init_student = True,distill_cfg = [ dict(student_module = 'neck.fpn_convs.3.conv',teacher_module = 'neck.fpn_convs.3.conv',output_hook = True,methods=[dict(type='FeatureLoss',name='loss_fgd_fpn_3',student_channels = 256,teacher_channels = 256,temp = temp,alpha_fgd=alpha_fgd,beta_fgd=beta_fgd,gamma_fgd=gamma_fgd,lambda_fgd=lambda_fgd,)]),dict(student_module = 'neck.fpn_convs.2.conv',teacher_module = 'neck.fpn_convs.2.conv',output_hook = True,methods=[dict(type='FeatureLoss',name='loss_fgd_fpn_2',student_channels = 256,teacher_channels = 256,temp = temp,alpha_fgd=alpha_fgd,beta_fgd=beta_fgd,gamma_fgd=gamma_fgd,lambda_fgd=lambda_fgd,)]),dict(student_module = 'neck.fpn_convs.1.conv',teacher_module = 'neck.fpn_convs.1.conv',output_hook = True,methods=[dict(type='FeatureLoss',name='loss_fgd_fpn_1',student_channels = 256,teacher_channels = 256,temp = temp,alpha_fgd=alpha_fgd,beta_fgd=beta_fgd,gamma_fgd=gamma_fgd,lambda_fgd=lambda_fgd,)]),dict(student_module = 'neck.fpn_convs.0.conv',teacher_module = 'neck.fpn_convs.0.conv',output_hook = True,methods=[dict(type='FeatureLoss',name='loss_fgd_fpn_0',student_channels = 256,teacher_channels = 256,temp = temp,alpha_fgd=alpha_fgd,beta_fgd=beta_fgd,gamma_fgd=gamma_fgd,lambda_fgd=lambda_fgd,)]),])student_cfg = 'configs/faster_rcnn/faster_rcnn_r50_fpn_2x_coco.py'
teacher_cfg = 'configs/faster_rcnn/faster_rcnn_r101_fpn_2x_coco.py'
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
data = dict(samples_per_gpu=2,workers_per_gpu=2,)
3.2 forward_train函数
detection_distiller.py 中的forward_train函数定义了模型蒸馏的前向推理流程,可以发现就是针对配置文件中的layer计算FeatureLoss
def forward_train(self, img, img_metas, **kwargs):"""Args:img (Tensor): Input images of shape (N, C, H, W).Typically these should be mean centered and std scaled.img_metas (list[dict]): A List of image info dict where each dicthas: 'img_shape', 'scale_factor', 'flip', and may also contain'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.For details on the values of these keys see:class:`mmdet.datasets.pipelines.Collect`.Returns:dict[str, Tensor]: A dictionary of loss components(student's losses and distiller's losses)."""with torch.no_grad():self.teacher.eval()feat = self.teacher.extract_feat(img)student_loss = self.student.forward_train(img, img_metas, **kwargs)buffer_dict = dict(self.named_buffers())for item_loc in self.distill_cfg:student_module = 'student_' + item_loc.student_module.replace('.','_')teacher_module = 'teacher_' + item_loc.teacher_module.replace('.','_')student_feat = buffer_dict[student_module]teacher_feat = buffer_dict[teacher_module]for item_loss in item_loc.methods:loss_name = item_loss.namestudent_loss[loss_name] = self.distill_losses[loss_name](student_feat,teacher_feat,kwargs['gt_bboxes'], img_metas)return student_loss
3.3 Focal Global Distillation 代码
代码地址:
https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py
这里的代码实现比较复杂,博主认为是将Focal Distillation部分+Global 部分的GcBlock针对同一layer对象进行实现,并没有像论文示意图中作用于不同的layer
import torch.nn as nn
import torch.nn.functional as F
import torch
from mmcv.cnn import constant_init, kaiming_init
from ..builder import DISTILL_LOSSES@DISTILL_LOSSES.register_module()
class FeatureLoss(nn.Module):"""PyTorch version of `Focal and Global Knowledge Distillation for Detectors`Args:student_channels(int): Number of channels in the student's feature map.teacher_channels(int): Number of channels in the teacher's feature map. temp (float, optional): Temperature coefficient. Defaults to 0.5.name (str): the loss name of the layeralpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.001lambda_fgd (float, optional): Weight of relation_loss. Defaults to 0.000005"""def __init__(self,student_channels,teacher_channels,name,temp=0.5,alpha_fgd=0.001,beta_fgd=0.0005,gamma_fgd=0.001,lambda_fgd=0.000005,):super(FeatureLoss, self).__init__()self.temp = tempself.alpha_fgd = alpha_fgdself.beta_fgd = beta_fgdself.gamma_fgd = gamma_fgdself.lambda_fgd = lambda_fgdif student_channels != teacher_channels:self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)else:self.align = Noneself.conv_mask_s = nn.Conv2d(teacher_channels, 1, kernel_size=1)self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1)self.channel_add_conv_s = nn.Sequential(nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),nn.LayerNorm([teacher_channels//2, 1, 1]),nn.ReLU(inplace=True), # yapf: disablenn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))self.channel_add_conv_t = nn.Sequential(nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),nn.LayerNorm([teacher_channels//2, 1, 1]),nn.ReLU(inplace=True), # yapf: disablenn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1)) #FcBlockself.reset_parameters()def forward(self,preds_S,preds_T,gt_bboxes,img_metas):"""Forward function.Args:preds_S(Tensor): Bs*C*H*W, student's feature mappreds_T(Tensor): Bs*C*H*W, teacher's feature mapgt_bboxes(tuple): Bs*[nt*4], pixel decimal: (tl_x, tl_y, br_x, br_y)img_metas (list[dict]): Meta information of each image, e.g.,image size, scaling factor, etc."""assert preds_S.shape[-2:] == preds_T.shape[-2:],'the output dim of teacher and student differ'if self.align is not None:preds_S = self.align(preds_S)N,C,H,W = preds_S.shapeS_attention_t, C_attention_t = self.get_attention(preds_T, self.temp)S_attention_s, C_attention_s = self.get_attention(preds_S, self.temp)Mask_fg = torch.zeros_like(S_attention_t)Mask_bg = torch.ones_like(S_attention_t)wmin,wmax,hmin,hmax = [],[],[],[]for i in range(N):new_boxxes = torch.ones_like(gt_bboxes[i])new_boxxes[:, 0] = gt_bboxes[i][:, 0]/img_metas[i]['img_shape'][1]*Wnew_boxxes[:, 2] = gt_bboxes[i][:, 2]/img_metas[i]['img_shape'][1]*Wnew_boxxes[:, 1] = gt_bboxes[i][:, 1]/img_metas[i]['img_shape'][0]*Hnew_boxxes[:, 3] = gt_bboxes[i][:, 3]/img_metas[i]['img_shape'][0]*Hwmin.append(torch.floor(new_boxxes[:, 0]).int())wmax.append(torch.ceil(new_boxxes[:, 2]).int())hmin.append(torch.floor(new_boxxes[:, 1]).int())hmax.append(torch.ceil(new_boxxes[:, 3]).int())area = 1.0/(hmax[i].view(1,-1)+1-hmin[i].view(1,-1))/(wmax[i].view(1,-1)+1-wmin[i].view(1,-1))for j in range(len(gt_bboxes[i])):Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1] = \torch.maximum(Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1], area[0][j])Mask_bg[i] = torch.where(Mask_fg[i]>0, 0, 1)if torch.sum(Mask_bg[i]):Mask_bg[i] /= torch.sum(Mask_bg[i])fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg, C_attention_s, C_attention_t, S_attention_s, S_attention_t)mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, S_attention_s, S_attention_t)rela_loss = self.get_rela_loss(preds_S, preds_T)loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \+ self.gamma_fgd * mask_loss + self.lambda_fgd * rela_lossreturn lossdef get_attention(self, preds, temp):""" preds: Bs*C*W*H """N, C, H, W= preds.shapevalue = torch.abs(preds)# Bs*W*Hfea_map = value.mean(axis=1, keepdim=True)S_attention = (H * W * F.softmax((fea_map/temp).view(N,-1), dim=1)).view(N, H, W)# Bs*Cchannel_map = value.mean(axis=2,keepdim=False).mean(axis=2,keepdim=False)C_attention = C * F.softmax(channel_map/temp, dim=1)return S_attention, C_attentiondef get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t):loss_mse = nn.MSELoss(reduction='sum')Mask_fg = Mask_fg.unsqueeze(dim=1)Mask_bg = Mask_bg.unsqueeze(dim=1)C_t = C_t.unsqueeze(dim=-1)C_t = C_t.unsqueeze(dim=-1)S_t = S_t.unsqueeze(dim=1)fea_t= torch.mul(preds_T, torch.sqrt(S_t))fea_t = torch.mul(fea_t, torch.sqrt(C_t))fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))fea_s = torch.mul(preds_S, torch.sqrt(S_t))fea_s = torch.mul(fea_s, torch.sqrt(C_t))fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))fg_loss = loss_mse(fg_fea_s, fg_fea_t)/len(Mask_fg)bg_loss = loss_mse(bg_fea_s, bg_fea_t)/len(Mask_bg)return fg_loss, bg_lossdef get_mask_loss(self, C_s, C_t, S_s, S_t):mask_loss = torch.sum(torch.abs((C_s-C_t)))/len(C_s) + torch.sum(torch.abs((S_s-S_t)))/len(S_s)return mask_lossdef spatial_pool(self, x, in_type):batch, channel, width, height = x.size()input_x = x# [N, C, H * W]input_x = input_x.view(batch, channel, height * width)# [N, 1, C, H * W]input_x = input_x.unsqueeze(1)# [N, 1, H, W]if in_type == 0:context_mask = self.conv_mask_s(x)else:context_mask = self.conv_mask_t(x)# [N, 1, H * W]context_mask = context_mask.view(batch, 1, height * width)# [N, 1, H * W]context_mask = F.softmax(context_mask, dim=2)# [N, 1, H * W, 1]context_mask = context_mask.unsqueeze(-1)# [N, 1, C, 1]context = torch.matmul(input_x, context_mask)# [N, C, 1, 1]context = context.view(batch, channel, 1, 1)return contextdef get_rela_loss(self, preds_S, preds_T):loss_mse = nn.MSELoss(reduction='sum')context_s = self.spatial_pool(preds_S, 0)context_t = self.spatial_pool(preds_T, 1)out_s = preds_Sout_t = preds_Tchannel_add_s = self.channel_add_conv_s(context_s)out_s = out_s + channel_add_schannel_add_t = self.channel_add_conv_t(context_t)out_t = out_t + channel_add_trela_loss = loss_mse(out_s, out_t)/len(out_s)return rela_lossdef last_zero_init(self, m):if isinstance(m, nn.Sequential):constant_init(m[-1], val=0)else:constant_init(m, val=0)def reset_parameters(self):kaiming_init(self.conv_mask_s, mode='fan_in')kaiming_init(self.conv_mask_t, mode='fan_in')self.conv_mask_s.inited = Trueself.conv_mask_t.inited = Trueself.last_zero_init(self.channel_add_conv_s)self.last_zero_init(self.channel_add_conv_t)