中文完形填空

本文通过ChnSentiCorp数据集介绍了完型填空任务过程,主要使用预训练语言模型bert-base-chinese直接在测试集上进行测试,也简要介绍了模型训练流程,不过最后没有保存训练好的模型。

一.完形填空
完形填空应该大家都比较熟悉,就是把句子中的词挖掉,根据上下文推测挖掉的词是什么。

二.准备数据集
本文使用ChnSentiCorp数据集,不清楚的可以参考中文情感分类介绍。一些样例如下所示:

本文做法为将每句话截断为固定的30个词,同时将第15个词替换为[MASK],模型任务为根据上下文预测第15个词。

1.使用编码工具

def load_encode_tool(pretrained_model_name_or_path):token = BertTokenizer.from_pretrained(Path(f'{pretrained_model_name_or_path}'))return token
if __name__ == '__main__':# 测试编码工具pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\bert-base-chinese'token = load_encode_tool(pretrained_model_name_or_path)print(token)

输出结果如下所示:

BertTokenizer(name_or_path='L:\20230713_HuggingFaceModel\bert-base-chinese', vocab_size=21128, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True)

测试编码句子如下所示:

if __name__ == '__main__':# 测试编码工具pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\bert-base-chinese'token = load_encode_tool(pretrained_model_name_or_path)# 测试编码句子out = token.batch_encode_plus(batch_text_or_text_pairs=[('不是一切大树,', '都被风暴折断。'),('不是一切种子,', '都找不到生根的土壤。')],truncation=True,padding='max_length',max_length=18,return_tensors='pt',return_length=True, # 返回长度)# 查看编码输出for k, v in out.items():print(k, v.shape)print(token.decode(out['input_ids'][0]))print(token.decode(out['input_ids'][1]))

结果输出如下所示:

input_ids torch.Size([2, 18])
token_type_ids torch.Size([2, 18])
length torch.Size([2])
attention_mask torch.Size([2, 18])
[CLS] 不 是 一 切 大 树 , [SEP] 都 被 风 暴 折 断 。 [SEP] [PAD]
[CLS] 不 是 一 切 种 子 , [SEP] 都 找 不 到 生 根 的 土 [SEP]

第1个句子长度为17,补了1个[PAD],第2个句子长度为18。return_length=True表示返回句子真实长度,即不包括[PAD]填充部分长度。如下所示:

编码结果如下所示:

2.定义数据集

def load_dataset_from_disk():pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\ChnSentiCorp'dataset = load_from_disk(pretrained_model_name_or_path)# batched=True表示批量处理# batch_size=1000表示每次处理1000个样本# num_proc=8表示使用8个线程操作# remove_columns=['text']表示移除text列dataset = dataset.map(f1, batched=True, batch_size=1000, num_proc=8, remove_columns=['text', 'label'])return dataset
if __name__ == '__main__':# 加载数据集dataset = load_dataset_from_disk()print(dataset)

结果输出如下所示:

DatasetDict({train: Dataset({features: ['input_ids', 'token_type_ids', 'attention_mask', 'length'],num_rows: 9600})validation: Dataset({features: ['input_ids', 'token_type_ids', 'attention_mask', 'length'],num_rows: 1200})test: Dataset({features: ['input_ids', 'token_type_ids', 'attention_mask', 'length'],num_rows: 1200})
})

3.定义计算设备

# 定义计算设备
device = 'cpu'
if torch.cuda.is_available():device = 'cuda'
# print(device)

4.定义数据整理函数
本质是将每个句子第15个词替换为[MASK],同时将第15个词作为标签,即根据上下文要预测的词。如下所示:

