BERT的代码实现

目录

1.BERT的理论

2.代码实现 

 2.1构建输入数据格式

 2.2定义BERT编码器的类

 2.3BERT的两个任务

2.3.1任务一:Masked Language Modeling MLM掩蔽语言模型任务 

2.3.2 任务二:next sentence prediction

3.整合代码 

 4.知识点个人理解


 

1.BERT的理论

BERT全称叫做Bidirectional Encoder Representations from Transformers, 论文地址: [1810.04805] BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (arxiv.org)

BERT是谷歌AI研究院在2018年10月提出的一种预训练模型. BERT本质上就是Transformer模型的encoder部分, 并且对encoder做了一些改进.

  • 官方代码和预训练模型 Github: https://github.com/google-research/bert

下图中编码器部分即BERT的基本结构.

  

2.代码实现 

import torch
from torch import nn
import dltools

 2.1构建输入数据格式

def get_tokens_and_segments(tokens_a, tokens_b=None):#classification 分类#BERT是两句话作为一对句子一同传入的,也可以单独传一句话,若序列长度长,可以补padding#假设先传一句话tokens_atokens = ['<cls>'] + tokens_a + ['<sep>']  #tokens_embedding层的处理segments = [0] * (len(tokens_a) + 2)  #判断词元属于哪一句话,加标记,0属于第一句话if tokens_b is not None:tokens += tokens_b + ['sep']segments += [1] * (len(tokens_b) + 1)return tokens, segments#测试上面的函数
get_tokens_and_segments([1, 2, 3], [4, 5, 6])

(['<cls>', 1, 2, 3, '<sep>', 4, 5, 6, 'sep'], [0, 0, 0, 0, 0, 1, 1, 1, 1])

 2.2定义BERT编码器的类

class BERTEncoder(nn.Module):#由于前馈网络的ffn_num_outputs = num_hiddens,没有初始化传入#__init__()里面的参数,是创建类的时候传入的参数def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout,max_len=1000, key_size=768, query_size=768, value_size=768, **kwargs):super().__init__(**kwargs)#token_embeddings层self.token_emdedding = nn.Embedding(vocab_size, num_hiddens)#segment_embedding层  (传入两个句子,所以第0维为2)self.segment_embedding = nn.Embedding(2, num_hiddens)#pos_embedding层  :位置嵌入层是可以学习的, 用nn.Parameter()定义可学习的参数self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))#设置Encoder_block的数量self.blks = nn.Sequential()  #为使用的Encoder_block依次编号for i in range(num_layers):  #有几层网络循环几层self.blks.add_module(f'{i}', dltools.EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout))#__init__()里面的参数,是创建类的时候传入的参数#foward里面的参数是创建完类对象之后,调用类方法时传入的参数def forward(self, tokens, segments, valid_lens):#X = token_embedding + segment_embedding + pos_embedding#传入的token_embedding,segment_embedding两者的shape相同,可以直接相加X = self.token_emdedding(tokens) + self.segment_embedding(segments)#pos_embedding与前两层的数据shape不相同,不能直接相加#切片让self.pos_embedding的第1维度的数据切片到token_embedding,segment_embedding相加之后的数X = X + self.pos_embedding.data[:, :X.shape[1], :]for blk in self.blks:X = blk(X, valid_lens)return X  
#测试上面代码#创建BERTEncoder类对象
vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4
norm_shape, ffn_num_input, num_layers, dropout = [768], 768, 2, 0.2
encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)tokens = torch.randint(0, vocab_size, (2, 8)) #生成随机正整数
segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 0, 1, 1, 1, 1]])
#调用类方法
encoded_X = encoder(tokens, segments, None)encoded_X.shape
torch.Size([2, 8, 768])

#  nn.Sequential()是PyTorch中的一个类,它允许用户将多个计算层按照顺序组合成一个模型。在深度学习中,模型可以是由各种不同类型的层组成的,例如卷积层、池化层、全连接层等。nn.Sequential()方法可以将这些层组合在一起,形成一个整体模型。 

 2.3BERT的两个任务

2.3.1任务一:Masked Language Modeling MLM掩蔽语言模型任务 

class MaskLM(nn.Module):def __init__(self, vocab_size, num_inputs=768, **kwargs):super().__init__(**kwargs)self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),  #全连接层nn.ReLU(),  nn.LayerNorm(num_hiddens), nn.Linear(num_hiddens, vocab_size))  #输出层#X表示随机(15%概率)将一些词元换成mask#pred_positions表示已经处理好的80%概率将选中的词换成mask>, 10%概率换成随机词元,10%概率保持原有词元#pred_position是二维数据def forward(self, X, pred_positions):  num_pred_positions = pred_positions.shape[1]  #索引出80%、10%、10%三个概率选出的需要转换的词位置数量pred_positions = pred_positions.reshape(-1)  #变成一维数据batch_size = X.shape[0]  #获取批次batch_idx = torch.arange(0, batch_size) #获取批次的编号#将批次编号与元素数量对应起来#例如:batch_size = [0, 1]   -->   [0, 0, 0, 1, 1, 1]batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)  #将batch_idx中每个元素重复num_pred_positions次#把要预测位置的数据取出来masked_X = X[batch_idx, pred_positions]masked_X = masked_X.reshape(batch_size, num_pred_positions, -1)  #还原维度mlm_Y_hat = self.mlp(masked_X)return mlm_Y_hat
#测试代码mlm = MaskLM(vocab_size, num_hiddens)
mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])
mlm_Y_hat = mlm(encoded_X, mlm_positions)mlm_Y_hat.shape    #2:2个批次,   3:三个需要转换词元的位置     10000:计算的概率数量(在最后会用softmax函数计算分类结果),vocab_size有10000个,
torch.Size([2, 3, 10000])
mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])  #假设真实值
loss = nn.CrossEntropyLoss(reduction='none')
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))  # mlm_Y_hat的shape=(6, 10000)     mlm_Y的shape=(6)
mlm_l.shape

