Bert Encoder和Transformer Encoder有什么不同

前言:本篇文章主要从代码实现角度研究 Bert Encoder和Transformer Encoder 有什么不同?应该可以帮助你:

  • 深入了解Bert Encoder 的结构实现
  • 深入了解Transformer Encoder的结构实现

本篇文章不涉及对注意力机制实现的代码研究。

注:本篇文章所得出的结论和其它文章略有不同,有可能是本人代码理解上存在问题,但是又没有找到更多的文章加以验证,并且代码也检查过多遍。

观点不太一致的文章:bert-pytorch版源码详细解读_bert pytorch源码-CSDN博客 这篇文章中,存在 “这个和我之前看的transformers的残差连接层差别还挺大的,所以并不完全和transformers的encoder部分结构一致。” 但是我的分析是:代码实现上不太一样,但是本质上没啥不同,只是Bert Encoder在Attention之后多了一层Linear。具体分析过程和结论可以阅读如下文章。

如有错误或问题,请在评论区回复。

1、研究目标

这里主要的观察对象是BertModel中Bert Encoder是如何构造的?从Bert Tensorflow源码,以及transformers库中源码去看。

然后再看TransformerEncoder是如何构造的?从pytorch内置的transformer模块去看。

最后再对比不同。

2、tensorflow中BertModel主要代码如下

class BertModel(object):def __init__(...):...得到了self.embedding_output以及attention_mask# transformer_model就代表了Bert Encoder层的所有操作self.all_encoder_layers = transformer_model(input_tensor=self.embedding_output, attention_mask=attention_mask,...)# 这里all_encoder_layers[-1]是取最后一层encoder的输出self.sequence_output = self.all_encoder_layers[-1]...pooler层,对 sequence_output中的first_token_tensor,即CLS对应的表示向量,进行dense+tanh操作with tf.variable_scope("pooler"):first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)self.pooled_output = tf.layers.dense(first_token_tensor,config.hidden_size,activation=tf.tanh,kernel_initializer=create_initializer(config.initializer_range))def transformer_model(input_tensor, attention_mask=None,...):...for layer_idx in range(num_hidden_layers):# 如下(1)(2)(3)就是每一层Bert Encoder包含的结构和操作with tf.variable_scope("layer_%d" % layer_idx):# (1)attention层:主要包含两个操作,获取attention_output,对attention_output进行dense + dropout + layer_normwith tf.variable_scope("attention"):# (1.1)通过attention_layer获得 attention_outputattention_output# (1.2)output层:attention_output需要经过dense + dropout + layer_norm操作with tf.variable_scope("output"):attention_output = tf.layers.dense(attention_output,hidden_size,...)attention_output = dropout(attention_output, hidden_dropout_prob)# “attention_output + layer_input” 表示 残差连接操作attention_output = layer_norm(attention_output + layer_input)# (2)intermediate中间层:对attention_output进行dense+激活(GELU)with tf.variable_scope("intermediate"):intermediate_output = tf.layers.dense(attention_output,intermediate_size,activation=intermediate_act_fn,)# (3)output层:对intermediater_out进行dense + dropout + layer_normwith tf.variable_scope("output"):layer_output = tf.layers.dense(intermediate_output,hidden_size,kernel_initializer=create_initializer(initializer_range))layer_output = dropout(layer_output, hidden_dropout_prob)# "layer_output + attention_output"是残差连接操作layer_output = layer_norm(layer_output + attention_output)all_layer_outputs.append(layer_output)

3、pytorch的transformers库中的BertModel主要代码;

  • 其中BertEncoder对应要研究的目标
