HuggingFace模型头的自定义

 

在线工具推荐:  Three.js AI纹理开发包 -  YOLO合成数据生成器 -  GLTF/GLB在线编辑 -  3D模型格式在线转换 -  可编程3D场景编辑器

在本文中我们将介绍如何使HuggingFace的模型适应你的任务,在Pytorch中建立自定义模型头并将其连接到HF模型的主体,并端到端地训练系统。

1、HF模型头和模型体

这是典型的HF模型的样子:

为什么我需要单独使用模型头(Model Head)和模型体(Model Body)?

一些HF的模型针对下游任务(例如提问或文本分类)训练,并包含有关其权重培训的数据的知识。

有时,尤其是当我们手头的任务包含很少的数据或领域特定(例如医学或运动特定任务)时,我们可以在HUB上使用其他任务训练的模型(不一定与我们的任务相同的任务 手但属于相同领域,例如运动或药物),并利用一些验证的知识来提高我们模型在我们自己任务的性能表现。

  • 一个非常简单的例子是,如果说我们有一个小数据集,比如分类某些财务报表是积极还是负面的。 但是,我们进入了HF,发现许多模型已经经过与金融相关的问答数据集的训练,那么 我们可以使用这些模型的某些层来改进自己的任务。
  • 另一个简单的示例是,某个特定领域的模型经过巨大数据集的训练学会了将文本从中分为5个类别。 假设我们有类似的分类任务,在同一域中的一个完全不同的数据集,只想将数据分类为2个类别而不是5。 这时我们也可以复用模型主体,添加自己的模型头来增强我们自己任务的特定领域知识。

这就是我们要做的事情的示意图:

2、自定义HF模型头

我们的任务是简单的,从Kaggle上的这个数据集进行讽刺检测。

你可以在此处查看完整的代码。 为了时间的考虑,我没有在下面包括预处理和一些训练的详细信息,因此请确保查看整个代码的笔记本。

我将使用一个在大量推文上训练的模型,有5个分类输出不同的情感类型。我们将提取模型体,在pytorch中添加自定义层(2个标签,讽刺/不讽刺),并训练新的模型。

注意:你可以在此示例中使用任何模型(不一定是对分类训练的模型),因为我们只会使用该模型主体并拆除模型头。

这就是我们的工作流程:

我将跳过数据预处理步骤,然后直接跳到主类,但是你可以在本节开头的链接中查看整个代码。

3、令牌化和动态填充

使用如下代码将文本转化为令牌并进行动态填充:

checkpoint = "cardiffnlp/twitter-roberta-base-emotion"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.model_max_len=512def tokenize(batch):return tokenizer(batch["headline"], truncation=True,max_length=512)tokenized_dataset = data.map(tokenize, batched=True)
print(tokenized_dataset)tokenized_dataset.set_format("torch",columns=["input_ids", "attention_mask", "label"])
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

结果如下:

DatasetDict({train: Dataset({features: ['headline', 'label', 'input_ids', 'attention_mask'],num_rows: 22802})test: Dataset({features: ['headline', 'label', 'input_ids', 'attention_mask'],num_rows: 2851})valid: Dataset({features: ['headline', 'label', 'input_ids', 'attention_mask'],num_rows: 2850})
})

4、提取模型体并添加我们自己的层

代码如下:

class CustomModel(nn.Module):def __init__(self,checkpoint,num_labels): super(CustomModel,self).__init__() self.num_labels = num_labels #Load Model with given checkpoint and extract its bodyself.model = model = AutoModel.from_pretrained(checkpoint,config=AutoConfig.from_pretrained(checkpoint, output_attentions=True,output_hidden_states=True))self.dropout = nn.Dropout(0.1) self.classifier = nn.Linear(768,num_labels) # load and initialize weightsdef forward(self, input_ids=None, attention_mask=None,labels=None):#Extract outputs from the bodyoutputs = self.model(input_ids=input_ids, attention_mask=attention_mask)#Add custom layerssequence_output = self.dropout(outputs[0]) #outputs[0]=last hidden statelogits = self.classifier(sequence_output[:,0,:].view(-1,768)) # calculate lossesloss = Noneif labels is not None:loss_fct = nn.CrossEntropyLoss()loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states,attentions=outputs.attentions)

如你所见,我们首先是继承Pytorch中的 nn.Module,使用AutoModel(来自transformers库)提取加载了指定检查点的模型主体。

请注意, forward() 方法返回 TokenClassifierOutput,从而确保我们输出的格式与HF预训练模型一致。

5、端到端训练新的模型

代码如下:

