[论文翻译]GLU Variants Improve Transformer

引言

今天带来一篇短小精悍的论文GLU Variants Improve Transformer笔记,作者提出了GLU1的一种变体。

GLU(Gated Linear Units,门控线性单元)由两个线性投影的逐元素乘积组成,其中一个首先经过sigmoid函数。GLU的变体是可能生效的,可以使用不同的非线性(甚至线性)函数来替代sigmoid。作者在Transformer序列到序列模型的前馈子层中测试了这些变体,并发现其中一些相对于通常使用的ReLU或GELU激活函数会带来质量改进。

总体介绍

Transformer序列到序列模型在多头注意力和位置感知前馈网络之间交替进行。FFN接受一个向量 x x x​,并通过两个学习得到的线性变换进行传递。在两个线性变换之间应用了ReLU激活函数。
FFN ( x , W 1 , W 2 , b 1 , b 2 ) = max ⁡ ( 0 , x W 1 + b 1 ) W 2 + b 2 (1) \text{FFN}(x,W_1,W_2,b_1,b_2) =\max(0,xW_1 +b_1)W_2 +b_2 \tag 1 FFN(x,W1,W2,b1,b2)=max(0,xW1+b1)W2+b2(1)
沿用T5 codebase的设定,作者使用无偏置的版本
FFN ( x , W 1 , W 2 ) = max ⁡ ( 0 , x W 1 ) W 2 (2) \text{FFN}(x,W_1,W_2) =\max(0,xW_1 )W_2 \tag 2 FFN(x,W1,W2)=max(0,xW1)W2(2)
后续的研究提出了将ReLU替换为其他非线性激活函数,例如GELU(Gaussian Error Linear Units)2 GELU ( x ) = x Φ ( x ) \text{GELU}(x) = x\Phi(x) GELU(x)=xΦ(x) Swish β ( x ) = x σ ( β x ) \text{Swish}_β(x) = xσ(βx) Swishβ(x)=xσ(βx)3
FFN GELU ( x , W 1 , W 2 ) = GELU ( x W 1 ) W 2 FFN Swish ( x , W 1 , W 2 ) = Swish 1 ( x W 1 ) W 2 (3) \text{FFN}_{\text{GELU}}(x,W_1,W_2) =\text{GELU}(xW_1 )W_2 \\ \text{FFN}_{\text{Swish}}(x,W_1,W_2) =\text{Swish}_1(xW_1 )W_2 \tag{3} FFNGELU(x,W1,W2)=GELU(xW1)W2FFNSwish(x,W1,W2)=Swish1(xW1)W2(3)

Swish 1 = x σ ( x ) \text{Swish}_1 = xσ(x) Swish1=xσ(x),即 β = 1 \beta=1 β=1

GLU及其变体

Dauphin等人1引入了GLU,一种神经网络层,被定义为输入的两个线性变换的逐元素乘积,其中一个经过了sigmoid激活。其作者还建议省略激活函数,称之为双线性(bilinear)层:
GLU ( x , W , V , b , c ) = σ ( x W + b ) ⊗ ( x V + c ) Bilinear ( x , W , V , b , c ) = ( x W + b ) ⊗ ( x V + c ) (4) \begin{aligned} \text{GLU}(x,W,V,b,c) &= \sigma(xW+b) \otimes (xV +c) \\ \text{Bilinear}(x,W,V,b,c) &= (xW+b) \otimes (xV +c) \end{aligned} \tag 4 GLU(x,W,V,b,c)Bilinear(x,W,V,b,c)=σ(xW+b)(xV+c)=(xW+b)(xV+c)(4)
我们也可以使用其他激活函数来定义GLU的变体:
ReGLU ( x , W , V , b , c ) = max ⁡ ( 0 , x W + b ) ⊗ ( x V + c ) GeGLU ( x , W , V , b , c ) = GELU ( x W + b ) ⊗ ( x V + c ) SwiGLU ( x , W , V , b , c ) = Swish β ( x W + b ) ⊗ ( x V + c ) (5) \begin{aligned} \text{ReGLU}(x,W,V,b,c) &= \max(0,xW+b) \otimes (xV +c) \\ \text{GeGLU}(x,W,V,b,c) &= \text{GELU}(xW+b) \otimes (xV +c) \\ \text{SwiGLU}(x,W,V,b,c) &= \text{Swish}_\beta(xW+b) \otimes (xV +c) \\ \end{aligned} \tag 5 ReGLU(x,W,V,b,c)GeGLU(x,W,V,b,c)SwiGLU(x,W,V,b,c)=max(0,xW+b)(xV+c)=GELU(xW+b)(xV+c)=Swishβ(xW+b)(xV+c)(5)

