【AI大模型】Transformers大模型库(十一):Trainer训练类

 

目录

一、引言 

二、Trainer训练类

2.1 概述

2.2 使用示例

三、总结


一、引言 

 这里的Transformers指的是huggingface开发的大模型库,为huggingface上数以万计的预训练大模型提供预测、训练等服务。

🤗 Transformers 提供了数以千计的预训练模型,支持 100 多种语言的文本分类、信息抽取、问答、摘要、翻译、文本生成。它的宗旨是让最先进的 NLP 技术人人易用。
🤗 Transformers 提供了便于快速下载和使用的API,让你可以把预训练模型用在给定文本、在你的数据集上微调然后通过 model hub 与社区共享。同时,每个定义的 Python 模块均完全独立,方便修改和快速研究实验。
🤗 Transformers 支持三个最热门的深度学习库: Jax, PyTorch 以及 TensorFlow — 并与之无缝整合。你可以直接使用一个框架训练你的模型然后用另一个加载和推理。

本文重点介绍Trainer训练类

二、Trainer训练类

2.1 概述

2.2 使用示例

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset# 1. 加载数据集
# 假设我们使用的是Hugging Face的内置数据集,例如SST-2
dataset = load_dataset('sst2')  # 或者使用你自己的数据集# 2. 数据预处理,可能需要根据模型进行Tokenization
# 以BERT为例,使用AutoTokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def tokenize_function(examples):return tokenizer(examples["sentence"], truncation=True)
tokenized_datasets = dataset.map(tokenize_function, batched=True)# 3. 准备训练参数
training_args = TrainingArguments(output_dir='./results',          # 输出目录num_train_epochs=3,              # 总的训练轮数per_device_train_batch_size=16,  # 每个GPU的训练批次大小per_device_eval_batch_size=64,   # 每个GPU的评估批次大小warmup_steps=500,                # 预热步数weight_decay=0.01,               # 权重衰减logging_dir='./logs',            # 日志目录logging_steps=10,
)# 4. 准备模型
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")# 5. 创建Trainer并开始训练
trainer = Trainer(model=model,                         # 要训练的模型args=training_args,                  # 训练参数train_dataset=tokenized_datasets['train'],  # 训练数据集eval_dataset=tokenized_datasets['validation'], # 验证数据集
)# 开始训练
trainer.train()

整个流程是机器学习项目中的标准流程:数据准备、模型选择、参数设置、训练与评估。每个步骤都是为了确保模型能够高效、正确地训练,以解决特定的机器学习任务:

  • 加载数据集 (load_dataset('sst2')):这行代码是使用Hugging Face的datasets库加载SST-2数据集,这是一个情感分析任务的数据集。如果你使用自定义数据集,需要相应地处理和加载数据。
  • 数据预处理 (tokenizer(examples["sentence"], truncation=True)):在训练模型之前,需要将文本数据转换为模型可以理解的格式。这里使用AutoTokenizer对文本进行分词(Tokenization),truncation=True意味着如果句子超过模型的最大输入长度,将截断超出部分。这一步是将文本转换为模型输入的张量格式。
  • 训练参数 (TrainingArguments):这部分定义了训练过程的配置,包括训练轮数(num_train_epochs)、每个设备的训练和评估批次大小、预热步数(warmup_steps)、权重衰减(weight_decay)等。这些参数对训练效率和模型性能有重要影响。
  • 准备模型 (AutoModelForSequenceClassification.from_pretrained()):这里选择或初始化模型,AutoModelForSequenceClassification是用于序列分类任务的模型,from_pretrained方法加载预训练的模型权重。选择的模型(如BERT的“bert-base-uncased”)是基于任务需求的。
  • 创建Trainer (Trainer):Trainer是Transformers库中的核心类,它负责模型的训练和评估流程。它接收模型、训练参数、训练数据集和评估数据集作为输入。Trainer自动处理了训练循环、损失计算、优化器更新、评估、日志记录等复杂操作,使得训练过程更加简洁和高效。
  • 开始训练 (trainer.train()):调用此方法开始模型的训练过程。Trainer会根据之前设定的参数和数据进行模型训练,并在每个指定的步骤打印日志,训练完成后,模型的权重会保存到指定的输出目录。

三、总结

本文对transformers训练类Trainer进行讲述并赋予应用代码,希望可以帮到大家!