class BertModel(BertPreTrainedModel):def __init__(self, config, add_pooling_layer=True):self.embeddings = BertEmbeddings(config)self.encoder = BertEncoder(config)self.pooler = BertPooler(config) if add_pooling_layer else Nonedef forward(...):# 这是嵌入层操作embedding_output = self.embeddings(input_ids=input_ids,position_ids=position_ids,token_type_ids=token_type_ids,...)# 这是BertEncoder层的操作encoder_outputs = self.encoder(embedding_output,attention_mask=extended_attention_mask,...)# 这里encoder_outputs是一个对象,encoder_outputs[0]是指最后一层Encoder(BertLayer)输出sequence_output = encoder_outputs[0]# self.pooler操作是BertPooler层操作,是先取first_token_tensor(即CLS对应的表示向量),然后进行dense+tanh操作# 通常pooled_output用于做下游分类任务pooled_output = self.pooler(sequence_output) if self.pooler is not None else Noneclass BertEncoder(nn.Module):def __init__(self, config):...self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])...def forward(...):for i, layer_module in enumerate(self.layer):# 元组的append做法,将每一层的hidden_states保存到all_hidden_states;# 第一个hidden_states是BertEncoder的输入,后面的都是每一个BertLayer的输出if output_hidden_states:all_hidden_states = all_hidden_states + (hidden_states,)...# 执行BertLayer的forward方法,包含BertAttention层 + BertIntermediate中间层 + BertOutput层layer_outputs = layer_module(...)# 当前BertLayer的输出hidden_states = layer_outputs[0]# 添加到all_hidden_states元组中if output_hidden_states:all_hidden_states = all_hidden_states + (hidden_states,)class BertLayer(nn.Module):def __init__(self, config):self.attention = BertAttention(config)self.intermediate = BertIntermediate(config)self.output = BertOutput(config)def forward(...):# (1)Attention是指BertAttention# BertAttention包含:BertSelfAttention + BertSelfOutput# BertSelfAttention包括计算Attention+Dropout# BertSelfOutput包含:dense+dropout+LayerNorm,LayerNorm之前会进行残差连接self_attention_outputs = self.attention(...)# self_attention_outputs是一个元组,取[0]获取当前BertLayer中的Attention层的输出attention_output = self_attention_outputs[0]# (2)BertIntermediate中间层包含:dense+gelu激活# (3)BertOutput层包含:dense+dropout+LayerNorm,LayerNorm之前会进行残差连接# feed_forward_chunk的操作是:BertIntermediate(attention_output) + BertOutput(intermediate_output, attention_output)# BertIntermediate(attention_output)是:dense+gelu激活# BertOutput(intermediate_output, attention_output)是:dense+dropout+LayerNorm;# 其中LayerNorm(intermediate_output + attention_output)中的“intermediate_output + attention_output”是残差连接操作layer_output = apply_chunking_to_forward(self.feed_forward_chunk, ..., attention_output)

4、pytorch中内置的transformer的TransformerEncoderLayer主要代码

  • torch.nn.modules.transformer.TransformerEncoderLayer
class TransformerEncoderLayer(Module):'''Args:d_model: the number of expected features in the input (required).nhead: the number of heads in the multiheadattention models (required).dim_feedforward: the dimension of the feedforward network model (default=2048).dropout: the dropout value (default=0.1).activation: the activation function of intermediate layer, relu or gelu (default=relu).Examples::>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)>>> src = torch.rand(10, 32, 512)>>> out = encoder_layer(src)'''def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):super(TransformerEncoderLayer, self).__init__()self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)# Implementation of Feedforward modelself.linear1 = Linear(d_model, dim_feedforward)self.dropout = Dropout(dropout)self.linear2 = Linear(dim_feedforward, d_model)self.norm1 = LayerNorm(d_model)self.norm2 = LayerNorm(d_model)self.dropout1 = Dropout(dropout)self.dropout2 = Dropout(dropout)self.activation = _get_activation_fn(activation)def forward(...):# 过程:# (1)MultiheadAttention操作:src2 = self.self_attn# (2)Dropout操作:self.dropout1(src2)# (3)残差连接:src = src + self.dropout1(src2)# (4)LayerNorm操作:src = self.norm1(src)# 如下是FeedForword:做两次线性变换,为了更深入的提取特征# (5)Linear操作:src = self.linear1(src)# (6)RELU激活(默认RELU)操作:self.activation(self.linear1(src))# (7)Dropout操作:self.dropout(self.activation(self.linear1(src)))# (8)Linear操作:src2 = self.linear2(...)# (9)Dropout操作:self.dropout2(src2)# (10)残差连接:src = src + self.dropout2(src2)# (11)LayerNorm操作:src = self.norm2(src)src2 = self.self_attn(src, src, src, attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0]src = src + self.dropout1(src2)src = self.norm1(src)src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))src = src + self.dropout2(src2)src = self.norm2(src)return src

5、区别总结

        Transformer Encoder的结构如上图所示,代码也基本和上图描述的一致,不过代码中在Multi-Head Attention和Feed Forward之后都存在一个Dropout操作。(可以认为每层网络之后都会接一个Dropout层,是作为网络模块的一部分)

可以将Transformer Encoder过程表述为:

(1)MultiheadAttention + Dropout + 残差连接 + LayerNorm

(2)FeedForword(Linear + RELU + Dropout + Linear + Dropout) + 残差连接 + LayerNorm;Transformer默认的隐含层激活函数是RELU;

