ChatGLM-RLHF(六)-PPO(Proximal Policy Optimization)原理实现代码逐行注释

 一,前言

从open AI 的论文可以看到,大语言模型的优化,分下面三个步骤,SFT,RM,PPO,我们跟随大神的步伐,来学习一下这三个步骤和代码实现,本章介绍PPO代码实现。


上章我们介绍了PPO算法的公式,其形式如下:

$ L_{\text{CLIP}+\text{VF}+S}(\theta) = \hat{E}_t [ L_{\text{CLIP},t}(\theta) - c_1 L_{\text{VF},t}(\theta) + c_2 S[\pi_\theta](s_t)] $。      

其中(1)

L_{\text{CLIP}}(\theta) = \hat{E}_t \left[ \min \left( r_t(\theta)A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)A_t \right) \right]

对应下面优化

$J(\theta) = E_{\tau \sim p_{\mu}(\tau)}[\sum_{t=0}^T \rho_t A_t \log \pi_{\theta}(a_t|s_t)] +E_{\tau \sim p_{\theta_{old}}(\tau)}[\min(r_t(\theta)A_t, clip(r_t(\theta), 1-\epsilon, 1+\epsilon)A_t)]$

其中(2)

$L_{\text{VF},t}$是一个平方误差损失 $(V_{\theta}(s_t)-V_{\text{targ},t})^2$

其中(3)

$S$ 表示熵奖励,$c_1$$c_2$ 是系数,

下面是公式与代码的对应关系:

一,优势函数(Advantage)计算,对应公式里的At:

1. 优势函数用于衡量当前状态或行动相对于期望值的优劣程度。A(s,a) = Q(s,a) - V(s),

2. $A_t$ 表示时间步 t 的优势函数,然后使用 $A_t$$r_t(\theta)$ 计算出 $E_{\tau \sim p_{\theta_{old}}(\tau)}[\min(r_t(\theta)A_t, clip(r_t(\theta), 1-\epsilon, 1+\epsilon)A_t)]$部分的损失。 $r_t(\theta)$由当前策略 $\pi_{\theta}(a_t|s_t)$ 和旧策略 $\pi_{\theta_{\text{old}}}(a_t|s_t)$ 分别计算出,同时,$A_t$还用于计算E_{\tau \sim p_{\mu}(\tau)}[\sum_{t=0}^T \rho_t A_t \log \pi_{\theta}(a_t|s_t)] 部分的损失。优势函数的计算是一个重要的部分,它帮助我们估计当前策略相对于旧策略的性能提升。

3. 代码里计算$A_t$没有直接使用Q(s,a) - V(s),而是使用了GAE-Lambda 算法进行计算,GAE-Lambda 算法通过将多个时间步的奖励加权平均,计算出当前时间步的优势函数,GAE-Lambda 算法的优点在于,它不仅考虑了当前时间步的奖励,还考虑了未来时间步的奖励,并且通过 Lambda 参数进行加权平均,从而更好地估计当前状态或行动的价值。这使得 GAE-Lambda 算法在训练强化学习模型时具有更好的稳定性和收敛性。

GAE-Lambda 算法的计算过程如下:

  1. 对于每个时间步 t,计算 delta,即当前时间步的奖励加上折扣因子乘以下一个时间步的值减去当前时间步的值:

    delta = rewards[t] + gamma * values[t+1] - values[t]

  2. 对于每个时间步 t,计算 GAE-Lambda,即 delta 加上折扣因子乘以 Lambda 倍的上一个时间步的 GAE-Lambda:

    lastgaelam = delta + gamma * lam * lastgaelam

  3. 将计算得到的 GAE-Lambda 添加到 advantages_reversed 列表中。

  4. 将 advantages_reversed 列表转换为张量,并进行维度转置,得到最终的优势函数张量 advantages。

  5. 具体如下代码

