基于transformers框架实践Bert系列4-文本相似度

本系列用于Bert模型实践实际场景,分别包括分类器、命名实体识别、选择题、文本摘要等等。(关于Bert的结构和详细这里就不做讲解,但了解Bert的基本结构是做实践的基础,因此看本系列之前,最好了解一下transformers和Bert等)
本篇主要讲解文本相似度应用场景。本系列代码和数据集都上传到GitHub上:https://github.com/forever1986/bert_task

目录

  • 1 环境说明
  • 2 前期准备
    • 2.1 了解Bert的输入输出
    • 2.1 了解Bert的输入输出
    • 2.2 数据集与模型
    • 2.3 任务说明
    • 2.4 实现关键
  • 3 关键代码
    • 3.1 数据集处理
    • 3.2 模型加载
    • 3.3 评估函数
  • 4 整体代码
  • 5 运行效果

1 环境说明

1)本次实践的框架采用torch-2.1+transformer-4.37
2)另外还采用或依赖其它一些库,如:evaluate、pandas、datasets、accelerate等

2 前期准备

2.1 了解Bert的输入输出

Bert模型是一个只包含transformer的encoder部分,并采用双向上下文和预测下一句训练而成的预训练模型。可以基于该模型做很多下游任务。

2.1 了解Bert的输入输出

Bert的输入:input_ids(使用tokenizer将句子向量化),attention_mask,token_type_ids(句子序号)、labels(结果)
Bert的输出:
last_hidden_state:最后一层encoder的输出;大小是(batch_size, sequence_length, hidden_size)
pooler_output:这是序列的第一个token(classification token)的最后一层的隐藏状态,输出的大小是(batch_size, hidden_size),它是由线性层和Tanh激活函数进一步处理的。(通常用于句子分类,至于是使用这个表示,还是使用整个输入序列的隐藏状态序列的平均化或池化,视情况而定)。(注意:这是关键输出,本次任务就需要获取该值,并进行相似度计算
hidden_states: 这是输出的一个可选项,如果输出,需要指定config.output_hidden_states=True,它也是一个元组,它的第一个元素是embedding,其余元素是各层的输出,每个元素的形状是(batch_size, sequence_length, hidden_size)
attentions:这是输出的一个可选项,如果输出,需要指定config.output_attentions=True,它也是一个元组,它的元素是每一层的注意力权重,用于计算self-attention heads的加权平均值。

2.2 数据集与模型

1)数据集来自:Chinese_Text_Similarity
2)模型权重使用:bert-base-chinese

2.3 任务说明

1)文本相似度任务就是判断2段文本的相似程度,可以理解为是否表达相同的意思。这时可能想到的最简单的方法就是将2段文本作为输入,label是0或1这样一个分类方法,可以采用系列1(情感分类)的方式。但是如果找是一段文本对应多个文本之间的相似度呢?或许你会想到系列3(选择题)的方式。但是如果是一段文本对应几十万的文本之间的相似度呢?虽然系列3(选择题)也能解决问题,但是会很慢,因为你要一一匹配。这里我们可以采用一个特征提取方式,先将文本输入到模型做特征,最后在通过相似度比较函数对2段文本的特征进行比较即可,虽然也是需要每段文本都做比较,但是好处是先将文本做好特征。
在这里插入图片描述
2)这时候我们需要做的是分别将数据放入同一个BERT模型进行特征提取,然后通过相似度和余弦相似度损失计算进行模型训练即可

2.4 实现关键

1)将数据处理成对放入模型中
2)自定义模型,在forward中对2个句子分别通过bert做特征提取,然后计算相似度和余弦相似度损失

3 关键代码

3.1 数据集处理

Chinese_Text_Similarity数据集是一个txt文件,每一行分别存储“句子1”、“句子2”、“相似度”。下面代码就是读取数据并处理为模型想要的类型