可以将 Bert Encoder过程表述为:

(1)BertSelfAttention: MultiheadAttention + Dropout

(2)BertSelfOutput:Linear+ Dropout + 残差连接 + LayerNorm; 注意:这里的残差连接是作用在BertSelfAttention的输入上,不是Linear的输入。

(3)BertIntermediate:Linear + GELU激活

(4)BertOutput:Linear + Dropout + 残差连接 + LayerNorm;注意:这里的残差连接是作用在BertIntermediate的输入上,不是Linear的输入;

进一步,把(1)(2)合并,(3)(4)合并:

(1)MultiheadAttention + Dropout + Linear + Dropout + 残差连接 + LayerNorm

(2)FeedForword(Linear + GELU激活 + Linear + Dropout) + 残差连接 + LayerNorm;Bert默认的隐含层激活函数是GELU;

所以,Bert Encoder和Transformer Encoder最大的区别是,Bert Encoder在做完Attention计算后,还会用一个线性层去提取特征,然后才进行残差连接。其次,是FeedForword中的默认激活函数不同。Bert Encoder图结构如下:

Bert 为什么要这么做?或许是多一个线性层,特征提取能力更强,模型表征能力更好。

GELU和RELU:GELU是RELU的改进版,效果更好。

Reference

  • GeLU、ReLU函数学习_gelu和relu-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/271502.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构之单链表详解(C语言手撕)

​ 🎉个人名片:🐼作者简介:一名乐于分享在学习道路上收获的大二在校生 🙈个人主页🎉:GOTXX 🐼个人WeChat:ILXOXVJE 🐼本文由GOTXX原创,首发CSDN…

前端实现单点登录

简单概括就是&#xff0c;一个系统登录&#xff0c;跳转多个系统&#xff0c;其他系统不需要再登录&#xff0c;直接进入页面 登录的系统 <template><div><div class"content"><div class"item" v-for"(item,index) in list&q…

使用Echarts绘制中国七大区地图

先上效果图&#xff08;文字是否显示&#xff0c;显示什么字&#xff0c;各种颜色之类的&#xff0c;都能随便改&#xff09; 直接上完整代码 <!DOCTYPE html> <html style"height: 100%"><head><meta charset"utf-8" /></hea…

windows安装Chocolatey方法注意事项,以及安装openssl方法

chock是一个很强大的软件包管理工具官方&#xff1a;Chocolatey Software | Installing Chocolatey 使用管理员打开powershell工具&#xff1a; 必须以管理员打开&#xff0c;不然安装失败&#xff0c;提示没有权限 然后输入&#xff1a; Get-ExecutionPolicy 如果返回&…

20个Python函数程序实例

前面介绍的函数太简单了&#xff1a; 以下是 20 个不同的 Python 函数实例 下面深入一点点&#xff1a; 以下是20个稍微深入一点的&#xff0c;使用Python语言定义并调用函数的示例程序&#xff1a; 20个函数实例 简单函数调用 def greet():print("Hello!")greet…

JVM-对象创建与内存分配机制深度剖析 3

JVM对象创建过程详解 类加载检查 虚拟机遇到一条new指令时&#xff0c;首先将去检查这个指令的参数是否能在常量池中定位到一个类的符号引用&#xff0c;并且检查这个 符号引用代表的类是否已被加载、解析和初始化过。如果没有&#xff0c;那必须先执行相应的类加载过程。 new…

“2024杭州智慧城市及安防展会”将于4月在杭州博览中心盛大召开

2024杭州国际智慧城市及安防展览会&#xff0c;将于4月24日在杭州国际博览中心盛大开幕。这场备受瞩目的盛会&#xff0c;不仅汇集了全球智慧城市与安防领域的顶尖企业&#xff0c;更是展示最新技术、交流创新理念的重要平台。近日&#xff0c;从组委会传来消息&#xff0c;展会…

Qt+FFmpeg+opengl从零制作视频播放器-1.项目介绍

1.简介 学习音视频开发&#xff0c;首先从做一款播放器开始是比较合理的&#xff0c;每一章节&#xff0c;我都会将源码贴在最后&#xff0c;此专栏你将学习到以下内容&#xff1a; 1&#xff09;音视频的解封装、解码&#xff1b; 2&#xff09;Qtopengl如何渲染视频&#…

PythonWeb