# 计算优势函数
for t in reversed(range(gen_len)):nextvalues = values[:,t + 1] if t < gen_len - 1 else last_values  # 获取下一个时间步的值,如果当前时间步是最后一个时间步,则使用 last_valuesdelta = rewards[:, t] + self.config.gamma * nextvalues - values[:,t]  # 计算 delta,即当前时间步的奖励加上折扣因子乘以下一个时间步的值减去当前时间步的值lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam  # 计算 GAE-Lambda,即 delta 加上折扣因子乘以 Lambda 倍的上一个时间步的 GAE-Lambdaadvantages_reversed.append(lastgaelam)  # 将计算得到的 GAE-Lambda 添加到 advantages_reversed 列表中
advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)  # 将 advantages_reversed 列表转换为张量,并进行维度转置

二,策略函数的损失(Policy Function Loss)的计算:

这部分对应公式E_{\tau \sim p_{\mu}(\tau)}[\sum_{t=0}^T \rho_t A_t \log \pi_{\theta}(a_t|s_t)]

在PPO算法中,我们采用两种不同的方式计算策略损失,即pg_losses和pg_losses2。这两种方式分别对应目标函数中的两个部分。

pg_losses表示使用原始比率计算得到的损失,即:

$ L^{PG}_1(\theta) = -\frac{1}{N} \sum_{i=1}^N \sum_{t=0}^{T_i} \rho_{i,t} A_{i,t} \log \pi_{\theta}(a_{i,t}|s_{i,t}) $

其中,N表示采样轨迹的数量,$\rho_{i,t}$ 表示第 i 条轨迹在时间步 t 的重要性采样比例,$A_{i,t}$表示第 i 条轨迹在时间步 t 的优势函数。

pg_losses2表示使用限制后的比率计算得到的损失,即:

$ L^{PG}_2(\theta) = -\frac{1}{N} \sum_{i=1}^N \sum_{t=0}^{T_i} \min(r_{i,t}(\theta)A_{i,t}, \text{clip}(r_{i,t}(\theta), 1-\epsilon, 1+\epsilon)A_{i,t}) \log \pi_{\theta}(a_{i,t}|s_{i,t}) $

其中,$r_{i,t}(\theta)$ 表示第i条轨迹在时间步t的比率,$\epsilon$表示剪切幅度。

最终,将两种方式计算得到的损失取较大值,即:

pg_loss = \max(pg_losses, pg_losses2)

            # 策略函数的损失logprobs = F.log_softmax(logits, dim=1)ratio = torch.exp(logprobs - old_logprobs)pg_losses = -advantages * ratiopg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), masks)pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).double(), masks)

总损失计算

            # 总损失loss = pg_loss + self.config.vf_coef * vf_loss

 三,值函数的损失(Value Function Loss)的计算

值函数  $L_{\text{VF},t}$的损失公式通常使用均方差(Mean Squared Error,MSE)来衡量值函数的预测误差。值函数的损失公式可以表示为:

优化公式为 L(θ) = 0.5 * E[(V(s) - R)^2]

其中,L(θ)表示值函数的损失,θ表示值函数的参数,V(s)表示值函数对状态s的预测值,R表示实际的回报值。

这个公式的含义是,首先,通过 clip_by_value 函数将当前状态的价值函数 values 限制在一个区间内,得到 vpredclipped。然后,分别计算使用原始价值函数和限制后的价值函数计算得到的损失,即 vf_losses1 和 vf_losses2。通过计算值函数对状态的预测值与实际回报值之间的差异的平方,来衡量值函数的预测误差。然后取这些差异的平方的期望值,再乘以0.5,得到最终的损失值。最终,将两者的较大值作为值函数的损失,通过 masked_mean 函数计算期望。

            # 值函数的损失vpredclipped = clip_by_value(values, values - self.config.cliprange_value, values + self.config.cliprange_value)vf_losses1 = (values - returns) ** 2vf_losses2 = (vpredclipped - returns) ** 2vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), masks)vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).double(), masks)

四, 完整代码可以参考:

以上我们就完成了PPO公式的优化。完整代码参考下面git   /ptuning/utils/ppo_trainer.py文件

GitHub - Pillars-Creation/ChatGLM-RLHF-LoRA-RM-PPO: ChatGLM-6B添加了RLHF的实现,以及部分核心代码的逐行讲解 ,实例部分是做了个新闻短标题的生成

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/83290.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

