使用自动模型

本文通过文本分类任务演示了HuggingFace自动模型使用方法,既不需要手动计算loss,也不需要手动定义下游任务模型,通过阅读自动模型实现源码,提高NLP建模能力。

一.任务和数据集介绍
1.任务介绍
前面章节通过手动方式定义下游任务模型,HuggingFace也提供了一些常见的预定义下游任务模型,如下所示:

说明:包括预测下一个词,文本填空,问答任务,文本摘要,文本分类,命名实体识别,翻译等。

2.数据集介绍
本文使用ChnSentiCorp数据集,不清楚的可以参考中文情感分类介绍。一些样例如下所示:

二.准备数据集
1.使用编码工具

def load_encode_tool(pretrained_model_name_or_path):"""加载编码工具"""tokenizer = BertTokenizer.from_pretrained(Path(f'{pretrained_model_name_or_path}'))return tokenizer
if __name__ == '__main__':# 测试编码工具pretrained_model_name_or_path = r'L:/20230713_HuggingFaceModel/bert-base-chinese'tokenizer = load_encode_tool(pretrained_model_name_or_path)print(tokenizer)

输出结果如下所示:

BertTokenizer(name_or_path='L:\20230713_HuggingFaceModel\bert-base-chinese', vocab_size=21128, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True)

2.定义数据集
直接使用HuggingFace数据集对象,如下所示:

def load_dataset_from_disk():pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\ChnSentiCorp'dataset = load_from_disk(pretrained_model_name_or_path)return dataset
if __name__ == '__main__':# 加载数据集dataset = load_dataset_from_disk()print(dataset)

输出结果如下所示:

DatasetDict({train: Dataset({features: ['text', 'label'],num_rows: 9600})validation: Dataset({features: ['text', 'label'],num_rows: 1200})test: Dataset({features: ['text', 'label'],num_rows: 1200})
})

3.定义计算设备

# 定义计算设备
device = 'cpu'
if torch.cuda.is_available():device = 'cuda'
# print(device)

4.定义数据整理函数

def collate_fn(data):sents = [i['text'] for i in data]labels = [i['label'] for i in data]#编码data = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents, # 输入文本truncation=True, # 是否截断padding=True, # 是否填充max_length=512, # 最大长度return_tensors='pt') # 返回的类型#转移到计算设备for k, v in data.items():data[k] = v.to(device)data['labels'] = torch.LongTensor(labels).to(device)return data

5.定义数据集加载器

# 数据集加载器
loader = torch.utils.data.DataLoader(dataset=dataset['train'], batch_size=16, collate_fn=collate_fn, shuffle=True, drop_last=True)
print(len(loader))# 查看数据样例
for i, data in enumerate(loader):break
for k, v in data.items():print(k, v.shape)

输出结果如下所示:

600
input_ids torch.Size([16, 200])
token_type_ids torch.Size([16, 200])
attention_mask torch.Size([16, 200])
labels torch.Size([16])

三.加载自动模型
使用HuggingFace的AutoModelForSequenceClassification工具类加载自动模型,来实现文本分类任务,代码如下:

# 加载预训练模型
model = AutoModelForSequenceClassification.from_pretrained(Path(f'{pretrained_model_name_or_path}'), num_labels=2)
model.to(device)
print(sum(i.numel() for i in model.parameters()) / 10000)

四.训练和测试
1.训练
需要说明自动模型本身包括loss计算,因此在train()中就不再需要手工计算loss,如下所示:

def train():# 定义优化器optimizer = AdamW(model.parameters(), lr=5e-4)# 定义学习率调节器scheduler = get_scheduler(name='linear', # 调节器名称num_warmup_steps=0, # 预热步数num_training_steps=len(loader), # 训练步数optimizer=optimizer) # 优化器# 将模型切换到训练模式model.train()# 按批次遍历训练集中的数据for i, data in enumerate(loader):# print(i, data)# 模型计算out = model(**data)# 计算1oss并使用梯度下降法优化模型参数out['loss'].backward() # 反向传播optimizer.step() # 优化器更新scheduler.step() # 学习率调节器更新optimizer.zero_grad() # 梯度清零model.zero_grad() # 梯度清零# 输出各项数据的情况,便于观察if i % 10 == 0:out_result = out['logits'].argmax(dim=1)accuracy = (out_result == data.labels).sum().item() / len(data.labels)lr = optimizer.state_dict()['param_groups'][0]['lr']print(i, out['loss'].item(), lr, accuracy)

其中,out数据结构如下所示:

2.测试

def test():# 定义测试数据集加载器loader_test = torch.utils.data.DataLoader(dataset=dataset['test'],batch_size=32,collate_fn=collate_fn,shuffle=True,drop_last=True)# 将下游任务模型切换到运行模式model.eval()correct = 0total = 0# 按批次遍历测试集中的数据for i, data in enumerate(loader_test):# 计算5个批次即可,不需要全部遍历if i == 5:breakprint(i)# 计算with torch.no_grad():out = model(**data)# 统计正确率out = out['logits'].argmax(dim=1)correct += (out == data.labels).sum().item()total += len(data.labels)print(correct / total)

