【Block总结】DynamicFilter,动态滤波器降低计算复杂度,替换传统的MHSA|即插即用

论文信息

标题: FFT-based Dynamic Token Mixer for Vision

论文链接: https://arxiv.org/pdf/2303.03932

关键词: 深度学习、计算机视觉、对象检测、分割

GitHub链接: https://github.com/okojoalg/dfformer

在这里插入图片描述

创新点

本论文提出了一种新的标记混合器(token mixer),称为动态滤波器(Dynamic Filter),旨在解决多头自注意力(MHSA)模型在处理高分辨率图像时的计算复杂度问题。传统的MHSA模型在输入特征图中像素数量的平方上具有计算复杂度,导致处理速度缓慢。通过引入基于快速傅里叶变换(FFT)的动态滤波器,论文展示了在保持性能的同时显著降低计算复杂度的可能性。

方法

论文中提出的动态滤波器结合了全局操作的优点,类似于MHSA,但在计算效率上更具优势。具体方法包括:

  • FFT-based Token Mixer: 通过FFT实现全局操作,降低计算复杂度。
  • DFFormer和CDFFormer模型: 这两种新型图像识别模型利用动态滤波器进行图像分类和其他下游任务。
    在这里插入图片描述

动态滤波器如何具体降低MHSA模型的计算复杂度?

动态滤波器通过引入基于快速傅里叶变换(FFT)的机制,显著降低了多头自注意力(MHSA)模型的计算复杂度。以下是其具体工作原理和优势:

计算复杂度问题

传统的MHSA模型在处理输入特征图时,其计算复杂度与特征图中像素数量的平方成正比。这意味着,当输入图像的分辨率增加时,计算需求会急剧上升,导致处理速度变慢,尤其是在高分辨率图像的情况下。

动态滤波器的工作原理

  1. 频域转换: 动态滤波器首先利用FFT将输入特征图转换到频域。FFT是一种高效的算法,可以将计算复杂度降低到 O ( N log ⁡ N ) O(N \log N) O(NlogN),其中 N N N是数据的长度。这一转换使得后续的操作可以在频域中进行,从而减少了计算量。

  2. 动态生成滤波器: 在频域中,动态滤波器通过一个多层感知机(MLP)动态生成每个特征通道的滤波器。这些滤波器是根据输入特征图的内容进行调整的,能够更好地捕捉到图像中的重要信息。

  3. 频域操作: 生成的滤波器在频域中应用于特征图,进行全局信息的捕捉。通过这种方式,动态滤波器能够有效地进行全局操作,同时避免了MHSA中计算复杂度的急剧增加。

  4. 逆FFT转换: 最后,经过滤波的频域特征图通过逆FFT转换回空间域,得到最终的输出结果。

优势

  • 降低计算复杂度: 通过在频域中进行操作,动态滤波器显著降低了MHSA模型的计算复杂度,使得处理高分辨率图像时的速度得以提升。

  • 提高内存效率: 动态滤波器的设计使得模型在处理时占用更少的内存,适合在资源有限的环境中运行。

  • 保持性能: 尽管计算复杂度降低,动态滤波器仍然能够保持与MHSA相似的性能,尤其是在图像分类和其他视觉任务中表现出色。

效果

实验结果表明,DFFormer和CDFFormer在高分辨率图像识别任务中表现出色,具有显著的吞吐量和内存效率。具体而言,这些模型在处理高分辨率图像时的性能优于传统的MHSA模型,显示出动态滤波器在实际应用中的潜力。

实验结果

论文通过一系列实验验证了提出模型的有效性,包括:

  • 图像分类: DFFormer和CDFFormer在标准数据集上的表现接近或超过了现有的最先进模型。
  • 下游任务分析: 通过对比实验,展示了动态滤波器在不同视觉任务中的适用性和优势。

总结

