【复现DeepSeek-R1之Open R1实战】系列6:GRPO源码逐行深度解析(上)

目录

  • 4 GRPO源码分析
    • 4.1 数据类 `GRPOScriptArguments`
    • 4.2 系统提示字符串 `SYSTEM_PROMPT`
    • 4.3 奖励函数
      • 4.3.1 accuracy_reward函数
      • 4.3.2 verify函数
      • 4.3.3 format_reward函数
    • 4.4 将数据集格式化为对话形式
    • 4.5 初始化GRPO Trainer


【复现DeepSeek-R1之Open R1实战】系列3:SFT和GRPO源码逐行深度解析(上)
【复现DeepSeek-R1之Open R1实战】系列5:SFT和GRPO源码逐行深度解析(中)

4 GRPO源码分析

前面两篇博文已经详细介绍了一些基础知识和SFT源码,本文继续解读GRPO源码。与SFT源码差不多的部分,我们就不展开细说了,这里只解析GRPO独特的部分。

4.1 数据类 GRPOScriptArguments

该类使用了 Python 的 dataclass 装饰器,这是一种简化类定义的方式,特别是对于那些主要用来存储数据的类。它继承自 ScriptArguments 类。

  • reward_funcs: 这是一个列表,包含了一系列可能的奖励函数名称,默认值为 ["accuracy", "format"]。这些奖励函数可能是用于评估模型性能的不同标准。

    reward_funcs: list[str] = field(default_factory=lambda: ["accuracy", "format"],metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"},
    )
    
  • cosine_min_value_wrongcosine_max_value_wrong: 分别表示错误答案在余弦相似度尺度上的最小和最大奖励值,默认分别为 0.0-0.5

  • cosine_min_value_correctcosine_max_value_correct: 分别表示正确答案在余弦相似度尺度上的最小和最大奖励值,默认分别为 0.51.0

  • cosine_max_len: 表示余弦相似度尺度的最大长度,默认值为 1000

  • repetition_n_grams: 表示用于重复惩罚奖励的n-gram数量,默认值为 3

  • repetition_max_penalty: 表示重复惩罚奖励的最大负值,默认值为 -1.0

每个字段都使用了 field() 函数来定义其默认值和元数据(如帮助信息)。这有助于工具和库更好地理解和处理这些字段,例如生成命令行解析器时。

4.2 系统提示字符串 SYSTEM_PROMPT

SYSTEM_PROMPT = ("A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant ""first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning ""process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., ""<think> reasoning process here </think><answer> answer here </answer>"
)

字符串描述了一个对话场景,用户先提问,助手首先思考推理过程,然后提供答案。推理过程和答案分别用 <think><answer> 标签包裹,这种格式化有助于区分和识别不同的部分,和DeepSeek-R1的思考过程格式一致。

4.3 奖励函数

奖励函数的定义如下,GRPO默认用到了accuracy_reward和format_reward这两个函数。

# Get reward functionsREWARD_FUNCS_REGISTRY = {"accuracy": accuracy_reward,"format": format_reward,"reasoning_steps": reasoning_steps_reward,"cosine": get_cosine_scaled_reward(min_value_wrong=script_args.cosine_min_value_wrong,max_value_wrong=script_args.cosine_max_value_wrong,min_value_correct=script_args.cosine_min_value_correct,max_value_correct=script_args.cosine_max_value_correct,max_len=script_args.cosine_max_len,),"repetition_penalty": get_repetition_penalty_reward(ngram_size=script_args.repetition_n_grams,max_penalty=script_args.repetition_max_penalty,),"length": len_reward,}reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

这段代码定义了一个奖励函数注册表 REWARD_FUNCS_REGISTRY,并根据用户提供的配置动态生成一个奖励函数列表 reward_funcs。每个奖励函数用于评估模型输出的不同方面,如准确性、格式、推理步骤等。

  1. 注册表定义
  • accuracy: 使用 accuracy_reward 函数评估模型输出的准确性。
  • format: 使用 format_reward 函数评估模型输出的格式。
  • reasoning_steps: 使用 reasoning_steps_reward 函数评估模型输出的推理步骤。
  • cosine: 使用 get_cosine_scaled_reward 函数计算余弦相似度奖励,参数包括:
    • min_value_wrong: 错误情况下的最小值。
    • max_value_wrong: 错误情况下的最大值。
    • min_value_correct: 正确情况下的最小值。
    • max_value_correct: 正确情况下的最大值。
    • max_len: 最大长度。
  • repetition_penalty: 使用 get_repetition_penalty_reward 函数计算重复惩罚奖励,参数包括:
    • ngram_size: n-gram 的大小。
    • max_penalty: 最大惩罚值。
  • length: 使用 len_reward 函数评估模型输出的长度。
  1. 动态生成奖励函数列表
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
  • 根据 script_args.reward_funcs 中指定的奖励函数名称,从 REWARD_FUNCS_REGISTRY 中获取相应的奖励函数,并生成一个列表 reward_funcs

