Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略

Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略

目录

trl的简介

1、亮点

2、PPO是如何工作的:PPO对语言模型微调三步骤,Rollout→Evaluation→Optimization

trl的安装

trl的使用方法

1、基础用法

(1)、如何使用库中的SFTTrainer

(2)、如何使用库中的RewardTrainer

(3)、如何使用库中的PPOTrainer

2、进阶用法

LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型—解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略

LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略


trl的简介

          TRL - Transformer Reinforcement Learning使用强化学习的全栈Transformer语言模型。trl 是一个全栈库,其中我们提供一组工具,用于通过强化学习训练Transformer语言模型和稳定扩散模型,从监督微调步骤(SFT)到奖励建模步骤(RM)再到近端策略优化(PPO)步骤。该库建立在Hugging Face 的 transformers 库之上。因此,可以通过 transformers 直接加载预训练语言模型。目前,大多数解码器架构和编码器-解码器架构都得到支持。请参阅文档或示例/文件夹,以查看示例代码片段以及如何运行这些工具。

GitHub地址:GitHub - huggingface/trl: Train transformer language models with reinforcement learning.

1、亮点

>> SFTTrainer:一个轻量级且友好的围绕transformer Trainer的包装器,可以在自定义数据集上轻松微调语言模型或适配器。

>> RewardTrainer: transformer Trainer的一个轻量级包装,可以轻松地微调人类偏好的语言模型(Reward Modeling)。

>> potrainer:用于语言模型的PPO训练器,它只需要(查询、响应、奖励)三元组来优化语言模型。

>> AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead:一个转换器模型,每个令牌有一个额外的标量输出,可以用作强化学习中的值函数。

>> 示例:使用BERT情感分类器训练GPT2生成积极的电影评论,仅使用适配器的完整RLHF,训练GPT-j减少毒性,Stack-Llama示例等。

2、PPO是如何工作的PPO对语言模型微调三步骤,Rollout→Evaluation→Optimization

通过PPO对语言模型进行微调大致包括三个步骤:

Rollout

Rollout(展开):语言模型基于查询生成响应或继续,查询可以是句子的开头。

Evaluation

Evaluation(评估):使用一个函数、模型、人类反馈或它们的组合来评估查询和响应。重要的是,此过程应为每个查询/响应对产生一个标量值

Optimization

Optimization(优化):这是最复杂的部分。在优化步骤中,使用查询/响应对来计算序列中token的对数概率。这是通过训练的模型和一个参考模型(通常是微调之前的预训练模型)来完成的。两个输出之间的KL-散度被用作附加奖励信号,以确保生成的响应不会偏离参考语言模型太远。然后,使用PPO训练主动语言模型。

这个过程在下面的示意图中说明。

trl的安装

pip install trl

trl的使用方法

1、基础用法

(1)、如何使用库中的SFTTrainer

以下是如何使用库中的SFTTrainer的基本示例。SFTTrainer是用于轻松微调语言模型或适配器的transformers Trainer的轻量包装器。

# imports
from datasets import load_dataset
from trl import SFTTrainer# get dataset
dataset = load_dataset("imdb", split="train")# get trainer
trainer = SFTTrainer("facebook/opt-350m",train_dataset=dataset,dataset_text_field="text",max_seq_length=512,
)# train
trainer.train()

(2)、如何使用库中的RewardTrainer

以下是如何使用库中的RewardTrainer的基本示例。RewardTrainer是用于轻松微调奖励模型或适配器的transformers Trainer的包装器。

# imports
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer# load model and dataset - dataset needs to be in a specific format
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")...# load trainer
trainer = RewardTrainer(model=model,tokenizer=tokenizer,train_dataset=dataset,
)# train
trainer.train()

(3)、如何使用库中的PPOTrainer

以下是如何使用库中的PPOTrainer的基本示例。基于查询,语言模型创建响应,然后进行评估。评估可以是人工干预或另一个模型的输出。

# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)tokenizer = AutoTokenizer.from_pretrained('gpt2')# initialize trainer
ppo_config = PPOConfig(batch_size=1,
)# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")# get model response
response_tensor  = respond_to_batch(model, query_tensor)# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