如果您还有时间,可以看看我的其他文章:

《AI—工程篇》

AI智能体研发之路-工程篇(一):Docker助力AI智能体开发提效

AI智能体研发之路-工程篇(二):Dify智能体开发平台一键部署

AI智能体研发之路-工程篇(三):大模型推理服务框架Ollama一键部署

AI智能体研发之路-工程篇(四):大模型推理服务框架Xinference一键部署

AI智能体研发之路-工程篇(五):大模型推理服务框架LocalAI一键部署

《AI—模型篇》

AI智能体研发之路-模型篇(一):大模型训练框架LLaMA-Factory在国内网络环境下的安装、部署及使用

AI智能体研发之路-模型篇(二):DeepSeek-V2-Chat 训练与推理实战

AI智能体研发之路-模型篇(三):中文大模型开、闭源之争

AI智能体研发之路-模型篇(四):一文入门pytorch开发

AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比

AI智能体研发之路-模型篇(六):【机器学习】基于tensorflow实现你的第一个DNN网络

AI智能体研发之路-模型篇(七):【机器学习】基于YOLOv10实现你的第一个视觉AI大模型

AI智能体研发之路-模型篇(八):【机器学习】Qwen1.5-14B-Chat大模型训练与推理实战

AI智能体研发之路-模型篇(九):【机器学习】GLM4-9B-Chat大模型/GLM-4V-9B多模态大模型概述、原理及推理实战

《AI—Transformers应用》

【AI大模型】Transformers大模型库(一):Tokenizer

【AI大模型】Transformers大模型库(二):AutoModelForCausalLM

【AI大模型】Transformers大模型库(三):特殊标记(special tokens)

【AI大模型】Transformers大模型库(四):AutoTokenizer

【AI大模型】Transformers大模型库(五):AutoModel、Model Head及查看模型结构

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/355394.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

千呼新零售2.0【更新日志】持续更新ing

千呼新零售2.0系统是零售行业连锁店一体化收银系统,包括线下收银线上商城连锁店管理ERP管理商品管理供应商管理会员营销等功能为一体,线上线下数据全部打通。 适用于商超、便利店、水果、生鲜、母婴、服装、零食、百货等连锁店使用。 详细介绍请查看下…

Mybatis中BaseEntity作用

新建各种对象的时候,一般来说,有几个属性是所有对象共有的,比如说id,is_del,is_enable这些,然后设置一个基础对象,以后新建所有对象的时候都继承它,就省的每次都要写这些共有的属性了

GPT-4系列模型,在文档理解中的多维度评测

著名云数据平台Snowflake的研究人员发布了一篇论文,主要对OpenAI的GPT-4系列模型进行了研究,查看其文本生成、图像理解、文档摘要等能力。 在DocVQA、InfographicsVQA、SlideVQA和DUDE数据集上对GPT-4、GPT-4 V、GPT-4 Turbo V OCR等进行了多维度测试。…

SD3发布,送你3个ComfyUI工作流

大家好,我是每天分享AI应用的萤火君! 这几天AI绘画界最轰动的消息莫过于Stable Diffusion 3(简称SD3)的发布。SD3是一个多模态的 Diffusion Transformer 模型,其在图像质量、排版、复杂提示理解和资源效率方面具有显著…

计算机考研|408第二轮复习是二刷王道还是先看强化课?

在基础复习完一轮后,大部分人会把前面的内容忘掉很多!这个时候不要着急进入强化,在强化阶段之前先把4本书再重新整理复习一遍,查缺补漏。然后再看王道强化课! 对于408这门具有大量知识需要学习的专业课,有…

【深度学习】智能手写数字识别系统

文章目录 一.实验课题背景说明1.1实验目的1.2实验环境1.2.1安装PyTorch1.2.2安装其他必要的库 二.模型说明2.1模型概述2.2模型结构 三.数据说明3.1 输入数据3.1.1输入数据特征3.1.2输入数据维度3.1.3输入数据预处理 3.2 数据格式3.2.1输出数据…

如何调用讯飞星火认知大模型的API以利用其卓越功能

摘要 讯飞星火认知大模型,作为科大讯飞精心打造的一款人工智能模型,在自然语言理解和生成方面展现出了卓越的能力。这款模型通过深度学习技术和大量数据的训练,具备了强大的语言理解、文本生成和对话交互等功能。 一、模型功能概述 讯飞星…

