模块出处
[SPL 25] [link] [code] KAN See In the Dark
模块名称
Kolmogorov-Arnold Network Block (KAN-Block)
模块作用
用于视觉（vision）任务的 KAN 结构：以 KANLinear 层与深度卷积（DW-Conv + BN + ReLU）交替堆叠，处理形状为 (B, N, C) 的 token 序列（N = H × W）。
模块结构
模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``.

    Not used by the classes in this module (they default to ``torch.nn.SiLU``,
    which computes the same function); kept for code that imports it.
    """

    def forward(self, x):
        return x * torch.sigmoid(x)


class KANLinear(torch.nn.Module):
    """Linear layer whose edges are learnable univariate functions (KAN layer).

    The output is the sum of two branches computed on the last input dimension:

    * base branch:   ``F.linear(base_activation(x), base_weight)``
    * spline branch: each input feature is expanded into
      ``grid_size + spline_order`` B-spline basis values, which are mixed by
      ``scaled_spline_weight`` into ``out_features`` outputs.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        grid_size: Number of grid intervals inside ``grid_range``.
        spline_order: Order of the B-spline bases (3 = cubic).
        scale_noise: Amplitude of the random curve the spline weights are
            fitted to at initialization.
        scale_base: Multiplies ``sqrt(5)`` to form the ``a`` argument of
            ``kaiming_uniform_`` for ``base_weight``.
        scale_spline: Multiplies the initial spline weights, or, when
            ``enable_standalone_scale_spline`` is set, enters the
            ``kaiming_uniform_`` init of ``spline_scaler`` instead.
        enable_standalone_scale_spline: If True, learn a separate
            per-(out, in) multiplier ``spline_scaler`` for the spline weights.
        base_activation: Activation class (instantiated without arguments)
            used in the base branch.
        grid_eps: Blend factor used by :meth:`update_grid` between a uniform
            grid (weight ``grid_eps``) and a data-adaptive grid.
        grid_range: ``(low, high)`` interval covered by the initial grid.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # `weight` and `bias` are never read by forward(). They stay registered so
        # the parameter list and state_dict keys are unchanged; reset_parameters()
        # zeroes them so they never hold uninitialized memory.
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.bias = nn.Parameter(torch.Tensor(out_features))

        # Uniform knot vector over grid_range, extended by `spline_order` knots on
        # each side: shape (in_features, grid_size + 2 * spline_order + 1).
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0])
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(torch.Tensor(out_features, in_features))

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialize every parameter of the layer."""
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Spline weights start as a least-squares fit to a small random curve
            # sampled at the grid_size + 1 knots lying inside grid_range.
            noise = (
                (torch.rand(self.grid_size + 1, self.in_features, self.out_features) - 1 / 2)
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                torch.nn.init.kaiming_uniform_(
                    self.spline_scaler, a=math.sqrt(5) * self.scale_spline
                )
            # Unused parameters: give them a defined value (no RNG consumed).
            torch.nn.init.zeros_(self.weight)
            torch.nn.init.zeros_(self.bias)

    def b_splines(self, x: torch.Tensor):
        """Compute the B-spline bases for the given input tensor.

        Args:
            x: Input tensor of shape (batch_size, in_features).

        Returns:
            B-spline bases of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = self.grid  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the half-open knot interval [t_j, t_{j+1}).
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Cox-de Boor recursion: each pass raises the order by one and shortens
        # the last dimension by one.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """Least-squares fit spline coefficients that interpolate the points (x, y).

        Args:
            x: Sample positions of shape (batch_size, in_features).
            y: Target values of shape (batch_size, in_features, out_features).

        Returns:
            Coefficients of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(0, 1)  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        # One independent least-squares problem per input feature (batched over dim 0).
        solution = torch.linalg.lstsq(A, B).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(2, 0, 1)  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """Spline weights multiplied by ``spline_scaler`` when that scaler is enabled."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1) if self.enable_standalone_scale_spline else 1.0
        )

    def forward(self, x: torch.Tensor):
        """Apply the layer to the last dimension of ``x``.

        Args:
            x: Tensor of shape (..., in_features). Leading dimensions are
                flattened into one batch dimension and restored afterwards; a
                (batch, in_features) input is processed exactly as before.

        Returns:
            Tensor of shape (..., out_features).
        """
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output
        return output.reshape(*original_shape[:-1], self.out_features)

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Re-fit the knot grid to ``x`` and refit the spline weights to match.

        The new knots blend knots taken at evenly spaced ranks of each sorted
        feature (weight ``1 - grid_eps``) with a uniform grid over
        ``[min(x) - margin, max(x) + margin]`` (weight ``grid_eps``), extended by
        ``spline_order`` uniform steps on each side. The spline weights are then
        refit by least squares so the spline branch approximately reproduces its
        previous output on ``x``.

        Args:
            x: Samples of shape (batch, in_features).
            margin: Padding added beyond the data range of the uniform grid.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        # Per-(sample, in, out) spline contributions under the current grid.
        splines = self.b_splines(x).permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(1, 0, 2)  # (batch, in, out)

        # Sort each feature independently to read off its empirical distribution.
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device)
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(self.grid_size + 1, dtype=torch.float32, device=x.device).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Extend by `spline_order` uniform steps beyond each end of the grid.
        grid = torch.cat(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)

        # curve2coeff fits the *scaled* spline output. forward() multiplies
        # spline_weight by spline_scaler again, so divide it back out; where the
        # scaler is exactly zero that edge contributes nothing and the fitted
        # coefficient is stored unchanged to avoid inf.
        new_coeff = self.curve2coeff(x, unreduced_spline_output)
        if self.enable_standalone_scale_spline:
            scaler = self.spline_scaler.unsqueeze(-1)
            new_coeff = torch.where(scaler != 0, new_coeff / scaler, new_coeff)
        self.spline_weight.data.copy_(new_coeff)

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Compute a cheap stand-in for the KAN paper's L1 + entropy regularization.

        The paper's regularizer needs the expanded (batch, in_features,
        out_features) activations, which F.linear never materializes in this
        memory-efficient implementation. Instead the L1 term is the mean
        absolute spline weight per (out, in) edge, summed over edges, and the
        entropy term is the entropy of those per-edge values after
        normalizing them to sum to one.

        Args:
            regularize_activation: Weight of the L1 term.
            regularize_entropy: Weight of the entropy term.

        Returns:
            Scalar tensor. NOTE: becomes NaN if any per-edge mean is exactly
            zero (``0 * log(0)``).
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class DW_bn_relu(nn.Module):
    """Depthwise 3x3 conv -> BatchNorm2d -> ReLU applied to a token sequence.

    The (B, N, C) input is reshaped into a (B, C, H, W) feature map, processed,
    and flattened back to (B, N, C); ``N`` must equal ``H * W``.
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
        self.bn = nn.BatchNorm2d(dim)
        self.relu = nn.ReLU()

    def forward(self, x, H, W):
        B, N, C = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.dwconv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class KANBlock(nn.Module):
    """Three KANLinear layers, each followed by depthwise conv + BN + ReLU.

    Expects a token sequence of shape (B, N, in_features) that is a flattened
    H x W feature map (``N == H * W``). Channel widths along the block are
    in_features -> hidden_features -> hidden_features -> out_features.

    Args:
        in_features: Input channel count.
        hidden_features: Width of the intermediate stages (default: in_features).
        out_features: Output channel count (default: in_features).
        act_layer, shift_size, version: Accepted for interface compatibility;
            not used by this block (``shift_size // 2`` is stored as ``self.pad``).
        drop: Stored as ``self.drop`` (an ``nn.Dropout``) but not applied in
            :meth:`forward`.
        grid_size, spline_order, scale_noise, scale_base, scale_spline,
        base_activation, grid_eps, grid_range: Keyword-only settings forwarded
            to every KANLinear layer.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
        shift_size=5,
        version=4,
        *,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dim = in_features

        kan_kwargs = dict(
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )
        # fc2 maps hidden -> hidden so dwconv_2 and fc3 receive the hidden width
        # they are built for; only fc3/dwconv_3 switch to out_features.
        self.fc1 = KANLinear(in_features, hidden_features, **kan_kwargs)
        self.fc2 = KANLinear(hidden_features, hidden_features, **kan_kwargs)
        self.fc3 = KANLinear(hidden_features, out_features, **kan_kwargs)

        self.dwconv_1 = DW_bn_relu(hidden_features)
        self.dwconv_2 = DW_bn_relu(hidden_features)
        self.dwconv_3 = DW_bn_relu(out_features)

        self.drop = nn.Dropout(drop)

        self.shift_size = shift_size
        self.pad = shift_size // 2

    def forward(self, x, H, W):
        """Map (B, N, in_features) to (B, N, out_features); requires N == H * W."""
        B, N, _ = x.shape
        stages = (
            (self.fc1, self.dwconv_1),
            (self.fc2, self.dwconv_2),
            (self.fc3, self.dwconv_3),
        )
        for fc, dwconv in stages:
            # KANLinear works on (batch, features); fold the token axis into batch.
            x = fc(x.reshape(B * N, -1))
            x = x.reshape(B, N, -1).contiguous()
            x = dwconv(x, H, W)
        return x


if __name__ == '__main__':
    x = torch.randn([1, 22 * 22, 128])
    kan = KANBlock(in_features=128)
    out = kan(x, H=22, W=22)
    print(out.shape)  # torch.Size([1, 484, 128])