2、进阶用法

LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型—解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略

https://yunyaniu.blog.csdn.net/article/details/133865725

LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略

https://yunyaniu.blog.csdn.net/article/details/133873621

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/160953.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

云汉芯城一站式电子制造平台启想智联顺利通过IATF16949:2016质量管理体系认证

近日,云汉芯城旗下一站式电子制造服务平台上海启想智能科技有限公司(以下简称“启想智联”)顺利通过IATF16949:2016质量管理体系认证,并获得由URS颁发的认证证书。通过此项认证,标志着启想智联在全球汽车行业的技术规范…

【Linux初阶】多线程3 | 线程同步,生产消费者模型(普通版、BlockingQueue版)

文章目录 ☀️一、线程同步🌻1.条件变量🌻2.同步概念与竞态条件🌻3.条件变量函数🌻4.条件变量使用规范🌻5.代码案例 ☀️二、生产者消费者模型🌻1.为何要使用生产者消费者模型🌻2.生产者消费者模…

【iOS】计算器仿写

文章目录 前言一、构建View界面二、Model中进行数据处理三、Controller层实现View与Model交互总结 前言 在前两周组内进行了计算器的仿写,计算器仿写主要用到了MVC框架的思想以及数据结构中用栈进行四则运算的思想,还有就是对OC中的字符串进行各种判错操…

C++stack和queue模拟实现以及deque的介绍

stack和queue介绍以及模拟实现 1.stack1.1stack的介绍1.2stack的使用 2.queue2.1queue的介绍2.2queue的使用 3.容器适配器3.1什么是适配器 4.stack模拟实现5.queue的模拟实现6.deque(双端队列) 1.stack 1.1stack的介绍 stack的文档介绍 stack是一种容…

Springboot整合WebSocket实现浏览器和服务器交互

Websocket定义 代码实现 引入maven依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId></dependency>配置类 import org.springframework.context.annotation.Bean;i…

2023 编程资料合集汇总

资源合集 名称链接Rabbitmq精讲&#xff0c;项目驱动落地&#xff0c;分布式事务拔高资料https://www.aliyundrive.com/s/5VwmhTCPBNa程序员书籍大全https://www.aliyundrive.com/s/Kz5UiijQB7i后端Java教程&#xff08;学完直接去BAT&#xff09;https://www.aliyundrive.com…

机器学习笔记 - 3D 对象跟踪极简概述

一、简述 大多数对象跟踪应用程序都是 2D 的。但现实世界是 3D 的,无论您是跟踪汽车、人、直升机、导弹,还是进行增强现实,您都需要使用 3D。在 CVPR 2022(计算机视觉和模式识别)会议上,已经出现了大量3D目标检测论文。 二、什么是 3D 对象跟踪? 对象跟踪是指随着时间的…

【环境搭建】linux docker-compose安装seata1.6.1,使用nacos注册、db模式

新建目录&#xff0c;挂载用 mkdir -p /data/docker/seata/resources mkdir -p /data/docker/seata/logs 给权限 chmod -R 777 /data/docker/seata 先在/data/docker/seata目录编写一个使用file启动的docker-compose.yml文件&#xff08;seata包目录的script文件夹有&#…

项目管理软件排行榜:点赞榜TOP5揭晓!

通过项目管理软件企业可以快速、高效地管理项目、整合团队成员以及资源。现如今市场上各类项目管理软件层出不穷&#xff0c;因此选择一款适合自身企业需求的软件显得尤为重要。本文将为大家介绍项目管理软件排行榜点赞榜&#xff0c;为大家选购提供一些参考。 1.Zoho Project…

【来点小剧场--项目测试报告】个人博客项目自动化测试

前述 针对个人博客项目进行测试&#xff0c;个人博客主要由七个页面构成&#xff1a;注册页、登录页、个人博客列表页、博客发布页、博客修改页、博客列表页、博客详情页&#xff0c;主要功能包括&#xff1a;注册、登录、编辑并发布博客、修改已发布的博客、查看详情、删除博…

