NLP深入学习:结合源码详解 BERT 模型(三)

文章目录

  • 1. 前言
  • 2. 预训练
    • 2.1 modeling.BertModel
      • 2.1.1 embedding_lookup
      • 2.1.2 embedding_postprocessor
      • 2.1.3 transformer_model
    • 2.2 get_masked_lm_output
    • 2.3 get_next_sentence_output
    • 2.4 训练
  • 3. 参考


1. 前言

前情提要:
《NLP深入学习:结合源码详解 BERT 模型(一)》
《NLP深入学习:结合源码详解 BERT 模型(二)》

之前已经详细说明了 BERT 模型的主要架构和思想,并且讲解了 BERT 源代码对于数据准备的流程,回顾下关键字段的含义:

# 以下是输出到文件的值,也是会作为后续预训练的输入值,重点看!
input_ids:tokens在字典的索引位置,不足max_seq_length(128)则补0
input_mask:初始化为1,不足max_seq_length(128)则补0
segment_ids: 句子A的token和句子B的token,按照0/1排列区分。不足max_seq_length(128)则补0
masked_lm_positions: 被选中 MASK 的token位置索引
masked_lm_ids:被选中 MASK 的token原始值在字典的索引位置
masked_lm_weights:初始化为1
next_sentence_labels:对应is_random_next,1表示随机选择,0表示正常语序

下面我们结合预训练代码详细讲解下 BERT 的预训练流程。

2. 预训练

预训练代码在 run_pretraing.py 文件中,注意我们需要把数据准备的结果作为预训练的输入:
在这里插入图片描述
那我们打上断点,继续开启 debug 吧!
在这里插入图片描述

2.1 modeling.BertModel

看预训练代码,大部分的核心代码集中在 modeling.BertModel 这个 class 的 __init__ 代码中:
在这里插入图片描述
解释下 modeling.BertModel 的参数:

  • config: BERT 的配置文件,后续的很多参数都来源于此。我放到路径 ./multi_cased_L-12_H-768_A-12/bert_config.json ,内容如下:
{"attention_probs_dropout_prob": 0.1, "directionality": "bidi", "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 768, "initializer_range": 0.02, "intermediate_size": 3072, "max_position_embeddings": 512, "num_attention_heads": 12, "num_hidden_layers": 12, "pooler_fc_size": 768, "pooler_num_attention_heads": 12, "pooler_num_fc_layers": 3, "pooler_size_per_head": 128, "pooler_type": "first_token_transform", "type_vocab_size": 2, "vocab_size": 119547
}
  • is_training:True 表示训练,False 表示评估
  • input_ids:对应于数据准备的字段 input_ids,形状 [batch_size, seq_length],即 [32, 128]
  • input_mask:对应于数据准备的字段 input_mask,形状 [batch_size, seq_length],即 [32, 128]
  • token_type_ids:对应于数据准备的字段 segment_ids,形状 [batch_size, seq_length],即 [32, 128]
  • use_one_hot_embeddings:词嵌入是否用 one_hot 模式
  • scope:变量的scope,用于 tf.variable_scope(scope, default_name="bert") 默认是 bert

2.1.1 embedding_lookup

modeling.BertModel__init__ 代码中,第一个重要的方法是 embedding_lookup
在这里插入图片描述
我们看下具体的代码,返回值有两个:

  • out_put 是根据输入的 input_ids 在字典中找到对应的词,并且返回词对应的 embedding 向量,out_put 的形状是 [batch_size, seq_length, embedding_size]
  • embedding_table 是字典每一个词对应的向量,形状是 [vocab_size, embedding_size]

在这里插入图片描述
ps: 有些同学不清楚字典是什么?字典在项目的 ./multi_cased_L-12_H-768_A-12/vocab.txt 里,每一行对应一个词,里例如id=0则表示字典第一个对应的词[PAD],字典内容如下:

[PAD]
[unused1]
[unused2]
[unused3]
[unused4]
...
[unused99]
[UNK]
[CLS]
[SEP]
[MASK]
<S>
<T>
!
"
#
$
%
...
A
B
C
D
E
F
G
H

