使用 RLHF 训练 LLaMA 的实践指南:StackLLaMA

由于LLaMA没有使用RLHF,后来有一个初创公司 Nebuly AI使用LangChain agent生成的数据集对LLaMA模型使用了RLHF进行学习,得到了ChatLLaMA模型,详情请参考:Meta开源的LLaMA性能真如论文所述吗?如果增加RLHF,效果会提升吗?,其实RLHF未必是必须的,主要是高质量的标注数据获取成本比较高,RLHF是一个trade-off。

StackLLaMA模型介绍

今天分享的StackLLaMA是按照InstructGPT论文的方法获得的,它的目的是,在算法流程上和ChatGPT类似,大致流程如下:

  • 监督微调 (SFT)

  • 奖励/偏好建模 (RM)

  • 从人类反馈中强化学习 (RLHF)

主要区别在于

基础模型:ChatGPT使用的是GPT3.5,StackLLaMA使用的是LLaMA;

SFT阶段:StackLLaMA使用的是StackExchange 数据集

(https://huggingface.co/datasets/HuggingFaceH4/stack-exchange-preferences),而ChatGPT的简单数据没有公开;

StackLLaMA的主要共贡献

    • StackLLaMA模型开源了,并且在Huggingface Hub上可以使用,地址:https://huggingface.co/trl-lib/llama-7b-se-rl-peft;

    • 集成到Hugging Face TRL库,为广大朋友提供了基础库使用,地址:https://huggingface.co/docs/trl/index;

    • 开源了监督训练数据集StackExchange,地址:https://huggingface.co/datasets/HuggingFaceH4/stack-exchange-preferences;

    • 开源了数据集和处理笔记本https://huggingface.co/datasets/lvwerra/stack-exchange-paired;

    • 介绍了训练过程的细节以及解决方案;

    • RLHF库:https://github.com/lvwerra/trl

Stack Exchange数据集

该数据集包括来自 StackExchange 平台的问题及其相应的答案(包括针对代码和许多其他主题的 StackOverflow)

根据论文《https://arxiv.org/abs/2112.00861》介绍,给每个答案一个score,公式如下:

score = log2 (1 + upvotes) rounded to the nearest integer, plus 1 if the questioner accepted the answer (we assign a score of −1 if the number of upvotes is negative).

对于奖励模型,始终需要每个问题有两个答案进行比较。对每个问题最多采样 10 个答案对,以限制每个问题的数据点数量。最后,将 HTML 转换为 Markdown 来清理格式,使模型的输出更具可读性。

高效的训练策略

即使训练最小的 LLaMA 模型也需要大量内存,比如bf16,每个参数需要2个字节来存储,fp32需要4个字节,更多可以参考:https://huggingface.co/docs/transformers/perf_train_gpu_one#optimizer。对于一个 7B 参数的模型,只是参数就需要(2+8)*7B=70GB的内存,实际存储还包括计算注意力分数等中间值,可能需要更多内存。因此,即使是7B的模型也无法在单个 80GB A100 上训练模型。

使用参数高效微调 (PEFT) 技术,比如使用https://github.com/huggingface/peft库来实现,这种技术可以加载8-bit模型执行低秩自适应 (LoRA),如下图所示:

线性层的低秩自适应:在冻结层(蓝色)旁边添加额外参数(橙色),并将生成的编码隐藏状态与冻结层的隐藏状态一起添加。

8-bit模型,每个参数只需要一个字节(例如,7B LlaMa 的内存为 7GB)。LoRA 不是直接训练原始权重,而是在某些特定层(通常是注意力层)之上添加小的适配器层;因此,可训练参数的数量大大减少。

在这种情况下,根据batch大小和序列长度不同每十亿个参数大概需要1.2-1.4GB内存,这可以以低成本微调更大的模型(在 NVIDIA A100 80GB 上高达 50-60B 比例模型)。

这些技术可以在消费级设备和 Google Colab 上微调大型模型。比如:facebook/opt-6.7b(13GB in float16)和openai/whisper-largeGoogle Colab(15GB GPU RAM)。更多参考:https://github.com/huggingface/peft和https://huggingface.co/blog/trl-peft

将非常大的模型放入单个 GPU 中,如果训练速度仍然很慢,可以使用数据并行技术,如下图所示:

数据并行可以使用transformers.Trainer和accelerate,无需任何代码更改,只需在使用torchrunor调用脚本时传递参数即可accelerate launch。下面分别用accelerate和在单台机器上运行带有 8 个 GPU 的训练脚本torchrun

accelerate launch --multi_gpu --num_machines 1  --num_processes 8 my_accelerate_script.pytorchrun --nnodes 1  --nproc_per_node 8 my_torch_script.py

监督微调SFT

首先需要对监督数据做一些处理,传统方式通常是需要保证每个batch中序列长度是一样(采用填充或者截断),与传统的方式不同,GPT模型监督数据是把多个sentence通过EOS标记拼接到一起来使用,如下图所示:

使用peft加载模型之后就可以使用Trainer训练模型了。具体是:首先

首先导入int8模型,然后加入到训练准备,最后再添加LoRA adapters,代码如下所示:

# load model in 8bitmodel = AutoModelForCausalLM.from_pretrained(        args.model_path,        load_in_8bit=True,        device_map={"": Accelerator().local_process_index}    )model = prepare_model_for_int8_training(model)
# add LoRA to modellora_config = LoraConfig(    r=16,    lora_alpha=32,    lora_dropout=0.05,    bias="none",    task_type="CAUSAL_LM",)
model = get_peft_model(model, config)

Note:最终预测的时候需要把LoRA adapters的模型参与与LLaMA模型参数加起来使用。

通过运行脚本(https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py)将它们转换为 🤗 Transformers 格式。

奖励建模和人类偏好

奖励模型的目标是模仿人类如何评价文本。建立奖励模型有几种可能的策略:最直接的方法是预测注释(例如评分或“好”/“坏”的二进制值)。在实践中,更好的方法是预测两个例子的排名,奖励模型必须预测哪一个会被人类标注人员评分更高。

这可以转化为以下损失函数:

使用 StackExchange 数据集,我们可以根据分数推断用户更喜欢这两个答案中的哪一个。有了这些信息和上面定义的损失,就可以使用transformers.Trainer通过添加自定义损失函数来修改。​​​​​​​​​​​​​​

class RewardTrainer(Trainer):    def compute_loss(self, model, inputs, return_outputs=False):        rewards_j = model(input_ids=inputs["input_ids_j"],  attention_mask=inputs["attention_mask_j"])[0]        rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]        loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()        if return_outputs:            return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}        return loss

