论文题目:Learning Spatial Fusion for Single-Shot Object Detection
论文地址:Paper - ASFF
官方源码:GitHub - GOATmessi8/ASFF
简 介
多尺度特征融合是解决多尺度目标检测问题的关键技术。其中,FPN(特征金字塔网络)通过自顶向下的路径,将高层语义特征与低层细节特征简单相加,提升了检测效果。然而,FPN 的融合方式没有充分考虑不同层级特征图之间的表征不一致性,可能引入相互冲突的信息,从而限制了融合效果的进一步提升。ASFF(Adaptively Spatial Feature Fusion,自适应空间特征融合)先将各层特征缩放到同一分辨率,再为每个空间位置学习一组经 softmax 归一化的融合权重,按位置自适应地加权融合各层特征,从而有效抑制层级特征间的冲突信息,提高多尺度目标检测的效果。这种设计同时兼顾了层级差异与空间自适应性。
核 心 代 码
(1)融合相邻层与非相邻层:
import torch
import torch.nn as nn
from ultralytics.utils.tal import dist2bbox, make_anchors
import math
import torch.nn.functional as F

__all__ = ['ASFF_Detect']


def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply convolution and activation without batch normalization.

        Presumably intended for use after BN has been folded into the conv weights.
        """
        return self.act(self.conv(x))


class DFL(nn.Module):
    """Integral module of Distribution Focal Loss (DFL).

    Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
    """

    def __init__(self, c1=16):
        """Create a frozen 1x1 conv whose weights are the bin indices 0..c1-1."""
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        x = torch.arange(c1, dtype=torch.float)
        self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
        self.c1 = c1

    def forward(self, x):
        """Decode (b, 4 * c1, a) bin logits into (b, 4, a) expected values.

        Softmax is taken over the c1 bins of each of the 4 sides. The frozen conv then
        computes sum(i * p_i), i.e. the expectation of the bin index.
        """
        b, c, a = x.shape  # batch, channels, anchors
        return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)


class ASFFV5(nn.Module):
    """Adaptively Spatial Feature Fusion (ASFF): fuse all three pyramid levels into one output level."""

    def __init__(self, level, ch, multiplier=1, rfb=False, vis=False, act_cfg=True):
        """ASFF version for YoloV5 (different than YoloV3).

        ``multiplier`` should be 1 or 0.5, meaning the ASFF channels can be
        512, 256, 128 -> multiplier=1, or 256, 128, 64 -> multiplier=0.5.
        For even smaller widths the code must be changed manually.

        Args:
            level (int): Output level. 0 = coarsest map (x[2] / ch[2]), 1 = middle (x[1] / ch[1]),
                2 = finest map (x[0] / ch[0]).
            ch (Sequence[int]): Channel counts ordered (s, m, l), i.e. finest to coarsest.
            multiplier (float): Scale applied to every entry of ``ch``. Inputs must carry
                ``int(ch[i] * multiplier)`` channels.
            rfb (bool): If True, the per-level weight branches use 8 instead of 16 channels.
            vis (bool): If True, ``forward`` also returns the fusion weights and a channel-summed fused map.
            act_cfg: Unused; kept for signature compatibility.
        """
        super(ASFFV5, self).__init__()
        self.level = level
        self.dim = [int(ch[2] * multiplier), int(ch[1] * multiplier), int(ch[0] * multiplier)]
        self.inter_dim = self.dim[self.level]
        if level == 0:
            # Bring the finer m and s maps down to the coarsest resolution.
            self.stride_level_1 = Conv(int(ch[1] * multiplier), self.inter_dim, 3, 2)
            self.stride_level_2 = Conv(int(ch[0] * multiplier), self.inter_dim, 3, 2)
            self.expand = Conv(self.inter_dim, int(ch[2] * multiplier), 3, 1)
        elif level == 1:
            # Compress the coarse l map for upsampling, stride the fine s map down.
            self.compress_level_0 = Conv(int(ch[2] * multiplier), self.inter_dim, 1, 1)
            self.stride_level_2 = Conv(int(ch[0] * multiplier), self.inter_dim, 3, 2)
            self.expand = Conv(self.inter_dim, int(ch[1] * multiplier), 3, 1)
        elif level == 2:
            # Compress both coarser maps; they are upsampled to the finest resolution in forward().
            self.compress_level_0 = Conv(int(ch[2] * multiplier), self.inter_dim, 1, 1)
            self.compress_level_1 = Conv(int(ch[1] * multiplier), self.inter_dim, 1, 1)
            self.expand = Conv(self.inter_dim, int(ch[0] * multiplier), 3, 1)

        # when adding rfb, we use half number of channels to save memory
        compress_c = 8 if rfb else 16
        self.weight_level_0 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_1 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_levels = Conv(compress_c * 3, 3, 1, 1)
        self.vis = vis

    def forward(self, x):  # l,m,s
        """Fuse the three input maps at the resolution of this module's level.

        Args:
            x (Sequence[Tensor]): Feature maps ordered (s, m, l), i.e. x[0] finest, x[2] coarsest.

        Returns:
            Tensor with ``self.dim[self.level]`` channels at the spatial size of ``x[2 - self.level]``.
            If ``self.vis``, returns ``(out, levels_weight, fused_out_reduced.sum(dim=1))`` instead.
        """
        x_level_0 = x[2]  # l (coarsest)
        x_level_1 = x[1]  # m
        x_level_2 = x[0]  # s (finest)

        if self.level == 0:
            level_0_resized = x_level_0
            level_1_resized = self.stride_level_1(x_level_1)
            level_2_downsampled_inter = F.max_pool2d(x_level_2, 3, stride=2, padding=1)
            level_2_resized = self.stride_level_2(level_2_downsampled_inter)
        elif self.level == 1:
            level_0_compressed = self.compress_level_0(x_level_0)
            # Upsample to the exact target size (not scale_factor=2): the stride-2 paths round up,
            # so for odd feature-map sizes a fixed scale factor would overshoot by one pixel.
            level_0_resized = F.interpolate(level_0_compressed, size=x_level_1.shape[-2:], mode='nearest')
            level_1_resized = x_level_1
            level_2_resized = self.stride_level_2(x_level_2)
        elif self.level == 2:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(level_0_compressed, size=x_level_2.shape[-2:], mode='nearest')
            x_level_1_compressed = self.compress_level_1(x_level_1)
            level_1_resized = F.interpolate(x_level_1_compressed, size=x_level_2.shape[-2:], mode='nearest')
            level_2_resized = x_level_2

        level_0_weight_v = self.weight_level_0(level_0_resized)
        level_1_weight_v = self.weight_level_1(level_1_resized)
        level_2_weight_v = self.weight_level_2(level_2_resized)

        levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        # Per-pixel softmax over the 3 levels: the fusion weights sum to 1 at every location.
        levels_weight = F.softmax(levels_weight, dim=1)

        fused_out_reduced = level_0_resized * levels_weight[:, 0:1, :, :] + \
            level_1_resized * levels_weight[:, 1:2, :, :] + \
            level_2_resized * levels_weight[:, 2:, :, :]

        out = self.expand(fused_out_reduced)

        if self.vis:
            return out, levels_weight, fused_out_reduced.sum(dim=1)
        else:
            return out


class ASFF_Detect(nn.Module):
    """YOLOv8 Detect head whose three inputs are first re-fused by one ASFFV5 module per level."""

    dynamic = False  # force grid reconstruction
    export = False  # export mode
    format = None  # export format; only read in forward() when export is True
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=(), multiplier=1, rfb=False):
        """Initialize the ASFF detection head.

        Args:
            nc (int): Number of classes.
            ch (Sequence[int]): Input channels of the 3 levels, ordered (s, m, l) / finest to coarsest.
            multiplier (float): Forwarded to ASFFV5. The heads (cv2/cv3) are built on ``ch``, so
                ``int(c * multiplier)`` must equal ``c`` for every channel count.
            rfb (bool): Forwarded to ASFFV5 (smaller weight branches).

        Raises:
            ValueError: If ``ch`` does not have exactly 3 entries, or ``multiplier`` would change a
                channel count (which would otherwise crash later with a shape mismatch).
        """
        super().__init__()
        if len(ch) != 3:
            raise ValueError(f'ASFF_Detect expects exactly 3 input levels, got ch={tuple(ch)}')
        if any(int(c * multiplier) != c for c in ch):
            raise ValueError(f'multiplier={multiplier} changes channel counts of ch={tuple(ch)}; '
                             f'ASFF_Detect requires ASFF output channels to equal ch')
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
        self.cv3 = nn.ModuleList(
            nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

        self.l0_fusion = ASFFV5(level=0, ch=ch, multiplier=multiplier, rfb=rfb)
        self.l1_fusion = ASFFV5(level=1, ch=ch, multiplier=multiplier, rfb=rfb)
        self.l2_fusion = ASFFV5(level=2, ch=ch, multiplier=multiplier, rfb=rfb)

    def forward(self, x):
        """Fuse the inputs with ASFF, then predict boxes and class scores.

        Returns:
            In training mode, the list of per-level raw outputs ``(b, no, h, w)`` ordered (s, m, l).
            Otherwise ``(y, x)``, or just ``y`` when ``export`` is set. ``y`` is the concatenation of
            decoded boxes and sigmoid class scores.
        """
        x1 = self.l0_fusion(x)  # coarsest level
        x2 = self.l1_fusion(x)
        x3 = self.l2_fusion(x)  # finest level
        x = [x3, x2, x1]  # back to (s, m, l) order to match ch / cv2 / cv3
        shape = x[0].shape  # BCHW
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:
            return x
        elif self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'):  # avoid TF FlexSplitV ops
            box = x_cat[:, :self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4:]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides

        if self.export and self.format in ('tflite', 'edgetpu'):
            # Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:
            # https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309
            # See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
            img_h = shape[2] * self.stride[0]
            img_w = shape[3] * self.stride[0]
            img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)
            dbox /= img_size

        y = torch.cat((dbox, cls.sigmoid()), 1)
        return y if self.export else (y, x)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)


if __name__ == "__main__":
    image1 = torch.rand(1, 128, 160, 160)
    image2 = torch.rand(1, 256, 80, 80)
    image3 = torch.rand(1, 512, 40, 40)

    image = [image1, image2, image3]
    channel = (128, 256, 512)
    model = ASFF_Detect(nc=80, ch=channel)
    out = model(image)
    print(out[1].shape)
(2)仅融合相邻层(最深层与最浅层各自只与相邻的一层融合,中间层与上下两层融合):
import torch
import torch.nn as nn
from ultralytics.utils.tal import dist2bbox, make_anchors
import math
import torch.nn.functional as F

__all__ = ['ASFF_Detect']


def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply convolution and activation without batch normalization.

        Presumably intended for use after BN has been folded into the conv weights.
        """
        return self.act(self.conv(x))


