这里我们尝试通过RM训练让模型学会从给定上下文中提取信息,来进行RM模型的实践。你可以从下面链接获取代码
GitHub - Pillars-Creation/ChatGLM-RLHF-LoRA-RM: ChatGLM-6B添加了RLHF的实现,以及部分核心代码的逐行讲解 ,实例部分是做了个新闻短标题的生成
首先我们看例子,从随机给的50条新闻中提取我们希望的信息。
# 输入候选文章prompt和query
Prompt = "工程师舍身救人 各界纷纷伸出援手 妻子:将更多时间陪伴家人 (新闻)', '男生为恢复妈妈听力,考上了协和博士研究生!这应该是这位妈妈最骄傲的时候 (生活)', '三盘鏖战晋级!小花郑钦文险胜3号种子,首进红土巡回赛决赛 (体育)', '内蒙古锡林浩特降雨引发洪水,人员车辆被困 (新闻)', '海南师范大学回应录取唐尚珺:已录名单中未看到其报考信息 (生活)', '34国联合对华施压,要求加强领海船只识别,防止朝鲜规避制裁 (军事)', '郭峰演唱《让世界充满爱》,歌词耳熟能详,满满的回忆 (音乐)', '这是韩国裁判?韩国组合动作完全不同步,竟然还能拿58.50分 (体育)', '黑海成冲突新焦点,乌船无视俄警告要出海,普京敲打波兰不要冒险 (军事)', '30秒 | 成都大运村“开村” 执行副村长苏波介绍大运村亮点 (新闻)', '继续加油!中国女足世界杯首战失利,终场前丢球遭绝杀 (体育)', '当“社恐”在演唱会上被求婚,女生第一反应不要太搞笑 (生活)', '落夜前瞻预测:小虎帮前队友复仇成功!WBG轻松拿下NIP (游戏)', '这球精彩!费利佩凌空侧钩破门 划出美丽的弧线 (体育)', '公安部:专项行动侦办“网络水军”案件130余起,抓获620余人 (新闻)', '女子清洗空调外机中暑被困4楼外,丈夫营救也被困,消防紧急施救 (新闻)', '19岁男生给旧空调标价1000元,回收师傅多给1000元,男生痛哭跪谢 (新闻)', '美官员:乌克兰或于年底前接收F-16战机 (军事)', '女足门将“飞”起来多可怕:哥斯达黎加门前长城,单场10扑“锁”西班牙 (体育)', '【赛后采访】苏州KSG.啊泽:当时是想把小兵推走,没想到小兵比我还猛 (游戏)', '女子拒绝给弟弟买房,被父母起诉了!! (新闻)', '美国一客机舱内近44℃,等待起飞期间多人热晕 (新闻)', '小品《特殊劝导》:王小欠撞车却没钱赔,包袱十足笑点一个接一个 (综艺)', '拒绝盲人携带导盲犬入住酒店,甚至拒绝退款,被嫌弃的导盲犬究竟招惹了谁? (新闻)', '大暑炎炎,体验亲手制作汽水的乐趣|节节的外拍VLOG (新闻)', '希腊:罗得岛火势难控,20余艘船只协助疏散人员 (新闻)', '别买最贵的!iPad终于不是唯一真神了?安卓平板横评选购推荐 (数码)', '泽连斯基隔空对俄方撂下狠话:克里米亚大桥必须清除 (军事)', '锋芝有望合体?谢霆锋张柏芝都在澳洲,男方暖心带儿子滑雪 (娱乐)', '航拍新疆赛里木湖,蓝宝石湖面非常治愈,满目所及都是夏日清凉 (旅游)', '曝伍德和湖人的讨论正在升温!可能签底薪合同辅佐詹眉 (体育)', '朱婧佳演唱《陪我看日出》,句句醉人心扉,百听不厌 (音乐)', '英国铁路工作人员再次举行大规模罢工 (新闻)', '\\xa0骑着赃车去作案,这个小偷胆贼大 (新闻)', '山东省2023年选派216名“业务院长”下基层 (新闻)', '后续来了!逆行司机下车还讨要说法,西安交警迅速核查处罚 (新闻)', '江南韵、国际范、智能化!感受南大苏州校区的双面魅力 (生活)', '世界上只有一种英雄主义,就是在看清生活真相后,依然热爱生活 (生活)', '中国选了个好时候,普京此行或将决定俄罗斯未来 (新闻)', '练书法、学射箭、看锡剧!在东林书院体验惬意“书式生活” (新闻)', '鳞鲀警惕地守护着自己的领地,当入侵者出现时,筑巢的鳞鲀会及时回击 (纪录片)', '欧盟提出四年内向乌提供200亿欧元军事支持,匈牙利外长:此举意味着延长冲突 (新闻)', '各种馅料的派对酥皮馅饼是怎样制作的?圣诞节上的美味小点心! (纪录片)', '【视频】守护孩子健康从“口”开始!武汉太平社区开展爱牙护齿小课堂 (新闻)', '埃及媳妇回中国后胖了4斤,直言宁夏美食太诱人,想不胖都难 (旅游)', '巴西:民众拥有及携带枪支数量将受限 (新闻)', '因挪车发生纠纷 酒驾司机弃车逃跑 (新闻)', '关注苏丹武装冲突:谈判前景仍不明朗,民众期盼战争尽快结束 (军事)', '山高水长物象千万!百人王屋山同书《上阳台帖》 (新闻)"query = "从上面文章中找到旅游相关的新闻"
未训练前,大模型更倾向从过去学习的知识中获取答案,比如我们用glm和llama直接从候选文章中选择答案,模型倾向于旧有知识,大概率会给出一些不在候选池中的内容。
所以我们制定目标和lable的时候需要让模型能区分出从prompt里获取答案能获得奖励,不从prompt里获取答案会受到惩罚
2,lable定义
生成的lable数据如图,instruction是我们的输入,包含随机50条新闻。output是模型的lable,包含两个部分一个正例一个负例,比如我们希望查找一条军事新闻,正例是从50条新闻中获取的军事新闻标题,负例是随机给了一条军事新闻标题。
3,loss计算定义
通常我们利用
对应代码如下:
r_accept, r_reject = values[-1].split(batch_size, dim=0) # 将输出值按照 batch_size 进行拆分,得到 r_accept 和 r_rejectloss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean() # 计算损失函数,这里使用二元交叉熵损失函数outputs = {"r_accept": r_accept, "r_reject": r_reject} # 将 r_accept 和 r_reject 保存到 outputs 中return (loss, outputs) if return_outputs else loss # 返回损失函数和 outputs,或者只返回损失函数,取决于 return_outputs 的值
4,