【强化学习】DPO(Direct Preference Optimization)算法学习笔记

【强化学习】DPO(Direct Preference Optimization)算法学习笔记

  • RLHF与DPO的关系
  • KL散度
  • Bradley-Terry模型
  • DPO算法流程
  • 参考文献

RLHF与DPO的关系

  • DPO(Direct Preference Optimization)和RLHF(Reinforcement Learning from Human Feedback)都是用于训练和优化人工智能模型的方法,特别是在大型语言模型的训练中
  • DPO和RLHF都旨在通过人类的反馈来优化模型的表现,它们都试图让模型学习到更符合人类偏好的行为或输出
  • RLHF通常涉及三个阶段:全监督微调(Supervised Fine-Tuning)、奖励模型(Reward Model)的训练,以及强化学习(Reinforcement Learning)的微调
  • DPO是一种直接优化模型偏好的方法,不需要显式地定义奖励函数,而是通过比较不同模型输出的结果,选择更符合人类偏好的结果作为训练目标,主要是通过直接最小化或最大化目标函数来实现优化,利用偏好直接指导优化过程,而不依赖于强化学习框架
    在这里插入图片描述

KL散度

  • KL散度(Kullback-Leibler divergence),也被称为相对熵,是衡量两个概率分布P和Q差异的一种方法
  • 公式: K L ( P ∣ ∣ Q ) = ∑ x P ( x ) log ⁡ ( P ( x ) Q ( x ) ) \mathrm{KL}(P||Q)=\sum_xP(x)\log\left(\frac{P(x)}{Q(x)}\right) KL(P∣∣Q)=xP(x)log(Q(x)P(x))
  • KL散度是不对称的, K L ( P ∣ ∣ Q ) ! = K L ( Q ∣ ∣ P ) KL(P||Q)!=KL(Q||P) KL(P∣∣Q)!=KL(Q∣∣P)

在这里插入图片描述

Bradley-Terry模型

  • Bradley-Terry模型是一种用于比较成对对象并确定相对偏好或能力的方法。这种模型特别适用于对成对比较数据进行分析,从而对一组对象进行排序

  • P ( i > j ) = α i α i + α j P(i{>}j)=\frac{\alpha_i}{\alpha_i{+}\alpha_j} P(i>j)=αi+αjαi

  • α i \alpha_i αi表示第 i i i个元素的能力参数,且大于0。 P ( i > j ) P(i>j) P(i>j)表示第 i i i个元素战胜第 j j j个元素的概率

  • Bradley-Terry模型的参数通常通过最大似然估计(MLE)来确定
    在这里插入图片描述

  • sigmoid函数: σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1 + e^{-x}} σ(x)=1+ex1

  • loss函数的化简
    L o s s = − E ( x , y w , y l ) ∼ D [ ln ⁡ e x p ( r ( x , y w ) ) e x p ( r ( x , y w ) ) + e x p ( r ( x , y l ) ) ] = − E ( x , y w , y l ) ∼ D [ ln ⁡ 1 1 + e x p ( r ( x , y l ) − r ( x , y w ) ) ] = − E ( x , y w , y l ) ∼ D [ ln ⁡ σ ( r ( x , y w ) − r ( x , y l ) ) ] \begin{aligned}Loss &=-\mathbb{E}_{(x,y_{w},y_{l})\sim D}[\ln\frac{exp(r(x,y_{w}))}{exp(r(x,y_{w}))+exp(r(x,y_{l}))}] \\ &= -\mathbb{E}_{(x,y_{w},y_{l})\sim D}[\ln\frac{1}{1 + exp(r(x,y_{l})- r(x,y_{w}))}] \\ &= -\mathbb{E}_{(x,y_{w},y_{l})\sim D}[\ln \sigma(r(x,y_{w})-r(x,y_{l}))] \end{aligned} Loss=E(x,yw,yl)D[lnexp(r(x,yw))+exp(r(x,yl))exp(r(x,yw))]=E(x,yw,yl)D[ln1+exp(r(x,yl)r(x,yw))1]=E(x,yw,yl)D[lnσ(r(x,yw)r(x,yl))]

  • loss函数的目标是优化LLM输出的 y w y_w yw,经过reward计算的得分尽可能的大于 y w y_w yw经过reward计算的得分

在这里插入图片描述

DPO算法流程

  • DPO通过比较不同输出的偏好,构建一个目标函数,该函数直接反映人类的偏好,通常使用排序损失函数(例如Pairwise Ranking Loss),该函数用来衡量模型在用户偏好上的表现
  • DPO优化过程:使用梯度下降等优化算法,直接最小化或最大化目标函数。通过不断调整模型参数,使得模型生成的输出更加符合用户的偏好
    在这里插入图片描述
  • 基准模型一般指经过SFT有监督微调后的模型
  • DPO的目标是尽可能得到多的奖励,同时使得新训练的 模型尽可能与基准模型分布一致

