4.1 文本相似度(二)

目录

1 文本相似度评估

2 代码

2.1 load_dataset 方法

2.2 AutoTokenizer、AutoModelForSequenceClassification


1 文本相似度评估

      对两个文本拼接起来,然后作为一个样本喂给模型,作为一个二分类的任务;

数据处理的方式以及训练的基本流程与上文相似。

2 代码

  1. 数据预处理,把需要对比的文本放置一起,作为一个样本; tokenizer: 输入的语句是两个。分类标签的类型必须是 int,不能是其他的类型;
  2. 加载模型。
  3. 输出结果; 


2.1 load_dataset 方法

datasets是抱抱脸开发的一个数据集python库,可以很方便的从Hugging Face Hub里下载数据,也可很方便的从本地加载数据集,本文主要对load_dataset方法的使用进行详细说明。

def load_dataset(
    path: str,
    name: Optional[str] = None,
    data_dir: Optional[str] = None,
    data_files: Union[Dict, List] = None,
    split: Optional[Union[str, Split]] = None,
    cache_dir: Optional[str] = None,
    features: Optional[Features] = None,
    download_config: Optional[DownloadConfig] = None,
    download_mode: Optional[GenerateMode] = None,
    ignore_verifications: bool = False,
    save_infos: bool = False,
    script_version: Optional[Union[str, Version]] = None,
    **config_kwargs,
) -> Union[DatasetDict, Dataset]:

path:参数path表示数据集的名字或者路径。可以是如下几种形式(每种形式的使用方式后面会详细说明)
数据集的名字,比如imdb、glue
数据集文件格式,比如json、csv、parquet、txt
数据集目录中的处理数据集的脚本(.py)文件,比如“glue/glue.py”
name:参数name表示数据集中的子数据集,当一个数据集包含多个数据集时,就需要这个参数,比如glue数据集下就包含"sst2"、“cola”、"qqp"等多个子数据集,此时就需要指定name来表示加载哪一个子数据集
data_dir:数据集所在的目录
data_files:数据集文件
cache_dir:构建的数据集缓存目录,方便下次快速加载。

2.2 AutoTokenizer、AutoModelForSequenceClassification

类名称介绍
AutoTokenizerAutoTokenizer 是 Hugging Face Transformers 库中的一个类,用于自动选择适合特定预训练模型的 tokenizer。该类可以根据指定的模型名称或路径,自动选择对应的 tokenizer 类型,无需手动指定。这样可以方便地在不同的预训练模型之间切换,而无需更改代码中的 tokenizer 类型。
AutoModelForSequenceClassificationAutoModelForSequenceClassification 是 Hugging Face Transformers 库中的一个类,用于自动选择适合特定预训练模型的用于序列分类任务的模型。这个类会根据指定的模型名称或路径自动选择对应的模型类型,无需手动指定。这样可以方便地在不同的预训练模型之间切换,而无需更改代码中的模型类型。
TrainerTrainer 是 Hugging Face Transformers 库中用于训练和评估模型的高级 API。它提供了一个简单而强大的接口,用于管理训练循环、验证循环、日志记录、保存模型等任务。使用 Trainer 可以方便地训练和微调预训练模型,同时还支持分布式训练和混合精度训练等功能。
TrainingArgumentsTrainingArguments 是 Hugging Face Transformers 库中用于配置训练参数的类。通过 TrainingArguments 类,可以指定训练过程中的各种参数,如训练轮数、学习率、批次大小、日志路径、模型保存路径等。这些参数可以帮助控制训练过程的行为,并对训练过程进行定制。

