使用 BERT 进行文本分类 (03/3)

一、说明

        在使用BERT(2)进行文本分类时,我们讨论了什么是PyTorch以及如何预处理我们的数据,以便可以使用BERT模型对其进行分析。在这篇文章中,我将向您展示如何训练分类器并对其进行评估。

二、准备数据的又一个步骤

        上次,我们使用train_test_split将数据拆分为测试和验证数据。接下来需要的一个重要步骤是将数据转换为值列表,以便稍后可以在我们的训练器方法中调用它们。此步骤在其他教程中经常被忽略,当您无法微调模型时,这通常是问题所在。

# This is a continuation from the code written in Text Classification with BERT (2)
from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(df_balanced['Message'],df_balanced['Label'], stratify=df_balanced['Label'], test_size=.2)# Store everything in list of values
train_texts = X_train.to_list()
val_texts = X_val.to_list()
train_labels = y_train.to_list()
val_labels = y_val.to_list()

2.1 标记化

        现在我们已经准备好了我们的数据集,我们需要做一些标记化。我们将使用DistilBERT来实现这一点。引用拥抱脸的话:

DistilBERT是一种小型,快速,廉价和轻便的变压器模型,通过蒸馏Bert基础进行训练。它的参数比 bert-base-uncase 少 40%,运行速度快 60%,同时保留了 95% 以上的 Bert 性能,如 GLUE 语言理解基准测试所示。

        导入模型后,我们将文本传递给分词器。如果您已经忘记了填充和截断,请检查使用 BERT 进行文本分类 (01/3) 的  文

from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)

2.2 格式化我们的数据集

        在这里,我们需要将输入数据转换为可用于使用 PyTorch 训练深度学习模型的格式。

import torchclass SmapDataset(torch.utils.data.Dataset):def __init__(self, encodings, labels):self.encodings = encodingsself.labels = labelsdef __getitem__(self, idx):item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}item['labels'] = torch.tensor(self.labels[idx])return itemdef __len__(self):return len(self.labels)train_dataset = SmapDataset(train_encodings, train_labels)
val_dataset = SmapDataset(val_encodings, val_labels)

        类的构造函数方法 () 通过将输入存储为 和 类属性来初始化数据集对象。__init__SmapDatasetencodingslabels

        此类的方法用于从给定索引处的数据集中检索单个项目。它返回一个包含两个元素的字典对象:__getitem__idxitem

  • 该元素是包含输入编码的字典对象,其中键是编码功能的名称,值是包含给定索引处的编码数据的 PyTorch 张量。encodingsidx
  • 该元素是一个 PyTorch 张量,其中包含给定索引处的标签数据。labelsidx

        此类的方法返回数据集中的样本总数。__len__

        最后,代码创建两个数据集对象,并使用类传入 、、 和 作为输入参数。这些数据集对象可用于在 PyTorch 模型中进行训练和验证。train_datasetval_datasetSmapDatasettrain_encodingstrain_labelsval_encodingsval_labels

2.3 使用培训师进行微调

        我们以培训师预期的方式准备了数据。现在我们需要根据数据微调预训练模型。默认情况下,trainer.train 方法将仅报告训练损失。我将定义自己的指标函数并将其传递给培训师。

from sklearn.metrics import accuracy_score, precision_recall_fscore_supportdef compute_metrics(pred):labels = pred.label_idspreds = pred.predictions.argmax(-1)precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted', zero_division=0)acc = accuracy_score(labels, preds)return {'accuracy': acc,'f1': f1,'precision': precision,'recall': recall}
  • 准确性: 这是正确分类的样本占数据集中样本总数的比例。换句话说,它衡量模型正确预测数据集中所有样本的类标签的能力。虽然准确性是一个常用的指标,但在某些情况下可能会产生误导,尤其是在处理类分布不相等的不平衡数据集时。
  • 精度:此指标度量真阳性预测(正确预测的正样本)在模型做出的所有正预测中的比例。换句话说,它衡量模型正确预测正样本的频率。当我们想要避免假阳性预测时,即当错误地将样本预测为阳性时,当样本实际上是负数时,精度非常有用。
  • 召回:此指标衡量数据集中所有真阳性样本中真正预测的比例。换句话说,它衡量模型找到所有正样本的能力。当我们想要避免假阴性预测时,即当错误地将样本预测为阴性时,当样本实际上是阳性时,召回率很有用。
  • F1比分:此指标是精度和召回率的调和平均值,并提供了一种平衡这两个指标的方法。它衡量精度和召回率之间的平衡,并且在假阳性和假阴性错误都有后果时很有用。
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArgumentstraining_args = TrainingArguments(output_dir='./results',          # output directorynum_train_epochs=3,              # total number of training epochsper_device_train_batch_size=16,  # batch size per device during trainingper_device_eval_batch_size=64,   # batch size for evaluationwarmup_steps=500,                # number of warmup steps for learning rate schedulerweight_decay=0.01,               # strength of weight decaylogging_dir='./logs',            # directory for storing logslogging_steps=10,evaluation_strategy="steps"
)model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")trainer = Trainer(model=model,                         # the instantiated 🤗 Transformers model to be trainedargs=training_args,                  # training arguments, defined abovetrain_dataset=train_dataset,         # training dataseteval_dataset=val_dataset,            # validation datasetcompute_metrics=compute_metrics     
)trainer.train()