训练集100000个样本的子集,评估集50000个样本,batch大小为4,采用LoRA adapter(bf16混合精度的Adam优化器)fine-tuning LLaMA模型。LoRA 配置如下:​​​​​​​

peft_config = LoraConfig(    task_type=TaskType.SEQ_CLS,    inference_mode=False,    r=8,    lora_alpha=32,    lora_dropout=0.1,)

在 8个A100 GPU 训练了几个小时,训练的Weights & Biases记录地址:https://wandb.ai/krasul/huggingface/runs/wmd8rvq6?workspace=user-krasul

模型最终达到了67% 的准确率。虽然这听起来分数很低,但任务也非常艰巨,即使对于人工标注者也是如此。

如下一节所述,生成的适配器可以合并到冻结模型中并保存以供下游进一步使用。

从人类反馈中强化学习RLHF

有了经过微调的语言模型和奖励模型,我们现在可以运行 RL 循环了。大致分为三个步骤:

  1. 根据提示生成响应

  2. 使用奖励模型对响应进行评分

  3. 使用评级运行强化学习策略优化步骤

Query and Response模板如下:​​​​​​​

Question: <Query>
Answer: <Response>

SFT、RM 和 RLHF 阶段使用相同的模板。

使用 RL 训练语言模型的一个常见问题是,奖励模型为了得到高的reward,通常会生成一些完整的乱码序列。为了解决这个问题,在reward中增加了一个KL散度的惩罚,公式如下:

训练的时候,首先导入8-bit的SFT模型并且冻结参数,然后使用PPO优化LoRA参数,代码如下:​​​​​​​

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):    question_tensors = batch["input_ids"]
    # sample from the policy and generate responses    response_tensors = ppo_trainer.generate(        question_tensors,        return_prompt=False,        length_sampler=output_length_sampler,        **generation_kwargs,    )    batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
    # Compute sentiment score    texts = [q + r for q, r in zip(batch["query"], batch["response"])]    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)    rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]
    # Run PPO step    stats = ppo_trainer.step(question_tensors, response_tensors, rewards)    # Log stats to WandB    ppo_trainer.log_stats(stats, batch, rewards)

在 3x8 A100-80GB GPU 上训练了 20 小时,也可以使用更少的资料(例如,在 8 个 A100 GPU 上训练约 20 小时后)。训练过程每个step的reward变化如下图所示:

模型的性能在大约 1000 步后达到稳定状态。

StackLLaMA模型效果

答案看起来连贯,甚至提供了一个谷歌链接。让我们来看看接下来的一些训练挑战。

StackLLaMA模型训练的一些坑以及解决方案

  • 更高的reward意味着更好的表现吗?

