论文信息
- 标题: EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
- 作者: Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han(MIT/浙江大学/清华大学/MIT-IBM Watson AI Lab)[3][7]
- GitHub: mit-han-lab/efficientvit
- 研究背景: 高分辨率密集预测(如语义分割、超分辨率)在自动驾驶、计算摄影等领域应用广泛,但现有模型存在计算成本高、硬件部署效率低的问题[3][7]。
核心创新点
-
多尺度线性注意力(Multi-Scale Linear Attention)
- 替代传统Softmax注意力,通过ReLU线性注意力降低计算复杂度,实现全局感受野和多尺度学习[3][5][7]。
- 使用多尺度卷积核(3×3、5×5)聚合多尺度特征,提升细节捕捉能力[3][7]。
-
稀疏注意力优化
- 通过局部注意力聚焦关键区域,减少冗余计算,提升硬件效率(如移动端GPU)。
-
高效网络架构设计
- 主干网络采用多阶段层级结构,结合深度可分离卷积(DWConv)增强局部特征提取[3][5]。
- 特征金字塔融合多尺度特征,支持高分辨率输出(如1024×2048分辨率输入)。
-
硬件友好性
- 在移动CPU(如骁龙855)、边缘GPU(Jetson AGX Orin)和云端GPU(A100)上实现显著加速,延迟降低最高达13.9倍[3][7]。
方法详解
模型结构
输入 → 多阶段主干网络(4阶段下采样) → 多尺度线性注意力模块 → FFN+DWConv → 特征金字塔融合 → 上采样输出
关键设计
组件 | 作用 | 优势 |
---|---|---|
多尺度线性注意力 | 全局上下文建模 + 多尺度特征融合 | 硬件高效,支持高分辨率输入 |
FFN+DWConv | 局部细节增强 | 减少计算量,提升移动端推理速度 |
特征金字塔 | 多阶段特征融合 | 平衡语义与空间信息 |
- 多尺度线性注意力模块:该模块旨在通过硬件高效的操作实现全局感受野和多尺度学习。它使用ReLU线性注意力替代传统的softmax注意力,以降低计算复杂度并保持功能性。同时,通过小核卷积聚合附近的tokens生成多尺度tokens,进一步增强了局部信息提取和多尺度学习能力。
- ReLU线性注意力:在EfficientViT中,ReLU线性注意力用于实现全局感受野。其相似性函数定义为 S i m ( Q , K ) = R e L U ( Q ) R e L U ( K ) T Sim(Q,K)=ReLU(Q)ReLU(K)^T Sim(Q,K)=ReLU(Q)ReLU(K)T,通过矩阵乘法的结合律,可将计算复杂度从二次降为线性,同时避免了softmax等硬件低效操作。
- 深度卷积增强:在每个前馈神经网络(FFN)层中插入深度卷积,以进一步提高局部特征提取能力。
总结
-
性能与效率平衡
- 在语义分割、超分辨率等任务中达到SOTA,部分指标超越CNN和传统Transformer模型[3][7]。
- 支持移动端到云端的多平台实时推理(如Jetson AGX Orin延迟<8ms)[3][7]。
-
扩展与应用
- EfficientViT-SAM: 在医疗图像分割挑战赛(CVPR 2024)中夺冠,支持实时交互式分割[GitHub]。
- DC-AE压缩自编码器: 后续工作,加速扩散模型生成(如文本到图像生成速度提升128倍)[GitHub]。
关键公式:ReLU线性注意力
O i = ∑ j = 1 N [ ReLU ( Q i ) ReLU ( K j ) T ] V j ReLU ( Q i ) ∑ j = 1 N ReLU ( K j ) T = ∑ j = 1 N ReLU ( Q i ) [ ( ReLU ( K j ) T V j ) ] ReLU ( Q i ) ∑ j = 1 N ReLU ( K j ) T = ReLU ( Q i ) ( ∑ j = 1 N ReLU ( K j ) T V j ) ReLU ( Q i ) ( ∑ j = 1 N ReLU ( K j ) T ) . \begin{aligned} O_{i} & =\frac{\sum_{j=1}^{N}\left[\operatorname{ReLU}\left(Q_{i}\right) \operatorname{ReLU}\left(K_{j}\right)^{T}\right] V_{j}}{\operatorname{ReLU}\left(Q_{i}\right) \sum_{j=1}^{N} \operatorname{ReLU}\left(K_{j}\right)^{T}} \\ & =\frac{\sum_{j=1}^{N} \operatorname{ReLU}\left(Q_{i}\right)\left[\left(\operatorname{ReLU}\left(K_{j}\right)^{T} V_{j}\right)\right]}{\operatorname{ReLU}\left(Q_{i}\right) \sum_{j=1}^{N} \operatorname{ReLU}\left(K_{j}\right)^{T}} \\ & =\frac{\operatorname{ReLU}\left(Q_{i}\right)\left(\sum_{j=1}^{N} \operatorname{ReLU}\left(K_{j}\right)^{T} V_{j}\right)}{\operatorname{ReLU}\left(Q_{i}\right)\left(\sum_{j=1}^{N} \operatorname{ReLU}\left(K_{j}\right)^{T}\right)} . \end{aligned} Oi=ReLU(Qi)∑j=1NReLU(Kj)T∑j=1N[ReLU(Qi)ReLU(Kj)T]Vj=ReLU(Qi)∑j=1NReLU(Kj)T∑j=1NReLU(Qi)[(ReLU(Kj)TVj)]=ReLU(Qi)(∑j=1NReLU(Kj)T)ReLU(Qi)(∑j=1NReLU(Kj)TVj).
代码
from functools import partial
from inspect import signature
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Optional,Dictdef val2list(x, repeat_time=1) -> list:if isinstance(x, (list, tuple)):return list(x)return [x for _ in range(repeat_time)]def val2tuple(x, min_len: int = 1, idx_repeat: int = -1) -> tuple:x = val2list(x)# repeat elements if necessaryif len(x) > 0:x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]return tuple(x)def get_same_padding(kernel_size):if isinstance(kernel_size, tuple):return tuple([get_same_padding(ks) for ks in kernel_size])else:assert kernel_size % 2 > 0, "kernel size should be odd number"return kernel_size // 2def build_kwargs_from_config(config: dict, target_func: callable) -> Dict[str, any]:valid_keys = list(signature(target_func).parameters)kwargs = {}for key in config:if key in valid_keys:kwargs[key] = config[key]return kwargs# register activation function here
REGISTERED_ACT_DICT: dict[str, type] = {"relu": nn.ReLU,"relu6": nn.ReLU6,"hswish": nn.Hardswish,"silu": nn.SiLU,"gelu": partial(nn.GELU, approximate="tanh"),
}class LayerNorm2d(nn.LayerNorm):def forward(self, x: torch.Tensor) -> torch.Tensor:out = x - torch.mean(x, dim=1, keepdim=True)out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)if self.elementwise_affine:out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)return outdef build_act(name: str, **kwargs) -> Optional[nn.Module]:if name in REGISTERED_ACT_DICT:act_cls = REGISTERED_ACT_DICT[name]args = build_kwargs_from_config(kwargs, act_cls)return act_cls(**args)else:return None# register normalization function here
REGISTERED_NORM_DICT: dict[str, type] = {"bn2d": nn.BatchNorm2d,"ln": nn.LayerNorm,"ln2d": LayerNorm2d,
}def build_norm(name="bn2d", num_features=None, **kwargs) -> Optional[nn.Module]:if name in ["ln", "ln2d", "trms2d"]:kwargs["normalized_shape"] = num_featureselse:kwargs["num_features"] = num_featuresif name in REGISTERED_NORM_DICT:norm_cls = REGISTERED_NORM_DICT[name]args = build_kwargs_from_config(kwargs, norm_cls)return norm_cls(**args)else:return Noneclass ConvLayer(nn.Module):def __init__(self,in_channels: int,out_channels: int,kernel_size=3,stride=1,dilation=1,groups=1,use_bias=False,dropout=0,norm="bn2d",act_func="relu",):super(ConvLayer, self).__init__()padding = get_same_padding(kernel_size)padding *= dilationself.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else Noneself.conv = nn.Conv2d(in_channels,out_channels,kernel_size=(kernel_size, kernel_size),stride=(stride, stride),padding=padding,dilation=(dilation, dilation),groups=groups,bias=use_bias,)self.norm = build_norm(norm, num_features=out_channels)self.act = build_act(act_func)def forward(self, x: torch.Tensor) -> torch.Tensor:if self.dropout is not None:x = self.dropout(x)x = self.conv(x)if self.norm:x = self.norm(x)if self.act:x = self.act(x)return xclass LiteMLA(nn.Module):r"""Lightweight multi-scale linear attention"""def __init__(self,in_channels: int,out_channels: int,heads: Optional[int] = None,heads_ratio: float = 1.0,dim=8,use_bias=False,norm=(None, "bn2d"),act_func=(None, None),kernel_func="relu",scales: tuple[int, ...] = (5,),eps=1.0e-15,):super(LiteMLA, self).__init__()self.eps = epsheads = int(in_channels // dim * heads_ratio) if heads is None else headstotal_dim = heads * dimuse_bias = val2tuple(use_bias, 2)norm = val2tuple(norm, 2)act_func = val2tuple(act_func, 2)self.dim = dimself.qkv = ConvLayer(in_channels,3 * total_dim,1,use_bias=use_bias[0],norm=norm[0],act_func=act_func[0],)self.aggreg = nn.ModuleList([nn.Sequential(nn.Conv2d(3 * total_dim,3 * total_dim,scale,padding=get_same_padding(scale),groups=3 * total_dim,bias=use_bias[0],),nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),)for scale in scales])self.kernel_func = build_act(kernel_func, inplace=False)self.proj = ConvLayer(total_dim * (1 + len(scales)),out_channels,1,use_bias=use_bias[1],norm=norm[1],act_func=act_func[1],)@torch.autocast(device_type="cuda", enabled=False)def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:B, _, H, W = list(qkv.size())if qkv.dtype == torch.float16:qkv = qkv.float()qkv = torch.reshape(qkv,(B,-1,3 * self.dim,H * W,),)q, k, v = (qkv[:, :, 0 : self.dim],qkv[:, :, self.dim : 2 * self.dim],qkv[:, :, 2 * self.dim :],)# lightweight linear attentionq = self.kernel_func(q)k = self.kernel_func(k)# linear matmultrans_k = k.transpose(-1, -2)v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)vk = torch.matmul(v, trans_k)out = torch.matmul(vk, q)if out.dtype == torch.bfloat16:out = out.float()out = out[:, :, :-1] / (out[:, :, -1:] + self.eps)out = torch.reshape(out, (B, -1, H, W))return out@torch.autocast(device_type="cuda", enabled=False)def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor:B, _, H, W = list(qkv.size())qkv = torch.reshape(qkv,(B,-1,3 * self.dim,H * W,),)q, k, v = (qkv[:, :, 0 : self.dim],qkv[:, :, self.dim : 2 * self.dim],qkv[:, :, 2 * self.dim :],)q = self.kernel_func(q)k = self.kernel_func(k)att_map = torch.matmul(k.transpose(-1, -2), q) # b h n noriginal_dtype = att_map.dtypeif original_dtype in [torch.float16, torch.bfloat16]:att_map = att_map.float()att_map = att_map / (torch.sum(att_map, dim=2, keepdim=True) + self.eps) # b h n natt_map = att_map.to(original_dtype)out = torch.matmul(v, att_map) # b h d nout = torch.reshape(out, (B, -1, H, W))return outdef forward(self, x: torch.Tensor) -> torch.Tensor:# generate multi-scale q, k, vqkv = self.qkv(x)multi_scale_qkv = [qkv]for op in self.aggreg:multi_scale_qkv.append(op(qkv))qkv = torch.cat(multi_scale_qkv, dim=1)H, W = list(qkv.size())[-2:]if H * W > self.dim:out = self.relu_linear_att(qkv).to(qkv.dtype)else:out = self.relu_quadratic_att(qkv)out = self.proj(out)return outif __name__ == "__main__":if __name__ == '__main__':# 定义输入张量大小(Batch、Channel、Height、Wight)B, C, H, W = 2, 64, 40, 40input_tensor = torch.randn(B, C, H, W) # 随机生成输入张量# 初始化 SAFMdim = C # 输入和输出通道数# 创建 SAFM 实例block = LiteMLA(in_channels=dim,out_channels=dim,scales=(5,))# 如果GPU可用将模块移动到 GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")sablock = block.to(device)print(sablock)input_tensor = input_tensor.to(device)# 执行前向传播output = sablock(input_tensor)# 打印输入和输出的形状print(f"Input: {input_tensor.shape}")print(f"Output: {output.shape}")
代码的详细解释:
类定义和初始化
LiteMLA
类继承自nn.Module
,是一个自定义的神经网络层。- 在初始化方法
__init__
中,定义了多个参数来控制模块的行为:in_channels
和out_channels
分别表示输入和输出通道数。heads
和heads_ratio
用于控制注意力头的数量和比例。dim
是每个注意力头的维度。use_bias
指定是否在卷积层中使用偏置项。norm
指定归一化层的类型,可以是None
或"bn2d"
(二维批量归一化)。act_func
指定激活函数的类型。kernel_func
是用于 q 和 k 的激活函数,默认为"relu"
。scales
是一个元组,定义了多尺度聚合中使用的卷积核大小。eps
是一个小的正数,用于防止除以零的错误。
主要组件
self.qkv
是一个卷积层,用于生成查询(q)、键(k)和值(v)。self.aggreg
是一个nn.ModuleList
,包含多个卷积序列,用于多尺度聚合。self.kernel_func
是根据kernel_func
参数构建的激活函数。self.proj
是一个卷积层,用于将多尺度注意力机制的输出投影到所需的输出通道数。
方法
relu_linear_att
和relu_quadratic_att
是两个实现不同注意力机制的方法。它们都接受一个形状为(B, C, H, W)
的张量qkv
作为输入,并返回一个形状相同的张量作为输出。relu_linear_att
使用线性注意力机制,其中 q 和 k 通过self.kernel_func
激活后,进行线性矩阵乘法来计算注意力权重。relu_quadratic_att
使用传统的二次(softmax)注意力机制,其中 q 和 k 的点积结果通过 softmax 归一化后,与 v 相乘得到输出。
forward
方法是模块的前向传播逻辑。它首先通过self.qkv
生成 qkv 张量,然后通过self.aggreg
中的多尺度聚合操作生成多尺度 qkv 张量。根据输入特征图的高度和宽度与dim
的比较,选择使用relu_linear_att
或relu_quadratic_att
方法计算注意力输出。最后,通过self.proj
将输出投影到所需的通道数。