这几个激活函数的图像如下所示:
在这里插入图片描述

在本篇工作中,作提出了对Transformer FFN层的额外变种,其中使用GLU或其变种代替第一个线性变换和激活函数。同样,省略了偏置项。
FFN GLU ( x , W , V , W 2 ) = ( σ ( x W ) ⊗ x V ) W 2 FFN Bilinear ( x , W , V , W 2 ) = ( x W ⊗ x V ) W 2 FFN ReGLU ( x , W , V , W 2 ) = ( max ⁡ ( 0 , x W ) ⊗ x V ) W 2 FFN GEGLU ( x , W , V , W 2 ) = ( GELU ( x W ) ⊗ x V ) W 2 FFN SwiGLU ( x , W , V , W 2 ) = ( Swish 1 ( x W ) ⊗ x V ) W 2 (6) \begin{aligned} \text{FFN}_{\text{GLU}}(x,W,V,W_2) &= (\sigma(xW) \otimes xV)W_2 \\ \text{FFN}_{\text{Bilinear}}(x,W,V,W_2) &= (xW \otimes xV)W_2 \\ \text{FFN}_{\text{ReGLU}}(x,W,V,W_2) &= (\max(0,xW) \otimes xV)W_2 \\ \text{FFN}_{\text{GEGLU}}(x,W,V,W_2) &= (\text{GELU}(xW) \otimes xV)W_2 \\ \text{FFN}_{\text{SwiGLU}}(x,W,V,W_2) &= (\text{Swish}_1(xW) \otimes xV)W_2 \\ \end{aligned} \tag 6 FFNGLU(x,W,V,W2)FFNBilinear(x,W,V,W2)FFNReGLU(x,W,V,W2)FFNGEGLU(x,W,V,W2)FFNSwiGLU(x,W,V,W2)=(σ(xW)xV)W2=(xWxV)W2=(max(0,xW)xV)W2=(GELU(xW)xV)W2=(Swish1(xW)xV)W2(6)
与原始的FFN层相比,所有这些层都有三个权重矩阵,而不是两个。为了保持参数量和计算量的恒定,当将这些层与原始的双矩阵版本进行比较时,作者将隐藏单元的数量 d f f d_{ff} dff W W W V V V的第二个维度以及 W 2 2 W_22 W22的第一个维度)减少了 2 3 \frac{2}{3} 32

实验

作者在T5的迁移学习设置上对所描述的FFN变种进行了测试。使用了一个编码器-解码器的Transformer模型,在预测缺失文本段的去噪目标上进行训练,并随后在各种语言理解任务上进行了微调。

模型架构

image-20240413211104044

使用与T5的基准模型相同的代码库、模型架构和训练任务。编码器和解码器各由12个层组成, d m o d e l = 768 d_{model} = 768 dmodel=768。对于注意力层, h = 12 , d k = d v = 64 h = 12,d_k = d_v = 64 h=12,dk=dv=64。FFN层的隐藏大小为 d f f = 3072 d_{ff} = 3072 dff=3072。如上所述,对于基于GLU变种的FFN层,它们具有三个权重矩阵而不是两个,将隐藏层减少到 d f f = 2048 d_{ff} = 2048 dff=2048,以保持与基准模型相同的参数和操作数量。

预训练和困惑度

与T5完全一致,在C4数据集上使用填充跨度任务进行了524288步的预训练。每个训练批次包含128个示例,每个示例的输入为512个标记,输出为114个标记,输出中包含从输入中删除的多个标记跨度。

类似于T5,使用Adafactor优化器和反平方根学习率调度。还在线性方式下在训练最后10%的步骤中衰减学习率。与T5的主要不同之处在于,在预训练期间不使用dropout。作者发现这样可以产生更好的结果。使用C4数据集中的一个保留分片计算训练目标的对数困惑度,祖宗认为这是模型质量的一个很好的指标。对于每个模型架构,还训练了四个模型进行较短的时间65536步的训练,以衡量不同运行之间的可变性。结果列在表1中。GEGLU和SwiGLU变种产生了最佳的困惑度。

微调

image-20240413211440445

然后,作者对每个完全训练的模型进行了一次微调,使用的是SQuAD和GLUE以及SuperGlue基准测试中的所有语言理解任务的例子按比例混合而成。微调共包含131072步,学习率为 1 0 − 3 10^{-3} 103​​。与训练过程类似,每一步的输入序列的总长度约为65536个标记。根据T5的建议,作者在层输出、前馈隐藏层和注意力权重上使用了0.1的dropout率。在微调期间,嵌入矩阵将被固定。