from tqdm.auto import tqdmprogress_bar_train = tqdm(range(num_training_steps))
progress_bar_eval = tqdm(range(num_epochs * len(eval_dataloader)))for epoch in range(num_epochs):model.train()for batch in train_dataloader:batch = {k: v.to(device) for k, v in batch.items()}outputs = model(**batch)loss = outputs.lossloss.backward()optimizer.step()lr_scheduler.step()optimizer.zero_grad()progress_bar_train.update(1)model.eval()for batch in eval_dataloader:batch = {k: v.to(device) for k, v in batch.items()}with torch.no_grad():outputs = model(**batch)logits = outputs.logitspredictions = torch.argmax(logits, dim=-1)metric.add_batch(predictions=predictions, references=batch["labels"])progress_bar_eval.update(1)print(metric.compute())model.eval()test_dataloader = DataLoader(tokenized_dataset["test"], batch_size=32, collate_fn=data_collator
)for batch in test_dataloader:batch = {k: v.to(device) for k, v in batch.items()}with torch.no_grad():outputs = model(**batch)logits = outputs.logitspredictions = torch.argmax(logits, dim=-1)metric.add_batch(predictions=predictions, references=batch["labels"])metric.compute()

结果如下:

  0%|          | 0/2139 [00:00<?, ?it/s]0%|          | 0/270 [00:00<?, ?it/s]
{'f1': 0.9335347432024169}
{'f1': 0.9360090874668686}
{'f1': 0.9274912756882513}

如你所见,我们使用此方法实现了不错的性能。 请记住,该博客的目的不是分析此特定数据集的性能,而是要学习如何使用预训练的身体并添加自定义头。

6、结束语

在本文中,我们看到了如何在HF预训练模型上添加自定义层。

一些收获:

  • 在我们拥有特定于域的数据集并希望利用在同一域(任务 - 努力的task-agnostic)上训练的模型以增强小型数据集中的性能的情况下,此技术特别有用。
  • 我们可以选择接受过与自己任务不同的下游任务训练的模型,并且仍然使用该模型主体的知识。
  • 如果你的数据集足够大且通用,那么这可能根本不需要,在这种情况下,你可以使用 AutoModeForSequenceCecrification或使用 BERT 解决的任何其他任务。 实际上,如果是这样,我强烈建议不要建立自己的模型头。

原文链接:HF自定义模型头 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/188650.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大数据Doris(二十一):数据导入演示

文章目录 数据导入演示 一、启动zookeeper集群(三台节点都启动) 二、启动hdfs集群

【Linux】语言层面缓冲区的刷新问题以及简易模拟实现

文章目录 前言一、缓冲区刷新方法分类a.无缓冲--直接刷新b.行缓冲--不刷新&#xff0c;直到碰到\n才刷新c.全缓冲--缓冲区满了才刷新 二、 缓冲区的常见刷新问题1.问题2.刷新本质 三、模拟实现1.Mystdio.h2.Mystdio.c3.main.c 前言 我们接下来要谈论的是我们语言层面的缓冲区&…

Spark的转换算子和操作算子

1 Transformation转换算子 1.1 Value类型 1&#xff09;创建包名&#xff1a;com.shangjack.value 1.1.1 map()映射 参数f是一个函数可以写作匿名子类&#xff0c;它可以接收一个参数。当某个RDD执行map方法时&#xff0c;会遍历该RDD中的每一个数据项&#xff0c;并依次应用f函…

idea 一直卡在maven正在解析maven依赖

修改maven Importing的jvm参数 -Xms1024m -Xmx2048m

srs webrtc推拉流环境搭建

官方代码https://github.com/ossrs/srs 拉取代码&#xff1a; git clone https://github.com/ossrs/srs.gitcd ./configure make ./objs/srs -c conf/rtc.confconf/rtc.conf中&#xff0c;当推拉流浏览器在本地时&#xff0c;如果srs也在本地&#xff0c;那么可以使用官网默认…

VUE页面导出PDF方案

1&#xff0c;技术方案为&#xff1a;html2canvas把页面生成canvas图片&#xff0c;再通过jspdf生成PDF文件&#xff1b; 2&#xff0c;安装依赖&#xff1a; npm i html2canvas -S npm i jspdf -S 3&#xff0c;封装导出pdf方法exportPdf.js: // 页面导出为pdf格式 //titl…

STM32基础--NVIC中断控制器

一、NVIC是什么&#xff1f; NVIC是一种中断控制器。当一个中断正在处理时&#xff0c;另一个更高优先级的中断可以打断当前中断的执行&#xff0c;并立即得到处理。这种机制使得处理器在高速运行的同时&#xff0c;能够及时响应不同优先级的中断请求。 二、有哪些优先级&…

19 异步通知