class DFL(nn.Module):
    """Integral module of Distribution Focal Loss (DFL).

    Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
    """

    def __init__(self, c1=16):
        """Create a frozen 1x1 conv whose weights are the bin indices 0..c1-1."""
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        x = torch.arange(c1, dtype=torch.float)
        self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
        self.c1 = c1

    def forward(self, x):
        """Decode (b, 4 * c1, a) bin logits into (b, 4, a) expected values.

        Softmax is taken over the c1 bins of each of the 4 sides. The frozen conv then
        computes sum(i * p_i), i.e. the expectation of the bin index.
        """
        b, c, a = x.shape  # batch, channels, anchors
        return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)


class ASFFV5(nn.Module):
    """ASFF variant that fuses each output level only with its adjacent level(s).

    Level 0 fuses levels 0+1, level 1 fuses levels 0+1+2, and level 2 fuses levels 1+2.
    """

    def __init__(self, level, ch, multiplier=1, rfb=False, vis=False, act_cfg=True):
        """Build the resize, weight and expand convolutions for one output level.

        Args:
            level (int): Output level. 0 = coarsest map (x[2] / ch[2]), 1 = middle (x[1] / ch[1]),
                2 = finest map (x[0] / ch[0]).
            ch (Sequence[int]): Channel counts ordered (s, m, l), i.e. finest to coarsest.
            multiplier (float): Scale applied to every entry of ``ch``. Inputs must carry
                ``int(ch[i] * multiplier)`` channels.
            rfb (bool): If True, the per-level weight branches use 8 instead of 16 channels.
            vis (bool): If True, ``forward`` also returns the fusion weights and a channel-summed fused map.
            act_cfg: Unused; kept for signature compatibility.
        """
        super(ASFFV5, self).__init__()
        self.level = level
        self.dim = [int(ch[2] * multiplier), int(ch[1] * multiplier), int(ch[0] * multiplier)]
        self.inter_dim = self.dim[self.level]
        if level == 0:
            self.stride_level_1 = Conv(int(ch[1] * multiplier), self.inter_dim, 3, 2)
            self.expand = Conv(self.inter_dim, int(ch[2] * multiplier), 3, 1)
        elif level == 1:
            self.compress_level_0 = Conv(int(ch[2] * multiplier), self.inter_dim, 1, 1)
            self.stride_level_2 = Conv(int(ch[0] * multiplier), self.inter_dim, 3, 2)
            self.expand = Conv(self.inter_dim, int(ch[1] * multiplier), 3, 1)
        elif level == 2:
            self.compress_level_1 = Conv(int(ch[1] * multiplier), self.inter_dim, 1, 1)
            self.expand = Conv(self.inter_dim, int(ch[0] * multiplier), 3, 1)

        compress_c = 8 if rfb else 16
        # NOTE(review): weight_level_2 is unused at level 0 and weight_level_0 is unused at level 2.
        # They are kept so the state_dict keys stay unchanged.
        self.weight_level_0 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_1 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = Conv(self.inter_dim, compress_c, 1, 1)

        # One fusion weight per fused input: 3 for the middle level, 2 for the outer levels.
        if level == 1:
            self.weight_levels = Conv(compress_c * 3, 3, 1, 1)
        else:
            self.weight_levels = Conv(compress_c * 2, 2, 1, 1)
        self.vis = vis

    def forward(self, x):  # l,m,s
        """Fuse the adjacent input maps at the resolution of this module's level.

        Args:
            x (Sequence[Tensor]): Feature maps ordered (s, m, l), i.e. x[0] finest, x[2] coarsest.

        Returns:
            Tensor with ``self.dim[self.level]`` channels at the spatial size of ``x[2 - self.level]``.
            If ``self.vis``, returns ``(out, levels_weight, fused_out_reduced.sum(dim=1))`` instead.
        """
        x_level_0 = x[2]  # l (coarsest)
        x_level_1 = x[1]  # m
        x_level_2 = x[0]  # s (finest)

        if self.level == 0:
            level_0_resized = x_level_0
            level_1_resized = self.stride_level_1(x_level_1)
            level_0_weight_v = self.weight_level_0(level_0_resized)
            level_1_weight_v = self.weight_level_1(level_1_resized)
            levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v), 1)
            levels_weight = self.weight_levels(levels_weight_v)
            levels_weight = F.softmax(levels_weight, dim=1)  # per-pixel weights summing to 1
            fused_out_reduced = level_0_resized * levels_weight[:, 0:1, :, :] + \
                level_1_resized * levels_weight[:, 1:, :, :]
        elif self.level == 1:
            level_0_resized = self.compress_level_0(x_level_0)
            # Upsample to the exact target size (not scale_factor=2): the stride-2 path rounds up,
            # so for odd feature-map sizes a fixed scale factor would overshoot by one pixel.
            level_0_resized = F.interpolate(level_0_resized, size=x_level_1.shape[-2:], mode='nearest')
            level_1_resized = x_level_1
            level_2_resized = self.stride_level_2(x_level_2)
            level_0_weight_v = self.weight_level_0(level_0_resized)
            level_1_weight_v = self.weight_level_1(level_1_resized)
            level_2_weight_v = self.weight_level_2(level_2_resized)
            levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v), 1)
            levels_weight = self.weight_levels(levels_weight_v)
            levels_weight = F.softmax(levels_weight, dim=1)  # per-pixel weights summing to 1
            fused_out_reduced = level_0_resized * levels_weight[:, 0:1, :, :] + \
                level_1_resized * levels_weight[:, 1:2, :, :] + \
                level_2_resized * levels_weight[:, 2:, :, :]
        elif self.level == 2:
            level_1_resized = self.compress_level_1(x_level_1)
            level_1_resized = F.interpolate(level_1_resized, size=x_level_2.shape[-2:], mode='nearest')
            level_2_resized = x_level_2
            level_1_weight_v = self.weight_level_1(level_1_resized)
            level_2_weight_v = self.weight_level_2(level_2_resized)
            levels_weight_v = torch.cat((level_1_weight_v, level_2_weight_v), 1)
            levels_weight = self.weight_levels(levels_weight_v)
            levels_weight = F.softmax(levels_weight, dim=1)  # per-pixel weights summing to 1
            fused_out_reduced = level_1_resized * levels_weight[:, 0:1, :, :] + \
                level_2_resized * levels_weight[:, 1:, :, :]

        out = self.expand(fused_out_reduced)

        if self.vis:
            return out, levels_weight, fused_out_reduced.sum(dim=1)
        else:
            return out


