【Block总结】EfficientViT中的多尺度线性注意力模块即插即用

论文信息

  • 标题: EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
  • 作者: Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han(MIT/浙江大学/清华大学/MIT-IBM Watson AI Lab)[3][7]
  • GitHub: mit-han-lab/efficientvit
  • 研究背景: 高分辨率密集预测(如语义分割、超分辨率)在自动驾驶、计算摄影等领域应用广泛,但现有模型存在计算成本高、硬件部署效率低的问题[3][7]。
    在这里插入图片描述

核心创新点

  1. 多尺度线性注意力(Multi-Scale Linear Attention)

    • 替代传统Softmax注意力,通过ReLU线性注意力降低计算复杂度,实现全局感受野和多尺度学习[3][5][7]。
    • 使用多尺度卷积核(3×3、5×5)聚合多尺度特征,提升细节捕捉能力[3][7]。
  2. 稀疏注意力优化

    • 通过局部注意力聚焦关键区域,减少冗余计算,提升硬件效率(如移动端GPU)。
  3. 高效网络架构设计

    • 主干网络采用多阶段层级结构,结合深度可分离卷积(DWConv)增强局部特征提取[3][5]。
    • 特征金字塔融合多尺度特征,支持高分辨率输出(如1024×2048分辨率输入)。
  4. 硬件友好性

    • 在移动CPU(如骁龙855)、边缘GPU(Jetson AGX Orin)和云端GPU(A100)上实现显著加速,延迟降低最高达13.9倍[3][7]。

方法详解

模型结构

输入 → 多阶段主干网络(4阶段下采样) → 多尺度线性注意力模块 → FFN+DWConv → 特征金字塔融合 → 上采样输出

在这里插入图片描述

关键设计

组件作用优势
多尺度线性注意力全局上下文建模 + 多尺度特征融合硬件高效,支持高分辨率输入
FFN+DWConv局部细节增强减少计算量,提升移动端推理速度
特征金字塔多阶段特征融合平衡语义与空间信息
  • 多尺度线性注意力模块:该模块旨在通过硬件高效的操作实现全局感受野和多尺度学习。它使用ReLU线性注意力替代传统的softmax注意力,以降低计算复杂度并保持功能性。同时,通过小核卷积聚合附近的tokens生成多尺度tokens,进一步增强了局部信息提取和多尺度学习能力。
  • ReLU线性注意力:在EfficientViT中,ReLU线性注意力用于实现全局感受野。其相似性函数定义为 S i m ( Q , K ) = R e L U ( Q ) R e L U ( K ) T Sim(Q,K)=ReLU(Q)ReLU(K)^T Sim(Q,K)=ReLU(Q)ReLU(K)T,通过矩阵乘法的结合律,可将计算复杂度从二次降为线性,同时避免了softmax等硬件低效操作。
  • 深度卷积增强:在每个前馈神经网络(FFN)层中插入深度卷积,以进一步提高局部特征提取能力。

在这里插入图片描述


总结

  1. 性能与效率平衡

    • 在语义分割、超分辨率等任务中达到SOTA,部分指标超越CNN和传统Transformer模型[3][7]。
    • 支持移动端到云端的多平台实时推理(如Jetson AGX Orin延迟<8ms)[3][7]。
  2. 扩展与应用

    • EfficientViT-SAM: 在医疗图像分割挑战赛(CVPR 2024)中夺冠,支持实时交互式分割[GitHub]。
    • DC-AE压缩自编码器: 后续工作,加速扩散模型生成(如文本到图像生成速度提升128倍)[GitHub]。

关键公式:ReLU线性注意力