# 读取数据
df = pd.read_csv(data_path, sep='\s+')
df = df.sample(n=5000)  # 取其中5000条
datasets = Dataset.from_pandas(df)
# 划分训练集和测试集
datasets = datasets.train_test_split(test_size=0.1, shuffle=True, seed=42)
# 划分训练集和验证集
train_datasets = datasets["train"].train_test_split(test_size=0.05, shuffle=True, seed=42)
datasets["train"] = train_datasets["train"]
datasets["validation"] = train_datasets["test"]
tokenizer = BertTokenizerFast.from_pretrained(model_path)# 数据处理函数
def process_function(datas):sentences = []labels = []for sentence1, sentence2, label in zip(datas["句子1"], datas["句子2"], datas["相似度"]):sentences.append(sentence1)sentences.append(sentence2)labels.append(1 if int(label) == 1 else -1)tokenized_datas = tokenizer(sentences, max_length=256, truncation=True, padding="max_length")# 关键点:这里将2条数据合并为一组,也就是reshape,从(2倍datas数量 * max_length),变成(datas数量 * 2 * max_length)tokenized_datas = {k: [v[i: i + 2] for i in range(0, len(v), 2)] for k, v in tokenized_datas.items()}tokenized_datas["labels"] = labelsreturn tokenized_datasnew_datasets = datasets.map(process_function, batched=True)

3.2 模型加载

自定义模型,模仿transformers中的其它BERT模型,继承BertPreTrainedModel(为了方便使用XXX.from_pretrained()获取模型),参照其它BERT模型写法,重新init和forward方法

class SimilarityModel(BertPreTrainedModel):# 不需要增加其它层def __init__(self, config: PretrainedConfig, *inputs, **kwargs):super().__init__(config, *inputs, **kwargs)self.bert = BertModel(config)self.post_init()# 在forward中对2个句子分别通过bert做特征提取,然后计算相似度和余弦相似度损失def forward(self,input_ids: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,token_type_ids: Optional[torch.Tensor] = None,position_ids: Optional[torch.Tensor] = None,head_mask: Optional[torch.Tensor] = None,inputs_embeds: Optional[torch.Tensor] = None,labels: Optional[torch.Tensor] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,return_dict: Optional[bool] = None,):return_dict = return_dict if return_dict is not None else self.config.use_return_dict# 分别获取sentenceA 和 sentenceB的输入senA_input_ids, senB_input_ids = input_ids[:, 0], input_ids[:, 1]senA_attention_mask, senB_attention_mask = attention_mask[:, 0], attention_mask[:, 1]senA_token_type_ids, senB_token_type_ids = token_type_ids[:, 0], token_type_ids[:, 1]# 分别获取sentenceA 和 sentenceB的向量表示senA_outputs = self.bert(senA_input_ids,attention_mask=senA_attention_mask,token_type_ids=senA_token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,)# 获得pooler_outputsenA_pooled_output = senA_outputs[1]senB_outputs = self.bert(senB_input_ids,attention_mask=senB_attention_mask,token_type_ids=senB_token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,)senB_pooled_output = senB_outputs[1]  # [batch, hidden]# 计算相似度cos = CosineSimilarity()(senA_pooled_output, senB_pooled_output)# 计算lossloss = Noneif labels is not None:loss_fct = CosineEmbeddingLoss(0.3)loss = loss_fct(senA_pooled_output, senB_pooled_output, labels)output = (cos,)return ((loss,) + output) if loss is not None else outputmodel = SimilarityModel.from_pretrained(model_path)

3.3 评估函数

这里采用evaluate库加载accuracy准确度计算方式来做评估,本次实验将accuracy和f1的计算py文件下载下来,因此也是本地加载

# 评估函数:此处的评估函数可以从https://github.com/huggingface/evaluate下载到本地
acc_metric = evaluate.load("./evaluate/metric_accuracy.py")
f1_metric = evaluate.load("./evaluate/metric_f1.py")def evaluate_function(eval_predict):predictions, labels = eval_predictpredictions = [int(p > 0.7) for p in predictions]labels = [int(l > 0) for l in labels]acc = acc_metric.compute(predictions=predictions, references=labels)f1 = f1_metric.compute(predictions=predictions, references=labels)acc.update(f1)return acc

4 整体代码

