大语言模型-RLHF(七)-PPO实践(Proximal Policy Optimization)原理实现代码逐行注释

从open AI 的论文可以看到,大语言模型的优化,分下面三个步骤,SFT,RM,PPO,我们跟随大神的步伐,来学习一下这三个步骤和代码实现,本章介绍PPO实践。

生活中,我们经常会遇到,希望chatgpt在指定内容范围内回答问题。目前的解决方案大致可以分为两大类,一类是知识库外挂,代表作如langchain。把chatgpt的结果转换为向量在知识库里检索。如下图,本质上最终还是一种向量检索,chatgpt的能力其实是打了一个大的折扣。

另外一类是扩展现有LLM模型的Context处理长度,把候选直接作为llm模型的Context。这里涉及到两个问题,一个是如何扩展Context长度,一个是如何让llm模型只在指定Context内回答问题。今天我们ppo优化主要解决llm模型只在指定Context内回答问题。


样本

我们在1000篇文章中随机选择30篇作为prompt,让模型从这30篇文章中选择出我们想要的文章。

        #随机选择30篇作为promptrandom_articles = df.sample(n=31)random_article = random_articles.iloc[0]cat = random_article['category']article_list = [title + ' (' + cat + ')' for title, cat in zip(random_articles['title'], random_articles['category'])]input_str = construct_input(article_list, cat)input_ids = tokenizer.encode(input_str, return_tensors='pt').to('cuda')

模型准确率判定

可以回答多篇结果,如果模型有我们希望的回答的结果,加1分,不符合减1分。

        #判断命中条数for ans in answer.split('\n'):similarity_threshold = 0.9  # 相似度阈值# 判断是否在input中且分类是否一致if is_similar(ans, article_list, similarity_threshold):positive_num = positive_num +1breakprint(i, 'accuracy:', positive_num / (i+1))

rm样本制作

第一种

正例:选择一条在prompt中符合条件的新闻为正例

负例:随机选择一条不在prompt中的新闻作为负例,        

第二种,

正例:sft一次预测多条,从预测的结果中,挑选出符合条件的为正

负例:sft一次预测多条,从预测的结果中,挑选出不符合条件的为负

比较的结果是第二种方案会好一些。

也可以参考这篇博文ChatGLM-RLHF(三)-RM(Reward Model)实现&代码逐行注释_Pillars-Creation的博客-CSDN博客

ppo训练预测

ppo原理前一章节已经讲了,传送门ChatGLM-RLHF(六)-PPO(Proximal Policy Optimization)原理&实现&代码逐行注释_Pillars-Creation的博客-CSDN博客

需要注意的就是,因为训练时候需要加载sft和rm两个模型, 你需要一个大一点显存的gpu,本例在A100,40G显存上跑通。如果显存小了容易报显存不足的错误。

训练结果

原始预测结果

sft预测结果

ppo预测结果

几点体会,

1,好的sft可以解决大部分的问题,从上面实验看简单sft训练后准确率就可以得到明显提升

2,要根据自身需要定制好的rm样本和loss。有时候单纯根据sft样本,模型可能很难总结出你真正的目的,rm可以帮助模型更好的理解人的期望。

3,rm单独使用效果不一定比sft效果更好,这也比较好理解,rm需要人工标注pair对,数量总是有限的,并且这个pair对,是否清晰表达给了模型用户的全部意图,容易顾此失彼。所以rm我们更多用在最后,结合ppo纠正模型。

4,rm过程可以进行多次,把自己的目标拆解成几个rm过程,更容易达到我们的目标

5,PPO过程确实帮助模型效果得到了提升,并且可以从比较粗劣的rm结果和sft模型对比中学到知识。

 完整代码可以参考:

GitHub - Pillars-Creation/ChatGLM-RLHF-LoRA-RM-PPO: ChatGLM-6B添加了RLHF的实现,以及部分核心代码的逐行讲解 ,实例部分是做了个新闻短标题的生成

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/95665.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

