BERT训练环节(代码实现)

1.代码实现

#导包
import torch
from torch import nn
import dltools
#加载数据需要用到的声明变量
batch_size, max_len = 1, 64
#获取训练数据迭代器、词汇表
train_iter, vocab = dltools.load_data_wiki(batch_size, max_len)
#其余都是二维数组
#tokens, segments, valid_lens(一维), pred_position, mlm_weights, mlm, nsp(一维)对应每条数据i中包含的数据
for i in train_iter:  #遍历迭代器break   #只遍历一条数据
[tensor([[    3,    25,     0,  4993,     0,    24,     4,    26,    13,     2,158,    20,     5,    73,  1399,     2,     9,   813,     9,   987,45,    26,    52,    46,    53,   158,     2,     5,  3140,  5880,9,   543,     6,  6974,     2,     2,   315,     6,     8,     5,8698,     8, 17229,     9,   308,     2,     4,     1,     1,     1,1,     1,     1,     1,     1,     1,     1,     1,     1,     1,1,     1,     1,     1]]),tensor([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),tensor([47.]),tensor([[ 9, 15, 26, 32, 34, 35, 45,  0,  0,  0]]),tensor([[1., 1., 1., 1., 1., 1., 1., 0., 0., 0.]]),tensor([[ 484, 1288,   20,    6, 2808,    9,   18,    0,    0,    0]]),tensor([0])]
#创建BERT网络模型
net = dltools.BERTModel(len(vocab), num_hiddens=128, norm_shape=[128], ffn_num_input=128, ffn_num_hiddens=256, num_heads=2, num_layers=2, dropout=0.2, key_size=128, query_size=128, value_size=128, hid_in_features=128, mlm_in_features=128, nsp_in_features=128)
#调用设备上的GPU
devices = dltools.try_all_gpus()
#损失函数对象
loss = nn.CrossEntropyLoss()   #多分类问题,使用交叉熵
#@save    #表示用于指示某些代码应该被保存或导出,以便于管理和重用
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y):#前向传播#获取遮蔽词元的预测结果、下一个句子的预测结果_, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X, valid_lens_x.reshape(-1), pred_positions_X)#计算遮蔽语言模型的损失mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1,1)mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)   #MLM损失函数的归一化版本   #加一个很小的数1e-8,防止分母为0,抵消上一行代码乘以的数值#计算下一个句子预测任务的损失nsp_l = loss(nsp_Y_hat, nsp_y)l = mlm_l + nsp_lreturn mlm_l, nsp_l, l  
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):  #文本词元样本量太多,全跑完花费的时间太多,若num_steps=1在BERT中表示,跑了1个batch_sizenet = nn.DataParallel(net, device_ids=devices).to(devices[0])  #调用设备的GPUtrainer = torch.optim.Adam(net.parameters(), lr=0.01)   #梯度下降的优化算法Adamstep, timer = 0, dltools.Timer()  #设置计时器#调用画图工具animator = dltools.Animator(xlabel='step', ylabel='loss', xlim=[1, num_steps], legend=['mlm', 'nsp'])#遮蔽语言模型损失的和, 下一句预测任务损失的和, 句子对的数量, 计数metric = dltools.Accumulator(4)  #Accumulator类被设计用来收集和累加各种指标(metric)num_steps_reached = False  #设置一个判断标志, 训练步数是否达到预设的步数while step < num_steps and not num_steps_reached:for tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y in train_iter:#将遍历的数据发送到设备上tokens_X = tokens_X.to(devices[0])segments_X = segments_X.to(devices[0])valid_lens_x = valid_lens_x.to(devices[0])pred_positions_X = pred_positions_X.to(devices[0])mlm_weights_X = mlm_weights_X.to(devices[0])mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])#梯度清零trainer.zero_grad()timer.start()  #开始计时mlm_l, nsp_l, l = _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)l.backward()  #反向传播trainer.step()  #梯度更新metric.add(mlm_l, nsp_l, tokens_X.shape[0], l)  #累积的参数指标timer.stop() #计时停止animator.add(step + 1, (metric[0] / metric[3], metric[1] / metric[3]))  #画图的step += 1  #训练完一个batch_size,就+1if step == num_steps:  #若步数与预设的训练步数相等num_steps_reached = True   #判断标志改为Truebreak  #退出while循环print(f'MLM loss {metric[0] / metric[3]:.3f}, 'f'NSP loss {metric[1] / metric[3]:.3f}')print(f'{metric[2]/ timer.sum():.1f} sentence pairs/sec on 'f'{str(devices)}')
train_bert(train_iter, net, loss, len(vocab), devices, 500)

 

