Training language models to follow instructions with human feedback 论文阅读

论文原文:https://arxiv.org/pdf/2203.02155

论文简介

语言模型越大并不意味着它能更好的理解用户的意图,因此在这篇论文中,展示了根据人的反馈对模型进行微调,使得语言模型能够在各种人物上更好的理解用户的意图。在评估中,1.3B参数的InstructGPT模型的输出比175B GPT-3的输出更受欢迎,尽管参数少了100倍。此外,InstructGPT模型虽然在公共的数据上的效果有所降低,但是真实性和减少有害方面生成的能力提升。论文表明,尽管InstructGPT仍然会犯一些简单的错误,但根据人类反馈进行微调是能够理解人类意图的一个有效的方式和方向。
**相当于是,OpenAI提出了”align“的概念,希望模型的输出与人类的意图”对齐“,其用的方法是RLHF(Reinforcement Learning from Human Feedback)基于人类反馈的强化学习。**

方法和实验细节

在这里插入图片描述

Collect demonstration data, and train a supervised policy. (收集范例数据,并以有监督方式训练)

我们的打标签者提供了输入提示分布(prompt distribution)上所需行为的范例(有关此分布的详细信息,请参阅第 3.2 节)。 然后,我们使用有监督学习在该数据集上微调预训练的 GPT-3 模型。这部分就是根据prompts,也就是写的各种问题,进行标注,将prompts和标注的对话作为人工标注的数据集,对预训练的GPT-3进行有监督微调

Collect comparison data, and train a reward model. (收集比较数据,训练奖励模型)

我们收集了模型输出之间比较的数据集,其中打标记者根据输入标明了他们更喜欢的输出。 然后我们训练奖励模型来预测人类偏好的输出。用上一步得到的SFT模型生成各种问题的答案,再对这些答案进行比较(排序式)标注,如D>C>A=B,基于这个标注数据集,在去掉最后的嵌入层的SFT模型基础上进行有监督学习训练一个RM(reward model),这样使用模型来模仿标注者进行打分

Optimize a policy against the reward model using PPO. (使用PPO针对奖励模型优化策略)

我们使用RM奖励模型的输出作为标量奖励。 我们使用 PPO 算法微调监督策略以优化此奖励。
步骤2和步骤3可以不断迭代; 收集当前最佳策略的更多比较数据,用于训练新的 RM,然后训练新的策略。 在实践中,我们的大部分比较数据来自监管的学习,也有一些来自我们的PPO学习。用上一步的RM模型进行打分,然后分数就可以用强化学习来对SFT模型进行优化

数据集

打标签者提供了输入提示分布(prompt distribution)上所需行为的范例,根据论文所说,为了训练第一个InstructGPT模型,打标签者需要自己编写提示,分为三种:

  • Plain:只是要求标记者提出一个任意的任务,同时确保任务具有足够的多样性。
  • Few-shot:要求标注者提出一条指令,以及针对该指令的多个查询/相应对。
  • User-based:在OpenAI API的候补名单申请中陈述了许多用例,要求标注者提出与这些用例相对应的提示。
    根据这些提示,生成了三个用于微调过程的不同数据集:(1)SFT数据集,带有用于训练SFT模型的打标签者范例数据,(2)RM数据集,带有用于训练的模型已被打标签者分了等级的数据,(3)PPO数据集,没有任何人工标签,用于RLHF微调的输入。SFT数据集包含大约13k个训练提示数据(来自API和标记者编写),RM数据集有33k个训练提示数据(来自API和打标记者编写),PPO数据集有31k个训练提示数据(仅来自API)。
    在这里插入图片描述
    上表显示了API提示(特别是RM数据集)的用例类别的分布,大多数用例都是生成的,而不是分类或QA。在表二中展示了一些说明性提示(由研究人员编写,以模仿提交给InstructGPT模型的提示类型)。

任务