# 数据整理函数
def collate_fn(data):# 取出编码结果input_ids = [i['input_ids'] for i in data]attention_mask = [i['attention_mask'] for i in data]token_type_ids = [i['token_type_ids'] for i in data]# 转换为Tensor格式input_ids = torch.LongTensor(input_ids)attention_mask = torch.LongTensor(attention_mask)token_type_ids = torch.LongTensor(token_type_ids)# 把第15个词替换为MASKlabels = input_ids[:, 15].reshape(-1).clone()input_ids[:, 15] = token.get_vocab()[token.mask_token]# 移动到计算设备input_ids = input_ids.to(device)attention_mask = attention_mask.to(device)token_type_ids = token_type_ids.to(device)labels = labels.to(device)return input_ids, attention_mask, token_type_ids, labels

5.定义数据集加载器

# 数据集加载器
loader = torch.utils.data.DataLoader(dataset=dataset['train'], batch_size=16, collate_fn=collate_fn, shuffle=True, drop_last=True)
print(len(loader)) #600=9600/16

查看样例数据如下所示:

# 查看数据样例
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):break
print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels)

输出结果如下所示:

torch.Size([16, 30])
torch.Size([16, 30])
torch.Size([16, 30])
tensor([4638, 8024, 3198, 6206, 6392, 4761, 3449, 2128, 3341,  119, 3315, 2697,2523, 2769, 6814, 1086], device='cuda:0')

三.定义模型
1.加载预训练模型
加载模型并移动到device(CPU或GPU)中,如下所示:

pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\bert-base-chinese'
# 加载预训练模型
pretrained = BertModel.from_pretrained(Path(f'{pretrained_model_name_or_path}'))
pretrained.to(device) 

2.定义下游任务模型
下游任务模型将BERT提取第15个词的特征(16×768),输入到全连接神经网络(768×21128),得到16×21128,即把第15个词的特征投影到全体词表空间中,还原为词典中的某个词。

class Model(torch.nn.Module):def __init__(self):super().__init__()self.decoder = torch.nn.Linear(in_features=768, out_features=token.vocab_size, bias=False)# 重新将decode中的bias参数初始化为全oself.bias = torch.nn.Parameter(data=torch.zeros(token.vocab_size))self.decoder.bias = self.bias# 定义 Dropout层,防止过拟合self.Dropout = torch.nn.Dropout(p=0.5)def forward(self, input_ids, attention_mask, token_type_ids):# 使用预训练模型抽取数据特征with torch.no_grad():out = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)# 把第15个词的特征投影到全字典范围内out = self.Dropout(out.last_hidden_state[:, 15])out = self.decoder(out)return out

四.训练和测试
1.训练
定义了AdamW优化器、loss损失函数(交叉熵损失函数)和线性学习率调节器,如下所示:

def train():# 定义优化器optimizer = AdamW(model.parameters(), lr=5e-4, weight_decay=1.0)# 定义1oss函数criterion = torch.nn.CrossEntropyLoss()# 定义学习率调节器scheduler = get_scheduler(name='linear', num_warmup_steps=0, num_training_steps=len(loader) * 5, optimizer=optimizer)# 将模型切换到训练模式model.train()# 共训练5个epochfor epoch in range(5):# 按批次遍历训练集中的数据for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):# 模型计算out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)# 计算loss并使用梯度下降法优化模型参数loss = criterion(out, labels)loss.backward()optimizer.step()scheduler.step()optimizer.zero_grad()# 输出各项数据的情况,便于观察if i % 50 == 0:out = out.argmax(dim=1)accuracy = (out == labels).sum().item() / len(labels)lr = optimizer.state_dict()['param_groups'][0]['lr']print(epoch, 1, loss.item(), lr, accuracy)

输出部分结果如下所示:

0 1 10.123428344726562 0.0004998333333333334 0.0
0 1 8.659417152404785 0.0004915 0.0625
0 1 7.431852340698242 0.0004831666666666667 0.0625
0 1 7.261701583862305 0.00047483333333333335 0.0625
0 1 6.693362236022949 0.0004665 0.125
0 1 4.0811614990234375 0.00045816666666666667 0.375
0 1 7.034963607788086 0.00044983333333333334 0.1875

