gMLP(NeurIPS 2021)原理与代码解析

paper:Pay Attention to MLPs

third-party implementation:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/mlp_mixer.py

方法介绍

gMLP和MLP-Mixer以及ResMLP都是基于MLP的网络结构,非常简单,关于MLP-Mixer和ResMLP的介绍见MLP-Mixer(NeurIPS 2021, Google)论文与源码解读-CSDN博客、ResMLP(NeurIPS 2021,Meta)论文与代码解析-CSDN博客。

在MLP-Mixer中每个block包含两个MLP,每个MLP包含两个线性层(即全连接层),一个MLP用于token间的信息交互,另一个MLP用于通道间的信息交互,每个MLP都用了residual connection,标准化采用LayerNorm。而在ResMLP中,第一个包含两个线性层的token MLP换成了单个线性层,此外在线性层前后包含两个标准化层pre-normalization和post-normalization,pre-normalization采用了简单的仿射变换,post-normalization采用了CaiT中的LayerScale。

gMLP的结构和伪代码如图1所示。可以看到gMLP将token_mlp(即这里的spatial gating unit)和channel_mlp放到了一起,只包含一个skip-connection,而不是像MLP-Mixer和ResMLP中每个mlp都采用一个skip-connection。此外block内的结构和MLP-Mixer以及ResMLP中的先token_mlp后channel_mlp不同,这里采用了channel+token+channel的形式。最后作者专门为token_mlp设计了一个门控机制,将输入split开一分为二,一半经过spatial proj得到的输出再和另一半相乘得到最终输出。

以上就是gMLP和MLP-Mixer以及ResMLP不同之处,总共包括三点,整体结构也非常简单。下面就直接用代码来解释具体的实现细节。

代码解析

一个完整的block的代码如下,forward函数中可以看到只包含一个skip-connection,self.mlp_channels包含了图1中第一个Channel Proj到最后的Channel Proj。

class SpatialGatingBlock(nn.Module):""" Residual Block w/ Spatial GatingBased on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050"""def __init__(self,dim,seq_len,mlp_ratio=4,mlp_layer=GatedMlp,norm_layer=partial(nn.LayerNorm, eps=1e-6),act_layer=nn.GELU,drop=0.,drop_path=0.,):super().__init__()channel_dim = int(dim * mlp_ratio)  # 512x6=3072self.norm = norm_layer(dim)sgu = partial(SpatialGatingUnit, seq_len=seq_len)  # 196self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, gate_layer=sgu, drop=drop)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()def forward(self, x):  # (1,196,512)x = x + self.drop_path(self.mlp_channels(self.norm(x)))return x

上面的mlp_layer的代码如下,self.fc1和self.fc2对应两个Channel Proj。

class GatedMlp(nn.Module):""" MLP as used in gMLP"""def __init__(self,in_features,hidden_features=None,out_features=None,act_layer=nn.GELU,norm_layer=None,gate_layer=None,bias=True,drop=0.,):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresbias = to_2tuple(bias)drop_probs = to_2tuple(drop)self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])self.act = act_layer()self.drop1 = nn.Dropout(drop_probs[0])if gate_layer is not None:assert hidden_features % 2 == 0self.gate = gate_layer(hidden_features)hidden_features = hidden_features // 2  # FIXME base reduction on gate property?else:self.gate = nn.Identity()self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])self.drop2 = nn.Dropout(drop_probs[1])def forward(self, x):  # (1,196,512)# Linear(in_features=512, out_features=3072, bias=True)x = self.fc1(x)  # (1,196,3072)x = self.act(x)x = self.drop1(x)x = self.gate(x)  # (1,196,1536)x = self.norm(x)# Linear(in_features=1536, out_features=512, bias=True)x = self.fc2(x)  # (1,196,512)x = self.drop2(x)return x

gate_layer的代码如下,其中x.chunk(2, dim=-1)表示将x沿最后一个维度均分为2份。

