LLM —— 强化学习(RLHF-PPO和DPO)学习笔记

强化学习整体流程

请添加图片描述
智能体执行动作与环境进行交互,根据奖励R的反馈结果不断进行更新。

价值函数

请添加图片描述
奖励将会考虑两个方面的奖励,一个当下的奖励,一个是未来的奖励(为了防止陷入局部最优解)。

LLM强化学习

请添加图片描述

强化学习模型分类

需要LLM在训练过程中做生成的方法是 On Policy,其余的为Off Policy

1、On Policy

(1)RLHF模型

组成部分
请添加图片描述
具有四个部分,演员模型、参考模型、评论家模型和奖励模型。
其中,演员模型和评论家模型是需要训练改变参数的,而奖励模型和参考模型在训练中不改变参数。

训练过程
请添加图片描述
首先,由第一个SFT阶段后,会得到SFT模型。然后,使用SFT模型作为参考模型和演员模型。使用偏好数据训练一个Reward模型作为PPO阶段的奖励模型和评论家模型。在第三部训练过程中,Reference模型和Reward模型的参数将不会变。Actor模型和Critic模型将会随着训练变化参数。

模型部分

  1. Actor模型
    请添加图片描述
    Actor模型是我们最终训练的目标模型,生成的reponse中每个token将作为生成的动作,后续会被奖励模型和评论家模型进行评判。
    请添加图片描述
    prompt为 S t S_t St,response为 A t A_t At

  2. Reference模型
    请添加图片描述
    Reference模型主要是为了防止Actor模型训练后与SFT阶段产生的模型差异过大,RLFH阶段只是为了让Actor模型生成符合人类偏好的数据,但并不希望与SFT阶段的生成效果偏差过大。因此,需要Reference模型去来 “纠正” 它。

  3. Reward模型
    请添加图片描述
    Reward模型用于计算生成token的当下收益,输入prompt+response,生成的Response中米哥token会作为动作输入给RW模型的最后一层Value Head得到 R t R_t Rt 奖励值。
    请添加图片描述
    其中,x为prompt, y w y_w yw为偏好正例数据response, y l y_l yl为偏好负例数据reponse。

  4. Critic模型
    在这里插入图片描述
    Critic模型会评价Actor模型未来的收益,因为Actor模型会随着训练不断变化,Critic模型也需要随着变化去更新,来更好的评判Actor模型未来会产生的收益。
    请添加图片描述
    R t R_t Rt是当下收益, V t + 1 V_{t+1} Vt+1是未来收益,两个结合后为对未来收益更综合的评估。

RLHF-PPO过程
5. prompt输入给Actor模型,会得到动作 A t A_t At。同时,也将prompt输入给Reference模型,得到参考的 A t A_t At
6. 将得到的动作 A t A_t At 输入给 Critic模型和 Reward模型分别得到 V t V_t Vt R t R_t Rt
7. 将Actor模型的 A t A_t At和Peference模型的 A t A_t At进行比较,得到一个KL散度分数。
8. 结合 R t R_t Rt V t V_t Vt和KL散度分数,对Actor模型和Critic模型进行联合优化更新Loss。
9. 更新 Actor模型和 Critic模型。

RLHF现存问题

主要集中在训练时的问题。

  1. 算力消耗大
    请添加图片描述

  2. 容易崩溃
    请添加图片描述
    因为牵扯模型多且需要协同配合,容易出现训练崩溃。

(2)deepspeed chat

请添加图片描述
1、如何高效生成答案?

DeepSpeed混合引擎
挑战1:在大多数高度优化的系统中,训练和推理通常使用两个不同的后端。

原因是这两个目标通常在不同的情况下使用——训练用于模型更新,推理用于模型部署。
在RLHF微调中,生成与训练是存在串行的方式,演员模型需要在每一步为每个查询生成答
案。因此,标准的训练模式可能是RLHF微调的瓶颈,因为它没有针对推理进行优化。

挑战2:因为模型分布在不同的GPU上。在生成过程中,生成步骤需要在GPU之间收集参数后进行推理,通信成本将会非常高

为了克服这两个挑战,引入了DeepSpeed混合引擎(DeepSpeed-HE)。
1.这个引擎可以自动在DeepSpeed提供的训练引擎和推理引擎之间切换
2.DeepSpeed-HE可以自动将ZeRO训练模式更改为张量并行推理,消除了重复参数收集的
需要。

2、如何处理多个模型使用的大量内存消耗?

第一,得益于DeepSpeed ZeRO优化,我们可以将模型参数和优化器分布在整个用于训练的GPU系统上。这显著减少了这些模型所需的内存消耗。