例题一 from flask import Flask app Flask(__name__) app.route(/index) def index():return f<h1>这是首页&#xff01;</h1> def second():return f<h1>这是第二页&#xff01;</h1> if __name__ __name__:app.run(host"0.0.0.0",port…

TypeScript 基础(一)

目录 一、概述 二、开发环境 三、数据类型 1.boolean 2.number 3.string 4.Array 5.type 6.tuple 7.enum 8.any 9.null / undefined 10.never 11.object 结束 一、概述 TypeScript 是一种由微软开发的开源编程语言。它是 JavaScript 的一个超集&#xff0c;这意…

HTML静态网页成品作业(HTML+CSS)——电影网首页网页设计制作(1个页面)

&#x1f389;不定期分享源码&#xff0c;关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 &#x1f3f7;️本套采用HTMLCSS&#xff0c;未使用Javacsript代码&#xff0c;共有1个页面。 二、作品演示 三、代…

场景问题: VisualVM工具Profiler JDBC不是真实执行的SQL

1. 问题 诡异的问题表象&#xff1a; 前端反馈分页接口的Total字段一直为0 使用Visualvm中的 Profiler 注入到应用后&#xff0c;查看JDBC监控得到了分页接口执行的SQL&#xff0c;复制出来执行是55. 此时还没有注意到 IN 的范围中有一个特别的值 NULL &#x1f928; 2. 排查…

vscode如何远程到linux python venv虚拟环境开发?(python虚拟环境、vscode远程开发、vscode远程连接)

文章目录 1. 安装VSCode2. 安装扩展插件3. 配置SSH连接4. 输入用户名和密码5. 打开远程文件夹6. 创建/选择Python虚拟环境7. 安装Python插件 Visual Studio Code (VSCode) 提供了一种称为 Remote Development 的功能&#xff0c;允许用户在远程系统、容器或甚至 Windows 子系统…

【微信小程序】基本语法

目录 一、列表渲染&#xff08;包括wx:for改变默认&#xff09; 二、事件冒泡和事件捕获 三、生命周期 一、列表渲染&#xff08;包括wx:for改变默认&#xff09; 1、列表渲染(wx-for、block 改变默认wx:for item等) <view> {{msg}} </view> //渲染跟普通vu…

泛型 --java学习笔记

什么是泛型 定义类、接口、方法时&#xff0c;同时声明了一个或者多个类型变量&#xff08;如&#xff1a;<E>&#xff09;&#xff0c;称为泛型类、泛型接口&#xff0c;泛型方法、它们统称为泛型 可以理解为扑克牌中的癞子&#xff0c;给它什么类型它就是什么类型 如…

设计模式—适配器模式

概念: 适配器模式&#xff08;有时候也称包装样式或者包装&#xff09;将一个类的接口适配成用户所期待的。一个适配允许通常因为接口不兼容而不能在一起工作的类工作在一起&#xff0c;做法是将类自己的接口包裹在一个已存在的类中。 本章代码: 小麻雀icknn/设计模式练习 - Gi…

开源的前端思维导图库介绍

在开源社区中&#xff0c;有许多优秀的思维导图库可供开发者使用。这些库通常具有丰富的功能和灵活的API&#xff0c;可以满足不同需求的前端开发。以下是一些流行的开源前端思维导图库&#xff0c;以及它们的特点和区别。 1. **MindMap** 特点&#xff1a; - 基于原生…

【数据结构】二、线性表:4.循环链表的定义及其基本操作(循环单链表,循环双链表的初始化、判空、判断头结点、尾结点、插入、删除)

文章目录 4.循环链表4.1循环单链表4.1.1初始化4.1.2判断单链表是否为空4.1.3判断p结点是否为循环单链表的表尾结点 4.2循环双链表4.2.1初始化4.2.2判断循环链表是否为空4.2.3判断结点p是否为循环双链表的表尾结点4.2.4双链表的插入4.2.5双链表的删除 4.循环链表 4.1循环单链表…

十堰网站建设公司华想科技具有10年的网站制作经验

2018年已经结束了。 华翔科技收到了很多客户的咨询&#xff0c;他们都有一个共同的问题&#xff1a;建一个网站需要多少钱&#xff1f; 但是&#xff0c;我们都会问&#xff1a;您有什么具体需求吗&#xff1f; 大多数人的答案是否定的&#xff0c;他们只是想打听一下价格。 十…

《JAVA与模式》之模板方法模式

系列文章目录 文章目录 系列文章目录前言一、模板方法模式的结构二、模板方法模式中的方法三、使用场景四、模板方法模式在Servlet中的应用前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了…