torch.Size([6])

2.3.2 任务二:next sentence prediction

class NextSentencePred(nn.Module):def __init__(self, num_inputs, **kwargs):super().__init__(**kwargs)self.output = nn.Linear(num_inputs, 2)  #预测输入的句子是否为下一个句子,预测目标值为“是/否”二分类问题def forward(self, X):#X的形状(batch_size, num_hiddens)return self.output(X)
#测试代码encoded_X = torch.flatten(encoded_X, start_dim=1)  #将数据展平,相当于reshape
nsp = NextSentencePred(encoded_X.shape[-1])
nsp_Y_hat = nsp(encoded_X)nsp_Y_hat.shape
torch.Size([2, 2])
#计算损失
nsp_y = torch.tensor([0, 1])   #假设真实值
nsp_1 = loss(nsp_Y_hat, nsp_y)
nsp_1.shape

torch.Size([2])

3.整合代码 

class BERTModel(nn.Module):def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout,max_len=1000, key_size=768, query_size=768, value_size=768,hid_in_features=768, mlm_in_features=768, nsp_in_features=768, **kwargs):super().__init__(**kwargs)#初始化编码器对象self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout,max_len=max_len, key_size=key_size, query_size=query_size, value_size=value_size)#掩蔽语言模型任务self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)#中间隐藏层的线性转换+激活函数self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens), nn.Tanh())#预测出下一句self.nsp = NextSentencePred(nsp_in_features)def forward(self, tokens, seqments, valid_lens=None, pred_position=None):encoded_X = self.encoder(tokens, seqments, valid_lens)if pred_position is not None:mlm_Y_hat = self.mlm(encoded_X, pred_position)else:pred_position = None#0表示<cls>标记的索引nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))return encoded_X, mlm_Y_hat, nsp_Y_hat

 4.知识点个人理解

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/429585.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习02-pytorch-08-自动微分模块

​​​​​​​ 其实自动微分模块&#xff0c;就是求相当于机器学习中的线性回归损失函数的导数。就是求梯度。 反向传播的目的&#xff1a; 更新参数&#xff0c; 所以会使用到自动微分模块。 神经网络传输的数据都是 float32 类型。 案例1: 代码功能概述&#xff1a; 该…

鸿蒙Harmony应用开发,数据驾驶舱 项目结构搭建

对于一个项目而言&#xff0c;在拿到我们的开发任务后&#xff0c;我们最重要的就是技术的选型。选型定下来了之后我们便开始脚手架的搭建&#xff0c;然后开始撸代码&#xff0c;开搞. 首先我们需要对一些常见依赖库的引入 我们需要再oh-package.json5的dependencies节点下面…

8--SpringBoot原理分析、注解-详解(面试高频提问点)

目录 SpringBootApplication 1.元注解 --->元注解 Target Retention Documented Inherited 2.SpringBootConfiguration Configuration Component Indexed 3.EnableAutoConfiguration&#xff08;自动配置核心注解&#xff09; 4.ComponentScan Conditional Co…

基于PHP的新闻管理系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于phpMySQL的新闻管理系统。…

JavaWeb--纯小白笔记03:servlet入门---动态网页的创建

笔记&#xff1a;index.html在tomcat中为默认的名字&#xff0c;html里面的语法不严谨。改配置文件要小心&#xff0c;不然容易删掉其他 Servlet&#xff1a;服务器端小程序&#xff0c;写动态网页需要用Servlet&#xff0c;普通的java类通过继承HttpServlet&#xff0c;可以响…

【GUI设计】基于Matlab的图像处理GUI系统(1),用matlab实现

博主简介&#xff1a;matlab图像代码项目合作&#xff08;扣扣&#xff1a;3249726188&#xff09; ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 本次案例是基于Matlab的图像处理GUI系统&#xff0c;用matlab实现。 本次内容主要分为两部分&a…

Why is OpenAI image generation Api returning 400 bad request in Unity?

题意&#xff1a;为什么 OpenAI 图像生成 API 在 Unity 中返回 400 Bad Request 错误&#xff1f; 问题背景&#xff1a; Im testing out dynamically generating images using OpenAI API in Unity. Amusingly, I actually generated most of this code from chatGPT. 我正在…