训练任务来自两个来源:(1)由标注者编写的提示数据集和(2)提交给API上的早期InstructGPT模型的提示数据集。这些提示非常多样化,包括生成、问答、对话、摘要、提取和其他自然语言任务。数据集超过96%是英语。
对于每个自然语言提示,任务通常是通过自然语言指令直接指定的(例如”写一个关于聪明青蛙的故事“),但也可以通过少数例子间接指定(例如给出两个青蛙故事的例子,并提示模型生成一个新的)或隐含的连续(例如提供一个关于青蛙的故事的开始)。在每种情况下,我们都要求标注者尽最大努力推断出写提示的用户的意图,并要求他们跳过任务非常不清楚的输入(相当于当任务非常不清楚的时候,可以跳过回答,避免答非所问)。此外,在我们提供给他们的指示和他们的最佳判断的指导下,标注者还需考虑到隐含的意图,如回应的真实性,以及潜在的有害输出,如有偏见或有毒的语言。

模型

我们从GPT-3预训练语言模型开始。这些模型是在广泛分布的互联网数据上进行训练的,可以适应广泛的下游任务,但行为特征不佳。从这些模型开始,我们用三种不同的技术训练模型:

  • 有监督微调(SFT——Supervised fine-tuning),我们使用监督学习对标记器演示中的GPT-3进行微调。我们训练了16个epoch,使用余弦学习率衰减,0.2的残差dropout。我们根据验证集上的RM分数进行最终的SFT模型选择。我们发现SFT模型在1个epoch后对验证损失上过拟合,然而我们发现尽管存在过拟合,但更多epochs的训练有助于RM分数和人类偏好评级。(尽管这个SFT模型训练更多的epoch会产生过拟合,但是这是为了得到后续的RM模型的初始化模型,对RM模型有帮助,并不是直接使用这个SFT模型,所以过拟合没关系
  • 奖励建模(RM——Reward model),从移除了最后的非嵌入层的SFT模型开始(GPT模型最后的softmax层是用于得到每个词的概率,去掉softmax层以后,增加一个线性层来投影,将所有词的输出投影到一个值上面,即输出一个标量的分数),我们训练了一个模型来接收提示和相应,并输出标量奖励。在本文中,我们只使用6B RM,这样可以节省大量计算,而且我们发现175B RM训练可能不稳定,因此不太适合用作RL(Reinforcement learning)中的值函数。RM在同一输入的两个模型输出之间进行比较的数据集上训练。他们使用交叉熵损失,将比较作为标签——奖励的差异代表了人类标记者更喜欢一种反应的对数几率。
    为了加速分等级数据的收集,我们向标签提供者提供 K = 4 K=4 K=4 K = 9 K=9 K=9之间的任何排名相应。这会为显示给标签者的每个提示生成 ( K 2 ) = C K 2 \binom{K}{2}=C_K^2 (2K)=CK2比较。由于分等级数据在每个标记任务中都非常相关,我们发现,如果我们简单地将分等级数据混洗到一个数据集中,在数据集上的一次遍历会导致奖励模型过拟合。相反,我们将每个提示的所有 ( K 2 ) \binom{K}{2} (2K)比较数据作为单个批处理元素进行训练。这在计算上要高效得多,因为它只需要每次完成一次RM的前向传递(而不是超过 ( K 2 ) \binom{K}{2} (2K)次前向传递),而且因为它不在过拟合,大大提高了验证准确性和日志损失。
    具体来说,奖励模型的损失函数为(这里使用的是排序中最常见的pairwise ranking loss,成对排名损失):
    在这里插入图片描述
    这里的 r θ ( x , y ) r_{\theta}(x,y) rθ(x,y)是表示prompt x x x和相应 y y y在参数为 θ \theta θ的奖励模型下的奖励值, y w y_w yw是在prompt x x x下生成的一对响应 y w y_w yw y l y_l yl中更受欢迎的那一个, D D D是比较的数据集。每一个排名对 y i , y j y_i,y_j yi,yj的损失是 − l o g ( σ ( y i − y j ) ) -log(\sigma(y_i-y_j)) log(σ(yiyj)),换成奖励函数就是 − l o g ( σ ( r θ ( x , y w ) ) − r θ ( x , y l ) ) -log(\sigma(r_{\theta}(x,y_w))-r_{\theta}(x,y_l)) log(σ(rθ(x,yw))rθ(x,yl)),然后共 C K 2 C_K^2 CK2个排序对,所以期望除以它。
    目标是最小化这个loss,也就是最大化这两个奖励的差值, l o g ( σ ) log(\sigma) log(σ)最开始的时候是把生成的每个输出对都作为单独的数据混洗到数据集中,这样的话就需要超过 ( K 2 ) \binom{K}{2} (2K)次前向传递,而且输出对之间有重复,这样容易过拟合,所以将所有的输出对都统一作为单个批处理元素进行训练,这样的话就只需要 K K K次前向传递,因为奖励模型只需要算出9个奖励。之所以取 K = 9 K=9 K=9,是因为考虑到人工标注的时候,很大一部分是花在读懂这个prompt,所以在 K = 4 K=4 K=4 K = 9 K=9 K=9之间,只多了不到一倍的时间,但是标注的数据由6变成了36,多了6倍
    最后,由于RM损失对于奖励的变化是不变的,我们使用偏差对奖励模型进行归一化,以便在进行RL之前,标记器演示的平均得分为0。
  • 强化学习(RL——Reinforcement learning),我们使用PPO在我们的环境中微调了SFT模型。该模型是一个bandit环境,它呈现随机的客户提示并期望对提示的响应。给定提示和相应,它会产生由奖励模型确定的奖励并结束情节。此外,我们在每个token上上添加了SFT模型的每个token的KL惩罚,以减轻奖励模型的过度优化。从RM初始化值函数。我们称这些模型为PPO。
    我们还尝试将预训练梯度混合到PPO梯度中,以修复公共NLP数据集上的性能回归。我们称这些模型为”PPO-ptx“。我们在RL训练中最大化以下组合目标函数:
    在这里插入图片描述
    其中 π Θ R L \pi_{\Theta}^{RL} πΘRL是学习到的RL策略, π S F T \pi^{SFT} πSFT是有监督训练的模型, D p r e t r a i n D_{pretrain} Dpretrain是预训练分布。KL奖励系数 β \beta β,预训练损失系数 γ \gamma γ分别控制KL惩罚和预训练梯度的强度。对于”PPO“模型, γ \gamma γ被设置为0,除非另有说明,本文中的InstructGPT指的是PPO-ptx模型。对于上面说的31k个prompts数据集 D D D,都使用当前的RL模型,也就是RL策略 π θ R L \pi_{\theta}^{RL} πθRL,输出 y y y,然后用RM模型得到分数 r θ ( x , y ) ,目标函数是希望这个分数最大化 r_{\theta}(x,y),目标函数是希望这个分数最大化 rθ(x,y),目标函数是希望这个分数最大化然后根据这个目标函数,更新RL模型,然后再用RM模型计算得分,反复迭代。
    目标函数中还有两项,在此分别解释一下, β l o g ( π Θ R L ( y ∣ x ) / π S F T ( y ∣ x ) ) \beta log(\pi_{\Theta}^{RL}(y|x)/\pi^{SFT}(y|x)) βlog(πΘRL(yx)/πSFT(yx))是正则项,这是PPO的主要思想,随着模型的更新,RL产生的输出 y y y和原始的 S F T SFT SFT模型输出的 y y y会逐渐不一样,即数据分布( y ∣ x y|x yx)的差异会越来越大, R L RL RL的输出可能会不准,所以论文在loss里加入了一个KL散度 KL ( P ∥ Q ) = ∑ x P ( x ) log ⁡ ( P ( x ) Q ( x ) ) = ∫ P ( x ) log ⁡ ( P ( x ) Q ( x ) ) d x \text{KL}(P \parallel Q) = \sum_{x} P(x) \log \left(\frac{P(x)}{Q(x)}\right)= \int P(x) \log \left(\frac{P(x)}{Q(x)}\right)\, dx KL(PQ)=xP(x)log(Q(x)P(x))=P(x)log(Q(x)P(x))dx,用于描述一个概率分布相对于另一个概率分布的非对称性差异,相当于用这个散度来正则,希望RLSFT的输出分布不要偏太远,因为是最大化目标函数,所以要最小化KL散度需要在前面加一个负号。
    γ E x D p r e t r a i n [ l o g ( π Θ R L ( x ) ) ] \gamma E_x ~ D_{pretrain}[log(\pi_{\Theta}^{RL}(x))] γEx Dpretrain[log(πΘRL(x))],由于前两项目标函数只和人类排序部分有关,所以训练出来会导致模型仅仅对排序的结果较好,而在最终任务通用NLP任务上性能会下降,所以论文在loss中加入了GPT-3预训练模型的目标函数, D p r e t r a i n D_{pretrain} Dpretrain表示从训练GPT-3的预训练数据中采样 x x x,然后输入RL模型得到输出概率 π Θ R L ( x ) \pi_{\Theta}^{RL}(x) πΘRL(x),这样相当于是GPT-3本身的损失函数。

    总的来说,如果 γ = 0 \gamma=0 γ=0就是一个PPO函数,否则就是一个PPO加上一个GPT-3的目标函数的结合成为RL模型的目标函数,也就是PPO-ptx
    在这里插入图片描述

讨论

论文提出,本文使用的”对齐技术“——RLHF,是用于对齐人类系统的一个重要方法。与预训练相比,增加模型对齐的成本是适中的(仅仅标注几万条prompt数据),与训练GPT-3的花费相比(海量的各种数据),只占一小部分。上述结果也表明,RLHF在使语言模型更加helpful(真实性和无害性是被隐式优化了)方面非常有效,甚至比模型增加100倍更有效。所以,在自然语言领域,研究alignment可能比训练更大规模的模型更具性价比。
align也有争议,就是到底要align人类到什么地步,是用户让做什么就做什么,还是要理解用户更深层的、内在的一些东西。此外最后的RL模型也不是必要的,如果在第一步多标数据,在GPT-3微调,步骤会变得简单,可能更加实用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/351920.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C++】模板进阶(特化)

🌈个人主页:秦jh_-CSDN博客🔥 系列专栏:https://blog.csdn.net/qinjh_/category_12575764.html?spm1001.2014.3001.5482 目录 非类型模板参数 数组越界检查 按需实例化 模板的特化 函数模板特化 类模板特化 全特化 ​…

LabVIEW故障预测

在LabVIEW故障预测中,振动信号特征提取的关键技术主要包括以下几个方面: 时域特征提取:时域特征是直接从振动信号的时间序列中提取的特征。常见的时域特征包括振动信号的均值、方差、峰值、峰-峰值、均方根、脉冲指数等。这些特征能够反映振动…

【文末附gpt升级秘笈】AI热潮降温与AGI场景普及的局限性

AI热潮降温与AGI场景普及的局限性 摘要: 随着人工智能(AI)技术的迅猛发展,AI热一度席卷全球,引发了广泛的关注和讨论。然而,近期一些学者和行业专家对AI的发展前景提出了质疑,认为AI热潮将逐渐…

如何警用root用户登录ssh

使用tail指令,可以动态查看日志信息。 (tail -f /var/log/secure或messages) 使用>符号,可以清空日志内容,不删除文件本身。 禁用root用户为以下步骤: 首先使用useradd创建用户(可以修改为其…

路由器虚拟服务器有什么作用

现如今在IPv4时代,由于公网IP地址的匮乏,约有70%的电脑都处于内网中,上网需要通过路由器。如果反过来想要访问身处内网的电脑,我们就需要在路由器里开放相应的端口才能实现。而这开放端口的功能,在路由器里就叫做虚拟服…

俄罗斯Yandex推广投放如何开户?Yandex广告开户和代运营推广流程详解_俄罗斯_受众_搜索引擎

在俄罗斯进行Yandex广告推广是一种有效的在线营销方式,特别是针对俄罗斯市场。Yandex是俄罗斯最受欢迎的搜索引擎,类似于Google在全球范围内的地位。以下是通过Yandex广告推广的一般步骤,以及如何通过上海上弦进行广告开户和代运营。 1. Yan…

GPT_AI高速发展中什么是Prompt提示词?

提示词(Prompt)是给大语言模型(以下简称模型)的输入文本,用于指定模型应该执行什么样的任务并生成什么样的输出。 提示词发挥了“提示” 模型 应该做什么的作用。设计高质量的提示词需要根据目标任务和模型能力进行精…

49.Python-web框架-Django解决多语言redirect时把post改为get的问题

目录 1.背景 2.思路 3.寻找 Find and Replace 4.再次运行程序,POST来了 5.小结 1.背景 昨天在练习一个Django功能时,把form的method设置为POST,但是实际提交时,一直是GET方法。最后发现这是与多语言相关,django前面…

架构设计 - MySQL 插入数据性能优化策略

mysql 数据库提高数据插入效率主要可以考虑以下方面: 使用批量插入数据的 SQL 语句,避免使用 for 循环逐条记录插入。 所有插入语句共用一个事务,避免1条SQL语句开1个事务,所有操作都完成后再提交事务。 尽量按照索引递增顺序插入…

【JVM】JVisualVM的介绍、使用和GC过程

VisualVM介绍 VisualVM 是Netbeans的profile子项目,已在JDK6.0 update 7 中自带,能够监控线程,内存情况,查看方法的CPU时间和内存中的对 象,已被GC的对象,反向查看分配的堆栈(如100个String对象分别由哪几…

【Python/Pytorch - 网络模型】-- SVD算法

文章目录 文章目录 00 写在前面01 基于Pytorch版本的SVD算代码02 理论知识 00 写在前面 (1)矩阵的奇异值分解在最优化问题、特征值问题、最小二乘方问题、广义逆矩阵问题及统计学等方面都有重要应用; (2)应用&#…

看穿人性!现货白银交易的一些博弈心得

很多投资者认为现货白银交易最应该讲求的是交易技巧,但交易的技巧和套路是“死”的,行情走势却是“活”的,投资者需要在实践中不断地累积经验和总结心得,才能更加灵活地面对行情走势的变化,逐步达至盈利的理想彼岸。 无…

Linux设置静态ip

Linux配置静态ip 把win的Vmnet8改成静态ip,和你虚拟机虚拟网络编辑器自动分配的ip网段设置成一样的,最后一位不是虚拟网络编辑器里的网关和子网ip一样就行,掩码点一下就自动出来了 桌面环境配置静态ip 虚拟机内填写,点击你的网络连接配置&a…

CentOS 5(CentOS 6、Redhat 6)服务器配置VNC

一、配置服务器yum源 yum源(本地、华为云、阿里云、网易) 二、使用yum安装vnc服务 1、检查系统是否安装了vnc 和 vncserver, rpm -qa | grep vnc如果没有安装那就行自行下载安装(我这里用yum安装了,vncserver安装需…

值得推荐的品牌维权控价方法

数据调查 全面了解线上各渠道(如淘宝、天猫、拼多多、京东、抖音、快手等)的低价情况,包括哪些是授权店低价、窜货或假货,为后续针对性治理提供依据。人工排查适用于链接不多的情况,链接数量庞大时利用系统监测更高效…

黑龙江等保测评的流程和注意事项

黑龙江等保测评(信息安全级别保护评估),是根据国家信息安全等级保护的有关标准,以保证信息系统的安全性,对信息系统所做的一种安全性评价。下面是对等保进行评估的具体过程和说明: 一、黑龙江等保测评流程 …

Ms08067安全实验室成功实施多家业务系统渗透测试项目

点击星标,即时接收最新推文 近日,Ms08067安全实验室针对多家公司重要系统实施渗透测试项目。公司网络信息系统的业务应用和存储的重要信息资产均较多,存在网络系统结构的复杂性和庞杂等特点,使得公司网络信息系统面临一定风险。项…

【安装笔记-20240616-Windows-Gpg4win 证书管理器】

安装笔记-系列文章目录 安装笔记-20240616-Windows-Gpg4win 证书管理器 文章目录 安装笔记-系列文章目录安装笔记-20240616-Windows-Gpg4win 证书管理器 前言一、软件介绍名称:Gpg4win主页官方介绍 二、安装步骤测试版本:Gpg4win 4.3.1下载链接安装界面…

C语言之常用字符串函数总结、使用和模拟实现

文章目录 目录 一、strlen 的使用和模拟实现 二、strcpy 的使用及模拟实现 三、strcat 的使用和模拟实现 四、strcmp 的使用和模拟实现 五、strncpy 的使用和模拟实现 六、strncat 的使用和模拟实现 七、strncmp 的使用和模拟实现 八、strstr 的使用和模拟实现 九、st…

海成蜘蛛池广州官网下载

baidu搜索:如何联系八爪鱼SEO? baidu搜索:如何联系八爪鱼SEO? baidu搜索:如何联系八爪鱼SEO? 当我们给自己的泛目录设置仅蜘蛛抓取生成缓存的时候,我们需要模拟蜘蛛抓取测试我们的设置是否成功。绝大部分时候我们都使用网页蜘蛛模拟抓取测…