winform 封装unity web player 用户控件

环境: VS2015Unity 5.3.6f1 (64-bit) 目的: Unity官方提供的UnityWebPlayer控件在嵌入Winform时要求读取的.unity3d文件路径(Src)必须是绝对路径,如果移动代码到另一台电脑,需要重新修改src。于是考虑使…

阿里云100元预算可选的云服务器配置2核2G3M带宽

阿里云服务器100元可以买到哪些配置?如果是一年时长,轻量应用服务器2核2G3M带宽一年108元,系统盘为50GB高效云盘。以前阿里云服务器ECS卖过35元一年、69元、88元、89元和99元的都有过,但是现在整体费用上涨,入门级云服…

opencv直方图与模板匹配

import cv2 #opencv读取的格式是BGR import numpy as np import matplotlib.pyplot as plt#Matplotlib是RGB %matplotlib inline def cv_show(img,name):cv2.imshow(name,img)cv2.waitKey()cv2.destroyAllWindows() 直方图 cv2.calcHist(images,channels,mask,histSize,ran…

【ES6】箭头函数和普通函数的区别

它们之间的区别: (1)箭头函数没有自己的this。 (2)不可以当作构造函数,不可以对箭头函数使用new命令,否则抛出错误。 (3)不可以使用arguments对象,该对象在函…

【深度学习】PyTorch快速入门

【深度学习】学习PyTorch基础 介绍PyTorch 深度学习框架是一种软件工具,旨在简化和加速构建、训练和部署深度学习模型的过程。深度学习框架提供了一系列的函数、类和工具,用于定义、优化和执行各种深度神经网络模型。这些框架帮助研究人员和开发人员专注…

华为PPPOE配置实验

华为PPPOE配置实验 网络拓扑图拓扑说明电信ISP设备配置用户拨号路由器配置查看是否拨上号是否看不懂? 看不懂就对了,只是记录一下命令。至于所有原理,等想写了再写 网络拓扑图 拓扑说明 用户路由器用于模拟家用拨号路由器,该设备…

R语言处理缺失数据(1)-mice

#清空 rm(listls()) gc()###生成模拟数据### #生成100个随机数 library(magrittr) set.seed(1) asd<-rnorm(100, mean 60, sd 10) %>% round #平均60&#xff0c;标准差10 #将10个数随机替换为NA NA_positions <- sample(1:100, 10) asd[NA_positions] <- NA #转…

CentOS下MySQL的彻底卸载的几种方法

这里我为大家详细讲解下“CentOS下MySQL的彻底卸载的几种方法”的完整攻略。 一、关闭MySQL服务 在开始操作之前&#xff0c;需要先关闭MySQL服务。可以使用以下命令来关闭MySQL服务&#xff1a; systemctl stop mysqld 或者 service mysqld stop 二、使用yum命令卸载MySQL…

Unity制作一个简单的登入注册页面

1.创建Canvas组件 首先我们创建一个Canvas画布&#xff0c;我们再在Canvas画布底下创建一个空物体&#xff0c;取名为Resgister。把空物体的锚点设置为全屏撑开。 2.我们在Resgister空物体底下创建一个Image组件&#xff0c;改名为bg。我们也把它 的锚点设置为全屏撑开状态。接…

Flutter 测试小结

Flutter 项目结构 pubspec.yaml 类似于 RN 的 package.json&#xff0c;该文件分别在最外层及 example 中有&#xff0c;更新该文件后&#xff0c;需要执行的 Pub get lib 目录下的 dart 文件为 Flutter 插件封装后的接口源码&#xff0c;方便在其他 dart 文件中调用 example 目…

卷积神经网络全解!CNN结构、训练与优化全维度介绍!