2.测试
使用测试数据集进行测试,如下所示:

def test():# 定义测试数据集加载器loader_test = torch.utils.data.DataLoader(dataset=dataset['test'],  batch_size=32, collate_fn=collate_fn, shuffle=True, drop_last=True)# 将下游任务模型切换到运行模式model.eval()correct = 0total = 0# 按批次遍历测试集中的数据for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):# 计算15个批次即可,不需要全部遍历if i == 15:breakprint(i)# 计算with torch.no_grad():out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)# 统计正确率out = out.argmax(dim=1)correct += (out == labels).sum().item()total += len(labels)print(correct / total)

参考文献:
[1]HuggingFace自然语言处理详解:基于BERT中文模型的任务实战
[2]https://github.com/ai408/nlp-engineering/blob/main/20230625_HuggingFace自然语言处理详解/第8章:完形填空.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/115365.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Cloud Alibaba-Sentinel规则

1 流控规则 流量控制,其原理是监控应用流量的QPS(每秒查询率) 或并发线程数等指标,当达到指定的阈值时 对流量进行控制,以避免被瞬时的流量高峰冲垮,从而保障应用的高可用性。 第1步: 点击簇点链路,我们就可以看到访…

现代化畜牧业行业分析 - 商品猪养殖

改革开放以来,中国畜牧业生产基础条件不断改善、生产方式快速转变,畜牧业综合生产能力和保障市场有效供应能力不断加强。中国肉类、禽蛋产量均居世界第一位,奶类产量居世界第三位。随着产量的增长,中国人均畜产品占有量也持续上升…

联网智能实时监控静电离子风机的工作流程

联网智能实时监控静电离子风机是通过将静电离子风机与互联网连接,实现对其状态和性能的远程监控和管理。 具体实现该功能的方法可以包括以下几个步骤: 1. 传感器安装:在静电离子风机上安装适当的传感器,用于感知相关的参数&…

微信开发之一键创建标签的技术实现

简要描述: 添加标签 请求URL: http://域名地址/addContactLabel 请求方式: POST 请求头Headers: Content-Type:application/jsonAuthorization:login接口返回 参数: 参数名必选类型说明…

Python面试:什么是GIL

1. GIL (Global Interpreter lock)可以避免多个线程同时执行字节码。 import threadinglock threading.Lock()n [0]def foo():with lock:n[0] n[0] 1n[0] n[0] 1threads [] for i in range(5000):t threading.Thread(targetfoo)threads.append(t)for t in threads:t.s…

python conda实践 sanic框架gitee webhook实践

import subprocess import hmac import hashlib import base64 from sanic.response import text from sanic import Blueprint from git import Repo# 路由蓝图 hook_blue Blueprint(hook_blue)hook_blue.route(/hook/kaifa, methods["POST"]) async def kaifa(req…

UI自动化之关键字驱动

关键字驱动框架:将每一条测试用例分成四个不同的部分 测试步骤(Test Step):一个测试步骤的描述或者是测试对象的一个操作说明测试步骤中的对象(Test Object):指页面的对象或者元素对象执行的动…

rtsp 拉流 gb28181 收流 经AI 算法 再生成 rtsp server (一)

1、 rtsp 工具 1 vlc 必备工具 2 wireshark 必备工具 3 自己制作的工具 player 使用tcp 拉流,不自己写的话,使用ffmpeg 去写一个播放器就行 4 live555 编译好live555, 将live555的参数修改以下,主要是缓存大小 文章使用c 来写一…

C++面试题(叁)---操作系统篇

目录 操作系统篇 1 Linux中查看进程运行状态的指令、查看内存使用情况的指令、 tar解压文件的参数。 2 文件权限怎么修改 3 说说常用的Linux命令 4 说说如何以root权限运行某个程序。 5 说说软链接和硬链接的区别。 6 说说静态库和动态库怎么制作及如何使用,区…