一、异步通知 1. 异步通知简介 阻塞和非阻塞两种方式都是需要应用程序去主动查询设备的使用情况。 异步通知类似于驱动可以主动报告自己可以访问&#xff0c;应用程序获取信号后会从驱动设备中读取或写入数据。 异步通知最核心的就是信号&#xff1a; #define SIGHUP 1 /* 终…

Clickhouse学习笔记(5)—— ClickHouse 副本

Data Replication | ClickHouse Docs 副本的目的主要是保障数据的高可用性&#xff0c;即使一台 ClickHouse 节点宕机&#xff0c;那么也可以从其他服务器获得相同的数据 注意&#xff1a; clickhouse副本机制的实现要基于zookeeperclickhouse的副本机制只适用于MergeTree f…

API SIX系列-服务搭建(一)

APIsix简介 APISIX是一个微服务API网关&#xff0c;具有高性能、可扩展性等优点。它基于nginx&#xff08;openresty&#xff09;、Lua、etcd实现功能&#xff0c;借鉴了Kong的思路。和传统的API网关相比&#xff0c;APISIX具有较高的性能和较低的资源消耗&#xff0c;并且具有…

Lua更多语法与使用

文章目录 目的错误处理元表和元方法垃圾回收协程模块面向对象总结 目的 在前一篇文章&#xff1a; 《Lua入门使用与基础语法》 中介绍了一些基础的内容。这里将继续介绍Lua一些更多的内容。 同样的本文参考自官方手册&#xff1a; https://www.lua.org/manual/ 错误处理 下…

嵌入式养成计划-48----QT--信息管理系统:百川仓储管理

一百二十二、信息管理系统&#xff1a;百川仓储管理 122.1 UI界面 122.2 思路 客户端&#xff1a; 用户权限有两种类型&#xff0c;一种是用户权限&#xff0c;一种是管理员权限&#xff0c;登录时服务器端会根据数据库查询到的此用户名的权限返回不同的结果&#xff0c;客户…

学习c#的第五天

目录 C# 运算符 算术运算符 关系运算符 逻辑运算符 位运算符 赋值运算符 其他运算符 C# 中的运算符优先级 C# 运算符 算术运算符 下表显示了 C# 支持的所有算术运算符。假设变量 A 的值为 10&#xff0c;变量 B 的值为 20&#xff0c;则&#xff1a; 运算符描述实例…

JVM在线分析-解决问题的工具一(jinfo,jmap,jstack)

1. jinfo (base) PS C:\Users\zishi\Desktop> jinfo Usage:jinfo <option> <pid>(to connect to a running process)where <option> is one of:-flag <name> to print the value of the named VM flag #输出对应名称的参数-flag [|-]<n…

ElasticSearch7.x - HTTP 操作 - 文档操作

创建文档(添加数据) 索引已经创建好了,接下来我们来创建文档,并添加数据。这里的文档可以类比为关系型数 据库中的表数据,添加的数据格式为 JSON 格式 向 ES 服务器发 POST 请求 :http://192.168.254.101:9200/shopping/_doc 请求体内容为: {"title":"小…

【Python 千题 —— 基础篇】账号登录

题目描述 题目描述 简易登录系统。你的账号密码分别是 “student”&#xff0c;“123456”&#xff1b;请使用 if-else 设计一个简易登录系统&#xff0c;输入账号密码。登陆成功输出 “Welcome !”&#xff0c;登录失败输出 “Login failed !” 输入描述 输入账号和密码。…

idea配置tomcat参数,防止nvarchar保存韩文、俄文、日文等乱码

描述下我的场景&#xff1a; 数据库服务器在远程机器上&#xff0c;数据库使用的Oracle&#xff0c;字符集是ZHS16GBK&#xff0c;但保存韩文、俄文、日文等字段A的数据类型是nvarchar(120)&#xff0c;而nvarchar使用的是Unicode 编码&#xff0c;有点乱。。 遇到的问题&…

【机器学习范式】监督学习,无监督学习,强化学习, 半监督学习,自监督学习,迁移学习,对比分析+详解与示例代码

目录 1. 监督学习 (Supervised Learning): 2. 无监督学习 (Unsupervised Learning): 3. 强化学习 (Reinforcement Learning): 4. 半监督学习 (Semi-Supervised Learning): 5. 自监督学习 (Self-Supervised Learning): 6. 迁移学习 (Transfer Learning): 7 机器学习范式应…

AI:67-基于深度学习的脱机手写汉字识别

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌在这个漫长的过程,中途遇到了不少问题,但是…

『Linux升级路』基础开发工具——vim篇

&#x1f525;博客主页&#xff1a;小王又困了 &#x1f4da;系列专栏&#xff1a;Linux &#x1f31f;人之为学&#xff0c;不日近则日退 ❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 目录 一、vim的基本概念 &#x1f4d2;1.1命令模式 &#x1f4d2;1.2插入模式 &…