目录 一、引言1.1 背景和重要性1.2 卷积神经网络概述 二、卷积神经网络层介绍2.1 卷积操作卷积核与特征映射卷积核大小多通道卷积 步长与填充步长填充 空洞卷积&#xff08;Dilated Convolution&#xff09;分组卷积&#xff08;Grouped Convolution&#xff09; 2.2 激活函数R…

Wlan安全——认证与加密方式(WPA/WPA2)

目录 终端认证技术 WEP认证 PSK认证 802.1x认证与MAC认证 Portal认证 数据加密技术 WEP加密 TKIP加密 CCMP加密 TKIP和CCMP生成密钥所需要的密钥信息 802.11安全标准 WEP共享密钥认证、加密工作原理 WEP共享密钥认证 WEP加解密过程 PSK认证以及生成动态密钥的工…

【数据结构与算法——TypeScript】图结构(Graph)

【数据结构与算法——TypeScript】 图结构(Graph) 认识图结构以及特性 什么是图? 在计算机程序设计中&#xff0c;图结构 也是一种非常常见的数据结构。 但是&#xff0c;图论其实是一个非常大的话题 认识一下关于图的一些内容 图的抽象数据类型一些算法实现。 什么是图?…

Can‘t find end of central directory : is this a zip file ? at XMLHttpRequest

导出woed出现这个报错,原因其实很简单,路径写错了, 这个word首先必须是docx格式,然后必须放在public文件包下 如果放在public文件包下还没有用,则放在public包下 参考帖子: https://www.cnblogs.com/hejun26/p/13647927.html

Android Studio实现解析HTML获取图片URL,将URL存到list,进行瀑布流展示

目录 效果展示build.gradle(app)添加的依赖(用不上的可以不加)AndroidManifest.xml错误代码activity_main.xmlitem_image.xmlMainActivityImage适配器ImageModel 接收图片URL效果展示 build.gradle(app)添加的依赖(用不上的可以不加) dependencies {implementation co…

安防监控视频云存储平台EasyNVR出现内核报错的情况该如何解决?

安防视频监控汇聚EasyNVR视频集中存储平台&#xff0c;是基于RTSP/Onvif协议的安防视频平台&#xff0c;可支持将接入的视频流进行全平台、全终端分发&#xff0c;分发的视频流包括RTSP、RTMP、HTTP-FLV、WS-FLV、HLS、WebRTC等格式。 近期有用户联系到我们&#xff0c;EasyNVR…

[LitCTF 2023]Ping

因为直接ping会有弹窗。这里在火狐f12,然后f1选禁用javascript,然后ping 然后输入127.0.0.1;cat /flag 得到flag&#xff0c; 查看其他大佬的wp &#xff0c;这里还可以抓包。但是不知道为什么我这里的burp 用不了

「新整理」战略定位理论发展史

1954年&#xff0c;彼得德鲁克——现代管理学之父 代表作《管理的实践》 提出具有划时代意义的概念——目标管理&#xff0c;将管理划分为战略管理、组织管理和自我管理。将生产力从生产部门扩大到组织的所有职能部门&#xff0c;以知识精英为代表的职业经理人开始代替资本家走…

python ORM框架 sqlAlchemy

背景 最近在研究mysql的ORM框架&#xff0c;忽然看到了一个pip的包sqlalchemy&#xff0c;让我觉得很神奇&#xff0c;用下来的感觉和java的hibernate差不多&#xff0c;后边的链式查询又让我觉得和我很喜欢用的mybatis plus差不多&#xff0c;于是抱着好奇加上学习的态度&…

Docker中为RabbitMQ安装rabbitmq_delayed_message_exchange延迟队列插件

1、前言 rabbitmq_delayed_message_exchange是一款向RabbitMQ添加延迟消息传递&#xff08;或计划消息传递&#xff09;的插件。 插件下载地址&#xff1a;https://www.rabbitmq.com/community-plugins.html 1、下载插件 首先需要确定我们当前使用的RabbitMQ的版本&#xff0c…