BERT模型中的嵌入后处理与注意力掩码

摘要

BERT(Bidirectional Encoder Representations from Transformers)是一种强大的预训练模型,广泛应用于自然语言处理任务。本文将详细介绍BERT模型中的两个重要组件:嵌入后处理和注意力掩码的创建。通过理解这些组件的工作原理,读者可以更好地掌握BERT模型的内部机制,并在实际应用中进行优化和调整。

1. 引言

BERT模型的核心在于其强大的嵌入表示能力和多头自注意力机制。在模型的输入阶段,嵌入后处理是一个重要的步骤,它包括词嵌入、段嵌入和位置嵌入的叠加。此外,注意力掩码的创建也是确保模型正确处理序列数据的关键。本文将详细介绍这两个组件的实现。

2. 嵌入后处理
2.1 函数定义
def embedding_postprocessor(input_tensor,use_token_type=False,token_type_ids=None,token_type_vocab_size=16,token_type_embedding_name="token_type_embeddings",use_position_embeddings=True,position_embedding_name="position_embeddings",initializer_range=0.02,max_position_embeddings=512,dropout_prob=0.1):"""Performs various post-processing on a word embedding tensor.Args:input_tensor: float Tensor of shape [batch_size, seq_length, embedding_size].use_token_type: bool. Whether to add embeddings for `token_type_ids`.token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].Must be specified if `use_token_type` is True.token_type_vocab_size: int. The vocabulary size of `token_type_ids`.token_type_embedding_name: string. The name of the embedding table variablefor token type ids.use_position_embeddings: bool. Whether to add position embeddings for theposition of each token in the sequence.position_embedding_name: string. The name of the embedding table variablefor positional embeddings.initializer_range: float. Range of the weight initialization.max_position_embeddings: int. Maximum sequence length that might ever beused with this model. This can be longer than the sequence length ofinput_tensor, but cannot be shorter.dropout_prob: float. Dropout probability applied to the final output tensor.Returns:float tensor with same shape as `input_tensor`.Raises:ValueError: One of the tensor shapes or input values is invalid."""input_shape = get_shape_list(input_tensor, expected_rank=3)batch_size = input_shape[0]seq_length = input_shape[1]width = input_shape[2]output = input_tensorif use_token_type:if token_type_ids is None:raise ValueError("`token_type_ids` must be specified if""`use_token_type` is True.")token_type_table = tf.get_variable(name=token_type_embedding_name,shape=[token_type_vocab_size, width],initializer=create_initializer(initializer_range))flat_token_type_ids = tf.reshape(token_type_ids, [-1])one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)token_type_embeddings = tf.reshape(token_type_embeddings,[batch_size, seq_length, width])output += token_type_embeddingsif use_position_embeddings:assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)with tf.control_dependencies([assert_op]):full_position_embeddings = tf.get_variable(name=position_embedding_name,shape=[max_position_embeddings, width],initializer=create_initializer(initializer_range))position_embeddings = tf.slice(full_position_embeddings, [0, 0],[seq_length, -1])num_dims = len(output.shape.as_list())position_broadcast_shape = []for _ in range(num_dims - 2):position_broadcast_shape.append(1)position_broadcast_shape.extend([seq_length, width])position_embeddings = tf.reshape(position_embeddings,position_broadcast_shape)output += position_embeddingsoutput = layer_norm_and_dropout(output, dropout_prob)return output
2.2 功能解析
  1. 输入张量形状检查:首先,函数检查输入张量的形状是否符合预期(即 [batch_size, seq_length, embedding_size])。
  2. 段嵌入:如果 use_token_type 为 True,则添加段嵌入。段嵌入用于区分不同句子的标记。
  3. 位置嵌入:如果 use_position_embeddings 为 True,则添加位置嵌入。位置嵌入用于编码每个标记在序列中的位置信息。
  4. 层归一化和dropout:最后,对输出张量进行层归一化和dropout处理,以提高模型的泛化能力。
3. 注意力掩码的创建
3.1 函数定义
def create_attention_mask_from_input_mask(from_tensor, to_mask):"""Create 3D attention mask from a 2D tensor mask.Args:from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].to_mask: int32 Tensor of shape [batch_size, to_seq_length].Returns:float Tensor of shape [batch_size, from_seq_length, to_seq_length]."""from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])batch_size = from_shape[0]from_seq_length = from_shape[1]to_shape = get_shape_list(to_mask, expected_rank=2)to_seq_length = to_shape[1]to_mask = tf.cast(tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)broadcast_ones = tf.ones(shape=[batch_size, from_seq_length, 1], dtype=tf.float32)mask = broadcast_ones * to_maskreturn mask
3.2 功能解析
  1. 输入张量形状检查:首先,函数检查 from_tensor 和 to_mask 的形状是否符合预期。
  2. 重塑和类型转换:将 to_mask 重塑为 [batch_size, 1, to_seq_length] 并转换为浮点数。
  3. 广播和乘法:创建一个全1的张量 broadcast_ones,形状为 [batch_size, from_seq_length, 1]。然后将 broadcast_ones 与 to_mask 相乘,得到最终的注意力掩码。