DPO训练目标的化简

在这里插入图片描述
上图中第一步利用的是KL散度的定义,之所以式子中没有KL散度中的 P ( π ( y ∣ x ) ) P(\pi(y|x)) P(π(yx)),是因为KL散度可以理解成是一个概率比值的log的期望,在这里这个概率以期望的形式放到式子左边的期望中了

  • 求最大值 通过在式中加上负号转化为求最小值,并同时除以 β \beta β
  • DPO原论文中的推导过程

在这里插入图片描述

  • 继续推导

在这里插入图片描述
在这里插入图片描述

  • 求解reward函数的表达式,将reward函数的表达式代入loss函数中

在这里插入图片描述

  • DPO loss损失函数的表达形式

在这里插入图片描述

  • logZ(x)项被抵消,于是可以转而用最大似然估计MLE直接在这个概率模型上直接优化LM,去得到希望的最优的π*
    个人理解的一知半解 有时间还是得去看看原论文

参考文献

  1. DPO (Direct Preference Optimization) 算法讲解
  2. Direct Preference Optimization(DPO)学习笔记
  3. DPO原论文 Direct Preference Optimization: Your Language Model is Secretly a Reward Model

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/337221.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

根据状态转移图实现时序电路 (三段式状态机)

看图编程 * ** 代码 module seq_circuit(input C ,input clk ,input rst_n,output wire Y ); reg [1:0] current_stage ; reg [1:0] next_stage ; reg Y_reg; //输出//第一段 : 初始化当前状态和…

【再探】设计模式—访问者模式、策略模式及状态模式

访问者模式是用于访问复杂数据结构的元素,对不同的元素执行不同的操作。策略模式是对于具有多种实现的算法,在运行过程中可动态选择使用哪种具体的实现。状态模式是用于具有不同状态的对象,状态之间可以转换,且不同状态下对象的行…

【Gradle】Gradle的本地安装和使用

目录 1、Gradle 的安装 2、集成 IntelliJ IDEA 3、使用 Gradle Gradle 完全兼容 Maven 和 Ivy 仓库,你可以从中检索依赖也可以发布你的文件到仓库中,Gradle 提供转换器能把 Maven 的构建逻辑转换成 Gradle 的构建脚本。 1、Gradle 的安装 Gradle 的…

数据结构的快速排序(c语言版)

一.快速排序的概念 1.快排的基本概念 快速排序是一种常用的排序算法,它是基于分治策略的一种高效排序算法。它的基本思想如下: 从数列中挑出一个元素作为基准(pivot)。将所有小于基准值的元素放在基准前面,所有大于基准值的元素放在基准后面。这个过程称为分区(partition)操作…

微信小程序-页面导航

一、页面导航 页面导航指的是页面之间的相互跳转&#xff0c;例如&#xff1a;浏览器中实现页面导航的方式有如下两种&#xff1a; 1.<a>链接 2.location.href 二、小程序中实现页面导航的两种方式 1.声明式导航 在页面上声明一个<navigator>导航组件 通过点击…

AIGC微短剧轻量化制作

AIGC&#xff08;生成式AI&#xff09;和微短剧作为影视领域的发展趋势&#xff0c;热度持续不减。行业巨头纷纷涉足其中&#xff0c;头部公司也陆续宣布加入竞争。这一趋势带来了新一轮的资本狂欢。 事实上&#xff0c;无论是被视为未来发展的风向标&#xff0c;还是像只是一…

【核心动画-关键帧动画-CAKeyframeAnimation Objective-C语言】

一、接下来,我们来说这个关键帧动画, 1.我们把之前的基本动画,这一坨代码,备份到test1方法里边, 然后,开始说我们的关键帧动画,步骤都是一样的,都是三大步: // 关键帧动画 // 1.做什么动画 // 2.怎么做动画 // 3.对谁做动画 1)做什么动画 第一,我们现在要创建…

证件/文书类日期中文大写js/ts插件

说明 证件/文书类落款日期中文大写往往会将“零”写作“〇”&#xff0c;而数字依然使用简体“一二三”&#xff0c;而不是“壹贰叁”。 如下&#xff1a; 针对这一点&#xff0c;写了如下转换插件。 代码 function DateToUpperCase(date: Date new Date()) {const chStr …

【移动端】商场项目路由设计

1&#xff1a;路由设计配置&#xff1a; 一级路由配置 分析项目&#xff0c;设计路由&#xff0c;配置一级路由 一级路由&#xff1a;单个页面独立展示的&#xff0c;都是一级路由&#xff0c;例如&#xff1a;登录页面&#xff0c;首页架子&#xff0c;报错页面 二级路由&…