2.4 结果

        正如我们所看到的,我们的 F1 分数达到了 98% 左右,这表明我们的模型在判断邮件在我们的验证数据集中是垃圾邮件还是正常邮件方面表现良好。请记住,真正的测试数据集是野外未标记的消息。在本案例研究中,我们没有特权测试它在现实世界中的执行方式。

三、总结

        在这篇文章中,我们学习了如何微调BERT模型以进行文本分类,并定义了自己的函数来评估我们的自定义模型。达门·

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/118950.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

IDEA中Run/Debug Configurations添加VM options和Program arguments

1. 现象描述 我在我的IDEA当中打开配置模板后,发现没有VM options和Program arguments,也就是虚拟机选项和程序实参这两项,导致我不能配置系统属性参数和命令行参数!!!!!&#xff0…

【vue2第十一章】v-model的原理详解 与 如何使用v-model对父子组件的value绑定 和修饰符.sync

v-model的原理详解 v-model的本质就是一个语法糖,实际上就是 :value"msg" 与 input"msg $event.target.value" 的简写。 :value"msg" 从数据单向绑定到input框,当data数据中的msg内容一旦改变,而input框数据…

Matlab 基本教程

1 清空环境变量及命令 clear all % 清除Workspace 中的所有变量 clc % 清除Command Windows 中的所有命令 2 变量命令规则 (1)变量名长度不超过63位 (2)变量名以字母开头, 可以由字母、数字和下划线…

自定义类型:结构体、枚举、联合