4. 应用示例

假设我们有一个输入张量 input_tensor 和一个输入掩码 input_mask,我们可以使用上述函数进行嵌入后处理和注意力掩码的创建:

import tensorflow as tf# 假设的输入张量和掩码
input_tensor = tf.random.uniform([2, 10, 128])
input_mask = tf.constant([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0],[1, 1, 1, 1, 1, 0, 0, 0, 0, 0]], dtype=tf.int32)# 嵌入后处理
output_tensor = embedding_postprocessor(input_tensor=input_tensor,use_token_type=True,token_type_ids=tf.zeros_like(input_mask),use_position_embeddings=True,initializer_range=0.02,max_position_embeddings=512,dropout_prob=0.1
)# 注意力掩码的创建
attention_mask = create_attention_mask_from_input_mask(input_tensor, input_mask)with tf.Session() as sess:sess.run(tf.global_variables_initializer())output_tensor_val, attention_mask_val = sess.run([output_tensor, attention_mask])print("Output Tensor Shape:", output_tensor_val.shape)print("Attention Mask Shape:", attention_mask_val.shape)
5. 结论

本文详细介绍了BERT模型中的嵌入后处理和注意力掩码的创建。通过这些组件,BERT模型能够有效地处理自然语言任务中的输入数据,并生成高质量的嵌入表示。希望本文能为读者在自然语言处理领域的研究和开发提供有益的参考。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/474182.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

台式电脑没有声音怎么办?台式电脑没有声音解决详解

台式电脑一般来说都是没有内置扬声器的,需要连接耳机或者是音响才可以播放音乐。那么如果遇到台式电脑没有声音的问题,我们也需要确认这些设备硬件有没问题,知道原因才可以进行处理。下面本文将为你介绍台式电脑没有声音的可能原因和解决方法…

一文速学---红黑树

文章目录 一、红黑树简介二、 红黑树特性三、红黑树插入3.1 红黑树为空3.2 父节点为黑色3.3 父节点为红色3.3.1 父亲和叔叔都是红色3.3.2 父节点为红色,叔叔节点为黑色3.3.2.1 父节点在左节点,插入节点在父亲左节点3.3.2.2 父节点在左节点,插…

gitlab容器的迁移(部署)并配置自动备份

gitlab容器的迁移(部署)并配置自动备份 本文背景为从Ubuntu服务器上迁移gitlab容器到windows并备份,若要直接拉取镜直接安装配置可直接从第二小标题参考 1、原Ubuntu的gitlab容器制作为镜像 2.1 将运行的容器制为镜像 #镜像:i…

Linux:调试器-gdb/cgdb

文章目录 一、编译成debug1、-g 选项 二、gdb调试命令1、在CentOS系统下检查安装gdb2、进入gdb模式3、quit 退出gdb4、list (简写 l)显示文件内容5、b 打断点6、 r / run运行程序7、c 让程序直接运行完 三、cgdb1、info b查看打的所有断点2、d 删除断点3…

基于差分、粒子群算法下的TSP优化对比

TSP问题,即旅行商问题(Traveling Salesman Problem),是数学领域中的一个著名问题。以下是对TSP问题的详细解释: 一、问题定义 假设有一个旅行商人要拜访n个城市,他必须选择所要走的路径,路径的…

17.100ASK_T113-PRO 配置QT运行环境(三)

前言 1.打开QT,新建项目. 做成以下效果,会QT都没有问题吧 编译输出: /home/book/LED_and_TempHumi/build-LED_and_TempHumi-100ask-Debug LED_and_TempHumi 2.下载程序与测试 设置运行环境 export QT_QPA_PLATFORMlinuxfb 这个地方还需要加字体,不然不会显示字体.

智慧社区平台系统提升物业管理效率与居民生活质量

内容概要 智慧社区平台系统是为应对现代城市管理挑战而诞生的重要工具。随着城市化进程的加快,传统的物业管理方式已经难以满足日益增长的居民需求和管理复杂性。因此,引入智能化管理手段显得尤为重要。这个系统不仅仅是一个简单的软件,它是…