本论文的研究表明,基于FFT的动态滤波器是一种值得认真考虑的标记混合器选项,尤其是在处理高分辨率图像时。通过降低计算复杂度,动态滤波器不仅提高了模型的处理速度,还保持了良好的性能,推动了计算机视觉领域的进一步发展。研究结果为未来的视觉模型设计提供了新的思路和方向。

代码

import torch
import torch.nn as nn
from timm.models.layers import to_2tupleclass StarReLU(nn.Module):"""StarReLU: s * relu(x) ** 2 + b"""def __init__(self, scale_value=1.0, bias_value=0.0,scale_learnable=True, bias_learnable=True,mode=None, inplace=False):super().__init__()self.inplace = inplaceself.relu = nn.ReLU(inplace=inplace)self.scale = nn.Parameter(scale_value * torch.ones(1),requires_grad=scale_learnable)self.bias = nn.Parameter(bias_value * torch.ones(1),requires_grad=bias_learnable)def forward(self, x):return self.scale * self.relu(x) ** 2 + self.biasclass Mlp(nn.Module):""" MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks.Mostly copied from timm."""def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0.,bias=False, **kwargs):super().__init__()in_features = dimout_features = out_features or in_featureshidden_features = int(mlp_ratio * in_features)drop_probs = to_2tuple(drop)self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)self.act = act_layer()self.drop1 = nn.Dropout(drop_probs[0])self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)self.drop2 = nn.Dropout(drop_probs[1])def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop1(x)x = self.fc2(x)x = self.drop2(x)return xclass DynamicFilter(nn.Module):def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25,act1_layer=StarReLU, act2_layer=nn.Identity,bias=False, num_filters=4, size=14, weight_resize=False,**kwargs):super().__init__()size = to_2tuple(size)self.size = size[0]self.filter_size = size[1] // 2 + 1self.num_filters = num_filtersself.dim = dimself.med_channels = int(expansion_ratio * dim)self.weight_resize = weight_resizeself.pwconv1 = nn.Linear(dim, self.med_channels, bias=bias)self.act1 = act1_layer()self.reweight = Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels)self.complex_weights = nn.Parameter(torch.randn(self.size, self.filter_size, num_filters, 2,dtype=torch.float32) * 0.02)self.act2 = act2_layer()self.pwconv2 = nn.Linear(self.med_channels, dim, bias=bias)def forward(self, x):B, H, W, _ = x.shaperouteing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters,-1).softmax(dim=1)x = self.pwconv1(x)x = self.act1(x)x = x.to(torch.float32)x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')if self.weight_resize:complex_weights = resize_complex_weight(self.complex_weights, x.shape[1],x.shape[2])complex_weights = torch.view_as_complex(complex_weights.contiguous())else:complex_weights = torch.view_as_complex(self.complex_weights)routeing = routeing.to(torch.complex64)weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights)if self.weight_resize:weight = weight.view(-1, x.shape[1], x.shape[2], self.med_channels)else:weight = weight.view(-1, self.size, self.filter_size, self.med_channels)x = x * weightx = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')x = self.act2(x)x = self.pwconv2(x)return x
def resize_complex_weight(origin_weight, new_h, new_w):h, w, num_heads = origin_weight.shape[0:3]  # size, w, c, 2origin_weight = origin_weight.reshape(1, h, w, num_heads * 2).permute(0, 3, 1, 2)new_weight = torch.nn.functional.interpolate(origin_weight,size=(new_h, new_w),mode='bicubic',align_corners=True).permute(0, 2, 3, 1).reshape(new_h, new_w, num_heads, 2)return new_weightif __name__ == "__main__":# 如果GPU可用,将模块移动到 GPUinput_size=20device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 输入张量 (batch_size, height, width,channels)x = torch.randn(1, input_size , input_size, 32).to(device)# 初始化 pconv 模块dim = 32block = DynamicFilter(dim=dim,size=input_size)print(block)block = block.to(device)# 前向传播output = block(x)print("输入:", x.shape)print("输出:", output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/10194.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