一般来说,在 RL 中希望获得最高的reward,但是在 RLHF 中,使用了一个不完美的奖励模型,PPO 算法将利用这些不完美,这可能表现为奖励的突然增加,但是当我们从策略中查看文本生成时,它们主要包含字符串 ``` 的重复,因为奖励模型发现包含代码块的stack exchange答案reward分数是最高的。幸运的是,这个问题很少被观察到,一般来说,KL 惩罚应该可以抵消这种攻击。

  • KL 总是一个正值,不是吗?

通常,KL 散度衡量两个分布之间的距离,并且始终为正数。然而,在trl库使用 KL 的估计值时,希望期望分布与实际分布尽可能接近。

显然,当从概率低于 SFT 模型的策略中采样token时,这将导致负 KL 惩罚,但平均而言它将是正的,否则您将无法从策略中正确抽样。然而,一些生成策略可以强制生成一些token或者可以抑制一些token。例如,当批量生成完成的序列时,会填充序列,当设置最小长度时,EOS 令牌会被抑制。该模型可以为那些导致负 KL 的token分配非常高或低的概率。当 PPO 算法针对奖励进行优化时,它会追逐这些负面惩罚,从而导致不稳定,如下图所示:

在生成答案的时候,建议先使用简单的采用策略,后面再提高采用的复杂程度。

  • 持续的问题

还有一些问题需要更好地理解和解决。例如,损失偶尔会出现峰值,这可能会导致进一步的不稳定性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/30286.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于人类反馈的强化学习(RLHF) 理论

gpt 进程 GPT-1 用的是无监督预训练 有监督微调&#xff0c;只有简单的单向语言模型任务&#xff1b;GPT-2用的是纯无监督预训练&#xff0c;使用更多的数据&#xff0c;更大的模型&#xff0c;又新增了几个辅助的训练任务&#xff1b;GPT-3 沿用了 GPT-2 的纯无监督预训练&a…

【疑难杂症】overleaf公式显示异常并且被重复添加至正文内,正文内$符号消失,编译报错Missing $ inserted.inserted text。

问题描述 此问题困扰本人许久&#xff0c;搜索了许多相关情况都没有我这样的。每次编译后&#xff0c;overleaf中的公式会编译错误&#xff0c;并且被莫名其妙地添加到正文中&#xff0c;而且原来引用公式的dollar符号$$也异常消失。 问题举例 原始文本&#xff1a; 编译后…

【ChatGPT】从人类反馈 (RLHF) 中进行强化学习 | Illustrating Reinforcement Learning from Human Feedback (RLHF)

从人类反馈 (RLHF) 中进行强化学习 | Illustrating Reinforcement Learning from Human Feedback (RLHF) 目录

python 用 xlwings 处理 Excel 中的重复数据

xlwings 简介 xlwings 是一个 Python 库。简化了 Python 和 Excel 通信。 xlwings - 让Excel跑得飞快! 本文写作背景 & 需求 & 方案 因前几个月帮在医院工作的朋友现学现卖用VBA写了段程序&#xff0c;处理2个excel文档的数据到第3个Excel文档上&#xff0c;有模板数据…

解决Chrome网页编码显示乱码的问题

解决Chrome网页编码显示乱码的问题 记得在没多久以前&#xff0c;Google Chrome上面出现编码显示问题时&#xff0c;可以手动来调整网页编码问题&#xff0c;可是好像在Chrome 55.0版以后就不再提供手动调整编码&#xff0c;所以如果现在遇到big 5被误判为UTF8的网页问题时&…

全网最详细中英文ChatGPT-GPT-4示例文档-从0到1快速入门语法纠正应用——官网推荐的48种最佳应用场景(附python/node.js/curl命令源代码,小白也能学)

从0到1快速入门语法纠正应用场景 Introduce 简介setting 设置Prompt 提示Sample response 回复样本API request 接口请求python接口请求示例node.js接口请求示例curl命令示例json格式示例 其它资料下载 ChatGPT是目前最先进的AI聊天机器人&#xff0c;它能够理解图片和文字&…

ChatGPT其实无法获得法学保研资格

ChatGPT通过了美国明尼苏达大学法学院4门课程的考试&#xff0c;95个选择题、12个论述题&#xff0c;平均分为C&#xff1b;也通过了宾夕法尼亚大学沃顿商学院的考试&#xff0c;成绩也不错。但是在当下内卷的情形下&#xff0c;ChatGPT的考试成绩不会获得保研资格&#xff0c;…

chatgpt赋能python:Python摄像头运用介绍

Python摄像头运用介绍 Python是一种广泛应用于各种领域的高级编程语言。其中&#xff0c;Python摄像头应用越来越受欢迎&#xff0c;尤其是在计算机视觉和机器学习领域。本文将介绍Python摄像头运用的相关知识。 什么是Python摄像头运用&#xff1f; Python摄像头运用是使用…

NAT技术之NAT server

技术背景&#xff1a; 在很多场景中&#xff0c;比如企业、学校、甚至家里都有一些对外访问的业务提供&#xff0c;比如门户网址、NAS、ERP等&#xff0c;在实际部署中&#xff0c;这些提供访问的服务器都属于内网内&#xff0c;配置的是内网地址&#xff0c;导致的情况是公网…

在群晖NAS上搭建导航页_通过Web Station搭建

一、业务需求 1.1、需求说明 我们在使用群晖NAS的过程中&#xff0c;随着时间的推移会安装各种各样的软件内容和管理工具&#xff0c;而这些内容又都是一些网页界面&#xff08;特别是一些在Docker中搭建的工具&#xff09;时间久了我们也记不住那么多工具的Web界面地址&#…

[NAS] QNAP/威联通 常用设置和操作

&#x1f341;简介 QNap 产品是一种可扩展的数据存储解决方案。它们包括具有 1 到 30 个驱动器托架的设备&#xff0c;并提供 HDMI、Thunderbolt 2 和 USB 3.1 等连接选项&#xff0c;以及 802.11ac/a/n Wi-Fi 和高达每秒 40 Gb 的以太网。内置软件提供基本服务&#xff0c;例如…

详解央行数字货币和数字票据交易平台架构(多图)

独家披露&#xff1a;详解央行数字货币和数字票据交易平台架构(多图) 暴走时评&#xff1a;央行推动的基于区块链的数字票据交易平台已测试成功&#xff0c;由央行发行的法定数字货币已在该平台试运行。作为一种创新的货币和全新的支付体系架构&#xff0c;央行数字货币具有长远…

国家队入场,中国数字资产交易市场或将迎来新一轮“洗牌”

‍ ‍数据智能产业创新服务媒体 ——聚焦数智 改变商业 数字化已经成为中国文化产业的催化剂&#xff0c;一大批文化资源在数字技术的赋能下焕发了崭新的生机。 随着数字化的升级与科技进步&#xff0c;数字经济正在成为改变全球竞争格局的关键力量&#xff0c;各国家都争先出…

浅谈数字人民币什么时候正式推出DCEP钱包在哪里下载

11月23日,澎湃新闻从苏州一位知情人士处独家获悉,继深圳后,苏州将于双十二推出数字人民币红包测试。 上述知情人士告诉澎湃新闻记者,目前苏州相城区已有很多商家已经安装NFC(Near Field Communication,近场通信)二维码,只是支付载体还在测试员中,目前已有测试员体验过数字人民币…

Facebook 数字货币:缘起、意义和后果

来源 | 孟岩的区块链思考 作者 | 孟岩、邵青 出品 | 区块链大本营&#xff08;blockchain_camp&#xff09; 6 月 18日&#xff0c;Facebook 位于瑞士的子公司 Libra Network (天秤座网络&#xff09;将发布其加密数字货币项目白皮书。此前 BBC 报道说这个数字货币叫做 GlobalC…

数字货币钱包基础

我在前面3篇文章讲了区块链基础知识、普通人如何购买以及如何在imtoken里参与ICO。一个核心的问题其实是没有讲到的&#xff0c;我们这些数字货币到底怎么保存&#xff0c;因为之前讲的都是在交易市场上购买比特币、以太币&#xff0c;这些货币被保存在交易市场&#xff0c;本质…

大家知道微信个人收款码限额多少吗

大家知道微信个人收款码限额多少吗 随着移动支付的普及&#xff0c;微信、支付宝等平台已经成为了人们日常生活中不可或缺的支付工具。二维码收款作为这些平台的重要功能之一&#xff0c;可以方便快捷地完成转账和付款操作&#xff0c;受到了越来越多用户的广泛关注和使用。 对…

用户授信额度管理中,会运用到哪些策略?

关注“金科应用研院”&#xff0c;回复“CSDN” 领取风控资料合集 01、授信额度与贷款额度 授信额度是指金融机构能够为借款人提供的最大贷款金额。贷款额度一般是指借款人在金融机构给予的最大贷款金额范围内&#xff0c;实际借贷的金额。 授信额度和贷款额度的主要区别是授…

ChatGPT在做什么?为什么它有效?

2023 年 2 月 14 日 它只是一次添加一个词 ChatGPT可以自动生成一些表面上看起来像人类书写文本的东西&#xff0c;这是非常了不起的&#xff0c;也是出乎意料的。但是它是怎么做到的呢&#xff1f;为什么它有效&#xff1f;我在这里的目的是粗略概述 ChatGPT 内部发生的事情&…

智能车浅谈——手把手让车跑起来(电磁篇)

文章目录 前言材料准备备赛组车模硬件 练习组车模硬件方案 整车原理赛道信息获取及转向原理工字电感运放模块转向原理元素判断 电机及舵机控制原理 代码实现效果欣赏总结17届完赛代码智能车系列文章汇总 前言 电磁寻迹小车 之前智能车系列已经做了一个比较详细的解析&#xff0…