2.1.2 embedding_postprocessor

后续的该方法是用于加上位置编码!
在这里插入图片描述
我们进到函数内部查看具体细节:
在这里插入图片描述
上面代码中,token_type_ids 对应的是 segment_ids,即句子的表示(用0/1来表示),细节见《NLP深入学习:结合源码详解 BERT 模型(二)》 的 2.3章节。token_type_table 和上一节的 embedding_table 是一样的含义,这里就是向量化 segment_ids。由于 segment_ids 只用 0和1来表示,所以token_type_vocab_size=2,并且最终将 out_put 加上了 segment_ids 向量化的结果,就是图中的 TokenEmbeddings + SegmentEmbeddings
在这里插入图片描述
那么显而易见,下一段代码就是再加上 PositionEmbeddings 了!
在这里插入图片描述
注意,这里的 position_embeddings 实际就是词在句子中的位置对应的 embedding~

最后将输出加上了 layer_norm_and_dropout ,即层归一和dropout。

2.1.3 transformer_model

顺着代码debug下去,在准备好了数据之后,就是经典的 Transformer 模型了:
在这里插入图片描述
希望深入了解 Transformer 的,建议参考:
《NLP深入学习:大模型背后的Transformer模型究竟是什么?(一)》
《NLP深入学习:大模型背后的Transformer模型究竟是什么?(二)》

我们先回忆下 Transformer 的结构,因为下面的代码完全是对论文的编码器实现:
在这里插入图片描述
为了方便查看,我把代码的结构和论文的结构对比在一起:
在这里插入图片描述
transformer 结构构建完成之后,下面的self.sequence_out 是把最后一层的输出作为 transformer 的 encoding 结果输出。
在这里插入图片描述
此外,first_token_tensor 是取第一个 token 的输出结果,即 [CLS] 的结果。因为 [CLS] 已经带有上下文信息了,因此对于分类而言,用 [CLS] 的输出即可。这个论文中也有说明:
在这里插入图片描述
以上就是 BERT 模型的构建整体流程,下面来看 BERT 模型的评估流程,包含 Masked Language Model(MLM)和 Next Sentence Prediction(NSP)。

2.2 get_masked_lm_output

先来看 Masked Language Model(MLM)的评估,对应代码中的 get_masked_lm_out ,见下图:

首先看下 get_masked_lm_out 的输入参数:

  • bert_config : BERT 的配置文件,对应我的路径 ./multi_cased_L-12_H-768_A-12/bert_config.json
  • input_tensor:BERT 模型的输出,即上文的 self.sequence_out
  • output_weights:对应上文 embedding_lookup 的第二个输出,即字典每一个词对应的向量,形状是 [vocab_size, embedding_size]
  • positions:对应 features["masked_lm_positions"] ,即被选中 MASK 的 token 位置索引
  • label_ids:对应 features["masked_lm_ids"],即被选中 MASK 的 token 原始值在字典的索引位置
  • label_weights:对应 features["masked_lm_weights"]

下面是整体的代码,代码有些地方需要细细品味:

在这里插入图片描述
要看懂这里的代码,首先我们要知道 BERT 在 Masked Language Model(MLM)上要干啥。BERT 首先给句子的词打上了 [MASK] ,后续就要对 [MASK] 的词进行预测。预测,就是在词典中出现的词给出一个概率,看属于哪个词,本质上就是多分类问题。那么对于多分类问题,通常的做法是计算交叉熵。