第二,参考模型与PPO训练循环中的演员模型大小相同,这需要相当数量的内存。然而,这个参考模型只在我们需要“旧行为概率“时才被调用。因此,参考模型的计算成本低于演员模型。为了减少内存压力,我们提供了一个单模型卸载选项,只将参考模型卸载到CPU。我们观察到,在相同的训练批量大小下,卸载参考模型(到CPU)与否的吞吐量影响很小。然而,如果演员模型卸载到CPU,训练速度会显著减慢。

第三,优化器的优化状态消耗了大量的训练内存。为了缓解这个问题,采用LoRA训练的方式,它只更新训练期间参数的一小部分。结果,与标准训练相比,优化状态要小得多

(4)实战案例
1、偏好数据集的构建

请添加图片描述
基于XuanYuan-6B进行RLHF落地实战,该模型主要应用于金融领域。

Prompt构建
重点关注两个方面:一方面是数据的丰富性和多样性,一方面是数据的质量。

  • 数据的多样性保证
    请添加图片描述
    把通用性、安全性和金融性进行了更细粒度的拆分,得到了多个子项,并按照一定的量级和比例收集每一子项的数据。这样可以使收集的prompt覆盖到不同的方面,同事具备合理的量级和配比。

  • 数据质量保证
    专业人员对数据进行清晰:删除或修改有明显错误的prompt或格式有瑕疵的prompt。经过清洗后,获得4W+高质量的prompt数据。

Response生成
为保证RM训练数据和测试数据分布的一致性,避免出现OOD(Out of distribution)问题,生成步骤:

  1. 使用XuanYuan-6B-SFT来产生response。在强化学习阶段,RM模型的输入是Actor模型的输出,Actor的初始状态为XuanYuan-6B-SFT。
  2. 使用XuanYuan-6B-SFT的采样参数,提高其采样参数中temperature和top_p的值,然后再生成response,以保证response的多样性,以及其包含的偏好信息的多样性。

偏好标注
当前业界主要有两种流行的标注方式:rank标注和pair标注。

  • rank标注
    一个Prompt包含多个response(一般为4个),标注者要对多个response进行排序,之后根据排序信息,可以将response两两组合,构建形如所示的偏好数据。Instruct-GPT即采用这类标注方式。
  • pair标注
    一条prompt仅生成两个reponse,标注者直接比较两个reponse,标出哪条response更符合偏好。此外,一些标注方法也要求标出偏好的强度。Anthropic和LlaMA2即采用pair形式的偏好标注。

实践

请添加图片描述

  • 放弃rank标注方式: 原因是标注速度慢且不同的标注人员对标注结果评判的一致性较低。
  • 采用pair标注方式: 直接比较两个response进行标注,并且要求标注出偏好的强度,以收集更多的偏好信息,来提升RM的泛化性能。
  • 具体标注步骤: 标注页面可以选择8个档位进行标注,从左到右依次命名为A3、A2、A1、A0、B0、B1、B2、B3。其中,A3表示A优于B的程度,B3表示B优于A的程度,其他档位依次类推。
  • 制订了一套完善的标注标准: 覆盖了实际中可能出现的大多数场景,并在标注过程中不断发现和解决新出现的问题,不断扩充完善我们的标注标准。
  • 对交付的标注结果进行严格的质检: 如果数据不合格会重新进行标注,直至满足验收标准。
  • 删除了偏好强度最低的数据(即A0和B0): 偏好强度低意味着两个response较为接近,未包含明显的偏好信息。这类数据歧义较大,会让模型感觉比较"困惑",不利于模型进行偏好建模。

最终的数据量:约6W+条偏好数据,其中90%用于训练,剩余10%用于测试。

2、RM训练

架构
请添加图片描述
XuanYuan-6B-SFT作为RM的基本架构,去掉最后的LM_head layer(softmax层,输出词表中每个token的概率),并将其替换为value_head layer。Value_head layer为一个线性层,输入是XuanYuan-6B-SFT次顶层的特征,输出为一个一维的reward分数。

