PPO(Proximal Policy Optimization)算法介绍

引言

       Proximal Policy Optimization(PPO) 算法,是强化学习领域的一种先进算法,由OpenAI的研究人员在2017年提出。它以其高效性、稳定性和易于实现等优点,广泛应用于各类强化学习任务,尤其是在大规模模型的策略优化中。

一、背景与动机

       在策略优化的强化学习中,目标是找到一个策略,使得在与环境交互时获得的累计奖励最大。早期的策略梯度方法如REINFORCE,虽然概念简单,但在实践中可能出现学习效率低、收敛慢等问题。

       为了解决策略更新过程中可能出现的剧烈变化和不稳定性,研究人员提出了 信赖域策略优化(Trust Region Policy Optimization,TRPO) 算法。TRPO通过限制新旧策略的KL散度,保证每次更新不会偏离过远。但TRPO实现复杂,计算代价高,不易于大规模应用。

       PPO的提出旨在以一种更简单、高效的方式,实现类似于TRPO的效果,避免策略更新过大导致的性能下降,同时保持实现上的简洁性。

二、PPO的核心思想

       PPO的核心思想是在策略更新时,通过修改损失函数,限制新旧策略之间的差异,从而防止策略更新过大导致不稳定性。这种方法被称为“接近策略优化”(Proximal Policy Optimization),因为每次更新都使得新策略仅在“接近”于旧策略的范围内改进。

三、PPO的关键技术细节

3.1 概率比率(Probability Ratio)

       在策略梯度方法中,我们通常需要计算策略的梯度。PPO引入了概率比率来度量新旧策略在某个状态下采取某动作的概率之比:

r t ( θ ) = π θ ( a t ∣ s t ) π θ old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} rt(θ)=πθold(atst)πθ(atst)

π θ \pi_\theta πθ:参数为 θ \theta θ 的新策略。
π θ old \pi_{\theta_{\text{old}}} πθold:旧策略。

3.2 损失函数设计

       PPO定义了一个新型的剪辑(Clipped)损失函数,以限制策略更新的范围:

L CLIP ( θ ) = E t [ min ⁡ ( r t ( θ ) A ^ t , clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A ^ t ) ] L^{\text{CLIP}}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) \hat{A}_t, \ \text{clip}\left( r_t(\theta), 1 - \epsilon, 1 + \epsilon \right) \hat{A}_t \right) \right] LCLIP(θ)=Et[min(rt(θ)A^t, clip(rt(θ),1ϵ,1+ϵ)A^t)]

  • A ^ t \hat{A}_t A^t:优势函数的估计,表示在状态 s t s_t st 采取动作 a t a_t at 相对于某基准策略的优势。
  • ϵ \epsilon ϵ:一个很小的正数,通常取值如0.1或0.2,用于限制策略更新的幅度。
  • clip ( ⋅ ) \text{clip}(\cdot) clip():剪辑函数,将概率比率 r t ( θ ) r_t(\theta) rt(θ)限制在 [ 1 − ϵ , 1 + ϵ ] [1 - \epsilon, 1 + \epsilon] [1ϵ,1+ϵ] 范围内。

剪辑机制的作用:
       当 r t ( θ ) r_t(\theta) rt(θ) 偏离1的程度超过 ϵ \epsilon ϵ时,损失函数会被剪辑,以避免对策略参数的过度更新。这种机制在增大收敛速度的同时,保证了策略更新的稳定性。

3.3 优化目标

       PPO的优化目标是最大化上述剪辑损失函数,即:

θ new = arg ⁡ max ⁡ θ L CLIP ( θ ) \theta_{\text{new}} = \arg\max_\theta L^{\text{CLIP}}(\theta) θnew=argθmaxLCLIP(θ)

通过梯度上升方法,对策略参数 ( \theta ) 进行迭代更新。

四、PPO的工作流程

  • 1.采集数据: 在当前策略 π θ old \pi_{\theta_{\text{old}}} πθold 下,与环境交互,生成一系列状态、动作、奖励数据。

  • 2.计算优势函数 A ^ t \hat{A}_t A^t 利用时间差分(TD)方法或广义优势估计(GAE)来估计优势函数。

  • 3.计算概率比率 r t ( θ ) r_t(\theta) rt(θ) 根据新旧策略计算概率比率。

  • 4.更新策略参数 θ \theta θ 通过优化剪辑损失函数 L CLIP ( θ ) L^{\text{CLIP}}(\theta) LCLIP(θ),使用梯度上升或优化器(如Adam)更新策略参数。

  • 5.重复迭代: 更新后的策略作为新的旧策略,重复上述过程,直到收敛或达到预定的训练轮数。