「AI学习笔记」深度学习的起源与发展:从神经网络到大数据(二)

深度学习(DL)是现代人工智能(AI)的核心之一,但它并不是一夜之间出现的技术。从最初的理论提出到如今的广泛应用,深度学习经历了几乎一个世纪的不断探索与发展。今天,我们一起回顾深度学习的历史…

Axure PR 9 旋转效果 设计交互

大家好,我是大明同学。 这期内容,我们将学习Axure中的旋转效果设计与交互技巧。 旋转 创建旋转效果所需的元件 1.打开一个新的 RP 文件并在画布上打开 Page 1。 2.在元件库中拖出一个按钮元件。 创建交互 创建按钮交互状态 1.选中按钮元件&#xf…

【外文原版书阅读】《机器学习前置知识》2.用看电影推荐的例子带你深入了解向量点积在机器学习的作用

目录 3.3 Where Are You Looking, Vector? The Dot Product 个人主页:Icomi 大家好,我是Icomi,本专栏是我阅读外文原版书《Before Machine Learning》对于文章中我认为能够增进线性代数与机器学习之间的理解的内容的一个输出,希望…

论文阅读(八):结构方程模型用于研究数量遗传学中的因果表型网络

1.论文链接:Structural Equation Models for Studying Causal Phenotype Networks in Quantitative Genetics 摘要: 表型性状可能在它们之间发挥因果作用。例如,农业物种的高产可能会增加某些疾病的易感性,相反,疾病的…

每日一题——序列化二叉树

序列化二叉树 BM39 序列化二叉树题目描述序列化反序列化 示例示例1示例2 解题思路序列化过程反序列化过程 代码实现代码说明复杂度分析总结 BM39 序列化二叉树 题目描述 请实现两个函数,分别用来序列化和反序列化二叉树。二叉树的序列化是将二叉树按照某种遍历方式…

JVM_程序计数器的作用、特点、线程私有、本地方法的概述

①. 程序计数器 ①. 作用 (是用来存储指向下一条指令的地址,也即将要执行的指令代码。由执行引擎读取下一条指令) ②. 特点(是线程私有的 、不会存在内存溢出) ③. 注意:在物理上实现程序计数器是在寄存器实现的,整个cpu中最快的一个执行单元 ④. 它是唯一一个在java虚拟机规…

Attention--人工智能领域的核心技术

1. Attention 的全称与基本概念 在人工智能(Artificial Intelligence,AI)领域,Attention 机制的全称是 Attention Mechanism(注意力机制)。它是一种能够动态分配计算资源,使模型在处理输入数据…

机器学习2 (笔记)(朴素贝叶斯,集成学习,KNN和matlab运用)

朴素贝叶斯模型 贝叶斯定理: 常见类型 算法流程 优缺点 集成学习算法 基本原理 常见方法 KNN(聚类模型) 算法性质: 核心原理: 算法流程 优缺点 matlab中的运用 朴素贝叶斯模型 朴素贝叶斯模型是基于贝叶斯…

智慧园区系统助力企业智能化升级实现管理效率与安全性全方位提升

内容概要 在当今数字化转型的浪潮中,企业面临着前所未有的挑战和机遇。智慧园区系统作为一种创新性解决方案,正在快速崛起,帮助企业实现全面的智能化升级。这套系统不仅仅是一个简单的软件工具,而是一个强大的综合管理平台&#…

【视频+图文详解】HTML基础4-html标签的基本使用

图文教程 html标签的基本使用 无序列表 作用&#xff1a;定义一个没有顺序的列表结构 由两个标签组成&#xff1a;<ul>以及<li>&#xff08;两个标签都属于容器级标签&#xff0c;其中ul只能嵌套li标签&#xff0c;但li标签能嵌套任何标签&#xff0c;甚至ul标…

电子电气架构 --- 在智能座舱基础上定义人机交互