O i = ∑ j = 1 N [ ReLU ⁡ ( Q i ) ReLU ⁡ ( K j ) T ] V j ReLU ⁡ ( Q i ) ∑ j = 1 N ReLU ⁡ ( K j ) T = ∑ j = 1 N ReLU ⁡ ( Q i ) [ ( ReLU ⁡ ( K j ) T V j ) ] ReLU ⁡ ( Q i ) ∑ j = 1 N ReLU ⁡ ( K j ) T = ReLU ⁡ ( Q i ) ( ∑ j = 1 N ReLU ⁡ ( K j ) T V j ) ReLU ⁡ ( Q i ) ( ∑ j = 1 N ReLU ⁡ ( K j ) T ) . \begin{aligned} O_{i} & =\frac{\sum_{j=1}^{N}\left[\operatorname{ReLU}\left(Q_{i}\right) \operatorname{ReLU}\left(K_{j}\right)^{T}\right] V_{j}}{\operatorname{ReLU}\left(Q_{i}\right) \sum_{j=1}^{N} \operatorname{ReLU}\left(K_{j}\right)^{T}} \\ & =\frac{\sum_{j=1}^{N} \operatorname{ReLU}\left(Q_{i}\right)\left[\left(\operatorname{ReLU}\left(K_{j}\right)^{T} V_{j}\right)\right]}{\operatorname{ReLU}\left(Q_{i}\right) \sum_{j=1}^{N} \operatorname{ReLU}\left(K_{j}\right)^{T}} \\ & =\frac{\operatorname{ReLU}\left(Q_{i}\right)\left(\sum_{j=1}^{N} \operatorname{ReLU}\left(K_{j}\right)^{T} V_{j}\right)}{\operatorname{ReLU}\left(Q_{i}\right)\left(\sum_{j=1}^{N} \operatorname{ReLU}\left(K_{j}\right)^{T}\right)} . \end{aligned} Oi=ReLU(Qi)j=1NReLU(Kj)Tj=1N[ReLU(Qi)ReLU(Kj)T]Vj=ReLU(Qi)j=1NReLU(Kj)Tj=1NReLU(Qi)[(ReLU(Kj)TVj)]=ReLU(Qi)(j=1NReLU(Kj)T)ReLU(Qi)(j=1NReLU(Kj)TVj).

代码