五、PPO在RLHF中的应用

       在RLHF中,PPO被用于微调预训练语言模型的策略,使其生成的内容更符合人类偏好。

  1. 策略模型 π θ \pi_\theta πθ
  • 预训练的语言模型,被视为生成文本的策略,其参数为 θ \theta θ
  1. 奖励模型( R )
  • 奖通过人类反馈训练得到的模型,用于评估策略模型生成的文本质量,输出一个奖励值。
  1. 优化过程
  • 生成文本:策略模型根据输入提示生成回复。
  • 计算奖励:奖励模型评估生成的回复,得到奖励值 R ( s t , a t ) R(s_t, a_t) R(st,at)
  • 优势估计:计算优势函数 A ^ t \hat{A}_t A^t,通常基于奖励和价值函数估计。
  • 策略更新:使用PPO算法,优化策略模型的参数 θ \theta θ,最大化预期奖励。
  1. 优势
  • 效率:PPO的高效性使得在大规模模型上进行策略优化成为可能。
  • 稳定性:剪辑机制保证了策略更新的稳定性,避免了生成质量的大幅波动。
  • 可控性:通过奖励模型,结合PPO,可以精细地调控模型生成的行为,使其更符合人类偏好。

六、PPO的优点与局限

优点

  • 易于实现:相比于TRPO等算法,PPO的实现更加简单明了。
  • 高效稳定:在保证策略更新稳定性的同时,保持了较高的样本效率。
  • 适用广泛:适用于离散和连续动作空间的任务。

局限

  • 超参数敏感:需要仔细调节超参数 ϵ \epsilon ϵ 以及学习率等。
  • 样本相关性:需要足够多的样本来估计优势函数,否则可能影响收敛性能。

七、总结

       PPO通过引入剪辑损失函数,有效地限制了策略更新的幅度,兼顾了学习效率和更新稳定性。在RLHF中,PPO作为微调预训练语言模型的关键算法,发挥了重要作用,使得模型能够从人类反馈中高效学习,生成更符合人类期望的内容。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/465823.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

esp32cam+Arduino IDE在编译时提示找不到 esp_camera.h 的解决办法

多半是因为你的ESP32库升级了,不再是 1.02版本,或者根本就没有 ESp32 库。如果被升级了,还原为1.02版本就可以了。如果没有,按照下述方法添加: 首先,在"文件"->"首选项"->"…

基于物联网设计的地下煤矿安全监测与预警

文章目录 一、前言1.1 项目介绍【1】项目开发背景【2】设计实现的功能【3】项目硬件模块组成 1.2 设计思路1.3 系统功能总结1.4 开发工具的选择【1】设备端开发【2】上位机开发 1.5 模块的技术详情介绍【1】NBIOT-BC26模块【2】MQ5传感器【4】DHT11传感器【5】红外热释电人体检…

第8章 利用CSS制作导航菜单作业

1.利用CSS技术&#xff0c;结合链接和列表&#xff0c;设计并实现“山水之间”页面。 浏览效果如下&#xff1a; HTML代码如下&#xff1a; <!DOCTYPE html> <html><head><meta charset"utf-8" /><title>山水之间</title><…

32单片机HAL库的引脚初始化

在使用HAL库时&#xff0c;GPIO初始化函数定义在stm32f4xx_hal_gpio.c文件中&#xff0c;如下&#xff1a; void HAL_GPIO_Init(GPIO_TypeDef *GPIOx, GPIO_InitTypeDef *GPIO_Init); 由这个函数可以看出&#xff0c;在初始化GPIO时&#xff0c;需要向函数传入2个结构体&…

Django安装

在终端创建django项目 1.查看自己的python版本 输入对应自己本机python的版本&#xff0c;列如我的是3.11.8 先再全局安装django依赖包 2.在控制窗口输入安装命令&#xff1a; pip3.11 install django 看到Successflully 说明我们就安装成功了 python的Scripts文件用于存…

网络层5——IPV6

目录 一、IPv6 vs IPv4 1、对IPv6主要变化 2、IPv4 vs IPv6 二、IPv6基本首部 1、版本——4位 2、通信量类——8位 3、流标号——20位 4、有效载荷长度——16位 5、下一个首部——8位 6、跳数限制——8位 7、源 、 目的地址——128位 8、扩展首部 三、IPv6地址 1…

怎么样鉴定疾病相关稀有细胞群?二值化精细模型标签,这个刚发的顶刊单细胞算法值得一学!

生信碱移 HiDDEN&#xff1a;抽丝剥茧 在具有病例和对照单细胞RNA测序研究中&#xff0c;样本级标签通常被直接赋予单个细胞&#xff0c;假设所有病例细胞都受影响。这种传统方法在受影响细胞比例较小或扰动强度较弱时&#xff0c;难以有效识别关键细胞及其标记基因&#xff…

三周精通FastAPI:33 在编辑器中调试

官方文档&#xff1a;https://fastapi.tiangolo.com/zh/tutorial/debugging/ 调试 你可以在编辑器中连接调试器&#xff0c;例如使用 Visual Studio Code 或 PyCharm。 调用 uvicorn 在你的 FastAPI 应用中直接导入 uvicorn 并运行&#xff1a; import uvicorn from fast…