解惑Android Scoped Storage

原文链接 Android Scoped Storage Puzzles 安卓对于文件存储这块&#xff0c;其实是相当混乱的&#xff0c;在早期的版本中对存储甚至是没有所谓的管理的&#xff0c;有多种方法可以操作文件存储&#xff0c;比如通过Java原生的方式(File/InputStream/OutputStream)&#xff0…

【微服务 SpringCloudAlibaba】实用篇 · Nacos注册中心

微服务&#xff08;5&#xff09; 文章目录 微服务&#xff08;5&#xff09;1. 认识和安装Nacos2. 服务注册到nacos和拉取服务1&#xff09;引入依赖2&#xff09;配置nacos地址3&#xff09;重启 3. 服务分级存储模型3.1 给user-service配置集群3.2 同集群优先的负载均衡 4. …

php74 安装sodium

下载编译安装libsodium wget https://download.libsodium.org/libsodium/releases/libsodium-1.0.18-stable.tar.gz tar -zxf libsodium-1.0.18-stable.tar.gz cd libsodium-stable ./configure --without-libsodium make && make check sudo make install下载编译安装…

【Python搜索算法】广度优先搜索(BFS)算法原理详解与应用,示例+代码

目录 1 广度优先搜索 2 应用示例 2.1 迷宫路径搜索 2.2 社交网络中的关系度排序 2.3 查找连通区域 1 广度优先搜索 广度优先搜索&#xff08;Breadth-First Search&#xff0c;BFS&#xff09;是一种图遍历算法&#xff0c;用于系统地遍历或搜索图&#xff08;或树…

Linux Zabbix企业级监控平台+cpolar实现远程访问

文章目录 前言1. Linux 局域网访问Zabbix2. Linux 安装cpolar3. 配置Zabbix公网访问地址4. 公网远程访问Zabbix5. 固定Zabbix公网地址 前言 Zabbix是一个基于WEB界面的提供分布式系统监视以及网络监视功能的企业级的开源解决方案。能监视各种网络参数&#xff0c;保证服务器系…

HTML三叉戟,标签、元素、属性各个的意义是什么?

&#x1f31f;&#x1f31f;&#x1f31f; 专栏详解 &#x1f389; &#x1f389; &#x1f389; 欢迎来到前端开发之旅专栏&#xff01; 不管你是完全小白&#xff0c;还是有一点经验的开发者&#xff0c;在这里你会了解到最简单易懂的语言&#xff0c;与你分享有关前端技术和…

【Java基础面试十二】、说一说你对面向对象的理解

文章底部有个人公众号&#xff1a;热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享&#xff1f; 踩过的坑没必要让别人在再踩&#xff0c;自己复盘也能加深记忆。利己利人、所谓双赢。 面试官&#xff1a; 说一说你对面向对象的理…

thinkphp5.1 获取缓存cache(‘cache_name‘)特别慢,php 7.0 unserialize 特别慢

thinkphp5.1 获取缓存cache(‘cache_name’)特别慢&#xff0c;php 7.0 unserialize 特别慢 场景&#xff1a; 项目中大量使用了缓存&#xff0c;本地运行非常快&#xff0c;二三百毫秒&#xff0c;部署到服务器后 一个表格请求就七八秒&#xff0c;最初猜想是数据库查询慢&am…

消息队列学习分享

消息队列学习 消息队列来解决问题 &#xff08;1&#xff09;异步处理 消息通知、日志管理、更新统计数据等步骤 &#xff08;2&#xff09;流量控制 如何避免过多的请求压垮我们的系统&#xff1f; 比如一个秒杀系统&#xff0c;网关在收到请求后&#xff0c;将请求放入…

Kotlin中的数值类型

在Kotlin中&#xff0c;Byte、Short、Int、Long、Float和Double是基本数据类型&#xff0c;用于表示不同范围和精度的数值。 Byte&#xff08;字节&#xff09;&#xff1a;Byte类型是8位有符号整数类型&#xff0c;取值范围为-128到127。在Kotlin中&#xff0c;可以使用字面值…