def get_bert_encoding(net, tokens_a, tokens_b=None):tokens, segments = dltools.get_tokens_and_segments(tokens_a, tokens_b)token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0)  #unsqueeze(0)增加一个维度segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)  valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)endoced_X, _, _ = net(token_ids, segments, valid_len)return endoced_X
tokens_a = ['a', 'crane', 'is', 'flying']
encoded_text = get_bert_encoding(net, tokens_a)
# 词元:'<cls>','a','crane','is','flying','<sep>'
encoded_text_cls = encoded_text[:, 0, :]
encoded_text_crane = encoded_text[:, 2, :]
encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]
(torch.Size([1, 6, 128]),torch.Size([1, 128]),tensor([-0.5872, -0.0510, -0.7376], device='cuda:0', grad_fn=<SliceBackward0>))
encoded_text_crane

 

tensor([[-5.8725e-01, -5.0994e-02, -7.3764e-01, -4.3832e-02,  9.2467e-02,1.2745e+00,  2.7062e-01,  6.0271e-01, -5.5055e-02,  7.5122e-02,4.4872e-01,  7.5821e-01, -6.1558e-02, -1.2549e+00,  2.4479e-01,1.3132e+00, -1.0382e+00, -4.7851e-03, -6.3590e-01, -1.3180e+00,5.2245e-02,  5.0982e-01,  7.4168e-02, -2.2352e+00,  7.4425e-02,5.0371e-01,  7.2120e-02, -4.6384e-01, -1.6588e+00,  6.3987e-01,-6.4567e-01,  1.7187e+00, -6.9696e-01,  5.6788e-01,  3.2628e-01,-1.0486e+00, -7.2610e-01,  5.7909e-02, -1.6380e-01, -1.2834e+00,1.6431e+00, -1.5972e+00, -4.5678e-03,  8.8022e-02,  5.5931e-02,-7.2332e-02, -4.9313e-01, -4.2971e+00,  6.9757e-01,  7.0690e-02,-1.8613e+00,  2.0366e-01,  8.9868e-01, -3.4565e-01,  9.6776e-02,1.3699e-02,  7.1410e-01,  5.4820e-01,  9.7358e-01, -8.1038e-01,2.6216e-01, -5.7850e-01, -1.1969e-01, -2.5277e-01, -2.0046e-01,-1.6718e-01,  5.5540e-01, -1.8172e-01, -2.5639e-02, -6.0961e-01,-1.1521e-03, -9.2973e-02,  9.5226e-01, -2.4453e-01,  9.7340e-01,-1.7908e+00, -2.9840e-02,  2.3087e+00,  2.4889e-01, -7.2734e-01,2.1827e+00, -1.1172e+00, -7.0915e-02,  2.5138e+00, -1.0356e+00,-3.7332e-02, -5.6668e-01,  5.2251e-01, -5.0058e-01,  1.7354e+00,4.0760e-01, -1.2982e-01, -7.0230e-01,  3.1563e+00,  1.8754e-01,2.0220e-01,  1.4500e-01,  2.3296e+00,  4.5522e-02,  1.1762e-01,1.0662e+00, -4.0858e+00,  1.6024e-01,  1.7885e+00, -2.7034e-01,-1.6869e-01, -8.7018e-02, -4.2451e-01,  1.1446e-01, -1.5761e+00,7.6947e-02,  2.4336e+00,  4.5346e-02, -6.5078e-02,  1.4203e+00,3.7165e-01, -7.9571e-01, -1.3515e+00,  4.1511e-02,  1.3561e-01,-3.3006e+00,  1.4821e-01,  1.3024e-01,  1.9966e-01, -8.5910e-01,1.4505e+00,  7.6774e-02,  9.3771e-01]], device='cuda:0',grad_fn=<SliceBackward0>)
tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']
encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)
# 词元:'<cls>','a','crane','driver','came','<sep>','he','just', 'left','<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]
encoded_pair_crane = encoded_pair[:, 2, :]
encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]

 