4.3.1 accuracy_reward函数

该函数用于计算模型生成的补全与真实答案之间的准确性奖励。它通过解析和验证生成的内容与真实答案来确定奖励值。

def accuracy_reward(completions, solution, **kwargs):"""Reward function that checks if the completion is the same as the ground truth."""contents = [completion[0]["content"] for completion in completions]rewards = []for content, sol in zip(contents, solution):gold_parsed = parse(sol,extraction_mode="first_match",extraction_config=[LatexExtractionConfig()],)if len(gold_parsed) != 0:# We require the answer to be provided in correct latex (no malformed operators)answer_parsed = parse(content,extraction_config=[LatexExtractionConfig(normalization_config=NormalizationConfig(nits=False,malformed_operators=False,basic_latex=True,equations=True,boxed="all",units=True,),# Ensures that boxed is tried firstboxed_match_priority=0,try_extract_without_anchor=False,)],extraction_mode="first_match",)# Reward 1 if the content is the same as the ground truth, 0 otherwisereward = float(verify(answer_parsed, gold_parsed))else:# If the gold solution is not parseable, we reward 1 to skip this examplereward = 1.0print("Failed to parse gold solution: ", sol)rewards.append(reward)return rewards
  • completions (list): 包含多个补全结果的列表,每个补全结果是一个包含内容的字典列表。
  • solution (list): 真实答案的列表。
  • kwargs: 其他可选参数(在本函数中未使用)。
  1. 提取补全内容

    contents = [completion[0]["content"] for completion in completions]
    
    • completions 列表中提取每个补全的第一个内容(假设每个补全是单个元素的列表),形成一个新的 contents 列表。
  2. 初始化奖励列表

    rewards = []
    
  3. 遍历每个补全和对应的真实答案

    for content, sol in zip(contents, solution):gold_parsed = parse(sol,extraction_mode="first_match",extraction_config=[LatexExtractionConfig()],)
    
    • 使用 zip 函数将 contentssolution 配对。
    • 对于每一对补全内容和真实答案,首先解析真实答案 sol,使用 parse 函数提取其中的信息。
  4. 处理解析结果

    if len(gold_parsed) != 0:answer_parsed = parse(content,extraction_config=[LatexExtractionConfig(normalization_config=NormalizationConfig(nits=False,malformed_operators=False,basic_latex=True,equations=True,boxed="all",units=True,),# Ensures that boxed is tried firstboxed_match_priority=0,try_extract_without_anchor=False,)],extraction_mode="first_match",)
    
    • 如果解析得到的真实答案 gold_parsed 非空,则继续解析生成的补全内容 content
    • 使用 LatexExtractionConfigNormalizationConfig 进行详细配置,确保解析过程中考虑了各种格式要求(如方程、单位等)。
  5. 计算奖励

    reward = float(verify(answer_parsed, gold_parsed))
    
    • 使用 verify 函数比较生成的补全解析结果和真实答案的解析结果。
    • 如果两者匹配,则返回 1.0,否则返回 0.0
  6. 处理无法解析的情况

    else:reward = 1.0print("Failed to parse gold solution: ", sol)
    
    • 如果真实答案无法解析,则默认给予奖励 1.0 并打印警告信息。
  7. 添加奖励到列表

    rewards.append(reward)
    
  8. 返回所有奖励

    return rewards
    

4.3.2 verify函数

该函数用于验证目标表达式是否与参考表达式匹配,它通过多种比较策略来处理不同的数学对象(如数字、表达式、集合、矩阵等),并提供灵活的配置选项以适应不同的需求。

def verify(gold: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, target: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, float_rounding: int=6,numeric_precision: int=15,strict: bool=True,timeout_seconds: int=3
) -> bool:
  • gold: 参考或正确的表达式,可以是单个 SymPy 表达式(BasicMatrixBase)、字符串或这些类型的列表。
  • target: 需要验证的表达式,类型同 gold
  • float_rounding: 浮点数舍入的小数位数,默认为 6。
  • numeric_precision: 数值比较时考虑的小数位数,默认为 15。
  • strict: 是否启用严格比较模式,默认为 True
    • 在严格模式下:变量很重要,集合不可与元组比较。
    • 在非严格模式下:变量按位置匹配,集合可与元组比较。
  • timeout_seconds: 单次比较操作的最大超时时间(秒),默认为 3 秒。
  1. 定义内部比较函数 compare_single_extraction

    @timeout(timeout_seconds=timeout_seconds)
    def compare_single_extraction(gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str) -> bool:...
    
    • 使用装饰器 @timeout 设置超时保护,默认超时时间为 timeout_seconds
    • 比较两个表达式:
      • 如果两者都是 SymPy 表达式(BasicMatrixBase),则调用 sympy_expr_eq 进行比较。
      • 如果两者都是字符串,则进行简单的字符串比较。
  2. 定义包装函数 compare_single_extraction_wrapper

    def compare_single_extraction_wrapper(g, t):try:return compare_single_extraction(g, t)except Exception as e:logger.exception(f"Error comparing {g} and {t}")return False
    
    • 包装 compare_single_extraction,捕获并记录任何异常,返回 False 以避免程序中断。
  3. 处理输入列表

    if not isinstance(gold, list):gold = [gold]
    if not isinstance(target, list):target = [target]
    
    • 如果 goldtarget 不是列表,则将其转换为单元素列表,以便统一处理。
  4. 组合所有可能的比较

    return any(compare_single_extraction_wrapper(g, t) for g, t in product(gold, target))
    
    • 使用 itertools.product 生成所有可能的 goldtarget 组合。
    • 对每个组合调用 compare_single_extraction_wrapper,如果任意一对匹配成功,则返回 True

4.3.3 format_reward函数

函数用于检查给定的完成文本是否符合特定的格式,它验证完成文本是否包含 <think><answer> 标签,并且这两个标签的内容是非空的。

def format_reward(completions, **kwargs):"""Reward function that checks if the completion has a specific format."""pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"completion_contents = [completion[0]["content"] for completion in completions]matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]return [1.0 if match else 0.0 for match in matches]
  • completions: 这是一个列表,其中每个元素都是一个包含完成内容的对象(通常是字典)。假设每个完成对象的第一个元素包含一个键 "content",其值是需要检查的文本。
  • kwargs: 其他关键字参数,这里没有使用,但可以为未来的扩展提供灵活性。
  1. 正则表达式模式定义

    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    
    • 这个正则表达式用于匹配字符串是否以 <think> 开始,紧接着是任意字符(非贪婪匹配),然后是 </think>,接着可能有任意数量的空白字符(包括换行符),最后是以 <answer> 开始并以 </answer> 结束。
    • .*? 是非贪婪匹配,确保尽可能少地匹配字符。
    • \s* 匹配零个或多个空白字符(包括换行符)。
    • re.DOTALL | re.MULTILINE 标志允许点号 . 匹配所有字符(包括换行符),并且使多行文本中的每一行都可以独立匹配。
  2. 提取完成内容

    completion_contents = [completion[0]["content"] for completion in completions]
    
    • 这里通过列表推导式从 completions 列表中提取每个完成对象的第一个元素的 "content" 字段,形成一个新的列表 completion_contents
  3. 匹配正则表达式

    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    
    • 使用 re.match 函数对 completion_contents 中的每个内容应用正则表达式模式。
    • matches 列表将包含 re.Match 对象(如果匹配成功)或 None(如果匹配失败)。
  4. 生成奖励分数

    return [1.0 if match else 0.0 for match in matches]
    
    • 最后一步是根据匹配结果生成奖励分数。如果匹配成功(即 match 不是 None),则返回 1.0;否则返回 0.0

示例代码:

completions = [[{"content": "<think>This is reasoning.</think><answer>This is answer.</answer>"}],[{"content": "<think>This is reasoning.</think>"}],[{"content": "<answer>This is answer.</answer>"}],[{"content": "This does not match."}]
]reward_scores = format_reward(completions)
print(reward_scores)  # 输出: [1.0, 0.0, 0.0, 0.0]

在这个例子中:

  • 第一个完成内容完全匹配正则表达式,因此得分为 1.0
  • 后三个完成内容不符合要求,因此得分均为 0.0

4.4 将数据集格式化为对话形式

# Format into conversationdef make_conversation(example):return {"prompt": [{"role": "system", "content": SYSTEM_PROMPT},{"role": "user", "content": example["problem"]},],}dataset = dataset.map(make_conversation)for split in dataset:if "messages" in dataset[split].column_names:dataset[split] = dataset[split].remove_columns("messages")

将一个数据集中的每个示例转换为对话格式,并确保数据集中没有多余的列(如 messages)。

  • 输入example 是一个字典,包含单个数据样本的信息,其中 "problem" 键对应的值是用户的问题或任务描述。
  • 输出:返回一个新的字典,包含一个 "prompt" 键,其值是一个对话列表:
    • 第一条消息是系统消息,内容由 SYSTEM_PROMPT 定义。
    • 第二条消息是用户消息,内容是 example["problem"]
  • dataset.map(make_conversation):使用 map 方法将 make_conversation 函数应用到数据集的每个示例上,生成新的对话格式。
  • 移除多余列:遍历数据集的每个拆分(split),如果存在 "messages" 列,则将其移除。

4.5 初始化GRPO Trainer

trainer = GRPOTrainer(model=model_args.model_name_or_path,reward_funcs=reward_funcs,args=training_args,train_dataset=dataset[script_args.dataset_train_split],eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,peft_config=get_peft_config(model_args),callbacks=get_callbacks(training_args, model_args),)

篇幅有限,训练部分的代码我们放到下一篇博文详细解读!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/20460.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于SpringBoot+Vue的在线电影购票系统的设计与实现

获取源码&#xff1a;SpringBootVue的在线电影购票系统: 用户&#xff1a;登录、注册、忘记密码、主页、猜你想看、电影详情、选座购票、正在热映、即将上映、影院、排行榜、影视快报、平台公告、个人中心、我的收藏、想看、改签、评论、排行等功能管理员&#xff1a;登录、首页…

LabVIEW无刷电机控制器检测系统

开发了一种基于LabVIEW的无刷电机控制器检测系统。由于无刷电机具有高效率、低能耗等优点&#xff0c;在电动领域有取代传统电机的趋势&#xff0c;而无刷电机的核心部件无刷电机控制器产量也在不断增长。然而&#xff0c;无刷电机控制器的出厂检测仍处于半自动化状态&#xff…

C#功能测试

List 内部元素为引用 src[0]为"11" List<Source> src new List<Source>(); src.Add(new Source() { Name "1", Age 1, Description "1" }); src.Add(new Source() { Name "2", Age 2, Description "2"…

八种单例模式详解

亲爱的朋友们&#xff0c;大家好&#xff01; 今天是 CSDN博客之星 投票的开始日&#xff01;非常感谢你们在过去的时间里给予我无私的支持和鼓励&#xff0c;这一路走来&#xff0c;正是因为有你们的帮助&#xff0c;我才能不断进步&#xff0c;走得更远。 目前&#xff0c;…

STM32 PWM脉冲宽度调制介绍

目录 背景 PWM 模式 影子寄存器和预装载寄存器 PWM对齐模式 PWM 边沿对齐模式 向上计数配置 向下计数的配置 PWM 中央对齐模式 程序 第一步、使能GPIOB组、AFIO、TIM3外设时钟 第二步、输出通道端口配置​编辑 第三步、定时器配置产生频率 第四步、PWM输出配置 第…

Java面试第二山!《计算机网络》!

在 Java 面试里&#xff0c;计算机网络知识是高频考点&#xff0c;今天就来盘点那些最容易被问到的计算机网络面试题&#xff0c;帮你轻松应对面试&#xff0c;也方便和朋友们一起探讨学习。 一、HTTP 和 HTTPS 的区别 1. 面试题呈现 HTTP 和 HTTPS 有什么区别&#xff1f;在…

deepseek-v3在阿里云和腾讯云的使用中的差异

随着deepseek在各大云商上线&#xff0c;试用了下阿里云和腾讯云的deepseek服务&#xff0c;在回答经典数学问题9.9和9.11谁大时&#xff0c;发现还是有差异的。将相关的问题记录如下。 1、问题表现 笔者使用的openai的官方sdk go-openai。 因本文中测验主要使用阿里云和腾讯…

蓝桥杯单片机基础部分——单片机介绍部分

前言 这个部分是额外的&#xff0c;我看我有的学弟学妹基础比较差&#xff0c;对板子上面的模块不太熟悉&#xff0c;这里简单的介绍一下 蓝桥杯单片机 这个就是蓝桥杯单片机的板子&#xff0c;它的主控芯片是&#xff08;IAP15F2K61S2&#xff09;&#xff0c;这里就对他常用…

百度搜索和文心智能体接入DeepSeek满血版——AI搜索的新纪元

在当今数字化时代&#xff0c;搜索引擎作为互联网信息获取的核心工具&#xff0c;正经历着前所未有的变革。据悉&#xff0c;2025年2月16日&#xff0c;百度搜索和文心智能体平台宣布全面接入DeepSeek和文心大模型的最新深度搜索功能&#xff0c;搜索用户可免费使用DeepSeek和文…

redis解决高并发看门狗策略

当一个业务执行时间超过自己设定的锁释放时间&#xff0c;那么会导致有其他线程进入&#xff0c;从而抢到同一个票,所有需要使用看门狗策略&#xff0c;其实就是开一个守护线程&#xff0c;让守护线程去监控key&#xff0c;如果到时间了还未结束&#xff0c;就会将这个key重新s…

【koa】05-koa+mysql实现数据库集成:连接和增删改查

前言 前面我们已经介绍了第二阶段的第1-4点内容&#xff0c;本篇介绍第5点内容&#xff1a;数据库集成&#xff08;koamysql&#xff09; 也是第二阶段内容的完结。 一、学习目标 在koa项目中正常连接数据库&#xff0c;对数据表进行增删改查的操作。 二、操作步骤 本篇文章…

aws(学习笔记第二十八课) aws eks使用练习(hands on)

aws(学习笔记第二十八课) 使用aws eks 学习内容&#xff1a; 什么是aws eksaws eks的hands onaws eks的创建applicationeks和kubernetes简介 1. 使用aws eks 什么是aws eks aws eks的概念 aws eks是kubernetes在aws上包装出来 的新的方式&#xff0c;旨在更加方便结合aws&…

IM聊天系统架构实现

一、IM系统整体架构 二、企业级IM系统如何实现心跳与断线重连机制&#xff1b; 1、重连机制&#xff08;服务端下线&#xff09; 服务端下线&#xff0c;客户端netty可以感知到&#xff0c;在感知的方法中进行重连的操作&#xff0c;注意重连可能连接到旧的服务器继续报错&…

Kubeadm+Containerd部署k8s(v1.28.2)集群(非高可用版)

KubeadmContainerd部署k8s(v1.28.2)集群&#xff08;非高可用版&#xff09; KubeadmContainerd部署k8s高可用版本 文章目录 KubeadmContainerd部署k8s(v1.28.2)集群&#xff08;非高可用版&#xff09;一.环境准备1.服务器准备2.环境配置3.设置主机名4.修改国内镜像源地址5.配…

HarmonyOS进程通信及原理

大家好&#xff0c;我是学徒小z&#xff0c;最近在研究鸿蒙中一些偏底层原理的内容&#xff0c;今天分析进程通信给大家&#xff0c;请用餐&#x1f60a; 文章目录 进程间通信1. 通过公共事件&#xff08;ohos.commonEventManager&#xff09;公共事件的底层原理 2. IPC Kit能…

移动通信发展史

概念解释 第一代网络通信 1G 第二代网络通信 2G 第三代网络通信 3G 第四代网络通信 4G 4g网络有很高的速率和很低的延时——高到500M的上传和1G的下载 日常中的4G只是用到了4G技术 运营商 移动-从民企到国企 联通-南方教育口有人 电信 铁通&#xff1a;成立于 2000 年…

CAS单点登录(第7版)10.多因素身份验证

如有疑问&#xff0c;请看视频&#xff1a;CAS单点登录&#xff08;第7版&#xff09; 多因素身份验证 概述 多因素身份验证 &#xff08;MFA&#xff09; 多因素身份验证&#xff08;Multifactor Authentication MFA&#xff09;是一种安全机制&#xff0c;要求用户提供两种…

#渗透测试#批量漏洞挖掘#Fastjson 1.2.24 远程命令执行漏洞

免责声明 本教程仅为合法的教学目的而准备,严禁用于任何形式的违法犯罪活动及其他商业行为,在使用本教程前,您应确保该行为符合当地的法律法规,继续阅读即表示您需自行承担所有操作的后果,如有异议,请立即停止本文章读。 目录 Fastjson 1.2.24 远程命令执行漏洞综合分析…

【设计模式】 代理模式(静态代理、动态代理{JDK动态代理、JDK动态代理与CGLIB动态代理的区别})

代理模式 代理模式是一种结构型设计模式&#xff0c;它提供了一种替代访问的方法&#xff0c;即通过代理对象来间接访问目标对象。代理模式可以在不改变原始类代码的情况下&#xff0c;增加额外的功能&#xff0c;如权限控制、日志记录等。 静态代理 静态代理是指创建的或特…

动态规划

简介 动态规划最核心两步&#xff1a; 状态表示&#xff1a;dp[i]代表什么状态转移方程&#xff1a;如何利用已有的dp求解dp[i] 只要这两步搞对了&#xff0c; 就完成了动态规划的%95 剩下的就是细节问题&#xff1a; dp初始化顺序&#xff08;有时是倒序&#xff09;处理边…