这里就不详细阐述交叉熵的来龙去脉了,直接说明交叉熵如何计算。我们假设真实分布为 y,而模型输出分布为 y ^ \widehat{y} y ,总的类别数为 n,交叉熵损失函数的计算方法为:
l o s s = ∑ i = 1 n [ − y l o g y ^ i − ( 1 − y ) l o g ( 1 − y ^ i ) ] loss = \sum_{i=1}^{n}[-ylog\widehat{y}_i-(1-y)log(1-\widehat{y}_i)] loss=i=1n[ylogy i(1y)log(1y i)]
好,我们来看代码中关键的几个步骤:

  • log_probs = tf.nn.log_softmax(logits, axis=-1) ,这个方法实际上计算的是:
    l o g _ p r o b s = [ l o g y ^ 1 , l o g y ^ 2 , . . . , l o g y ^ n ] log\_probs = [log\widehat{y}_1, log\widehat{y}_2,...,log\widehat{y}_n] log_probs=[logy 1,logy 2,...,logy n]
    其中 l o g y ^ i log\widehat{y}_i logy i 表达的是属于词典第 i 个词的概率的对数值。

  • one_hot_labels = tf.one_hot(label_ids, depth=bert_config.vocab_size, dtype=tf.float32),计算每个词的在字典的 one_hot 结果,形状是 [batch_size*seq_len, vocab_size]。例如,“animal” 在字典第18883位置,那么"animal"对应的 one_hot 就是 [0,0,…0,1,0,…,0],其中向量长度就是字典的大小,1排在向量的18883个。

  • per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) ,这个方法是用于交叉熵的。因为我们知道真实的分布情况,就是 one_hot_labels 对应的结果,那么对于某一个具体的词,其交叉熵的计算就是 − y l o g y ^ i − ( 1 − y ) l o g ( 1 − y ^ i ) -ylog\widehat{y}_i-(1-y)log(1-\widehat{y}_i) ylogy i(1y)log(1y i),将 y=1(即事先知道一定属于某个词)代入,即交叉熵为 − l o g y ^ i -log\widehat{y}_i logy i。所以事先计算了 log_probsper_example_loss 可以直接得到每个词的交叉熵的结果。

  • lossper_example_loss 得到的结果赋予权重进行加权平均,得到一个最终的 loss,实际上就相当于 l o s s = ∑ i = 1 n w i [ − y l o g y ^ i − ( 1 − y ) l o g ( 1 − y ^ i ) ] loss = \sum_{i=1}^{n}w_i[-ylog\widehat{y}_i-(1-y)log(1-\widehat{y}_i)] loss=i=1nwi[ylogy i(1y)log(1y i)]

2.3 get_next_sentence_output

再来看 Next Sentence Prediction(NSP)评估,预测句子的下一句:
在这里插入图片描述
首先看下 get_next_sentence_output 的输入参数:

  • bert_config: BERT 的配置文件,对应我的路径 ./multi_cased_L-12_H-768_A-12/bert_config.json
  • input_tensor[CLS] 的输出线性变换后的结果,简单理解为 [CLS] 的输出作为当前函数的输入
  • labels:对应 features["next_sentence_labels"] ,1表示下一个句子是随机选择的,0表示正常语序

由于下一个句子只有两种选择,要么是随机的,要么是原先正常的句子,所以其实就是一个二分类问题:
在这里插入图片描述
二分类的交叉熵:
l o s s = ∑ i = 1 n − y l o g y ^ i loss = \sum_{i=1}^{n}-ylog\widehat{y}_i loss=i=1nylogy i
上面的核心逻辑跟 get_masked_lm_output 一模一样。只不过这里的 loss 用的是平均值,没有用加权平均

2.4 训练

计算了 masked_lm_loss 以及 next_sentence_loss 之后,将两种 loss 相加,即是总的 loss
在这里插入图片描述
后续就训练模型降低 loss

3. 参考

《NLP深入学习:结合源码详解 BERT 模型(一)》
《NLP深入学习:结合源码详解 BERT 模型(二)》
《BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding》
《NLP深入学习:大模型背后的Transformer模型究竟是什么?(一)》
《NLP深入学习:大模型背后的Transformer模型究竟是什么?(二)》

欢迎关注本人,我是喜欢搞事的程序猿; 一起进步,一起学习;

欢迎关注知乎:SmallerFL;

也欢迎关注我的wx公众号:一个比特定乾坤

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/289176.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分享:vue3+OpenTiny UI+cesium 实现三维地球

效果图 使用vue3 OpenTiny UI cesium 实现三维地球 node.js > v16.0 opentiny vue3 ui安装指南 https://opentiny.design/tiny-vue/zh-CN/os-theme/docs/installation yarn add opentiny/vue3 项目依赖 "dependencies": {"opentiny/vue": "3…