Spring Boot关闭时,如何确保内存里面的mq消息被消费完?

1.背景 之前写一篇文章Spring Boot集成disruptor快速入门demo&#xff0c;有网友留言如下图&#xff1a; 针对网友的留言&#xff0c;那么我们如何解决这个问题呢 Spring-Boot应用停机时&#xff0c;如何保证其内存消息都处理完成&#xff1f; 2.解决方法 方法其实挺简单的&…

vue3+vite搭建脚手架项目使用eletron打包成桌面应用+可以热更新

当前Node版本&#xff1a;18.12.0&#xff0c;npm版本&#xff1a;8.19.2 1.搭建脚手架项目 搭建Vue3ViteTs脚手架-CSDN博客 可删掉index.html文件的title标签 2.配置package.json {"name": "my-vite-project","private": true,"versi…

【Golang】validator库的使用

package mainimport ("fmt""github.com/go-playground/validator" )// MyStruct .. validate:"is-awesome"是一个结构体标签&#xff0c;它告诉验证器使用名为is-awesome的验证规则来验证String字段。 type MyStruct struct {String string vali…

Linux(CentOS)安装 MySQL

CentOS版本&#xff1a;CentOS 7 MySQL版本&#xff1a;MySQL Community Server 8.4.3 LTS 1、下载 MySQL 打开MySQL官网&#xff1a;https://www.mysql.com/ 直接下载网址&#xff1a;https://dev.mysql.com/downloads/mysql/ 其他版本 2、上传 MySQL 文件到 CentOS 使用F…

Pytorch实现transformer语言模型

转载自&#xff1a;| 03_language_model/02_Transformer语言模型.ipynb | 从头训练Transformer语言模型 |Open In Colab | Transformer语言模型 本节训练一个 sequence-to-sequence 模型&#xff0c;使用pytorch的 nn.Transformer <https://pytorch.org/docs/master/nn.ht…

<Project-20 YT-DLP> 给视频网站下载工具 yt-dlp/yt-dlp 加个页面 python web

介绍 yt-dlp Github 项目&#xff1a;https://github.com/yt-dlp/yt-dlp A feature-rich command-line audio/video downloader 一个功能丰富的视频与音频命令行下载器 原因与功能 之前我用的 cobalt 因为它不再提供Client Web功能&#xff0c;只能去它的官网使用。 翻 redd…

Sqli-Labs

目录 解题思路 题目设计原理 总结 解题思路 什么&#xff1f;sqli-labs&#xff1f;让我看看。还真是。想起了当初刚学被支配的恐惧。 悄咪咪点开第一关看看能不能秒了。测试闭合老样子&#xff0c;单引号闭合&#xff0c;双引号等都成功。这里 and 11 和 # 都不能通过检测&…

【基于Zynq FPGA对雷龙SD NAND的测试】

一、SD NAND 特征 1.1 SD 卡简介 雷龙的 SD NAND 有很多型号&#xff0c;在测试中使用的是 CSNP4GCR01-AMW 与 CSNP32GCR01-AOW。芯片是基于 NAND FLASH 和 SD 控制器实现的 SD 卡。具有强大的坏块管理和纠错功能&#xff0c;并且在意外掉电的情况下同样能保证数据的安全。 …

【NOIP提高组】引水入城

【NOIP提高组】引水入城 &#x1f490;The Begin&#x1f490;点点关注&#xff0c;收藏不迷路&#x1f490; 在一个遥远的国度&#xff0c;一侧是风景秀美的湖泊&#xff0c;另一侧则是漫无边际的沙漠。该国的行政 区划十分特殊&#xff0c;刚好构成一个N行M列的矩形&#xff…

鸿蒙开发:arkts 如何读取json数据

为了支持ArkTS语言的开发&#xff0c;华为提供了完善的工具链&#xff0c;包括代码编辑器、编译器、调试器、测试工具等。开发者可以使用这些工具进行ArkTS应用的开发、调试和测试。同时&#xff0c;华为还提供了DevEco Studio这一一站式的开发平台&#xff0c;为运行在Harmony…

OpenCV视觉分析之目标跟踪(11)计算两个图像之间的最佳变换矩阵函数findTransformECC的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 根据 ECC 标准 78找到两幅图像之间的几何变换&#xff08;warp&#xff09;。 该函数根据 ECC 标准 ([78]) 估计最优变换&#xff08;warpMatri…

【2024最新版Kotlin教程】Kotlin第一行代码系列第五课-类继承,抽象类,接口

【2024最新版Kotlin教程】Kotlin第一行代码系列第五课-类继承&#xff0c;抽象类&#xff0c;接口 为什么要有继承呢&#xff0c;现实中也是有继承的&#xff0c;对吧&#xff0c;你继承你爸的遗产&#xff0c;比如你爸建好了一个房子&#xff0c;儿子继承爸&#xff0c;就得了…