class SpatialGatingUnit(nn.Module):""" Spatial Gating UnitBased on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050"""def __init__(self, dim, seq_len, norm_layer=nn.LayerNorm):super().__init__()gate_dim = dim // 2self.norm = norm_layer(gate_dim)self.proj = nn.Linear(seq_len, seq_len)  # 196,196def init_weights(self):# special init for the projection gate, called as override by base model initnn.init.normal_(self.proj.weight, std=1e-6)nn.init.ones_(self.proj.bias)def forward(self, x):  # (1,196,3072)u, v = x.chunk(2, dim=-1)  # (1,196,1536),(1,196,1536)v = self.norm(v)v = self.proj(v.transpose(-1, -2))  # (1,1536,196)return u * v.transpose(-1, -2)  # (1,196,1536) * (1,196,1536)

实验结果

作者设计了三个不同大小的gMLP,具体参数配置如下

和其它模型在ImageNet上的分类性能对比如下,可以看到和类似大小的MLP-Mixer与ResMLP相比,gMLP用更少的参数得到了更好的性能。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/361018.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CesiumJS加载天地图数据后,可以实现什么效果?

说起地图,大家耳熟能详的百度地图、高德地图、腾讯地图等,由于授权的原因,使用起来心惊胆战的,而天地图就没有这方面的困扰,本文介绍下如何在cesium中时候用天地图数据,已经能够实现哪些交互效果。 一、关…

C# 任务调度 c# TaskScheduler

摘要 在C#中,TaskScheduler是一种非常有用的功能,它允许您在指定的时间或间隔内执行任务。TaskScheduler是一个抽象类,它提供了一个通用的方法来计划和执行任务。您可以使用TaskScheduler来调度多个任务,并且在多线程环境中控制它…

创建github个人博客

文章目录 安装Hexo安装git安装Node.js安装 Hexo git配置SSH key配置ssh 搭建个人博客新建博客生成静态网页 本文主要参考 【保姆级】利用Github搭建自己的个人博客,看完就会 安装Hexo 参考官方文档:https://hexo.io/zh-cn/docs/ Hexo 是一个快速、简洁且…

【STM32】USART串口通讯

1.USART简介 STM32芯片具有多个USART外设用于串口通讯,它是 Universal Synchronous Asynchronous Receiver and Transmitter的缩写, 即通用同步异步收发器可以灵活地与外部设备进行全双工数据交换。有别于USART, 它还有具有UART外设(Univers…

6.18 多态

多态相较于继承是更加重要的体现面向对象的特征。 多态: 同一个消息、同一种调用,在不同的场合,不同的情况下,执行不同的行为 。 背景需求:继承是实现可以在圆柱或者圆锥中复用圆的特征,多态是可以通过一…

东南亚本地化游戏

通常,亚洲电子游戏市场首先与中国联系在一起。但最近,分析人士越来越关注一个邻近地区:东南亚。而且有充分的理由。 该地区包括中南半岛、马来群岛和邻近岛屿上的十一个国家。1967年,其中10个国家(除东帝汶外&#xf…

反射及动态代理

反射 定义: 反射允许对封装类的字段,方法和构造 函数的信息进行编程访问 图来自黑马程序员 获取class对象的三种方式: 1)Class.forName("全类名") 2)类名.class 3) 对象.getClass() 图来自黑马程序员 pac…

2024广东省职业技能大赛云计算赛项实战——构建CICD

构建CI/CD 前言 题目如下: 构建CI/CD 编写流水线脚本.gitlab-ci.yml触发自动构建,具体要求如下: (1)基于镜像maven:3.6-jdk-8构建项目的drone分支; (2)构建镜像的名称&#xff1a…

Qt | 子类化 QStyle(Qt自带图标大全)

01、简介 1、把绘制自定义部件外观的步骤大致分为三大板块,如下: ①、样式元素:即指定需要绘制的图形元素(比如焦点框、按钮,工具栏等)。样式元素使 用 QStyle 类中的一系列枚举(共有 11 个枚举)进行描述。 ②、样式选项:包含了需要绘制的图形元素的所有信息,比如包含…

