模块出处
[link] [code] [ACM MM 23] Frequency Perception Network for Camouflaged Object Detection
模块名称
Frequency-Perception Module (FPM)
模块作用
获取频域信息，以更好地识别伪装对象
模块结构
模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F


def _square_kernel(kernel_size):
    """Return the single spatial extent used for the square convolution kernels.

    A tuple/list such as ``(3, 3)`` contributes only its first element (the
    layers below have always built square kernels this way); a plain ``int``
    is returned unchanged.
    """
    if isinstance(kernel_size, (tuple, list)):
        return kernel_size[0]
    return kernel_size


class FirstOctaveConv(nn.Module):
    """Entry layer of an octave-convolution stack.

    Turns one feature map into a ``(high, low)`` pair: the high-frequency
    branch convolves the input directly, while the low-frequency branch
    convolves a 2x2 average-pooled copy of it.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: accepted for signature compatibility but NOT used; the
            output split is derived from ``in_channels`` (high branch:
            ``in_channels - int(alpha * in_channels)`` channels, low branch:
            ``int(alpha * in_channels)`` channels).
        kernel_size: ``int`` or tuple/list; only the first element of a
            sequence is used, giving a square kernel.
        alpha: fraction of channels assigned to the low-frequency branch.
        stride: when ``2`` the input is average-pooled once before the split;
            any other value behaves like ``1`` (the convolutions themselves
            always use stride 1).
        padding, dilation, groups, bias: forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1,
                 groups=1, bias=False):
        super().__init__()
        self.stride = stride
        kernel_size = _square_kernel(kernel_size)
        low_channels = int(alpha * in_channels)
        high_channels = in_channels - low_channels
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.h2l = nn.Conv2d(in_channels, low_channels,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_channels, high_channels,
                             kernel_size, 1, padding, dilation, groups, bias)

    def forward(self, x):
        """Return ``(X_h, X_l)`` computed from the single feature map ``x``.

        ``X_l`` is computed from a 2x2 average-pooled copy of ``x``, so its
        spatial size is half (floored) of ``X_h``'s whenever the
        kernel/padding pair preserves size (e.g. kernel 3 with padding 1).
        """
        if self.stride == 2:
            x = self.h2g_pool(x)
        X_h = self.h2h(x)
        X_l = self.h2l(self.h2g_pool(x))
        return X_h, X_l


class OctaveConv(nn.Module):
    """Octave convolution on a ``(high, low)`` pair of feature maps.

    Four convolutions exchange information between the branches:
    ``h2h`` (high -> high), ``h2l`` (pooled high -> low),
    ``l2h`` (low -> high, then resized to the high resolution) and
    ``l2l`` (low -> low). Each output branch is the sum of its two paths.

    Args:
        in_channels: total channels of the incoming pair; the high branch must
            carry ``in_channels - int(alpha * in_channels)`` channels and the
            low branch ``int(alpha * in_channels)``.
        out_channels: total channels of the outgoing pair, split the same way.
        kernel_size: ``int`` or tuple/list; only the first element of a
            sequence is used, giving a square kernel.
        alpha: fraction of channels assigned to the low-frequency branch.
        stride: when ``2`` both inputs are average-pooled before the
            convolutions; any other value behaves like ``1``.
        padding, dilation, groups, bias: forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1,
                 groups=1, bias=False):
        super().__init__()
        kernel_size = _square_kernel(kernel_size)
        in_low = int(alpha * in_channels)
        in_high = in_channels - in_low
        out_low = int(alpha * out_channels)
        out_high = out_channels - out_low
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        # Not used by forward(); kept so existing attribute access keeps working
        # (nn.Upsample has no parameters, so state_dicts are unaffected).
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.stride = stride
        self.l2l = nn.Conv2d(in_low, out_low,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(in_low, out_high,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(in_high, out_low,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_high, out_high,
                             kernel_size, 1, padding, dilation, groups, bias)

    def forward(self, x):
        """Map the pair ``x = (X_h, X_l)`` to a new ``(X_h, X_l)`` pair."""
        X_h, X_l = x
        if self.stride == 2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        X_h2h = self.h2h(X_h)
        X_h2l = self.h2l(self.h2g_pool(X_h))
        X_l2l = self.l2l(X_l)
        # Resize to the exact high-branch size rather than a fixed 2x upsample,
        # so odd spatial sizes (e.g. 7 -> 3 after pooling) still line up.
        X_l2h = F.interpolate(self.l2h(X_l), size=X_h2h.shape[-2:],
                              mode='bilinear', align_corners=False)
        return X_l2h + X_h2h, X_h2l + X_l2l


class LastOctaveConv(nn.Module):
    """Exit layer that merges a ``(high, low)`` pair back into one feature map.

    Args:
        in_channels: accepted for signature compatibility but NOT used; the
            expected input split is derived from ``out_channels`` (high branch:
            ``out_channels - int(alpha * out_channels)`` channels, low branch:
            ``int(alpha * out_channels)`` channels), so the preceding layer must
            emit ``out_channels`` channels in total with the same ``alpha``.
        out_channels: channels of the single output map.
        kernel_size: ``int`` or tuple/list; only the first element of a
            sequence is used, giving a square kernel.
        alpha: fraction of channels carried by the low-frequency input.
        stride: when ``2`` both inputs are average-pooled before the
            convolutions; any other value behaves like ``1``.
        padding, dilation, groups, bias: forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1,
                 groups=1, bias=False):
        super().__init__()
        self.stride = stride
        kernel_size = _square_kernel(kernel_size)
        low_channels = int(alpha * out_channels)
        high_channels = out_channels - low_channels
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.l2h = nn.Conv2d(low_channels, out_channels,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(high_channels, out_channels,
                             kernel_size, 1, padding, dilation, groups, bias)
        # Not used by forward(); kept so existing attribute access keeps working
        # (nn.Upsample has no parameters, so state_dicts are unaffected).
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        """Merge the pair ``x = (X_h, X_l)`` into one map at the high resolution."""
        X_h, X_l = x
        if self.stride == 2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        X_h2h = self.h2h(X_h)
        X_l2h = F.interpolate(self.l2h(X_l), size=X_h2h.shape[-2:],
                              mode='bilinear', align_corners=False)
        return X_h2h + X_l2h


class FPM(nn.Module):
    """Frequency-Perception Module: a stack of octave convolutions.

    ``x`` is split into high/low-frequency branches, passed through the
    middle octave convolutions, and merged back into one map with
    ``out_channels`` channels. With the default 3x3 kernel the output keeps
    the input's spatial size.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        kernel_size: kernel for every stage (``int`` or tuple/list; only the
            first element of a sequence is used).
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
        super().__init__()
        self.fir = FirstOctaveConv(in_channels, out_channels, kernel_size)
        self.mid1 = OctaveConv(in_channels, in_channels, kernel_size)
        self.mid2 = OctaveConv(in_channels, out_channels, kernel_size)
        self.lst = LastOctaveConv(in_channels, out_channels, kernel_size)

    def forward(self, x):
        """Return the frequency-aware feature map for ``x``."""
        x_h, x_l = self.fir(x)
        # NOTE(review): ``mid1`` is applied twice, so these two stages share one
        # set of weights. Kept as-is because changing it would alter the
        # parameter set; confirm this sharing is intended.
        x_h, x_l = self.mid1((x_h, x_l))
        x_h, x_l = self.mid1((x_h, x_l))
        x_h, x_l = self.mid2((x_h, x_l))
        return self.lst((x_h, x_l))


if __name__ == '__main__':
    x = torch.randn([3, 256, 16, 16])
    fpm = FPM(in_channels=256, out_channels=64)
    out = fpm(x)
    print(out.shape)  # 3, 64, 16, 16
原文表述
具体来说，我们采用八度卷积（Octave Convolution）以端到端的方式自动感知高频和低频信息，从而实现伪装物体检测的在线学习。八度卷积可以有效避免 DCT 引起的块状效应，并能利用 GPU 的计算速度优势。此外，它可以轻松地插入任意网络中。