我是穿拖鞋的汉子&#xff0c;魔都中坚持长期主义的汽车电子工程师。 老规矩&#xff0c;分享一段喜欢的文字&#xff0c;避免自己成为高知识低文化的工程师&#xff1a; 简单&#xff0c;单纯&#xff0c;喜欢独处&#xff0c;独来独往&#xff0c;不易合同频过着接地气的生活…

SAP SD学习笔记27 - 请求计划(开票计划)之1 - 定期请求

上两章讲了贩卖契约&#xff08;框架协议&#xff09;的概要&#xff0c;以及贩卖契约中最为常用的 基本契约 - 数量契约和金额契约。 SAP SD学习笔记26 - 贩卖契约(框架协议)的概要&#xff0c;基本契约 - 数量契约_sap 框架协议-CSDN博客 SAP SD学习笔记27 - 贩卖契约(框架…

Ansible自动化运维实战--fetch、cron和group模块(5/8)

文章目录 一、fetch 模块1.1、功能1.2、常用参数1.3、测试1.4、注意事项 二、cron 模块2.1、功能2.2、常用参数2.3、注意事项 三、group模块3.1、功能3.2、常用参数3.3、例子3.4、注意事项 一、fetch 模块 1.1、功能 fetch 模块的主要功能是将远程主机上的文件复制到本地控制…

C++中常用的十大排序方法之1——冒泡排序

成长路上不孤单&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a; 【&#x1f60a;///计算机爱好者&#x1f60a;///持续分享所学&#x1f60a;///如有需要欢迎收藏转发///&#x1f60a;】 今日分享关于C中常用的排序方法之——冒泡排序的相关…

商密测评题库详解:商用密码应用安全性评估从业人员考核题库详细解析(8)

1. 重要领域网络和信息系统的范畴 题目 根据《商用密码应用安全性评估管理办法(试行)》,下列哪些属于重要领域网络和信息系统( )。 A. 基础信息网络 B. 面向社会服务的政务信息系统 C. 重要工业控制系统 D. 以上都是 答案 D 答案解析 依据《商用密码应用安全性评…

openssl 生成证书 windows导入证书

初级代码游戏的专栏介绍与文章目录-CSDN博客 我的github&#xff1a;codetoys&#xff0c;所有代码都将会位于ctfc库中。已经放入库中我会指出在库中的位置。 这些代码大部分以Linux为目标但部分代码是纯C的&#xff0c;可以在任何平台上使用。 源码指引&#xff1a;github源…

SpringBoot整合Swagger UI 用于提供接口可视化界面

目录 一、引入相关依赖 二、添加配置文件 三、测试 四、Swagger 相关注解 一、引入相关依赖 图像化依赖 Swagger UI 用于提供可视化界面&#xff1a; <dependency><groupId>io.springfox</groupId><artifactId>springfox-swagger-ui</artifactI…

Nuxt:利用public-ip这个npm包来获取公网IP

目录 一、安装public-ip包1.在Vue组件中使用2.在Nuxt.js插件中使用public-ip 一、安装public-ip包 npm install public-ip1.在Vue组件中使用 你可以在Nuxt.js的任意组件或者插件中使用public-ip来获取公网IP。下面是在一个Vue组件中如何使用它的例子&#xff1a; <template…

QT串口通信,实现单个温湿度传感器数据的采集

1、硬件设备 RS485中继器(一进二出),usb转485模块、电源等等 => 累计115元左右。 2、核心代码 #include "MainWindow.h" #include "ui_MainWindow.h"MainWindow::

【深度分析】DeepSeek 遭暴力破解,攻击 IP 均来自美国,造成影响有多大?有哪些好的防御措施?

技术铁幕下的暗战&#xff1a;当算力博弈演变为代码战争 一场针对中国AI独角兽的全球首例国家级密码爆破&#xff0c;揭开了数字时代技术博弈的残酷真相。DeepSeek服务器日志中持续跳动的美国IP地址&#xff0c;不仅是网络攻击的地理坐标&#xff0c;更是技术霸权对新兴挑战者的…