五.深入自动模型源代码
1.加载配置文件过程
在执行AutoModelForSequenceClassification.from_pretrained(Path(f'{pretrained_model_name_or_path}'), num_labels=2)时,实际上调用了AutoConfig.from_pretrained(),该函数返回的config对象内容如下所示:

config对象如下所示:

BertConfig {"_name_or_path": "L:\\20230713_HuggingFaceModel\\bert-base-chinese","architectures": ["BertForMaskedLM"],"attention_probs_dropout_prob": 0.1,"classifier_dropout": null,"directionality": "bidi","hidden_act": "gelu","hidden_dropout_prob": 0.1,"hidden_size": 768,"initializer_range": 0.02,"intermediate_size": 3072,"layer_norm_eps": 1e-12,"max_position_embeddings": 512,"model_type": "bert","num_attention_heads": 12,"num_hidden_layers": 12,"pad_token_id": 0,"pooler_fc_size": 768,"pooler_num_attention_heads": 12,"pooler_num_fc_layers": 3,"pooler_size_per_head": 128,"pooler_type": "first_token_transform","position_embedding_type": "absolute","transformers_version": "4.32.1","type_vocab_size": 2,"use_cache": true,"vocab_size": 21128
}

(1)_name_or_path=bert-base-chinese:模型名字。
(2)attention_probs_DropOut_prob=0.1:注意力层DropOut的比例。
(3)hidden_act=gelu:隐藏层的激活函数。
(4)hidden_DropOut_prob=0.1:隐藏层DropOut的比例。
(5)hidden_size=768:隐藏层神经元的数量。
(6)layer_norm_eps=1e-12:标准化层的eps参数。
(7)max_position_embeddings=512:句子的最大长度。
(8)model_type=bert:模型类型。
(9)num_attention_heads=12:注意力层的头数量。
(10)num_hidden_layers=12:隐藏层层数。
(11)pad_token_id=0:PAD的编号。
(12)pooler_fc_size=768:池化层的神经元数量。
(13)pooler_num_attention_heads=12:池化层的注意力头数。
(14)pooler_num_fc_layers=3:池化层的全连接神经网络层数。
(15)vocab_size=21128:字典的大小。

2.初始化模型过程
BertForSequenceClassification类构造函数包括一个BERT模型和全连接神经网络,基本思路为通过BERT提取特征,通过全连接神经网络进行分类,如下所示:

def __init__(self, config):super().__init__(config)self.num_labels = config.num_labelsself.config = configself.bert = BertModel(config)classifier_dropout = (config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob)self.dropout = nn.Dropout(classifier_dropout)self.classifier = nn.Linear(config.hidden_size, config.num_labels)# Initialize weights and apply final processingself.post_init()

通过forward()函数可证明以上推测,根据问题类型为regression(MSELoss()损失函数)、single_label_classification(CrossEntropyLoss()损失函数)和multi_label_classification(BCEWithLogitsLoss()损失函数)选择损失函数。

参考文献:
[1]HuggingFace自然语言处理详解:基于BERT中文模型的任务实战
[2]https://github.com/ai408/nlp-engineering/blob/main/20230625_HuggingFace自然语言处理详解/第12章:使用自动模型.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/119259.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Java 动态数据统计图】动态数据统计思路案例(动态,排序,动态数组(重点推荐))七(129)

需求:前端根据后端的返回数据:画统计图; 说明: 1.X轴为地域,Y轴为地域出现的次数; 2. 动态展示(有地域展示,没有不展示,且高低排序) Demo案例: …

Python 分析HTTP的可靠性

在这篇文章中,我们将介绍如何使用 Python 来分析代理服务提供商的可靠性。代理服务在许多场景中都非常有用,例如突破地理限制、保护隐私和提高网络安全性。然而,并非所有的代理服务提供商都是可靠的。因此,我们将使用 Python 来测…

leetcode 941. 有效的山脉数组