image-20240413211609348

表2、表3和表4显示了在开发集上的结果。对于每个任务,作者报告了在微调过程中记录的任何检查点中的最佳得分。尽管结果有些噪音,但新的GLU变种在大多数任务上表现最佳。为了比较,在每个表的底部,作者列出了T5的结果。他们的模型与 FFN ReLU \text{FFN}_{\text{ReLU}} FFNReLU模型完全相同。值得注意的是,他们的结果明显较差,作者认为这是由于他们在预训练期间使用了dropout所导致的。还列出了由T5测量的运行间标准偏差。

image-20240413211617037

结论

作者扩展了GLU家族并将它们应用于Transformer模型中。在迁移学习的设置中,新的变种似乎在预训练中用于去噪目标的困惑度上表现更好,并在许多下游语言理解任务上取得了更好的结果。

总结

⭐ 作者用流行的激活函数(Swish,GeLU和ReLU等)替换GLU中的激活函数,得到了一个困惑度比较好的GLU变体——SwiGLU,但作者也无法解释效果好的原因。

参考


  1. GLU Variants Improve Transformer ↩︎ ↩︎

  2. GAUSSIAN ERROR LINEAR UNITS (GELUS) ↩︎

  3. SEARCHING FOR ACTIVATION FUNCTIONS ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/307773.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c语言多功能计算软件170

定制魏:QTWZPW,获取更多源码等 目录 题目 要求 主要代码片段 题目 设计一个计算器软件,具备如下功能提示界面。 要求 设计出界面,注意界面名称最后为自己的姓名;(20分)能够实现加、减、乘、…

Java 面试宝典:Redis 的线程模型是怎么样的?

大家好,我是大明哥,一个专注「死磕 Java」系列创作的硬核程序员。 本文已收录到我的技术网站:https://www.skjava.com。有全网最优质的系列文章、Java 全栈技术文档以及大厂完整面经 Redis 的线程模型其实是分两块的: Redis 6.0 …

(学习日记)2024.04.15:UCOSIII第四十三节:任务消息队列

写在前面: 由于时间的不足与学习的碎片化,写博客变得有些奢侈。 但是对于记录学习(忘了以后能快速复习)的渴望一天天变得强烈。 既然如此 不如以天为单位,以时间为顺序,仅仅将博客当做一个知识学习的目录&a…

文件上传【2】--靶场通关

1.前端禁用js绕过 上传文件,进行抓包,没有抓到,说明这里的验证是前端js验证跳出的弹窗 禁用js后,php文件上传成功。 2.文件上传.htaccess 上传png木马后连接不上 代码中存在.htaccess,判断此时应该就是需要用到.htac…

1111111111

c语言中的小小白-CSDN博客c语言中的小小白关注算法,c,c语言,贪心算法,链表,mysql,动态规划,后端,线性回归,数据结构,排序算法领域.https://blog.csdn.net/bhbcdxb123?spm1001.2014.3001.5343 给大家分享一句我很喜欢我话: 知不足而奋进,望远山而前行&am…

21 标准错误

标准输出重定向关闭无数据 下面的代码&#xff1a; #include <stdio.h> #include <string.h> #include <stdlib.h> #include <unistd.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h>int main() {close(1);i…

超级详细的JDBC和数据库连接池讲解

文章目录 JDBC简介概念本质好处 JDBC快速入门JDBC中API详解DriverManager驱动管理类作用注册驱动获取连接 Connection数据库连接对象作用获取执行SQL的对象事务管理 Statement作用执行SQL语句 ResultSet原理使用步骤 PreparedStatementSQL注入获取对象操作步骤 原理好处 JDBC工…

力扣刷题 二叉树层序遍历相关题目II

NO.116 填充每个节点的下一个右侧节点指针 给定一个 完美二叉树 &#xff0c;其所有叶子节点都在同一层&#xff0c;每个父节点都有两个子节点。二叉树定义如下&#xff1a; struct Node {int val;Node *left;Node *right;Node *next; } 填充它的每个 next 指针&#xff0c;…

redis的主从复制(docker方式快速入门和实战)

目录 一、主从复制简介 二、配置主从服务器 2.1使用配置文件的形式来主从复制 2.2使用纯代码的方式来进行主从复制&#xff1b; 2.3脱离主服务器 三、一些注意事项 一、主从复制简介 主从复制&#xff0c;是指将一台Redis服务器的数据&#xff0c;复制到其他的Redis服务器…

【论文阅读】MCTformer: 弱监督语义分割的多类令牌转换器