(torch.Size([1, 10, 128]),torch.Size([1, 128]),tensor([-0.4637, -0.0569, -0.6119], device='cuda:0', grad_fn=<SliceBackward0>))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/431227.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OceanBase 3.X 高可用 (一)

OceanBase 3.X 高可用&#xff08;一&#xff09; 一、分布式核心 OceanBase 3.x 采用的是paxos 协议&#xff0c;与raft协议相比。其复杂程度高&#xff0c;实现技术难度大。 Paxos 协议允许事务日志乱序发送&#xff0c;顺序提交。raft允许事务顺序发送&#xff0c;顺序提…

关于 NLP 应用方向与深度训练的核心流程

文章目录 主流应用方向核心流程&#xff08;5步&#xff09;1.选定语言模型结构2.收集标注数据3.forward 正向传播4.backward 反向传播5.使用模型预测真实场景 主流应用方向 文本分类文本匹配序列标注生成式任务 核心流程&#xff08;5步&#xff09; 基本流程实现的先后顺序…

聊聊Thread Local Storage

聊聊ThreadLocal 为什么需要Thread Local StorageThread Local Storage的实现PThread库实现操作系统实现GCC __thread关键字实现C11 thread_local实现JAVA ThreadLocal实现 Thread Local Storage 线程局部存储&#xff0c;简称TLS。 为什么需要Thread Local Storage 变量分为全…

MySQL程序

目录 MySQL程序 常用的MySQL的程序 mysqld程序 mysql客户端 客户端命令的常用的选项 配置文件 配置文件语法 MySQL客户端命令 ​编辑 .sql 文件中执行SQL语句 mysqlcheck &#xff08;表维护程序&#xff09; Mysqldump&#xff08;数据库备份程序&#xff09; mysql…

[数据集][目标检测]基于yolov5增强数据集算法mosaic来扩充自己的数据集自动生成增强图片和对应标注无需重新标注

【算法介绍】 YOLOv5最引人注目的增强技术之一是马赛克增强&#xff0c;它将四张不同的图像拼接成一张图像。 思路&#xff1a;首先&#xff0c;从数据集中随机选择四张图像&#xff0c;然后将它们缩放、随机裁剪&#xff0c;并按马赛克模式拼接在一起。这种方式允许模型看到…

10. 排序

一、排序的概念及引用 1. 排序的概念 排序&#xff1a;所谓排序&#xff0c;就是使一串记录&#xff0c;按照其中的某个或某些关键字的大小&#xff0c;递增或递减的排列起来的操作。 稳定性&#xff1a;假定在待排序的记录序列中&#xff0c;存在多个具有相同的关键字的记录…

无人机之编程基础原理

无人机编程基础原理涉及多个方面&#xff0c;主要包括无人机的基本原理、飞行控制算法、编程语言及算法应用等。以下是对这些方面的详细阐述&#xff1a; 一、无人机基本原理 无人机的基本原理是理解其结构、飞行原理、传感器和控制系统等的基础。无人机通常由机身、动力系统&…

企业如何利用短视频平台做口碑塑造和品牌营销?

抖音和小红书作为短视频平台的代表&#xff0c;吸引了大量的用户和品牌。如何利用抖音、小红书等短视频平台进行品牌塑造和口碑营销呢&#xff1f;小马识途营销顾问分析&#xff0c;短视频平台的用户以年轻人为主&#xff0c;他们具有高度的社交性和消费意愿。短视频平台提供了…

fiddler抓包11_列表显示服务器IP (配置文件)

请求列表默认不显示服务器IP字段&#xff0c;也无法从定制列窗口添加&#xff0c;可以修改CustomRules.js实现。 ① 菜单栏“Rules”&#xff08;规则&#xff09; - “Customize Rules...”&#xff08;自定义规则&#xff09;&#xff0c;打开CustomRules.js文件。 &#xf…

Qt (17)【Qt 文件操作 读写保存】