"""
基于BERT做文本相似度
1)数据集来自:Chinese_Text_Similarity
2)模型权重使用:bert-base-chinese
"""# step 1 引入数据库
import torch
from torch.nn import CosineSimilarity, CosineEmbeddingLossimport evaluate
import pandas as pd
from typing import Optional
from datasets import Dataset
from transformers import TrainingArguments, Trainer, BertTokenizerFast, BertPreTrainedModel, PretrainedConfig, BertModelmodel_path = "./model/tiansz/bert-base-chinese"
data_path = "./data/Chinese_Text_Similarity.txt"# step 2 数据集处理
df = pd.read_csv(data_path, sep='\s+')
df = df.sample(n=5000)  # 取其中5000条
datasets = Dataset.from_pandas(df)
# 划分训练集和测试集
datasets = datasets.train_test_split(test_size=0.1, shuffle=True, seed=42)
# 划分训练集和验证集
train_datasets = datasets["train"].train_test_split(test_size=0.05, shuffle=True, seed=42)
datasets["train"] = train_datasets["train"]
datasets["validation"] = train_datasets["test"]
tokenizer = BertTokenizerFast.from_pretrained(model_path)def process_function(datas):sentences = []labels = []for sentence1, sentence2, label in zip(datas["句子1"], datas["句子2"], datas["相似度"]):sentences.append(sentence1)sentences.append(sentence2)labels.append(1 if int(label) == 1 else -1)tokenized_datas = tokenizer(sentences, max_length=256, truncation=True, padding="max_length")# 这里将2条数据合并为一组,也就是reshape,从(2倍datas数量 * max_length),变成(datas数量 * 2 * max_length)tokenized_datas = {k: [v[i: i + 2] for i in range(0, len(v), 2)] for k, v in tokenized_datas.items()}tokenized_datas["labels"] = labelsreturn tokenized_datasnew_datasets = datasets.map(process_function, batched=True)# step 3 加载模型
class SimilarityModel(BertPreTrainedModel):def __init__(self, config: PretrainedConfig, *inputs, **kwargs):super().__init__(config, *inputs, **kwargs)self.bert = BertModel(config)self.post_init()def forward(self,input_ids: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,token_type_ids: Optional[torch.Tensor] = None,position_ids: Optional[torch.Tensor] = None,head_mask: Optional[torch.Tensor] = None,inputs_embeds: Optional[torch.Tensor] = None,labels: Optional[torch.Tensor] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,return_dict: Optional[bool] = None,):return_dict = return_dict if return_dict is not None else self.config.use_return_dict# Step1 分别获取sentenceA 和 sentenceB的输入senA_input_ids, senB_input_ids = input_ids[:, 0], input_ids[:, 1]senA_attention_mask, senB_attention_mask = attention_mask[:, 0], attention_mask[:, 1]senA_token_type_ids, senB_token_type_ids = token_type_ids[:, 0], token_type_ids[:, 1]# Step2 分别获取sentenceA 和 sentenceB的向量表示senA_outputs = self.bert(senA_input_ids,attention_mask=senA_attention_mask,token_type_ids=senA_token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,)senA_pooled_output = senA_outputs[1]  # [batch, hidden]senB_outputs = self.bert(senB_input_ids,attention_mask=senB_attention_mask,token_type_ids=senB_token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,)senB_pooled_output = senB_outputs[1]  # [batch, hidden]# step3 计算相似度cos = CosineSimilarity()(senA_pooled_output, senB_pooled_output)  # [batch, ]# step4 计算lossloss = Noneif labels is not None:loss_fct = CosineEmbeddingLoss(0.3)loss = loss_fct(senA_pooled_output, senB_pooled_output, labels)output = (cos,)return ((loss,) + output) if loss is not None else outputmodel = SimilarityModel.from_pretrained(model_path)# step 4 评估函数:此处的评估函数可以从https://github.com/huggingface/evaluate下载到本地
acc_metric = evaluate.load("./evaluate/metric_accuracy.py")
f1_metric = evaluate.load("./evaluate/metric_f1.py")def evaluate_function(eval_predict):predictions, labels = eval_predictpredictions = [int(p > 0.7) for p in predictions]labels = [int(l > 0) for l in labels]acc = acc_metric.compute(predictions=predictions, references=labels)f1 = f1_metric.compute(predictions=predictions, references=labels)acc.update(f1)return acc# step 5 创建TrainingArguments
# train是4275条数据,batch_size=32,因此每个epoch的step=134,总step=402
train_args = TrainingArguments(output_dir="./checkpoints",      # 输出文件夹per_device_train_batch_size=32,  # 训练时的batch_sizeper_device_eval_batch_size=32,    # 验证时的batch_sizenum_train_epochs=3,              # 训练轮数logging_steps=50,                # log 打印的频率evaluation_strategy="epoch",     # 评估策略save_strategy="epoch",           # 保存策略save_total_limit=3,              # 最大保存数load_best_model_at_end=True      # 训练完成后加载最优模型)# step 6 创建Trainer
trainer = Trainer(model=model,args=train_args,train_dataset=new_datasets["train"],eval_dataset=new_datasets["validation"],compute_metrics=evaluate_function,)# step 7 训练
trainer.train()# step 8 模型评估
evaluate_result = trainer.evaluate(new_datasets["test"])
print(evaluate_result)# step 9 模型预测
class SentenceSimilarityPipeline:def __init__(self, model, tokenizer) -> None:self.model = model.bertself.tokenizer = tokenizerself.device = model.devicedef preprocess(self, senA, senB):return self.tokenizer([senA, senB], max_length=128, truncation=True, return_tensors="pt", padding=True)def predict(self, inputs):inputs = {k: v.to(self.device) for k, v in inputs.items()}return self.model(**inputs)[1]  # [2, 768]def postprocess(self, logits):cos = CosineSimilarity()(logits[None, 0, :], logits[None,1, :]).squeeze().cpu().item()return cosdef __call__(self, senA, senB):inputs = self.preprocess(senA, senB)logits = self.predict(inputs)result = self.postprocess(logits)if result >= 0.7:return "相似"return "不相似"pipe = SentenceSimilarityPipeline(model, tokenizer)
print(pipe("广东哪里最好玩啊?", "广东最好玩的地方在哪?"))

5 运行效果

在这里插入图片描述

注:本文参考来自大神:https://github.com/zyds/transformers-code

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/330012.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

两数交换,数组查找奇数个数的数(位运算)

文章目录 一、异或运算:1.1 Demo1.2 面试题 一、异或运算: 1.1 Demo 0和N进行异或运算都等于N 任何一个数和自己异或运算都等于0 且异或运算满足交换率 a^b b^a eg: a 甲 , b 已 那么则有 a a^b ​ b a^b ​ a a^b 故有&am…

ACM实训第十七天

Is It A Tree? 问题 考试时应该做不出来,果断放弃 树是一种众所周知的数据结构,它要么是空的(null, void, nothing),要么是一个或的集合满足以下属性的节点之间有向边连接的节点较多。 •只有一个节点,称为根节点,它…

Skywalking快速介绍

(01)SkyWalking简介 SkyWalking专为微服务,云原生架构和基于容器(Docker,k8s,Mesos等)的架构设计的应用程序性能监控工具,用于收集、分析、聚合和可视化来自服务和云原生基础设施的数…

Varjo XR-4功能详解:由凝视驱动的XR自动对焦相机系统

Varjo是XR市场中拥有领先技术的虚拟现实设备供应商,其将可变焦距摄像机直通系统带入到虚拟和混合现实场景中。在本篇文章中,Varjo的技术工程师维尔蒂莫宁详细介绍了这项在Varjo XR-4焦点版中投入应用的技术。 对可变焦距光学系统的需求 目前所有其他XR头…

Android 构建时:Manifest merger failed : Attribute application@name value

在AndroidStudio 构建时发现此问题&#xff1a; Manifest merger failed : Attribute applicationname value解决方案&#xff1a;在主Manifest中增加replace <applicationandroid:name".MyApp"android:allowBackup"false"tools:replace"android…

Linux:Ubuntu修改root密码

Linux&#xff1a;Ubuntu修改root密码 修改默认grub配置文件 rootshanxin:~# vim /etc/default/grub# 主要修改内容如下&#xff1a;GRUB_DEFAULT0 #GRUB_TIMEOUT_STYLEhidden 注释这一行 GRUB_TIMEOUT5 # 将这一行的时间改为5秒进行开启启动的grub文件的复写 rootshanxin:~…

不懂平面设计,这篇文章教你制作商业画册

​商业画册不仅是企业展示形象、推广产品的重要工具&#xff0c;也是设计师展现创意的平台。因此&#xff0c;制作一本高质量的画册对于企业来说至关重要。 那要怎么着手制作呢&#xff1f;以下是关于制作商业画册的步骤。 1.要制作电子杂志,首先需要选择一款适合自己的软件。…

Linux - 整理工作中常用的 Linux 命令(目录、文件、系统、进程、网络)持续更新~

目录 一、Linux 目录结构 二、Linux 中的常用指令 2.1、目录命令 cd 切换目录 pwd 打印当前所在目录 ls 展示当前目录内容 mkdir 创建目录 du 统计每个目录下的文件字节数 2.2、文件命令 which 查找 命令字 所在位置 find 查找文件 touch 创建一个空文件 cp 复制文…

设计软件有哪些?数据交换和导入导出工具篇,渲染100邀请码1a12

设计师制作的项目通常要在各种软件里导入导出&#xff0c;互相交换格式&#xff0c;这次我们介绍一些数据交换和导入导出工具。 1、OBJ OBJ&#xff08;Object File Format&#xff09;是一种常用的3D模型文件格式&#xff0c;用于存储和交换三维模型数据。它由一系列文本行组…

正点原子[第二期]Linux之ARM(MX6U)裸机篇学习笔记-18讲 高精度延时定时器GPT

前言&#xff1a; 本文是根据哔哩哔哩网站上“正点原子[第二期]Linux之ARM&#xff08;MX6U&#xff09;裸机篇”视频的学习笔记&#xff0c;在这里会记录下正点原子 I.MX6ULL 开发板的配套视频教程所作的实验和学习笔记内容。本文大量引用了正点原子教学视频和链接中的内容。…

前端 CSS 经典:元素倒影

前言&#xff1a;好看的元素倒影&#xff0c;可以通过-webkit-box-reflect 实现。但有兼容问题&#xff0c;必须是 webkit 内核的浏览器&#xff0c;不然没效果。但是好看啊。 效果图&#xff1a; 代码实现&#xff1a; <!DOCTYPE html> <html lang"en"&g…

VUE3好看的酒网站模板源码

文章目录 1.设计来源1.1 首页界面1.2 十大名酒界面1.3 名酒新闻界面1.4 联系我们界面1.5 在线留言界面 2.效果和结构2.1 动态效果2.2 代码结构 3.VUE框架系列源码4.源码下载 作者&#xff1a;xcLeigh 文章地址&#xff1a;https://blog.csdn.net/weixin_43151418/article/detai…

【C++初阶】—— 类和对象 (下)

&#x1f4dd;个人主页&#x1f339;&#xff1a;EterNity_TiMe_ ⏩收录专栏⏪&#xff1a;C “ 登神长阶 ” &#x1f339;&#x1f339;期待您的关注 &#x1f339;&#x1f339; 类和对象 1. 运算符重载运算符重载赋值运算符重载前置和后置重载 2. 成员函数的补充3. 初始化列…

Java中String类常用方法

写笔记记录自己的学习记录以及巩固细节 目录 1.String类的常用方法 1.1 字符串构造 1.2 String对象的比较 1.2.1 比较两个字符串是否相等 1.2.2 比较两个字符串的大小 1.3 字符串查找 1.4 字符串的转化 1.4.1 字符串转整数 1.4.2 字符串转数字 1.4.3 大小写的转换 1…

IT革命浪潮:技术革新如何改变我们的生活与工作

一、技术革新与行业应用 当前的IT行业正处于前所未有的技术革新阶段。其中&#xff0c;量子计算和虚拟现实是两项引人注目的技术。 量子计算&#xff1a;量子计算以其超越传统计算的潜力&#xff0c;正在逐步从理论走向实践。在材料科学、药物研发和气候模型等复杂计算领域&a…

利用kubeadm安装k8s集群 以及跟harbor私有仓库下载镜像

目录 环境准备 master&#xff08;2C/4G&#xff09; 192.168.88.3 docker、kubeadm、kubelet、kubectl、flannel node01&#xff08;2C/2G&#xff09; 192.168.88.4 docker、kubeadm、kubelet、kubectl、flannel node02&#xff08;…

Ansible自动化运维中的file文件模块模块应用详解

作者主页&#xff1a;点击&#xff01; Ansible专栏&#xff1a;点击&#xff01; 创作时间&#xff1a;2024年5月21日15点21分 &#x1f4af;趣站推荐&#x1f4af; 前些天发现了一个巨牛的&#x1f916;人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xf…

向npm发布自己写的vue组件,使用vite创建项目

向npm发布自己写的vue组件&#xff0c;使用vite创建项目 创建项目 pnpm create vite输入项目名称 由于我的组件是基于 ant-design-vue和vue的&#xff0c;需要解析.vue文件&#xff0c;我又安装了下面4个。 然后执行 pnpm i安装依赖 vite.config.ts import { defineC…

linux系统——ps命令的两种参数模式

ps命令后面接参数时&#xff0c;有“—”符号与无此符号&#xff0c;在具体实现功能上有很大区别 能够清晰表达进程之间层级关系

前端菜鸡,对于35+程序员失业这个事有点麻了

“经常看到30岁程序员失业的新闻&#xff0c;说实话&#xff0c;有点麻。目前程序员供求关系并未失衡&#xff0c;哪怕是最基础的前端或者后台、甚至事务型的岗位也是足够的。 事实上&#xff0c;现在一个开出的岗位要找到一位尽职尽责能顺利完成工作的程序员并不是一件那么容…