linux端口被占用 关闭端口

系列文章目录 文章目录 系列文章目录一、linux端口被占用 关闭端口1.参考链接2.具体命令 二、【linux关闭进程命令】fuser -k 和 kill -9 的区别1.参考链接2.具体命令 一、linux端口被占用 关闭端口 1.参考链接 linux端口被占用 关闭端口 2.具体命令 1.查看端口是否被占用 …

商超仓库管理系统

摘要 随着全球经济和互联网技术的快速发展,依靠互联网技术的各种管理系统逐渐应用到社会的方方面面。各行业的有识之士都逐渐开始意识到过去传统的人工管理模式已经逐渐成为企业发展的绊脚石,不再适应现代企业的发展需要。企业想要得到更好的发展&#…

FreeRtos-13资源管理

一、临界资源是什么 要独占式地访问临界资源,有3种方法: 1.公平竞争:比如使用互斥量,谁先获得互斥量谁就访问临界资源,这部分内容前面讲过。 谁要跟我抢,我就灭掉谁: 2.中断要跟我抢?我屏蔽中断 3.其他任务要跟我抢?我禁止调度器,不运行任务切换 二、暂停调度器…

【C语言】自定义类型

目录 一、结构体: 1、结构体的声明: 2、结构体的自引用: 3、结构体变量的定义和初始化: 4、结构体内存对齐: 5、结构体传参: 6、位段: 二、枚举类型: 三、联合体&#xff1a…

网络安全:什么是SQL注入

文章目录 网络安全:什么是SQL注入引言SQL注入简介工作原理示例代码 攻击类型为什么SQL注入危险结语 网络安全:什么是SQL注入 引言 在数字化时代,数据安全成为了企业和个人最关心的问题之一。SQL注入(SQL Injection)是…

【LLM之RAG】RAT论文阅读笔记

研究背景 近年来,大型语言模型(LLMs)在各种自然语言推理任务上取得了显著进展,尤其是在结合大规模模型和复杂提示策略(如链式思维提示(CoT))时。然而,LLMs 在推理的事实…

C++的智能指针 RAII

目录 产生原因 RAII思想 C11的智能指针 智能指针的拷贝与赋值 shared_ptr的拷贝构造 shared_ptr的赋值重置 shared_ptr的其它成员函数 weak_ptr 定制删除器 简单实现 产生原因 产生原因:抛异常等原因导致的内存泄漏 int div() {int a, b;cin >> a…

手机usb共享网络电脑没反应的方法

适用于win10电脑,安卓手机上可以 开启usb网络共享选择,如果选择后一直跳,让重复选择usb选项的话,就开启 开发者模式,进到 开发者模式 里设置 默认usb 共享网络 选项 ,就不会一直跳让你选。 1.先用数据线 连…

八大经典排序算法

前言 本片博客主要讲解一下八大排序算法的思想和排序的代码 💓 个人主页:普通young man-CSDN博客 ⏩ 文章专栏:排序_普通young man的博客-CSDN博客 若有问题 评论区见📝 🎉欢迎大家点赞👍收藏⭐文章 目录 …

HTTP详细总结

概念 HyperText Transfer Protocol,超文本传输协议,规定了浏览器和服务器之间数据传输的规则。 特点 基于TCP协议: 面向连接,安全 TCP是一种面向连接的(建立连接之前是需要经过三次握手)、可靠的、基于字节流的传输层通信协议,在…

Linux管道与重定向

管道 是进程通信的方法之一,在Linux中用命令1|命令2的形式表示,将前一个命令的结果作为后续命令的参数进行输入,也有tee管道,可以进行多次筛选,即多次使用|过滤命令。 重定向 文件描述符FD Linux中输入输出分为三种…

C语言变量、指针的内存关系

1. type p ? 表示从内存地址p开始,开辟一段内存,内存大小为类型type规定的字节数,然后把等号右边的值写入到这段内存中。 因此,这块内存起点位置是p,结束是ptype字节数-1。 2. type* p ?表示从内存地址p开始&…

SpingBoot快速入门下

响应HttpServietResponse 介绍 将ResponseBody 加到Controller方法/类上 作用:将方法返回值直接响应,如果返回值是 实体对象/集合,将会自动转JSON格式响应 RestController Controller ResponseBody; 一般响应 统一响应 在实际开发中一般…