远程jupyter lab的配置

打开虚拟环境 conda activate test 在环境下安装ipykernel软件包,这个软件包允许jupyter notebookl使用特定环境的python版本。 conda install ipykernel 将该环境添加到Jupyter Notebook中 python -m ipykernel install --user --nametest --display-name&quo…

python+Django+MySQL+echarts+bootstrap制作的教学质量评价系统,包括学生、老师、管理员三种角色

项目介绍 该教学质量评价系统基于Python、Django、MySQL、ECharts和Bootstrap技术,旨在为学校或教育机构提供一个全面的教学质量评估平台。系统主要包括三种角色:学生、老师和管理员,每个角色有不同的功能权限。 学生角色:学生可…

找不到vcruntime140.dll怎么办,彻底解决vcruntime140.dll丢失的5种方法

当计算机系统中无法找到vcruntime140.dll这个特定的动态链接库文件时,可能会引发一系列运行问题,具体表现形式多样且影响范围较广。对于依赖于该文件运行的各类软件应用来说,缺失vcruntime140.dll将直接导致程序无法正常启动或执行&#xff0…

设计模式-Adapter(适配器模式)GO语言版本

前言 个人感觉Adapter模式核心就在于接口之间的转换。将已有的一些接口转换成其他接口形式。并且一般用于对象上,而不是系统上 问题 就用一个简单的问题,懂数据结构的同学可能知道双端队列。那么就用双端队列实现一个栈(stack)或…

表格的选择弹窗,选中后返显到表格中

项目场景: 提示:这里简述项目相关背景: 表格的下拉框可以直接显示选项,那如果选择框不是下拉的,而是弹窗,那么在表格中如何返显呢? 问题描述 如上图所示,点击表格中的选择&#xf…

4.STM32之通信接口《精讲》之USART通信---实验串口发送程序

本节将进行实战,基础了解请查看第1,2,3节(Whappy) 开始背!! USART ---》全双工 异步/同步 点对点 C语言基础printf用法,这节将用到printf的重定向,来打印到串口助手上…

搭建MC服务器

局域网中玩MC,直接自己创建房间开启局域网就可以了。如果想开一个24小时不关机的服务器呢?其实最开始我是想在windows云服务器,图形化界面运行一个开启局域网即可。可能是云服务器上没有显卡,还是其他什么原因,游戏打开…

css 使用图片作为元素边框

先看原始图片 再看效果 边框的四个角灭有拉伸变形,但是图片的中部是拉伸的 代码 border-style: solid;/* 设置边框图像的来源 */border-image-source: url(/static/images/mmwz/index/bk_hd3x.png);/* 设置如何切割图像 */border-image-slice: 66;/* 设置边框的宽度 */border…

通用定时器---输出比较功能

目录 一、概念 二、输出比较的8种模式 三、输出比较输出PWM波形的基本结构 配置步骤 四、示例代码 一、概念 OC(OutPut Compare)输出比较。输出比较可以通过比较CNT与CCR寄存器的关系,来对输出电平进行置1/置0/翻转的操作,可…

【网页设计】CSS3 进阶(动画篇)

1. CSS3 2D 转换 转换(transform)是CSS3中具有颠覆性的特征之一,可以实现元素的位移、旋转、缩放等效果 转换(transform)你可以简单理解为变形 移动:translate旋转:rotate缩放&#xf…

探索 HTML 和 CSS 实现的 3D旋转相册

效果演示 这段HTML与CSS代码创建了一个包含10张卡片的3D旋转效果&#xff0c;每张卡片都有自己的边框颜色和图片。通过CSS的3D变换和动画&#xff0c;实现了一个动态的旋转展示效果 HTML <div class"wrapper"><div class"inner" style"-…

WTV芯片在智能电子锁语音留言上的应用方案解析

一、概述 电子锁的留言功能允许用户通过语音或文字方式给其他家庭成员留下信息。这项功能可以增强家庭成员之间的沟通&#xff0c;特别是在忙碌的家庭生活中提供便利。 WTV是一款功能强大的高品质语音芯片&#xff0c;采用了高性能32位处理器、最高频率可达120MHz。具有低成本、…

Ajax的相关内容

一、Ajax的使用步骤 1.创建XML对象 const xhrnew XMLHttpRequest(); 2.监听事件&#xff0c;处理响应 3.准备发送请求 true表示异步 ajax中永远是异步&#xff0c;永远是true 4.发送请求 二、GET和POST请求 三、JSON的三种形式 四、JSON的方法 五、跨域 六、XHR的属性和方法…