【图像合成】基于DCGAN典型网络的MNIST字符生成(pytorch)

关于 近年来&#xff0c;基于卷积网络&#xff08;CNN&#xff09;的监督学习已经 在计算机视觉应用中得到了广泛的采用。相比之下&#xff0c;无监督 使用 CNN 进行学习受到的关注较少。在这项工作中&#xff0c;我们希望能有所帮助 缩小了 CNN 在监督学习和无监督学习方面的成…

FPGA时钟资源详解(2)——Clock-Capable Inputs

FPGA时钟系列文章总览&#xff1a;FPGA原理与结构&#xff08;14&#xff09;——时钟资源https://ztzhang.blog.csdn.net/article/details/132307564 目录 一、概述 1.1 为什么使用CC 1.2 如何使用CC 二、Clock-Capable Inputs 2.1 SRCC 2.2 MRCC 2.3 其他用途 2.3.1…

element-plus中的日期时间选择器el-date-picker;日期选择面板中选定起始与结束的日期只能改具体的时刻,日期默认是一个月没法动态修改问题

目前遇到一个问题&#xff0c;在使用element-plus中的日期时间选择器el-date-picker&#xff0c;type为datetimerange时&#xff0c;展示的日期选择面板有两个输入框&#xff0c;开始时间和结束时间&#xff0c;element-plus只提供了default-time 使用datetimerange进行范围选择…

我们是如何测试人工智能的(八)包含大模型的企业级智能客服系统拆解与测试方法 -- 大模型 RAG

大模型的缺陷 -- 幻觉 接触过 GPT 这样的大模型产品的同学应该都知道大模型的强大之处&#xff0c; 很多人都应该调戏过 GPT&#xff0c;跟 GPT 聊很多的天。 作为一个面向大众的对话机器人&#xff0c;GPT 明显是鹤立鸡群&#xff0c;在世界范围内还没有看到有能跟 GPT 扳手腕…

五款会让你爱不离手的编程工具,用了都说好,加班变得少。

作为一名“CV工程师” 勤勤恳恳地复制粘贴 没想到AI来了之后 连搬运都不用了&#xff01; 融入了AI代码生成能力的工具 真的能代替程序员的位置吗&#xff1f; 看完这5个AI工具 咱们再来说结论吧&#xff01; aiXcoder 在平时写代码的过程中&#xff0c;经常需要通过上…

flutter3_douyin:基于flutter3+dart3短视频直播实例|Flutter3.x仿抖音

flutter3-dylive 跨平台仿抖音短视频直播app实战项目。 全新原创基于flutter3.19.2dart3.3.0getx等技术开发仿抖音app实战项目。实现了类似抖音整屏丝滑式上下滑动视频、左右滑动切换页面模块&#xff0c;直播间进场/礼物动效&#xff0c;聊天等模块。 运用技术 编辑器&#x…

吴恩达2022机器学习专项课程(一) 4.2 梯度下降实践

问题预览/关键词 本节内容梯度下降更新w的公式梯度下降更新b的公式的含义α的含义为什么要控制梯度下降的幅度&#xff1f;导数项的含义为什么要控制梯度下降的方向&#xff1f;梯度下降何时结束&#xff1f;梯度下降算法收敛的含义正确更新梯度下降的顺序错误更新梯度下降的顺…

网络编程之流式套接字

流式套接字&#xff08;SOCK_STREAM&#xff09;是一种网络编程接口&#xff0c;它提供了一种面向连接的、可靠的、无差错和无重复的数据传输服务。这种服务保证了数据按照发送的顺序被接收&#xff0c;使得数据传输具有高度的稳定性和正确性。通常用于那些对数据的顺序和完整性…

【vue3学习笔记(一)】vue3简介;使用vue-cli创建工程;使用vite创建工程;分析工程结构;安装开发者工具

尚硅谷Vue2.0Vue3.0全套教程丨vuejs从入门到精通 对应课程136-140节 课程 P136节 《vue3简介》笔记 课程 P137节 《使用vue-cli创建工程》笔记 官方文档&#xff1a; https://cli.vuejs.org/zh/guide/creating-a-project.html#vue-create官方文档地址 查看vue-cli版本&#x…