损失函数
请添加图片描述
在实践中计算损失有两种方式:token-level的对比损失和sentence-level的对比损失

  • token-level的对比损失
    参考DeepSpeed-Chat中做法,使用token-level的对比损失来进行RM训练。
    训练阶段:

    • 步骤1:对于 y c y_c yc y r y_r yr,先找到他们第一个不相同的token所在的位置,作为起始位置
    • 步骤2:找到两个response结束的位置,并取两者中的最长长度,作为结束位置。
    • 计算从起始位置,到结束位置,相同位置上 y c y_c yc y r y_r yr之间的对比损失,最后求对比损失的均值作为该条件偏好样本的损失。(逐字对比)
    • 预测阶段:取response最后一个token对应的reward作为该response的reward。
  • sentence-level的对比损失
    为保证训练/测试的一致性,训练时应该取 y c y_c yc 最后一个token的reward和 y c y_c yc 最后一个token的reward来计算对比损失。

  • 上述两个在实践中的实际情况
    实验对比了两种损失函数的表现,结果表明sentence-level损失训练RM可获得更高的测试精度,但是RM不仅用于给reward打分,还用与强化训练阶段critic model的初始化。我们发现使用sentence-level损失训练的RM初始化critic model后,强化训练会变得不稳定,难以收敛。因此,我们仍使用token-level损失来进行RM训练,虽然精度会小幅度下降,但是强化训练的稳定性会有较高提升。

模型选择
在RM训练阶段,训练多个epoch,并在每个epoch结束后存储当前RM,之后选择合适的RM进行后续强化训练。在选择RM时,我们主要看以下几点:

  1. 测试精度:因为测试精度客观反应了RM打分合理性;
  2. RM输出的reward值:如果reward值过小或过大,在后续强化训练时会产生数值问题,导致训练无法正常进行;
  3. 接受和拒绝response奖励值之间的差距:具体做法是计算测试集中的接受reward的均值和拒绝reward均值,观察两个均值之间是否存在一定的差距。如果存在一定的差距,则说明RM有较强的鲁棒性。(差距越大,对比的体现就越明显)

最终选择RM测试精度是63%,输出尺度在[-1, 1]区间内,差距为0.5。

3、RLHF训练

模型结构
actor model和reference model:XuanYuan-6B-SFT
critic model和reward model:XuanYuan-6B-RM进行初始化

训练中actor model和critic model需要更新,而reference model和reward model保持不变。

数据
强化训练的数据为prompt数据。

数据组成:偏好数据的prompt,增加了额外的新prompt,比例为1:1。

  • 偏好数据中的prompt用于强化训练会使训练过程更为"容易",很大程度上可以避免RM打分不准而导致的一系列问题,如reward hacking、训练不收敛等。
  • 仅采用偏好数据中的prompt是不够的,这样模型见到的数据过于局限,不利于提升模型的泛化性能,因此增加了额外的新prompt一起用于强化训练。新prompt的构建方式和偏好数据中prompt构建方式相同。≈

训练

  • 训练参考
    训练过程参考了Instruct-GPT】LlaMA2以及Anthropic的做法。在实现上,参考了DeepSpeed-Chat框架。

  • 强化训练的目标
    请添加图片描述

  • 超参选择

    • KL的权重 β = 0.05 \beta=0.05 β=0.05
      β \beta β 是一个超参数,用于平衡探索和保持现状之间的权衡。过高的 β \beta β 会使模型接近初始模型 π 0 \pi_0 π0(不怎么探索),强化训练效果不明显;过低的 β \beta β 会过度优化 reward值(探索过头),容易造成reward hacking。

    • actor model和critic model的学习率设置为5e-7
      过高的学习率会让RM值快速上升,容易造成reward hacking;而过低的学习率会极大降低训练速度。

    • loss精度
      在计算loss时,使用fp32的数据精度,避免loss的数值问题引起的训练不稳定现象。

    • 训练了约300 PPO step
      训练中重点关注critic loss和RM reward值的变化,critic loss整体上应呈现下降趋势,而RM reward整体上应呈现上升趋势。
      注:RM reward上升的过高也是一种异常现象,此时大概率出现了reward hacking。

  • 模型选择

    • 每训练20个PPO step,存储当前的actor model。
    • 训练完成后,根据RM reward变化情况,挑选几个不同阶段的代表性模型进行快速的人工评估。
    • 人工评估时对比对象是强化训练前的SFT模型,即XuanYuan-6B-SFT。
      评估完成后统计good(actor response > SFT model response),same(actor response = SFT model response),bad(actor response < SFT model response)数量。然后,选择最有优势的actor model进行更正式的人工评估。
  • 人工评估

    • 聘请了专业的评估人员进行模型评估,评估题目覆盖通用型、安全性、金融垂类等不同范畴。
    • 每道题目均由三个不同的评估人员进行评估,来避免不同评估人员的洗好偏差。
    • 评估题目对其他人员完全封闭,避免研发人员同构构造类似的评估题目进行训练来获得更好的评估结果。