2023.9.2 可以用双指针法来做,left指向数组起点,right指向数组终点,left满足条件则左移,right满足条件则右移,最终两指针重合则返回true。 期间任一条件不满足则返回false。 代码如下: class Solution { p…

【80天学习完《深入理解计算机系统》】第十二天3.6数组和结构体

专注 效率 记忆 预习 笔记 复习 做题 欢迎观看我的博客,如有问题交流,欢迎评论区留言,一定尽快回复!(大家可以去看我的专栏,是所有文章的目录)   文章字体风格: 红色文字表示&#…

OJ练习第159题——消灭怪物的最大数量

消灭怪物的最大数量 力扣链接:1921. 消灭怪物的最大数量 题目描述 你正在玩一款电子游戏,在游戏中你需要保护城市免受怪物侵袭。给你一个 下标从 0 开始 且长度为 n 的整数数组 dist ,其中 dist[i] 是第 i 个怪物与城市的 初始距离&#…

筑牢数据隐私安全底线,ADSCOPE通过ISO隐私信息管理体系认证!

数字时代,信息安全尤其是数据隐私信息保护已经成为社会共识。近日,ADSCOPE(上海倍孜网络技术有限公司)已通过相关组织机构评审,符合ISO/IEC 27701:2019标准,获得隐私信息管理体系认证证书&#…

Jmeter(三十一):制造大批量的用户数据数据

需求:因测试需要,要造100w用户数据,通过用户名、手机号、密码可新增用户,其中用户名和电话号码要求100w用户不能重复 要点: 1、通过Bean shell Sampler实现用户名和手机号的足够随机。 符合我们常用规则的手机号&#…

torch.bmm功能解读

bmm 是 batched matrix multiple 的简写,即批量矩阵乘法,矩阵是二维的,加上batch一个维度,因此该函数的输入必须是两个三维的 tensor,三个维度代表的含义分别是:(批量,行&#xff0c…

JVM学习(五)--方法区

概念: 方法区就是存和类相关的东西,成员方法,方法参数,成员变量,构造方法,类加载器等,逻辑上存在于堆中,但是不同的虚拟机对它的实现不同,oracle的hotsport vm在1.6的时…

腾讯云-对象存储服务(COS)的使用总结-JavaScript篇

简介 对象存储(Cloud Object Storage,COS)是腾讯云提供的一种存储海量文件的分布式存储服务,具有高扩展性、低成本、可靠安全等优点。通过控制台、API、SDK 和工具等多样化方式,用户可简单、快速地接入 COS&#xff0…

部署单点elasticsearch

部署elasticsearch 创建网络 因为我们还需要部署kibana容器,因此需要让es和kibana容器互联。这里先创建一个网络 docker network create es-net 拉取镜像 我们采用elasticsearch的7.12.1版本的镜像 docker pull elasticsearch:7.12.1 运行 运行docker命令&a…

苹果启动2024年SRDP计划:邀请安全专家使用定制iPhone寻找漏洞

苹果公司昨天(8月30日)正式宣布开始接受2024 年iPhone安全研究设备计划的申请,iOS 安全研究人员可以在 10 月底之前申请安全研究设备 SRD。 SRD设备是专门向安全研究人员提供的iPhone14Pro,该设备具有专为安全研究而设计的特殊硬…

如何在java中做基准测试

最近公司在搞新项目,由于是实验性质,且不会直接面对客户的项目,这次的技术选型非常激进,如,直接使用了Java 17。 作为公司里练习两年半的个人练习生,我自然也是深度的参与到了技术选型的工作中。不知道大家…

敏捷开发、V模型开发、瀑布模型

在软件开发领域,敏捷开发和V模型开发是两种主要的开发方法。它们之间的差异主要体现在开发过程的结构和组织方式上。在以下讨论中,我们将深入探讨这两种方法的特点和差异。 敏捷开发 敏捷开发是一种迭代和增量的软件开发方法,它强调灵活性和…

docker笔记7:Docker微服务实战

1.通过IDEA新建一个普通微服务模块 建Module docker_boot 改POM <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi…

解决Apache Tomcat “Request header is too large“ 异常 ‍

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

正则表达式 之 断言详解

正则表达式的先行断言和后行断言一共有 4 种形式&#xff1a; (?pattern) 零宽正向先行断言(zero-width positive lookahead assertion)(?!pattern) 零宽负向先行断言(zero-width negative lookahead assertion)(?<pattern) 零宽正向后行断言(zero-width positive lookb…

rrweb录制用户的操作过程,并上传服务端

1、客户端 准备工作&#xff0c;需要使用到的包有rrweb&#xff08;录制&#xff09; rrwebPlayer&#xff08;播放&#xff09; pako&#xff08;压缩&#xff09; 1.1、录制&#xff1a;1.2、pako 压缩工具的使用方式 import * as rrweb from rrweblet dispose null let rr…

机械制图(CAD)

目录 第一课&#xff08;80分钟&#xff09; 第二课&#xff08;80分钟&#xff09; 力啥学机械制图&#xff1f;我们的工厂要加工机械&#xff0c;而加工机械零件前&#xff0c;我们要先给工人图纸来看,工人才知道该怎样加工&#xff0c;所以我们今天就来学习下怎样画出符何…

el-table中点击跳转到详情页的两种方法

跳转的两种写法: 1.使用keep-alive使组件缓存,防止刷新时参数丢失 keep-alive 组件用于缓存和保持组件的状态&#xff0c;而不是路由参数。它可以在组件切换时保留组件的状态&#xff0c;从而避免重新渲染和加载数据。 keep-alive 主要用于提高页面性能和用户体验&#xff0c;而…