【笔记】第二节 轧制、热处理和焊接工艺

2.2 钢轨的轧制工艺 坯料进厂按标准验收, 然后装加热炉加热, 加热好的钢坯经高压水除鳞后进行轧制。轧出的钢轨经锯切、打印到中央冷床冷却, 然后装缓冷坑进行缓冷。缓冷后的钢轨进行矫直、轨端加工和端头淬火。钢轨入库前逐根进行探伤和外观检查。 钢轨的轧制 #mermaid-svg-…

foreach,for in和for of的区别

forEach 不能使用break return 结束并退出循环 for in 和 for of 可以使用break return&#xff1b; for in 遍历的是数组的索引&#xff08;即键名&#xff09;&#xff0c;而for of遍历的是数组元素值。 for of 遍历的只是数组内的元素&#xff0c;而不包括数组的原型属性…

后端-navicat查找语句(单表与多表)

表格字段设置如图 语句&#xff1a; 1.输出 1.输出name和age列 SELECT name,age from student 1.2.全部输出 select * from student 2.where子语句 1.运算符&#xff1a; 等于 >大于 >大于等于 <小于 <小于等于 ! <>不等于 select * from stude…

JdbcTemplate常用方法一览AG网页参数绑定与数据寻址实操

JdbcTemplate是Spring框架中的一个重要组件&#xff0c;主要用于简化JDBC数据库操作。它提供了许多常用的方法&#xff0c;如查询、插入、更新、删除等。本文将介绍JdbcTemplate的常用方法及其使用方式&#xff0c;以及参数绑定和删除数据的方法。 一、JdbcTemplate常用方法 查…

钉钉与MySQL对接集成获取部门列表2.0打通EXECUTE语句

钉钉与MySQL对接集成获取部门列表2.0打通EXECUTE语句 接入系统&#xff1a;钉钉 钉钉是阿里巴巴集团打造的企业级智能移动办公平台&#xff0c;是数字经济时代的企业组织协同办公和应用开发平台。钉钉将IM即时沟通、钉钉文档、钉闪会、钉盘、Teambition、OA审批、智能人事、钉工…

828华为云征文|华为Flexus云服务器搭建Cloudreve私人网盘

一、华为云 Flexus X 实例&#xff1a;开启高效云服务新篇&#x1f31f; 在云计算的广阔领域中&#xff0c;资源的灵活配置与卓越性能犹如璀璨星辰般闪耀。华为云 Flexus X 实例恰似一颗最为耀眼的新星&#xff0c;将云服务器技术推向了崭新的高度。 华为云 Flexus X 实例基于…

基于SpringBoot+Vue的商城积分系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、SSM项目源码 精品专栏&#xff1a;Java精选实战项目源码、Python精…

我的AI工具箱Tauri版-MicrosoftTTS文本转语音

本教程基于自研的AI工具箱Tauri版进行MicrosoftTTS文本转语音服务。 MicrosoftTTS文本转语音服务 是自研的AI工具箱Tauri版中的一款功能模块&#xff0c;专为实现高效的文本转语音操作而设计。通过集成微软TTS服务&#xff0c;用户可以将大量文本自动转换为自然流畅的语音文件…

物理学基础精解【9】

文章目录 直线与二元一次方程两直线夹角直线方程斜率两点式方程截距式方程将不同形式的直线方程转换为截距方程直线的一般方程直线一般方程的系数有一个或两个为零的直线 参考文献 直线与二元一次方程 两直线夹角 两直线 y 1 k 1 x b 1 , y 2 k 2 x b 2 形成夹角 a 1 和 a…

关于字节 c++

字节的介绍 字节是计算机中最小的存储单位&#xff0c;通常由8个二进制位组成&#xff0c;用来存储一个字符。在C中&#xff0c;字节也是基本数据类型之一&#xff0c;用关键字"byte"来表示。字节主要用于存储一些较小的数据&#xff0c;如整数、字符等。字节的大小…

【Delphi】通过 LiveBindings Designer 链接控件示例

本教程展示了如何使用 LiveBindings Designer 可视化地创建控件之间的 LiveBindings&#xff0c;以便创建只需很少或无需源代码的应用程序。 在本教程中&#xff0c;您将创建一个高清多设备应用程序&#xff0c;该应用程序使用 LiveBindings 绑定多个对象&#xff0c;以更改圆…

python - self 调用父类方法

Python 子类继承父类构造函数说明 | 菜鸟教程如果在子类中需要父类的构造方法就需要显式地调用父类的构造方法&#xff0c;或者不重写父类的构造方法。 子类不重写 __init__&#xff0c;实例化子类时&#xff0c;会自动调用父类定义的 __init__。 实例 [mycode3 typepython] cl…

Linux基础---13三剑客及正则表达式

一.划水阶段 首先我们先来一个三剑客与正则表达式混合使用的简单示例&#xff0c;大致了解是个啥玩意儿。下面我来演示一下如何查询登录失败的ip地址及次数。 1.首先&#xff0c;进入到 /var/log目录下 cd /var/log效果如下 2.最后&#xff0c;输入如下指令即可查看&#xf…