模块出处
[ACM MM 23] [link] [code] Efficient Parallel Multi-Scale Detail and Semantic Encoding
Network for Lightweight Semantic Segmentation
模块名称
Dynamic Weighted Feature Fusion (DWFF)
模块作用
双级特征融合
模块结构
模块思想
我们提出了 DWFF 策略,选择性地关注特征图中信息量最大的部分,以有效地结合浅层和深层特征,提高分割精度。DWFF 可用于在具有细粒度细节的区域中更重地加权浅层特征,在具有较高语义信息的区域中更重地加权深层特征,从而实现更好的特征组合和准确的分割。
模块代码
import torch
import torch.nn as nn
import torch.nn.functional as Fclass DWFF(nn.Module):def __init__(self,in_channels: int,height: int = 2,reduction: int = 8,bias: bool = False) -> None:super(DWFF, self).__init__()self.height = heightd = max(int(in_channels / reduction), 4)self.avg_pool = nn.AdaptiveAvgPool2d(1)self.conv_du = nn.Sequential(nn.Conv2d(in_channels, d, 1, padding=0, bias=bias),nn.BatchNorm2d(d),nn.LeakyReLU(0.2))self.fcs = nn.ModuleList([])for i in range(self.height):self.fcs.append(nn.Conv2d(d, in_channels, kernel_size=1, stride=1, bias=bias))self.softmax = nn.Softmax(dim=1)def forward(self, inp_feats):batch_size = inp_feats[0].shape[0]n_feats = inp_feats[0].shape[1]inp_feats = torch.cat(inp_feats, dim=1)inp_feats = inp_feats.view(batch_size, self.height, n_feats, inp_feats.shape[2], inp_feats.shape[3])feats_U = torch.sum(inp_feats, dim=1)feats_S = self.avg_pool(feats_U)feats_Z = self.conv_du(feats_S)attention_vectors = [fc(feats_Z) for fc in self.fcs]attention_vectors = torch.cat(attention_vectors, dim=1)attention_vectors = attention_vectors.view(batch_size, self.height, n_feats, 1, 1)attention_vectors = self.softmax(attention_vectors)feats_V = torch.sum(inp_feats * attention_vectors, dim=1)return feats_Vif __name__ == '__main__':dwff = DWFF(in_channels=64)x1 = torch.randn([2, 64, 16, 16])x2 = torch.randn([2, 64, 16, 16])out = dwff([x1, x2])print(out.shape) # 2, 64, 16, 16