from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset,load_from_disk
import traceback
from sklearn.model_selection import train_test_split#dataset = load_dataset("json", data_files="../data/train_pair_1w.json", split="train")
dataset = load_dataset("csv", data_files="/Users/user/studyFile/2024/nlp/text_similar/data/Chinese_Text_Similarity.csv", split="train")datasets = dataset.train_test_split(test_size=0.2,shuffle=True)import torchtokenizer = AutoTokenizer.from_pretrained("../chinese_macbert_base")
def process_function(examples):tokenized_examples = tokenizer(examples["sentence1"], examples["sentence2"], max_length=128, truncation=True)#  注意int(label)tokenized_examples["labels"] = [int(label) for label in examples["label"]]return tokenized_examplestokenized_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)
#tokenized_datasets# 创建模型
from transformers import BertForSequenceClassification 
model = AutoModelForSequenceClassification.from_pretrained("../chinese_macbert_base")import evaluate
acc_metric = evaluate.load("./metric_accuracy.py")
f1_metirc = evaluate.load("./metric_f1.py")
# 
# acc_metric = evaluate.load("accuracy")
# f1_metirc = evaluate.load("f1")
def eval_metric(eval_predict):predictions, labels = eval_predict#print(predictions,labels)predictions = predictions.argmax(axis=-1)#predictions = [int(p > 0.5) for p in predictions]labels = [int(l) for l in labels]# predictions = predictions.argmax(axis=-1)acc = acc_metric.compute(predictions=predictions, references=labels)f1 = f1_metirc.compute(predictions=predictions, references=labels)acc.update(f1)return acc
train_args = TrainingArguments(output_dir="./cross_model",      # 输出文件夹per_device_train_batch_size=32,  # 训练时的batch_sizeper_device_eval_batch_size=32,  # 验证时的batch_sizelogging_steps=10,                # log 打印的频率evaluation_strategy="epoch",     # 评估策略save_strategy="epoch",           # 保存策略save_total_limit=3,              # 最大保存数learning_rate=2e-5,              # 学习率weight_decay=0.01,               # weight_decaymetric_for_best_model="f1",      # 设定评估指标load_best_model_at_end=True)     # 训练完成后加载最优模型
train_args
from transformers import DataCollatorWithPadding
trainer = Trainer(model=model, args=train_args, train_dataset=tokenized_datasets["train"], eval_dataset=tokenized_datasets["test"], data_collator=DataCollatorWithPadding(tokenizer=tokenizer),compute_metrics=eval_metric)
trainer.train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/326163.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024最新版JavaScript逆向爬虫教程-------基础篇之无限debugger的原理与绕过

目录 一、无限debugger的原理与绕过1.1 案例介绍1.2 实现原理1.3 绕过debugger方法1.3.1 禁用所有断点1.3.2 禁用局部断点1.3.3 替换文件1.3.4 函数置空与hook 二、补充2.1 改写JavaScript文件2.2 浏览器开发者工具中出现的VM开头的JS文件是什么? 三、实战 一、无限…

一篇文章掌握所有国债期货的基本交易策略介绍

国债期货是一种基本的利率衍生品,根据交易者交易目的不同,可以将期货交易行为分为三类:套期保值、套利交易和投机交易。套期保值是投资者为了避免现有或将来预期的投资组合价值受市场利率变动的影响,而在国债期货市场上采取抵消性…

2023年30米分辨率土地利用遥感监测数据

改革开放以来,中国经济的快速发展对土地利用模式产生了深刻的影响。同时,中国又具有复杂的自然环境背景和广阔的陆地面积,其土地利用变化不仅对国家发展,也对全球环境变化产生了深刻的影响。为了恢复和重建我国土地利用变化的现代…

六、Redis五种常用数据结构-zset

zset是Redis的有序集合数据类型,但是其和set一样是不能重复的。但是相比于set其又是有序的。set的每个数据都有一个double类型的分数,zset正是根据这个分数来进行数据间的排序从小到大。有序集合中的元素是唯一的,但是分数(score)是可以重复的…

LeetCode416:分割等和子集

题目描述 给你一个 只包含正整数 的 非空 数组 nums 。请你判断是否可以将这个数组分割成两个子集,使得两个子集的元素和相等。 解题思想 [1,5,11,5] 和为22,其中一半为 11。如果能寻找到若干数的和为11则成立可以抽象为一个0-1背包问题:容…

浮点数的由来及运算解析

数学是自然科学的皇后,计算机的设计初衷是科学计算。计算机的最基本功能是需要存储整数、实数,及对整数和实数进行算术四则运算。 但是在计算机从业者的眼中,我们知道的数学相关的基本数据类型通常是整型、浮点型、布尔型。整型又分为int8&a…

给centos机器打个样格式化挂载磁盘(新机器)

文章目录 一、先安装lvm2二、观察磁盘三、磁盘分区四、建PV五、建VG六、创建LV七、在LV上创建文件系统八、挂载到/home(1)临时挂载(2)永久挂载 九、最后reboot一下 一、先安装lvm2 yum install lvm2二、观察磁盘 三、磁盘分区 四…

Springboot整合 Spring Cloud Alibaba Seata

1.事务简介 事务是访问并可能更新数据库中各种数据项的一个程序执行单元。在关系型数据库中,一个事务由一组sql语句组成。事务具有 原子性,一致性,隔离性,持久性四个属性(ACID)。 原子性:事务是一个不可分割的工作单位…

ThreadLocal 源码详解

概述 ThreadLocal是一个java提供的本地线程副本变量工具类。主要用于将私有线程和该线程存放的副本对象做一个映射,各个线程之间的变量互不干扰,在高并发场景下,可以实现无状态的调用,特别适用于各个线程依赖不通的变量值完成操作…

美国政府首次发布《国家网络安全态势报告》

报告提到,不断演变的关键基础设施风险、勒索软件、供应链利用、商业间谍软件和AI是主要趋势;国家网络总监办公室同时公布了第二版《国家网络安全战略实施计划》,新增了31项倡议。 前情回顾美国深化网络安全战略 美国发布国家网络安全战略实施…

快团团怎么做帮卖团长/供货大团长(如何从小白到优质团长)?

一名小白想要成长为快团团的优质团长,可以遵循以下步骤和策略: 了解平台与注册成为团长: 首先,熟悉快团团平台的操作流程和规则。快团团是一个基于微信的小程序,专注于社区团购业务。通过微信扫描团长资源二维码或在快…

【爬虫基础1.1课】——requests模块上

目录索引 requests模块的作用:实例引入: 特殊情况:锦囊1:锦囊2: 这一个栏目,我会给出我从零开始学习爬虫的全过程。感兴趣的小伙伴可以关注一波,用于复习和新学都是不错的选择。 那么废话不多说&#xff0c…

​​​​【收录 Hello 算法】5.2 队列

目录 5.2 队列 5.2.1 队列常用操作 5.2.2 队列实现 1. 基于链表的实现 2. 基于数组的实现 5.2.3 队列典型应用 5.2 队列 队列(queue)是一种遵循先入先出规则的线性数据结构。顾名思义,队列模拟了排队现象,即…

利用香港多IP服务器进行大数据分析的潜在优势?

利用香港多IP服务器进行大数据分析的潜在优势? 在当今数据驱动的时代,大数据分析已经成为企业获取竞争优势的不二选择。而香港作为一个拥有世界级通信基础设施的城市,提供了理想的环境来部署多IP服务器,从而为大数据分析提供了独特的优势。…

2024中国(重庆)人工智能展览会8月举办

2024中国(重庆)人工智能展览会8月举办 邀请函 主办单位: 中国航空学会 重庆市南岸区人民政府 招商执行单位: 重庆港华展览有限公司 【报名I59交易会 2351交易会 9466】 展会背景: 2024中国航空科普大会暨第八届全国青少年无人机大赛在…

Spring Framework-简介

Spring Framework Java Spring是一个开源的Java应用框架,它的主要目的是简化企业级应用开发的复杂性。Spring框架为开发者提供了许多基础功能,使得开发者能够更专注于业务逻辑的实现,而不是底层的细节。 主要特点和功能: 控制反…

数据库脚本编写规范(SQL编写规范)

编写本文档的目的是保证在开发过程中产出高效、格式统一、易阅读、易维护的SQL代码。 1 编写目 2 SQL书写规范 3 SQL编写原则 软件开发全文档获取:点我获取

全域运营是割韭菜吗?看完再下结论!

随着流量时代的到来,各大公私域平台中的流量争夺战日益激烈,商家和品牌实现流量变现的难度值也不断提高,运营人员的压力也逐渐增大。在此背景下,全域运营的兴起或许是一个契机,能够将所有人从内卷的状态中解救出来。而…

网络安全之OSPF进阶

该文针对OSPF进行一个全面的认识。建议了解OSPF的基础后进行本文的一个阅读能较好理解本文。 OSPF基础的内容请查看:网络安全之动态路由OSPF基础-CSDN博客 OSPF中更新方式中的触发更新30分钟的链路状态刷新。是因为其算法决定的,距离矢量型协议是边算边…

力扣刷题 day1

移动零 283. 移动零 - 力扣(LeetCode) 看到这道题,是否第一个想法就是双指针,下面我们就来使用双指针来完成这道题. 图: 代码: java: // 移动 0 双指针public void moveZeroes(int[] nums) {for (int current 0, dest -1; …