DDR3控制器(一)DDR3 IP调用

目录 一、DDR3 IP核简介 二、DDR3 IP核调用 在千兆以太网通信中用到了DDR3控制器,但是并没有对其做相关介绍。这次准备重新整理一下DDR3控制相关知识,复习巩固一下。 一、DDR3 IP核简介 MIG IP核(Memory Interface Generator)是…

【ajax基础04】form-serialize插件

目录 一:form-serialize插件 作用: 语法格式: 一:form-serialize插件 作用: 快速且大量的收集表单元素的值 例如上图对于多表单元素的情形,单靠通过”选择器获取节点.value”值的形式,获取…

QEMU + Vscode + Arm Arch‘s Linux调试小记

目录 下载QEMU 下载aarch64-gcc 下载BusyBox 编译linux 6.9.5的内核 启动! 链接到vscode进行远程调试 Reference 前几天看到了一篇讲授如何调试ARM Linux内核的文章,这里现在记录一下调试ARM Linux内核的办法 下载QEMU 对于Arch Linux用户而言&a…

如何集成CppCheck到visual studio中

1.CPPCheck安装 在Cppcheck官方网站下载最新版本1.70,官网链接:http://cppcheck.sourceforge.net/ 安装Cppcheck 2.集成步骤 打开VS,菜单栏工具->外部工具->添加,按照下图设置,记得勾选“使用输出窗口” 2.…

考研数学一有多难?130+背后的残酷真相

考研数学一很难 大家平时在网上上看到很多人说自己考了130,其实这些人只占参加考研数学人数的极少部分,有个数据可以展示出来考研数学到底有多难: 在几百万考研大军中,能考到120分以上的考生只有2%。绝大多数人的分数集中在30到…

回购注销高管减持,东软集团的“大手笔”有意义吗?

文:互联网江湖 作者:刘致呈 作为老牌软件巨头,东软集团这两年的业绩着实有些不够看。 看财报数据,22年东软集团营收94.66亿,净亏损3.47亿,扣非净利利润-5.30亿。23年,集团营收105.44亿&#x…

华为OD机试【高矮个子排队】(java)(100分)

1、题目描述 现在有一队小朋友,他们高矮不同,我们以正整数数组表示这一队小朋友的身高,如数组{5,3,1,2,3}。 我们现在希望小朋友排队,以“高”“矮”“高”“矮”顺序排列,每一个“高”位置的小朋友要比相邻的位置高或…

5000天后的世界:科技引领的未来之路

**你是否想过,5000天后的世界会是什么样子?** 科技日新月异,改变着我们的生活方式,也引领着人类文明的进程。著名科技思想家凯文凯利在他的著作《5000天后的世界》中,对未来进行了大胆的预测。 **这本书中&#xff0c…

SpringBoot学习04-[定制SpringMVC]

定制SpringMVC 定制SpringMvc的自动配置定制springmvc-configurePathMatch配置 定制SpringMvc的自动配置 SpringMvc的自动配置类:WebMvcAutoConfiguration 1、在大多数情况下,SpringBoot在自动配置中标记了很多ConditionalOnMissingBean,我们…

智慧互联:Vatee万腾平台展现科技魅力

随着科技的迅猛发展,我们的生活正逐渐变得智能化、互联化。在这个信息爆炸的时代,一个名为Vatee万腾的平台正以其独特的魅力,引领我们走向一个更加智能的未来。 Vatee万腾,这个名字本身就充满了对科技未来的憧憬与期待。作为一家专…

吴恩达揭秘:编程Agent如何革新软件开发行业

作为 AI 领域的杰出人物,吴恩达教授对编程 Agent 的兴起表示了极大的兴趣。他认为,编程 Agent 有潜力通过自动执行繁琐的任务、提高代码质量和加速开发周期来彻底改变软件开发行业。 本文将深入探讨吴恩达对编程 Agent 的见解, 多代理系统质…