无线液位传感器—简介

近年来&#xff0c;随着无线传感网络技术的愈发成熟和稳定&#xff0c;无线传感器因其安装、维护方便&#xff0c;不用布线、节约成本&#xff0c;监测方便&#xff0c;使用灵活&#xff0c;可适用于多种工业领域等优点&#xff0c;正在逐步替代部分传统有线传感器&#xff0c;…

LabVIEW对并行机器人结构进行建模仿真

LabVIEW对并行机器人结构进行建模仿真 为了对复杂机器人结构的数学模型进行建模、搜索、动画和验证&#xff0c;在工业机器人动态行为实验室中&#xff0c;设计并实现了具有五个自由度的单臂型机器人。在研究台上可以区分以下元素&#xff1a;带有直流电机和编码器的机器人;稳…

Linux基础学习

文章目录 Linux命令学习Linux环境准备Linux命令行学习Linux命令行格式与文件系统linux实用命令笔记Linux文件权限查看 Linux命令学习 理解Linux命令是什么 &#xff08;图形化的操作&#xff0c;文件查看&#xff0c;浏览器打开&#xff09; 你打开一个谷歌浏览器&#xff0c;…

【考研复习】24王道数据结构课后习题代码|第3章栈与队列

文章目录 3.1 栈3.2 队列3.3 栈和队列的应用 3.1 栈 int symmetry(linklist L,int n){char s[n/2];lnode *pL->next;int i;for(i0;i<n/2;i){s[i]p->data;pp->next;}i--;if(n%21) pp->next;while(p&&s[i]p->data){i--;pp->next;}if(i-1) return 1;…

Godot 4 源码分析 - 碰撞

碰撞功能应该是一个核心功能&#xff0c;它能自动产生相应的数据&#xff0c;比如目标对象进入、离开本对象的检测区域。 基于属性设置&#xff0c;能碰撞的都具备这样的属性&#xff1a;Layer、Mask. 在Godot 4中&#xff0c;Collision属性中的Layer和Mask属性是用于定义碰撞…

8个值得收藏的在线3D建模工具

如今&#xff0c;许多设计师、艺术家和建筑师尝试学习进行 3D 建模来表达他们的想法。 但 3D 建模并不总是看起来那样。 我们所有人都很难找到合适的工具&#xff0c;尤其是在学习阶段。 但不要害怕&#xff01; 你可以学习仅使用浏览器进行建模。 有许多基于浏览器的 3D 建模…

成功解决Android设备adb连接后显示device unauthorized

一、提出问题 在电脑通过USB连接新的Android设备&#xff0c;想要通过adb来进行一些操作时&#xff0c;却发现命令提示符上在输入下面命令后显示设备未授权的信息也就是"unauthorized" adb devices二、不可行的解决方案 有人提出的解决方案是打开Android设备的开发…

selenium获取b站视频标题

一、下载selenium 1. 下载对应版本的浏览器驱动 2. 安装selenium 3.把浏览器驱动放到使用的python内核的script目录中 二、测试效果模拟登录b站 from selenium import webdriver from selenium.webdriver.common.by import By import timebrowser webdriver.Chrome() # 打…

C++/Qt 读写文件

之前写过两篇跟文件操作相关的博客&#xff0c;有兴趣也可以看一下&#xff1a; C语言读写文件 Qt关于文件路径的处理 先讲一些关于基础文本文件和二进制文件的读写操作&#xff0c;后续将会整理C/Qt关于ini、xml、json、xlsx相关文件的读写操作。 C 相比于C语言使用FILE文…

冠达管理:股票注册制通俗理解?

目前我国A股商场正在进行股票注册制变革&#xff0c;相较之前的发行准则&#xff0c;股票注册制在理念上更为商场化&#xff0c;这意味着公司发行股票的门槛将下降&#xff0c;公司数量将添加&#xff0c;而股票流通的方式也将有所改变。那么股票注册制指的是什么&#xff0c;它…

ChatGPT会取代搜索引擎吗?BingChat、GoogleBard与ChatGPT区别

目前暂时不会&#xff0c;ChatGPT为代表的聊天机器人很可能会直接集成到搜索中&#xff0c;而不是取代它。微软已经通过Bing Chat和Bing做到了这一点&#xff0c;它将“聊天”选项卡直接放入Bing搜索的菜单中。Google、百度也分别开始尝试通过其AI生成技术将Google Bard、文心一…

栈和队列(二) 队列操作详解及栈与队列的相互实现

文章目录 四、队列1、什么是队列2、队列的基本操作Queue.hQueue.c初始化队列队尾入队列队头出队列获取队列头部元素获取队列队尾元素获取队列中有效元素个数检测队列是否为空&#xff0c;如果为空返回非零结果&#xff0c;如果非空返回0销毁队列 五、设计循环队列六、栈与队列的…

【Linux的开胃小菜】Linux系统安装后初始化配置操作

我们刚接手一台刚安装好服务器系统之后&#xff0c;可以对系统进行一些基础优化&#xff1a; 常规设定&#xff1a; centos: 1.关闭 iptables 2.关闭 selinux 3.设定 ChronyUbuntu: 4. /etc/security/limits.conf 5. /etc/sysctl.conf1.首先使用国内阿里云的yum源&#xff08…

Electron学习1 安装环境与第一个程序

Electron学习1 安装环境与第一个程序 一、 Electron 简介二、安装 nvm三、安装nodejs四、安装nrm五、安装electron1. npm 初始化2. 创建 package.json3. 安装electron4. 创建一个页面5. 创建文件main.js6. 创建预加载器文件 preload.js7. 启动程序 六、打包 一、 Electron 简介…

windows .gitignore 加入文件名后 依然可以从git status中看到文件问题

最近在学git&#xff0c;对着b站的视频操作&#xff0c;结果很简单的添加.gitignore文件操作&#xff0c;up主的正常隐藏&#xff0c;我的却一直出问题。 百思不得其解&#xff0c;网上各种啥啥啥清缓存都没讲到点上。 最后发现是.gitignore文件有问题&#xff0c;windows默认…

uniapp 实现滑动视图切换 顶部滚动导航栏

无论小程序的时候一般有这个功能,在页面处于首页时候,滑动视图,切换视图顶部滚动导航也跟着切换 1.想要实现这个功能就需要实现顶部导航栏,首先实现顶部滚导航栏 点击高亮颜色显示 模板代码 <scroll-view scroll-x"true" class"scroll-content" > …

IDEA离线安装插件

一、背景 有时&#xff0c;在ideal中我们无法获取到插件&#xff0c;可能是因为内网或者无法访问插件库等原因&#xff0c;此时我们需要离线安装插件 IDEA离线仓库&#xff1a;https://plugins.jetbrains.com/ 二、步骤 2.1 下载插件&#xff1a;https://plugins.jetbrains.…

护网行动 | AD360 在网络安全中的重要作用

随着数字化时代的来临&#xff0c;网络已经成为了人们生活和工作中不可或缺的一部分。然而&#xff0c;随之而来的是网络安全问题日益突出。为了应对这些安全威胁&#xff0c;护网行动应运而生&#xff0c;其中AD360在保障网络安全方面扮演着至关重要的角色。 AD360是一个集成的…

nginx 负载均衡

1.环境准备 我使用的说centos7的系统 1.20版本的nginx 另外还有3台虚拟机 主机&#xff1a;192.168.163.142 两台服务器&#xff1a;服务器A--192.168.163.140 服务器B---192.168.163.141 2.配置服务器A和B 找到nginx下的html目录&#xff0c;编辑其中的index.html(在此…

FreeRTOS(任务管理的创建、删除、挂起、恢复)

目录 一、任务的基本概念 二、任务状态的概念 1、Running—运行态&#xff1a; 2、Ready—就绪态 3、Blocked—阻塞态 4、Suspended—挂起态 三、任务状态的切换 四、系统启动 1、vTaskStartScheduler()函数 1.1 作用 1.2 启动函数介绍 2、空闲任务 2.1 空闲任务的作…