模型评估效果:
请添加图片描述
模型在通用性的各细分领域的评估结果:
请添加图片描述
从结果来看,在大多数子领域,经过强化训练后,模型的能力都有了明显的提升。在日常对话、逻辑推理、内容创作和安全性等子领域,强化带来的效果提升都很明显。

然后,在一些其他子领域,比如信息摘要、翻译等,强化训练并未带来明显的进步。在后续工作中,需要补充更多的偏好数据,同事提升偏好标注质量,来进一步补齐这些弱项的能力。

模型在金融子领域的各细分领域的评估结果:
请添加图片描述
在金融知识理解、金融业务分析连个子领域,强化训练带来了明显的能力提升。而在其他子领域,强化训练并未取得逾期的效果。对这些子领域,需要补充更多的高质量偏好数据,提高RM对这类prompt和response打分的准确性,进而提升强化训练的效果。

2、Off Policy

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/411057.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CTF—杂项学习

1 文件操作隐写 1.1 文件类型识别 1.1.1 File命令 当文件没有后缀名或有后缀名而无法打开时&#xff0c;根据识别出的文件类型来修改后缀名即可正常打开文件&#xff0c;file是Linux下的文件识别命令。 file 文件名 使用场景&#xff1a;不知道后缀名&#xff0c;无法打开文件…

【STM32开发笔记】STM32H7S78-DK上的CoreMark移植和优化--兼记STM32上的printf重定向实现及常见问题解决

【STM32开发笔记】STM32H7S78-DK上的CoreMark移植和优化--兼记STM32上的printf重定向实现及常见问题解决 一、CoreMark简介二、创建CubeMX项目2.1 选择MCU2.2 配置CPU时钟2.3 配置串口功能2.4 配置LED引脚2.5 生成CMake项目 三、基础功能支持3.1 支持记录耗时3.2 支持printf输出…

SEO之网站结构优化(十三-网站地图)

** 初创企业搭建网站的朋友看1号文章&#xff1b;想学习云计算&#xff0c;怎么入门看2号文章谢谢支持&#xff1a; ** 1、我给不会敲代码又想搭建网站的人建议 2、“新手上云”能够为你开启探索云世界的第一步 博客&#xff1a;阿幸SEO~探索搜索排名之道 网站无论大小&…

京存分布式赋能EDA应用

合抱之木&#xff0c;生于毫末&#xff1b;九层之台&#xff0c;起于累土&#xff1b;千里之行&#xff0c;始于足下。——《老子德经第六十四章》 EDA&#xff08;Electronic Design Automation 电子设计自动化&#xff09;是利用计算机&#xff0c;完成对VLSI &#xff08;V…

OpenCV绘图函数(8)填充凸多边形函数fillConvexPoly()的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 填充一个凸多边形。 函数 cv::fillConvexPoly 绘制一个填充的凸多边形。这个函数比 fillPoly 函数快得多。它可以填充的不仅仅是凸多边形&#…

25届最近5年自动化考研院校分析

哈尔滨工程大学 目录 一、学校学院专业简介 二、考试科目指定教材 三、近5年考研分数情况 四、近5年招生录取情况 五、最新一年分数段图表 六、初试大纲复试大纲 七、学费&奖学金&就业方向 一、学校学院专业简介 二、考试科目指定教材 1、考试科目介绍 2、指定…

C++ | Leetcode C++题解之第377题组合总和IV

题目&#xff1a; 题解&#xff1a; class Solution { public:int combinationSum4(vector<int>& nums, int target) {vector<int> dp(target 1);dp[0] 1;for (int i 1; i < target; i) {for (int& num : nums) {if (num < i && dp[i - …

《JavaEE进阶》----4.<SpringMVC①简介、基本操作>

本篇博客讲解 MVC思想、及Spring MVC&#xff08;是对MVC思想的一种实现&#xff09;。 Spring MVC的基本操作、学习了六个注解 RestController注解 RequestMappering注解 RequestParam注解 RequestBody注解 PathVariable注解 RequestPart注解 MVC View(视图) 指在应⽤程序中…

四大名著改编的ip大作,一个巨亏2亿,一个狂赚20亿!选择决定成败!

最近讨论热度比较高的当属《红楼梦》和《西游记》了 胡玫导演的《红楼梦之金玉良缘》耗费了18年的心血&#xff0c;投资了2个多亿 却仅仅只有600万票房&#xff0c;还被网友调侃称“一黛不如一黛” 而由《西游记》改编的游戏《黑神话悟空》&#xff0c;研发10年投资6亿&…