ADuM1201可使用π121U31间接替换π122U31直接替换

ADuM1201可使用π121U31间接替换π122U31直接替换 一般低速隔离通信150Kbps电路可使用π121U31&#xff0c;价格优势较大。速度快的有其它型号可达10M,200M,600M。 本文主要介绍ADUM1201,替换芯片π121U31简单资料请访问下行链接 只要0.74元的双通道数字隔离器&#xff0c;1T1…

【VSCode】快捷方式log去掉分号

文章目录 一、引入二、解决办法 一、引入 我们使用 log 快速生成的 console.log() 都是带分号的 但是我们的编程习惯都是不带分号&#xff0c;每次自动生成后还需要手动删掉分号&#xff0c;太麻烦了&#xff01; 那有没有办法能够生成的时候就不带分号呢&#xff1f;自然是有…

ubuntu 18.04 ros1学习

总结了一下&#xff0c;学习内容主要有&#xff1a; 1.ubuntu的基础命令 pwd: 获得当前路径 cd: 进入或者退出一个目录 ls:列举该文件夹下的所有文件名称 mv 移动一个文件到另一个目录中 cp 拷贝一个文件到另一个目录中 rm -r 删除文件 gedit sudo 给予管理员权限 sudo apt-…

开源硬件初识——Orange Pi AIpro(8T)

开源硬件初识——Orange Pi AIpro&#xff08;8T&#xff09; 大抵是因为缘&#xff0c;妙不可言地就有了这么一块儿新一代AI开发板&#xff0c;乐于接触新鲜玩意儿的小火苗噌一下就燃了起来。 还没等拿到硬件&#xff0c;就已经开始在Orange Pi AIpro 官网上查阅起资料&…

基于安卓的虫害识别软件设计--(1)模型训练与可视化

引言 简介&#xff1a;使用pytorch框架&#xff0c;从模型训练、模型部署完整地实现了一个基础的图像识别项目计算资源&#xff1a;使用的是Kaggle&#xff08;每周免费30h的GPU&#xff09; 1.创建名为“utils_1”的模块 模块中包含&#xff1a;训练和验证的加载器函数、训练…

如何使用Spring Cache优化后端接口?

Spring Cache是Spring框架提供的一种缓存抽象,它可以很方便地集成到应用程序中,用于提高接口的性能和响应速度。使用Spring Cache可以避免重复执行耗时的方法,并且还可以提供一个统一的缓存管理机制,简化缓存的配置和管理。 本文将详细介绍如何使用Spring Cache来优化接口,…

【前端】Mac安装node14教程

在macOS上安装Node.js版本14.x的步骤如下&#xff1a; 打开终端。 使用Node Version Manager (nvm)安装Node.js。如果你还没有安装nvm&#xff0c;可以使用以下命令安装&#xff1a; curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.1/install.sh | bash 然后关…

基于NANO 9K 开发板加载PICORV32软核,并建立交叉编译环境

目录 0. 环境准备 1. 安装交叉编译器 2. 理解makefile工作机理 3. 熟悉示例程序的代码结构&#xff0c;理解软核代码的底层驱动原理 4. 熟悉烧录环节的工作机理&#xff0c; 建立下载环境 5. 编写例子blink&#xff0c; printf等&#xff0c; 加载运行 6. 后续任务 0.…

卷积网络迁移学习:实现思想与TensorFlow实践

摘要&#xff1a;迁移学习是一种利用已有知识来改善新任务学习性能的方法。 在深度学习中&#xff0c;迁移学习通过迁移卷积网络&#xff08;CNN&#xff09;的预训练权重&#xff0c;实现了在新领域或任务上的高效学习。 下面我将详细介绍迁移学习的概念、实现思想&#xff0c…

【成品设计】基于STM32单片机的饮水售卖机

基于STM32单片机的饮水售卖机 所需器件&#xff1a; STM32最小系统板。RFID&#xff1a;MFRC-522用于IC卡检测。OLED屏幕&#xff1a;用于显示当前水容量、系统状态等。水泵软管&#xff1a;用于抽水。水位传感器&#xff08;3个&#xff09;&#xff1a;用于分别标定&#x…

低代码赋能企业数字化转型:数百家软件公司的成功实践

本文转载于葡萄城公众号&#xff0c;原文链接&#xff1a;https://mp.weixin.qq.com/s/gN8Rq9TDmkMpCtNMMsBUXQ 导读 在当今的软件开发时代&#xff0c;以新技术助力企业数字化转型已经成为一个热门话题。如何快速适应技术变革&#xff0c;构建符合时代需求的技术能力和业务模…