界面控件DevExpress WinForms(v23.2)下半年发展路线图

本文主要概述了官方在下半年(v23.2)中一些与DevExpress WinForms相关的开发计划,重点关注的领域将是可访问性支持和支持.NET 8。 DevExpress WinForms有180组件和UI库,能为Windows Forms平台创建具有影响力的业务解决方案。同时能…

自然语言处理-NLP

目录 自然语言处理-NLP 致命密码:一场关于语言的较量 自然语言处理的发展历程 兴起时期 符号主义时期 连接主义时期 深度学习时期 自然语言处理技术面临的挑战 语言学角度 同义词问题 情感倾向问题 歧义性问题 对话/篇章等长文本处理问题 探索自然语言…

四、高并发内存池整体框架设计

四、高并发内存池整体框架设计 现代很多的开发环境都是多核多线程,在申请内存的场景下,必然存在激烈的锁竞争问题。malloc本身其实已经很优秀,那么我们项目的原型TCmalloc就是在多线程高并发的场景下更胜一筹,所以这次我们实现的…

centos 7的超详细安装教程

打开虚拟机,创建一个新电脑 我们选择经典,然后选择下一步 我们选择稍后安装,我们在后面进行改设备 因为centos系统是linux系统的一个版本,所有我们选择linux,版本选择centos 7 64位,然后就是点击下一步 这一…

HTML <template> 标签

实例 使用 <template> 保留页面加载时隐藏的内容。使用 JavaScript 来显示: <button οnclick="showContent()">显示被隐藏的内容</button><template><h2>Flower</h2><img src="img_white_flower.jpg" width=&q…

2023年03月 C/C++(五级)真题解析#中国电子学会#全国青少年软件编程等级考试

第1题&#xff1a;拼点游戏 C和S两位同学一起玩拼点游戏。有一堆白色卡牌和一堆蓝色卡牌&#xff0c;每张卡牌上写了一个整数点数。C随机抽取n张白色卡牌&#xff0c;S随机抽取n张蓝色卡牌&#xff0c;他们进行n回合拼点&#xff0c;每次两人各出一张卡牌&#xff0c;点数大者获…

Word导出创建Adobe PDF其中emf图片公式马赛克化及文字缺失

软件版本 Word 2021 Visio 2019 Adobe Acrobat Pro 2020 问题描述 公式马赛克化&#xff0c;是指在Word中使用MathType编辑的公式&#xff0c;然后在Visio中使用图片(增强型图元文件)形式得到的粘贴对象&#xff0c;效果如下 文字缺失&#xff0c;是指Word导出→创建Adobe P…

源码安装cv_bridge

1. 下载源码 1去github上下载GitHub - ros-perception/vision_opencv&#xff0c;进去后注意选择与自己的ros对应的版本&#xff1a;&#xff08;我的为noetic&#xff09; 如果你直接使用 git clone https://github.com/ros-perception/vision_opencv.git 来拉取的话cmake的…

MySQL 8 数据清洗总结

MySQL 8 数据清洗三要素&#xff1a; 库表拷贝和数据备份数据清洗SQL数据清洗必杀技-存储过程 前提&#xff1a;数据库关联库表初始化和基础数据初始化&#xff1a; -- usc.t_project definitionCREATE TABLE t_project (id varchar(64) NOT NULL COMMENT 主键,tid varchar(…

网络基础之重中之重

目录 IP协议 ​编辑 基本概念&#xff1a; 协议头格式&#xff1a; ​编辑 网段划分 DHCP &#xff1a; CIDR&#xff1a; 特殊的IP地址&#xff1a; IP地址的数量限制&#xff1a; 私有IP和公网IP 路由 路由的过程&#xff1a; 数据链路层 认识以太网&#x…

GAN原理 代码解读

模型架构 代码 数据准备 import os import time import matplotlib.pyplot as plt import numpy as np import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision import datasets import torch.nn as nn import torch# 创建文…