【drools】Rulesengine构建及intelj配置

7.57.0.FinalRulesengineApplication 使用maven构建 intelj 打开文件资源管理器实在是太慢了所以直接把pom 扔到其主页识别为maven项目,自动下载maven包管理器 然后解析依赖: 给maven加一个代理 -DproxyHost=127.0.0.1 -DproxyPort=7890 还是卡主

机器人拾取系统关节机械臂通过NY-PN-EIPZ进行命令控制

关节机械臂是一种精密的机器&#xff0c;旨在模拟人类手臂在订单拣选操作中的运动。这些多功能机器人由多个关节组成&#xff0c;通常有 4 到 7 个轴&#xff0c;使它们能够高度自由地移动&#xff0c;并在仓库内以各种方向和位置接触物品。 制造工厂智能仓库系统中的关节机械臂…

Mobile-Agent项目部署与学习总结(DataWhale AI夏令营)

前言 你好&#xff0c;我是GISer Liu&#xff0c;一名热爱AI技术的GIS开发者&#xff0c;本文是DataWhale 2024 AI夏令营的最后一期——Mobile-Agent赛道&#xff0c;这是作者的学习文档&#xff0c;这里总结一下&#xff0c;和作者一起学习这个多模态大模型新项目吧&#x1f6…

AMEYA360 :“Radisol”,一款可改善智能手机Wi-Fi天线性能的村田电子新产品

株式会社村田制作所开发了村田首款(1)天线抗干扰器件‘Radisol’。Radisol是一款可配备到天线上来抑制无线性能下降的新产品&#xff0c;该产品已于2024年6月开始量产&#xff0c;并已用在Motorola Mobility LLC 2024年8月开始销售的智能手机“Edge系列”新机型。摩托罗拉通过采…

【Qt】垂直布局管理器QVBoxLayout

垂直布局管理器QVBoxLayout 在之前学习Qt的过程中&#xff0c;将控件放在界面上&#xff0c;都是依靠“手动”的方式来布局的&#xff0c;但是手动调整的方式是不科学的。 手动布局的方式非常复杂&#xff0c;而且不精确无法对窗口大小进行自适应 因此Qt引入布局管理器来解决…

缓存Mybatis一级缓存与二级缓存

缓存 为什么使用缓存 缓存(cache)的作用是为了减去数据库的压力,提高查询性能,缓存实现原理是从数据库中查询出来的对象在使用完后不销毁,而是存储在内存(缓存)中,当再次需要获取对象时,直接从内存(缓存)中提取,不再向数据库执行select语句,从而减少了对数据库的查询次数,因此…

无法启动此程序,因为计算机中丢失dll,整理了7种解决方法!

当电脑出现“无法启动此程序&#xff0c;因为计算机中丢失dll”的错误弹窗时&#xff0c;这通常意味着系统中的DLL文件出现了缺失或错误。DLL文件是动态链接库文件&#xff0c;它们在软件运行中起着至关重要的作用。 造成dll文件缺失和错误的原因有很多&#xff0c;大部分问题都…

python爬虫,使用pyppeteer异步,爬取,获得指定标签内容

获得指定 #pip install pyppeteer,使用 Pyppeteer&#xff08;异步方案&#xff09; import asyncio from pyppeteer import launch async def main():browser await launch()page await browser.newPage()await page.goto(http://xxx/#/login)# 等待页面加载完成await page…

算法-容斥原理

venn图&#xff1a; 如何求三个圆圈的面积之和&#xff1f; 此时&#xff0c;||不代表绝对值&#xff0c;代表集合的个数 解题思路&#xff1a; 实际上&#xff0c;我们不需要知道每个集合中的元素具体是什么&#xff0c;只需要知道每个集合的大小 例如 &#xff0c;表示10以…

Golang小项目(1)

Golang小项目(1) 前言 本项目适合Golang初学者,通过简单的项目实践来加深对Golang的基本语法和Web开发的理解。 建议前往 torna.top 查阅效果更佳 项目结构 . ├── main.go └── static├── form.html└── index.html项目流程图 定义三个路由: /:首页,显示static…

Windows隐藏起你的秘密文件以及文件夹工具

我们都知道&#xff0c;在 Windows 中可以右键文件夹&#xff0c;选择”属性“&#xff0c;勾选”隐藏“来实现隐藏某个文件夹。 我们还知道&#xff0c;在 Windows 中可以选择勾选 ”显示隐藏的项目和文件夹“&#xff0c;来使上述方法变得形同虚设。 本工具就是用于解决以上…