from functools import partial
from inspect import signature
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Optional,Dictdef val2list(x, repeat_time=1) -> list:if isinstance(x, (list, tuple)):return list(x)return [x for _ in range(repeat_time)]def val2tuple(x, min_len: int = 1, idx_repeat: int = -1) -> tuple:x = val2list(x)# repeat elements if necessaryif len(x) > 0:x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]return tuple(x)def get_same_padding(kernel_size):if isinstance(kernel_size, tuple):return tuple([get_same_padding(ks) for ks in kernel_size])else:assert kernel_size % 2 > 0, "kernel size should be odd number"return kernel_size // 2def build_kwargs_from_config(config: dict, target_func: callable) -> Dict[str, any]:valid_keys = list(signature(target_func).parameters)kwargs = {}for key in config:if key in valid_keys:kwargs[key] = config[key]return kwargs# register activation function here
REGISTERED_ACT_DICT: dict[str, type] = {"relu": nn.ReLU,"relu6": nn.ReLU6,"hswish": nn.Hardswish,"silu": nn.SiLU,"gelu": partial(nn.GELU, approximate="tanh"),
}class LayerNorm2d(nn.LayerNorm):def forward(self, x: torch.Tensor) -> torch.Tensor:out = x - torch.mean(x, dim=1, keepdim=True)out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)if self.elementwise_affine:out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)return outdef build_act(name: str, **kwargs) -> Optional[nn.Module]:if name in REGISTERED_ACT_DICT:act_cls = REGISTERED_ACT_DICT[name]args = build_kwargs_from_config(kwargs, act_cls)return act_cls(**args)else:return None# register normalization function here
REGISTERED_NORM_DICT: dict[str, type] = {"bn2d": nn.BatchNorm2d,"ln": nn.LayerNorm,"ln2d": LayerNorm2d,
}def build_norm(name="bn2d", num_features=None, **kwargs) -> Optional[nn.Module]:if name in ["ln", "ln2d", "trms2d"]:kwargs["normalized_shape"] = num_featureselse:kwargs["num_features"] = num_featuresif name in REGISTERED_NORM_DICT:norm_cls = REGISTERED_NORM_DICT[name]args = build_kwargs_from_config(kwargs, norm_cls)return norm_cls(**args)else:return Noneclass ConvLayer(nn.Module):def __init__(self,in_channels: int,out_channels: int,kernel_size=3,stride=1,dilation=1,groups=1,use_bias=False,dropout=0,norm="bn2d",act_func="relu",):super(ConvLayer, self).__init__()padding = get_same_padding(kernel_size)padding *= dilationself.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else Noneself.conv = nn.Conv2d(in_channels,out_channels,kernel_size=(kernel_size, kernel_size),stride=(stride, stride),padding=padding,dilation=(dilation, dilation),groups=groups,bias=use_bias,)self.norm = build_norm(norm, num_features=out_channels)self.act = build_act(act_func)def forward(self, x: torch.Tensor) -> torch.Tensor:if self.dropout is not None:x = self.dropout(x)x = self.conv(x)if self.norm:x = self.norm(x)if self.act:x = self.act(x)return xclass LiteMLA(nn.Module):r"""Lightweight multi-scale linear attention"""def __init__(self,in_channels: int,out_channels: int,heads: Optional[int] = None,heads_ratio: float = 1.0,dim=8,use_bias=False,norm=(None, "bn2d"),act_func=(None, None),kernel_func="relu",scales: tuple[int, ...] = (5,),eps=1.0e-15,):super(LiteMLA, self).__init__()self.eps = epsheads = int(in_channels // dim * heads_ratio) if heads is None else headstotal_dim = heads * dimuse_bias = val2tuple(use_bias, 2)norm = val2tuple(norm, 2)act_func = val2tuple(act_func, 2)self.dim = dimself.qkv = ConvLayer(in_channels,3 * total_dim,1,use_bias=use_bias[0],norm=norm[0],act_func=act_func[0],)self.aggreg = nn.ModuleList([nn.Sequential(nn.Conv2d(3 * total_dim,3 * total_dim,scale,padding=get_same_padding(scale),groups=3 * total_dim,bias=use_bias[0],),nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),)for scale in scales])self.kernel_func = build_act(kernel_func, inplace=False)self.proj = ConvLayer(total_dim * (1 + len(scales)),out_channels,1,use_bias=use_bias[1],norm=norm[1],act_func=act_func[1],)@torch.autocast(device_type="cuda", enabled=False)def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:B, _, H, W = list(qkv.size())if qkv.dtype == torch.float16:qkv = qkv.float()qkv = torch.reshape(qkv,(B,-1,3 * self.dim,H * W,),)q, k, v = (qkv[:, :, 0 : self.dim],qkv[:, :, self.dim : 2 * self.dim],qkv[:, :, 2 * self.dim :],)# lightweight linear attentionq = self.kernel_func(q)k = self.kernel_func(k)# linear matmultrans_k = k.transpose(-1, -2)v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)vk = torch.matmul(v, trans_k)out = torch.matmul(vk, q)if out.dtype == torch.bfloat16:out = out.float()out = out[:, :, :-1] / (out[:, :, -1:] + self.eps)out = torch.reshape(out, (B, -1, H, W))return out@torch.autocast(device_type="cuda", enabled=False)def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor:B, _, H, W = list(qkv.size())qkv = torch.reshape(qkv,(B,-1,3 * self.dim,H * W,),)q, k, v = (qkv[:, :, 0 : self.dim],qkv[:, :, self.dim : 2 * self.dim],qkv[:, :, 2 * self.dim :],)q = self.kernel_func(q)k = self.kernel_func(k)att_map = torch.matmul(k.transpose(-1, -2), q)  # b h n noriginal_dtype = att_map.dtypeif original_dtype in [torch.float16, torch.bfloat16]:att_map = att_map.float()att_map = att_map / (torch.sum(att_map, dim=2, keepdim=True) + self.eps)  # b h n natt_map = att_map.to(original_dtype)out = torch.matmul(v, att_map)  # b h d nout = torch.reshape(out, (B, -1, H, W))return outdef forward(self, x: torch.Tensor) -> torch.Tensor:# generate multi-scale q, k, vqkv = self.qkv(x)multi_scale_qkv = [qkv]for op in self.aggreg:multi_scale_qkv.append(op(qkv))qkv = torch.cat(multi_scale_qkv, dim=1)H, W = list(qkv.size())[-2:]if H * W > self.dim:out = self.relu_linear_att(qkv).to(qkv.dtype)else:out = self.relu_quadratic_att(qkv)out = self.proj(out)return outif __name__ == "__main__":if __name__ == '__main__':# 定义输入张量大小(Batch、Channel、Height、Wight)B, C, H, W = 2, 64, 40, 40input_tensor = torch.randn(B, C, H, W)  # 随机生成输入张量# 初始化 SAFMdim = C  # 输入和输出通道数# 创建 SAFM 实例block = LiteMLA(in_channels=dim,out_channels=dim,scales=(5,))# 如果GPU可用将模块移动到 GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")sablock = block.to(device)print(sablock)input_tensor = input_tensor.to(device)# 执行前向传播output = sablock(input_tensor)# 打印输入和输出的形状print(f"Input: {input_tensor.shape}")print(f"Output: {output.shape}")

代码的详细解释:

类定义和初始化

  • LiteMLA 类继承自 nn.Module,是一个自定义的神经网络层。
  • 在初始化方法 __init__ 中,定义了多个参数来控制模块的行为:
    • in_channelsout_channels 分别表示输入和输出通道数。
    • headsheads_ratio 用于控制注意力头的数量和比例。
    • dim 是每个注意力头的维度。
    • use_bias 指定是否在卷积层中使用偏置项。
    • norm 指定归一化层的类型,可以是 None"bn2d"(二维批量归一化)。
    • act_func 指定激活函数的类型。
    • kernel_func 是用于 q 和 k 的激活函数,默认为 "relu"
    • scales 是一个元组,定义了多尺度聚合中使用的卷积核大小。
    • eps 是一个小的正数,用于防止除以零的错误。

主要组件

  • self.qkv 是一个卷积层,用于生成查询(q)、键(k)和值(v)。
  • self.aggreg 是一个 nn.ModuleList,包含多个卷积序列,用于多尺度聚合。
  • self.kernel_func 是根据 kernel_func 参数构建的激活函数。
  • self.proj 是一个卷积层,用于将多尺度注意力机制的输出投影到所需的输出通道数。

方法

  • relu_linear_attrelu_quadratic_att 是两个实现不同注意力机制的方法。它们都接受一个形状为 (B, C, H, W) 的张量 qkv 作为输入,并返回一个形状相同的张量作为输出。
    • relu_linear_att 使用线性注意力机制,其中 q 和 k 通过 self.kernel_func 激活后,进行线性矩阵乘法来计算注意力权重。
    • relu_quadratic_att 使用传统的二次(softmax)注意力机制,其中 q 和 k 的点积结果通过 softmax 归一化后,与 v 相乘得到输出。
  • forward 方法是模块的前向传播逻辑。它首先通过 self.qkv 生成 qkv 张量,然后通过 self.aggreg 中的多尺度聚合操作生成多尺度 qkv 张量。根据输入特征图的高度和宽度与 dim 的比较,选择使用 relu_linear_attrelu_quadratic_att 方法计算注意力输出。最后,通过 self.proj 将输出投影到所需的通道数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/26854.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

unsloth报错FileNotFoundError: [WinError 3] 系统找不到指定的路径。

运行平台 Windows 报错信息 Traceback (most recent call last): File “C:\Python312\Lib\site-packages\IPython\core\interactiveshell.py”, line 3577, in run_code exec(code_obj, self.user_global_ns, self.user_ns) File “”, line 1, in runfile(‘D:\python_pr…

【清华大学】DeepSeek从入门到精通完整版pdf下载

DeepSeek从入门到精通.pdf 一共104页完整版 下载链接: https://pan.baidu.com/s/1-gnkTTD7EF2i_EKS5sx4vg?pwd1234 提取码: 1234 或 链接&#xff1a;https://pan.quark.cn/s/79118f5ab0fd 一、DeepSeek 概述 背景与定位 DeepSeek 的研发背景 核心功能与技术特点&#xff08…

如何使用ArcGIS Pro制作横向图例:详细步骤与实践指南

ArcGIS Pro&#xff0c;作为Esri公司推出的新一代地理信息系统&#xff08;GIS&#xff09;平台&#xff0c;以其强大的功能和灵活的操作界面&#xff0c;在地理数据处理、地图制作和空间分析等领域发挥着重要作用。 在地图制作过程中&#xff0c;图例作为地图的重要组成部分&…

监督学习单模型—线性模型—LASSO回归、Ridge回归

目标变量通常有很多影响因素&#xff0c;通过各类影响因素构建对目标变量的回归模型&#xff0c;能够实现对目标的预测。但根据稀疏性的假设&#xff0c;即使影响一个变量的因素有很多&#xff0c;其关键因素永远只会是少数。在这种情况下&#xff0c;还用传统的线性回归方法来…

【QT】QLinearGradient 线性渐变类简单使用教程

目录 0.简介 1&#xff09;qtDesigner中 2&#xff09;实际执行 1.功能详述 3.举一反三的样式 0.简介 QLinearGradient 是 Qt 框架中的一个类&#xff0c;用于定义线性渐变效果&#xff08;通过样式表设置&#xff09;。它可以用来填充形状、背景或其他图形元素&#xff0…

攻防世界GFSJ1184_welcome_CAT_CTF

题目 附件&#xff1a; 两个文件client和server Get Flag Exeinfo File分析 file client client: ELF 64-bit LSB pie executable, x86-64, version 1 (SYSV), dynamically linked, interpreter /lib64/ld-linux-x86-64.so.2, for GNU/Linux 3.2.0, BuildID[sha1]6045aa1ba5…

EL表达式和JSTL标签

目录 1. EL表达式 1.1. EL表达式概述 1.2. EL表达式运算 1.3. EL表达式操作对象 1.4. EL表达式内置对象 jsp 9个 11个 1.4.1. 参数隐藏对象 1.4.2. 域隐藏对象 1.4.3. PageContext对象 2. JSTL标签 2.1. JSTL概述 2.1.1. 什么是JSTL 2.1.2. 导入标签库 2.2. JSTL核…

PhotoShop学习01

了解Photoshop 这里省略了Photoshop的软件安装&#xff0c;请自行查找资源下载。 1.打开图片 下图为启动photoshop后出现的界面&#xff0c;我们可以通过创建新文件或打开已有文件来启用photoshop的工作界面。 可以通过左边的按钮进行新文件的创建或打开已有文件。 也可以点…

LabVIEW虚拟弗兰克赫兹实验仪

随着信息技术的飞速发展&#xff0c;虚拟仿真技术已经成为教学和研究中不可或缺的工具。开发了一种基于LabVIEW平台开发的虚拟弗兰克赫兹实验仪&#xff0c;该系统不仅能模拟实验操作&#xff0c;还能实时绘制数据图形&#xff0c;极大地丰富了物理实验的教学内容和方式。 ​ …

【TI毫米波雷达】DCA1000的ADC原始数据C语言解析及FMCW的Python解析2D-FFT图像

【TI毫米波雷达】DCA1000的ADC原始数据C语言解析及FMCW的Python解析2D-FFT图像 文章目录 ADC原始数据C语言解析Python的2D-FFT图像附录&#xff1a;结构框架雷达基本原理叙述雷达天线排列位置芯片框架Demo工程功能CCS工程导入工程叙述Software TasksData PathOutput informati…

【数据结构】堆与二叉树

一、树的概念 1.1 什么是树&#xff1f; 树是一种非线性的数据结构&#xff0c;其由 n 个 ( n > 0 ) 有限节点所组成的一个有层次关系的集合。之所以称其为树&#xff0c;是因为其逻辑结构看起来像是一颗倒挂的树。 在树中&#xff0c;有一个特殊的节点称为根节点&#xf…

从零开始开发纯血鸿蒙应用之语音朗读

从零开始开发纯血鸿蒙应用 〇、前言一、API 选型1、基本情况2、认识TextToSpeechEngine 二、功能集成实践1、改造右上角菜单2、实现语音播报功能2.1、语音引擎的获取和关闭2.2、设置待播报文本2.3、speak 目标文本2.4、设置语音回调 三、总结 〇、前言 中华汉字洋洋洒洒何其多…

8 SpringBoot进阶(上):AOP(面向切面编程技术)、AOP案例之统一操作日志

文章目录 前言1. AOP基础1.1 AOP概述: 什么是AOP?1.2 AOP快速入门1.3 Spring AOP核心中的相关术语(面试)2. AOP进阶2.1 通知类型2.1.1 @Around:环绕通知,此注解标注的通知方法在目标方法前、后都被执行(通知的代码在业务方法之前和之后都有)2.1.2 @Before:前置通知,此…

人大金仓国产数据库与PostgreSQL

一、简介 在前面项目中&#xff0c;我们使用若依前后端分离整合人大金仓&#xff0c;在后续开发过程中&#xff0c;我们经常因为各种”不适配“问题&#xff0c;但可以感觉得到大部分问题&#xff0c;将人大金仓视为postgreSQL就能去解决大部分问题。据了解&#xff0c;Kingba…

Deepseek 模型蒸馏

赋范课堂&#xff1a; https://www.bilibili.com/video/BV1qUN8enE4c/

经验分享:用一张表解决并发冲突!数据库事务锁的核心实现逻辑

背景 对于一些内部使用的管理系统来说&#xff0c;可能没有引入Redis&#xff0c;又想基于现有的基础设施处理并发问题&#xff0c;而数据库是每个应用都避不开的基础设施之一&#xff0c;因此分享个我曾经维护过的一个系统中&#xff0c;使用数据库表来实现事务锁的方式。 之…

【前端基础】1、HTML概述(HTML基本结构)

一、网页组成 HTML&#xff1a;网页的内容CSS&#xff1a;网页的样式JavaScript&#xff1a;网页的功能 二、HTML概述 HTML&#xff1a;全称为超文本标记语言&#xff0c;是一种标记语言。 超文本&#xff1a;文本、声音、图片、视频、表格、链接标记&#xff1a;由许许多多…

MongoDB—(一主、一从、一仲裁)副本集搭建

MongoDB集群介绍&#xff1a; MongoDB 副本集是由多个MongoDB实例组成的集群&#xff0c;其中包含一个主节点&#xff08;Primary&#xff09;和多个从节点&#xff08;Secondary&#xff09;&#xff0c;用于提供数据冗余和高可用性。以下是搭建 MongoDB 副本集的详细步骤&am…

Hive-06之函数 聚合Cube、Rollup、窗口函数

1、Hive函数介绍以及内置函数查看 内容较多&#xff0c;见《Hive官方文档》 https://cwiki.apache.org/confluence/display/Hive/LanguageManualUDF 1&#xff09;查看系统自带的函数 hive> show functions; 2&#xff09;显示自带的函数的用法 hive> desc function…

CSS定位详解

1. 相对定位 1.1 如何设置相对定位&#xff1f; 给元素设置 position:relative 即可实现相对定位。 可以使用 left 、 right 、 top 、 bottom 四个属性调整位置。 1.2 相对定位的参考点在哪里&#xff1f; 相对自己原来的位置 1.3 相对定位的特点&#xff1…