【论文阅读】MCTformer: 弱监督语义分割的多类令牌转换器 文章目录 【论文阅读】MCTformer: 弱监督语义分割的多类令牌转换器一、介绍二、联系工作三、方法四、实验结果 Multi-class Token Transformer for Weakly Supervised Semantic Segmentation 本文提出了一种新的基于变换…

尝试在手机上运行google 最新开源的gpt模型 gemma

Gemma介绍 Gemma简介 Gemma是谷歌于2024年2月21日发布的一系列轻量级、最先进的开放语言模型&#xff0c;使用了与创建Gemini模型相同的研究和技术。由Google DeepMind和Google其他团队共同开发。 Gemma提供两种尺寸的模型权重&#xff1a;2B和7B。每种尺寸都带有经过预训练&a…

【动手学深度学习】15_汉诺塔问题

注&#xff1a; 本系列仅为个人学习笔记&#xff0c;学习内容为《算法小讲堂》&#xff08;视频传送门&#xff09;&#xff0c;通俗易懂适合编程入门小白&#xff0c;需要具备python语言基础&#xff0c;本人小白&#xff0c;如内容有误感谢您的批评指正 汉诺塔&#xff08;To…

人员抽烟AI检测算法原理介绍及实际场景应用

抽烟检测AI算法是一种基于计算机视觉和深度学习技术的先进工具&#xff0c;旨在准确识别并监测个体是否抽烟。该算法通过训练大量图像数据&#xff0c;使模型能够识别出抽烟行为的关键特征&#xff0c;如烟雾、手部动作和口部形态等。 在原理上&#xff0c;抽烟检测AI算法主要…

[lesson22]对象的销毁

对象的销毁 对象的销毁 生活中的对象都是被初始化后才上市的 生活中的对象被销毁前会做一些清理工作 一般而言&#xff0c;需要销毁的对象都应该做清理 解决方案 为每个类都提供一个public的free函数对象不在需要时立即调用free函数进行清理 存在的问题 free只是一个普通…

稀碎从零算法笔记Day44-LeetCode:整数转罗马数字

题型&#xff1a;贪心、模拟 链接&#xff1a; 12. 整数转罗马数字 - 力扣&#xff08;LeetCode&#xff09; 来源&#xff1a;LeetCode 题目描述 罗马数字包含以下七种字符&#xff1a; I&#xff0c; V&#xff0c; X&#xff0c; L&#xff0c;C&#xff0c;D 和 M。 …

淘宝批量采集商品详情数据(属性丨详情图丨sku丨价格等)

淘宝批量采集商品详情数据&#xff08;包括属性、详情图、SKU、价格等&#xff09;可以通过以下几种方式实现&#xff1a; 使用淘宝数据抓取工具&#xff1a;这类工具&#xff0c;如某鱼等&#xff0c;能够自动化采集淘宝商品数据&#xff0c;并将其转换成CSV、Excel等格式&am…

【PyQt5】环境配置

PyQt5 环境配置 一、前言1.1 PyQt5介绍1.2 PyCharm集成Pyqt5 二、pyqt5安装三、PyQt5-tools工具包安装四、常用工具环境配置4.1、环境变量配置4。2、验证是否安装成功 五、pycharm中设置Qt工具&#xff08;Qt Designer、PyUIC、PyRcc&#xff09;5.1、配置Qt Designer5.2、配置…

C++11 设计模式4. 抽象工厂(Abstract Factory)模式

问题的提出 从前面我们已经使用了工厂方法模式 解决了一些问题。 现在 策划又提出了新的需求&#xff1a;对于各个怪物&#xff0c;在不同的场景下&#xff0c;怪物的面板数值会发生变化&#xff0c; //怪物分类&#xff1a;亡灵类&#xff0c;元素类&#xff0c;机械类 …

【数据交换格式】网络socket编程温度采集智能存储与上报项目技术------JSON、TLV

作者简介&#xff1a; 一个平凡而乐于分享的小比特&#xff0c;中南民族大学通信工程专业研究生在读&#xff0c;研究方向无线联邦学习 擅长领域&#xff1a;驱动开发&#xff0c;嵌入式软件开发&#xff0c;BSP开发 作者主页&#xff1a;一个平凡而乐于分享的小比特的个人主页…

蓝桥杯物联网竞赛_STM32L071KBU6_全部工程及国赛省赛真题及代码

包含stm32L071kbu6全部实验工程、源码、原理图、官方提供参考代码及国、省赛真题及代码 链接&#xff1a;https://pan.baidu.com/s/1pXnsMHE0t4RLCeluFhFpAg?pwdq497 提取码&#xff1a;q497