不要盲目开抖店,这才是开店的正确流程,2024全新版教程

我是王路飞。 抖音小店和视频号小店&#xff0c;我更建议没有经验的新手去做抖音小店。 虽然现在抖音小店不属于是一个蓝海项目了&#xff0c;但它依旧是我们普通人借助抖音流量变现非常重要的一个渠道&#xff0c;甚至没有之一。 至于视频号小店&#xff0c;可以说是当下最…

【JSON2WEB】11 基于 Amis 角色功能权限设置页面

【JSON2WEB】01 WEB管理信息系统架构设计 【JSON2WEB】02 JSON2WEB初步UI设计 【JSON2WEB】03 go的模板包html/template的使用 【JSON2WEB】04 amis低代码前端框架介绍 【JSON2WEB】05 前端开发三件套 HTML CSS JavaScript 速成 【JSON2WEB】06 JSON2WEB前端框架搭建 【J…

油缸位置传感器871D-DW2NP524-N4

概述 油缸位置传感器是一种使用电感原理来检测物体接近的开关装置。它通过感应物体的电磁场来判断物体的位置&#xff0c;并将信号转化为电信号输出。当物体靠近或远离电感式接近开关时&#xff0c;物体的电磁场会改变&#xff0c;从而使接近开关产生不同的信号输出。电感式接…

Chrome 插件 tabs API 解析

Chrome.tabs API 解析 使用 chrome.tabs API 与浏览器的标签页系统进行交互&#xff0c;可以使用此 API 在浏览器中创建、修改和重新排列标签页 Tabs API 不仅提供操作和管理标签页的功能&#xff0c;还可以检测标签页的语言、截取屏幕截图&#xff0c;以及与标签页的内容脚本…

MySQL面试汇总(一)

MySQL 如何定位慢查询 如何优化慢查询 索引及其底层实现 索引是一个数据结构&#xff0c;可以帮助MySQL高效获取数据。 聚簇索引和非聚簇索引 覆盖索引 索引创建原则 联合索引

6. 学习方法和Java概述

文章目录 1&#xff09;学习方法2&#xff09;Java是什么&#xff1f; 1&#xff09;学习方法 作为一个0基础入门的同学&#xff0c;在刚开始学习的时候&#xff0c;我们不要追求知识点的深度&#xff0c;而是要追求知识点的广度。简单来说&#xff0c;学一个知识点不要想的太…

TCP和UDP分别是什么?TCP和UDP的区别

在计算机网络通信中&#xff0c;TCP&#xff08;Transmission Control Protocol&#xff09;和UDP&#xff08;User Datagram Protocol&#xff09;是两种重要的传输层协议&#xff0c;它们在数据传输过程中发挥着关键作用。本文将深入探讨TCP和UDP的定义、特点以及它们之间的区…

【Qt】QDialog对话框

目录 一、概念 二、对话框的分类 2.1 模态对话框 2.2 非模态对话框 2.3 混合属性对话框 三、消息对话框QMessageBox 四、颜色对话框QColorDialog 五、文件对话框QFileDialog 六、字体对话框QFontDialog 七、输入对话框QInputDialog 一、概念 对话框是GUI程序中不可或…

MrDoc寻思文档 个人wiki搭建

通过Docker快速搭建个人wiki&#xff0c;开源wiki系统用于知识沉淀&#xff0c;教学管理&#xff0c;技术学习 部署步骤 ## 拉取 MrDoc 代码 ### 开源版&#xff1a; git clone https://gitee.com/zmister/MrDoc.git### 专业版&#xff1a; git clone https://{用户名}:{密码…

「媒体宣传」如何针对不同行业制定媒体邀约方案

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 针对不同行业制定媒体邀约方案时&#xff0c;需要考虑行业特点、目标受众、媒体偏好以及市场趋势等因素。 一、懂行业 先弄清楚你的行业是啥样&#xff0c;有啥特别之处。 了解行业的热…