阅读导航 引言一、Qt文件概述二、输入输出设备类三、文件读写类四、文件和目录信息类五、自定义“记事本” 引言 在上一篇文章中&#xff0c;我们学习了Qt的事件处理机制&#xff0c;知道了如何响应用户的操作。但应用程序常常还需要处理文件&#xff0c;比如读写数据。所以&a…

CVPR最牛图像评价算法!

本文所涉及所有资源均在 传知代码平台可获取。 目录 概述 一、论文思路 1.多任务学习框架&#xff1a; 2.视觉-语言对应关系&#xff1a; 3.动态损失权重&#xff1a; 4.模型优化和评估&#xff1a; 二、模型介绍 三、详细实现方法 1.图像编码器和语言编码器&#xff08;Image…

大语言模型的发展-OPENBMB

一、自然语言处理的基础 1、图灵测试 就是验证人工智能程序有多智能 让计算机像人一样&#xff0c;能够听懂问题&#xff0c;然后给出答案&#xff1b; 自然语言发展历史&#xff1a; advances in Natural Lannguage Processing --论文 2、自然语言处理的基本任务和应用 …

MES系统如何提升制造企业的运营效率和灵活性

参考拓展&#xff1a;苏州稳联-西门子MES系统-赋能智能制造的核心引擎 制造执行系统(MES)在提升制造企业运营效率和灵活性方面发挥着关键作用。 一、MES系统的基本概念和功能 MES系统是连接企业管理层与生产现场的重要桥梁。它主要负责生产调度、资源管理、质量控制等多个方…

【重学 MySQL】三十一、字符串函数

【重学 MySQL】三十一、字符串函数 函数名称用法描述ASCII(S)返回字符串S中的第一个字符的ASCII码值CHAR_LENGTH(s)返回字符串s的字符数&#xff0c;与CHARACTER_LENGTH(s)相同LENGTH(s)返回字符串s的字节数&#xff0c;和字符集有关CONCAT(s1,s2,…,sn)连接s1,s2,…,sn为一个字…

低代码可视化工具--vue条件判断v-if可视化设置-代码生成器

在Vue UniApp中&#xff0c;条件判断通常是通过指令v-if、v-else-if、v-else来实现的。这些机制允许你根据表达式的真假值来决定是否渲染某个元素或元素组&#xff0c;或者执行特定的逻辑。 条件判断说明 v-if 是惰性的&#xff1a;如果在初始渲染时条件为假&#xff0c;则什么…

如何使用ssm实现基于Java web的高校学生课堂考勤系统的设计与实现+vue

TOC ssm686基于Java web的高校学生课堂考勤系统的设计与实现vue 第一章 课题背景及研究内容 1.1 课题背景 信息数据从传统到当代&#xff0c;是一直在变革当中&#xff0c;突如其来的互联网让传统的信息管理看到了革命性的曙光&#xff0c;因为传统信息管理从时效性&#x…

BUUCTF [SCTF2019]电单车详解两种方法(python实现绝对原创)

使用audacity打开&#xff0c;发现是一段PT2242 信号 PT2242信号 有长有短&#xff0c;短的为0&#xff0c;长的为1化出来 这应该是截获电动车钥匙发射出的锁车信号 0 01110100101010100110 0010 0前四位为同步码0 。。。中间这20位为01110100101010100110为地址码0010为功…

Leetcode 反转链表

使用递归 /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode() {}* ListNode(int val) { this.val val; }* ListNode(int val, ListNode next) { this.val val; this.next next; }* }*/ class S…

Java基础知识扫盲

目录 Arrays.sort的底层实现 BigDecimal(double)和BigDecimal(String)有什么区别 Char可以存储一个汉字吗 Java中的Timer定时调度任务是咋实现的 Java中的序列化机制是咋实现的 Java中的注解是干嘛的 Arrays.sort的底层实现 Arrays.sort是Java中提供的对数组进行排序的…

动态规划11,完全背包模板

NC309 完全背包 问题一&#xff1a;求这个背包至多能装多大价值的物品&#xff1f; 状态表示&#xff1a;经验题目要求 dp[i][j] 表示 从前i个物品中挑选&#xff0c;总体积不超过j&#xff0c;所有选法中&#xff0c;能选出来的最大价值。 状态转移方程 根据最后一步的状态&a…