class ASFF_Detect(nn.Module):
    """YOLOv8 Detect head whose three inputs are first re-fused by one adjacent-level ASFFV5 per level."""

    dynamic = False  # force grid reconstruction
    export = False  # export mode
    format = None  # export format; only read in forward() when export is True
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=(), multiplier=1, rfb=False):
        """Initialize the ASFF detection head.

        Args:
            nc (int): Number of classes.
            ch (Sequence[int]): Input channels of the 3 levels, ordered (s, m, l) / finest to coarsest.
            multiplier (float): Forwarded to ASFFV5. The heads (cv2/cv3) are built on ``ch``, so
                ``int(c * multiplier)`` must equal ``c`` for every channel count.
            rfb (bool): Forwarded to ASFFV5 (smaller weight branches).

        Raises:
            ValueError: If ``ch`` does not have exactly 3 entries, or ``multiplier`` would change a
                channel count (which would otherwise crash later with a shape mismatch).
        """
        super().__init__()
        if len(ch) != 3:
            raise ValueError(f'ASFF_Detect expects exactly 3 input levels, got ch={tuple(ch)}')
        if any(int(c * multiplier) != c for c in ch):
            raise ValueError(f'multiplier={multiplier} changes channel counts of ch={tuple(ch)}; '
                             f'ASFF_Detect requires ASFF output channels to equal ch')
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
        self.cv3 = nn.ModuleList(
            nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

        self.l0_fusion = ASFFV5(level=0, ch=ch, multiplier=multiplier, rfb=rfb)
        self.l1_fusion = ASFFV5(level=1, ch=ch, multiplier=multiplier, rfb=rfb)
        self.l2_fusion = ASFFV5(level=2, ch=ch, multiplier=multiplier, rfb=rfb)

    def forward(self, x):
        """Fuse the inputs with ASFF, then predict boxes and class scores.

        Returns:
            In training mode, the list of per-level raw outputs ``(b, no, h, w)`` ordered (s, m, l).
            Otherwise ``(y, x)``, or just ``y`` when ``export`` is set. ``y`` is the concatenation of
            decoded boxes and sigmoid class scores.
        """
        x1 = self.l0_fusion(x)  # coarsest level
        x2 = self.l1_fusion(x)
        x3 = self.l2_fusion(x)  # finest level
        x = [x3, x2, x1]  # back to (s, m, l) order to match ch / cv2 / cv3
        shape = x[0].shape  # BCHW
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:
            return x
        elif self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'):  # avoid TF FlexSplitV ops
            box = x_cat[:, :self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4:]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides

        if self.export and self.format in ('tflite', 'edgetpu'):
            # Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:
            # https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309
            # See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
            img_h = shape[2] * self.stride[0]
            img_w = shape[3] * self.stride[0]
            img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)
            dbox /= img_size

        y = torch.cat((dbox, cls.sigmoid()), 1)
        return y if self.export else (y, x)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)


if __name__ == "__main__":
    image1 = torch.rand(1, 128, 160, 160)
    image2 = torch.rand(1, 256, 80, 80)
    image3 = torch.rand(1, 512, 40, 40)

    image = [image1, image2, image3]
    channel = (128, 256, 512)
    model = ASFF_Detect(nc=80, ch=channel)
    out = model(image)
    print(out[1].shape)