目录 结构体 结构体的基础知识 结构的声明 特殊的声明 结构体的自引用 结构体变量的定义和初始化 结构体内存对齐 修改默认对齐数 结构体传参 位段 什么是位段 位段的内存分配 位段的跨平台问题 位段的应用 枚举 枚举类型的定义 枚举的优点 联合体(共…

CLFS信息泄露漏洞CVE-2023-28266分析

引用 这篇文章的目的是介绍今年4月发布的CLFS信息泄露漏洞CVE-2023-28266分析. 文章目录 引用简介CVE-2023-28266漏洞分析CVE-2023-28266调试过程漏洞复现相关引用参与贡献 简介 文章结合了逆向代码和调试结果分析了CVE-2023-28266漏洞利用过程和漏洞成因. CVE-2023-28266漏洞…

stm32之27.iic协议oled显示

屏幕如果无法点亮,需要用GPIO_OType_PP推挽输出,加并上拉电阻 1.显示字符串代码 2.显示图片代码(unsigned强制转换(char*)) 汉字显示

友元(个人学习笔记黑马学习)

1、全局函数做友元 #include <iostream> using namespace std; #include <string>//建筑物类 class Building {//goodGay全局函数是 Building好朋友 可以访问Building中私有成员friend void goodGay(Building* building);public:Building() {m_SittingRoom "…

SpringMVC

学习流程图&#xff1a; 四个学习模块&#xff1a; 1、SpringMVC入门 2、请求与响应 3、rest风格 4、ssm整合 5、拦截器 第一章、SpringMVC入门 1、简介 SpringMVC 是一种基于 Java 的实现 MVC 设计模型的请求驱动类型的轻量级 Web 框架。 SpringMVC的开发步骤&#xff1a; …

vscode调教配置:快捷修复和格式化代码

配置vscode快捷键&#xff0c;让你像使用idea一样使用vscode&#xff0c;我们最常用的两个功能就是格式化代码和快捷修复&#xff0c;所以这里修改一下快捷修复和格式化代码的快捷键。 在设置中&#xff0c;找到快捷键配置&#xff1a; 然后搜索&#xff1a;快捷修复 在快捷键…

即时通讯开发中的性能优化技巧

即时通讯开发在如今的数字化社会中扮演着重要角色&#xff0c;然而&#xff0c;随着用户对即时通讯应用的需求不断增长&#xff0c;开发者们面临着使其应用保持高性能和可靠性的挑战。本文将探讨即时通讯开发中关键的性能优化技巧&#xff0c;帮助开发者们提升应用的用户体验和…

基于Java的OA办公管理系统,Spring Boot框架,vue技术,mysql数据库,前台+后台,完美运行,有一万一千字论文。

基于Java的OA办公管理系统&#xff0c;Spring Boot框架&#xff0c;vue技术&#xff0c;mysql数据库&#xff0c;前台后台&#xff0c;完美运行&#xff0c;有一万一千字论文。 系统中的功能模块主要是实现管理员和员工的管理&#xff1b; 管理员&#xff1a;个人中心、普通员工…

c语言每日一练(13)

前言&#xff1a;每日一练系列&#xff0c;每一期都包含5道选择题&#xff0c;2道编程题&#xff0c;博主会尽可能详细地进行讲解&#xff0c;令初学者也能听的清晰。每日一练系列会持续更新&#xff0c;上学期间将看学业情况更新。 五道选择题&#xff1a; 1、程序运行的结果…

图文详解PhPStudy安装教程

版权声明 本文原创作者&#xff1a;谷哥的小弟作者博客地址&#xff1a;http://blog.csdn.net/lfdfhl 官方下载 请在PhPStudy官方网站下载安装文件&#xff0c;官方链接如下&#xff1a;https://m.xp.cn/linux.html&#xff1b;图示如下&#xff1a; 请下载PhPStudy安装文件…

SpringCloudAlibaba Gateway(一)简单集成

SpringCloudAlibaba Gateway(一)简单集成 随着服务模块的增加&#xff0c;一定会产生多个接口地址&#xff0c;那么客户端调用多个接口只能使用多个地址&#xff0c;维护多个地址是很不方便的&#xff0c;这个时候就需要统一服务地址。同时也可以进行统一认证鉴权的需求。那么服…

golong基础相关操作--一

package main//go语言以包作为管理单位&#xff0c;每个文件必须先声明包 //程序必须有一个main包 // 导入包&#xff0c;必须要要使用 // 变量声明了&#xff0c;必须要使用 import ("fmt" )/* * 包内部的变量 */ var aa 3var ss "kkk"var bb truevar …

使用Python对数据的操作转换

1、列表加值转字典 在Python中&#xff0c;将列表的值转换为字典的键可以使用以下代码&#xff1a; myList ["name", "age", "location"] myDict {k: None for k in myList} print(myDict) 输出&#xff1a; {name: None, age: None, loca…

EasyExcel导出模板实现下拉选(解决下拉超过50个限制)

学习地址&#xff1a;https://d9bp4nr5ye.feishu.cn/wiki/O3obweIbgi2Rk1ksXJncpClTnAfB站视频&#xff1a;https://www.bilibili.com/video/BV1H34y1T7Lm 先来看看最终实现效果&#xff0c;如果效果是你想要的&#xff0c;再看看实现逻辑。 EasyExcel本身是支持设置下拉校验的…

小程序中如何给会员卡设置到期时间

通过设置会员卡到期时间&#xff0c;可以有效地管理会员卡的使用周期&#xff0c;提供更好的会员服务体验。下面将介绍一种常见的给会员卡设置到期时间的方法。 1. 找到指定的会员卡。在管理员后台->会员管理处&#xff0c;找到需要设置到期时间的会员卡。也支持对会员卡按…

OS 内存分区和分页 多级页表与快表

每个进程的PCB都有一个LDT 内存紧缩不实用&#xff0c;所需时间太长 类似于段表&#xff0c;存在页表 但是不连续需要的空间太多了&#xff0c;太麻烦了 多级页表&#xff1a;类比于书的章目录和节目录 构建页目录 每个页目录号指向4M的地址 快表是寄存器&#xff0c;很昂…

《Python入门到精通》webbrowser模块详解,Python webbrowser标准库,Python浏览器控制工具

「作者主页」&#xff1a;士别三日wyx 「作者简介」&#xff1a;CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者 「推荐专栏」&#xff1a;小白零基础《Python入门到精通》 webbrowser模块详解 1、常用操作2、函数大全webbrowser.open() 打开浏览器webbro…