π0源码解析——一个模型控制7种机械臂:对开源VLA sota之π0源码的全面分析,含我司的部分落地实践

前言 

ChatGPT出来后的两年多,也是我疯狂写博的两年多(年初deepseek更引爆了下),比如从创业起步时的15年到后来22年之间 每年2-6篇的,干到了23年30篇、24年65篇、25年前两月18篇,成了我在大模型和具身的原始技术积累

如今一转眼已到25年3月初,时光走得太快,近期和团队接了好几个大客户订单,使得3月起 不得不全力加速落地,自己也得每天抠paper、搞代码

虽然今年可能没法像去年24年那样干65篇,不过,我还是争取保持月月更新

  1. 一方面,有些文章是之前既定计划中的,比如如此文《π0开源了且推出自回归版π0-FAST——打造机器人动作专用的高效Tokenizer:比扩散π0的训练速度快5倍但效果相当》最后所说的,对π0源码的解读
    至于什么是π0,详见此文《π0——用于通用机器人控制的VLA模型:一套框架控制7种机械臂(基于PaliGemma和流匹配的3B模型)》
  2. 二方面,我司「七月在线」在做一系列工厂落地场景的过程中,我们也希望团结到可以和我们一块做的朋友,而若想团结,便需要对外分享我们每个季度在重点做的业务场景

比如过去一周,我把lerobot、reflect vlm、π0的仿真环境都在我自己本地电脑上跑了下(过程中,GitHub copilot这种AI编程工具在环境的安装上帮了我很大的忙——各种环境 只要几句命令,直接帮我装好,真心不错)

如此硬着头皮冥思苦想、摸索了好几天,随后使得我自己知道怎么带队完成『太多工厂希望实现的一个生产线任务』了,3月初先仿真训练,2-3个月内部署到真机

当然了,也不单纯只是「这几天的想」就能想出来的,​这几天之前

  1. 有把过去一年当三年用的具身技术积累
  2. 有一年多来,和同事们 如姚博士,以及朋友们许多的讨论
  3. 有去年十几个工厂对我们的支持与信任

我们正在不断壮大队伍

  • 有我司内部同事,亦有我带的北理、中南等985的具身研究生,及一块合作开发的朋友,很快会把多个生产线任务并行开发起来
  • 且无论哪个项目,都是不断长期迭代的,故过程中少不了科研层面的突破,欢迎更多伙伴加入我们(全、兼、实习皆可,有意者,敬请私我),和我们一块开发

话休絮烦,本文便按照如下图所示的源码结构,重点解读一下π的整个源码

  1. π0的源码结构非常清晰、可读性高,不愧是成熟的商业化公司,是我司七月的学习榜样之一
  2. 我身边的很多朋友目前都在做π0的微调及二次开发,相信本文无论对我身边的朋友,还是对更多人的学习与工作,都会起到比较大的提升


目录

前言 

第一部分 examples、packages、scripts等结构的分析

1.1 examples :各种机器人平台的示例实现

1.2 packages

1.3 scripts:包含数据处理、模型训练/推理的多个脚本

1.3.1 __init__.py

1.3.2 compute_norm_stats.py:计算数据的归一化统计信息

1.3.3 serve_policy.py:启动策略服务,用于模型推理

1.3.4 train_test.py:训练和测试模型

1.3.5 train.py:训练模型

1.3.6 scripts/docker

第二部分 核心模块src下models的全面分析与解读

2.1 models/pi0.py的实现

2.1.1 make_attn_mask:注意力掩码生成函数

2.1.2 posemb_sincos:位置编码函数

2.1.3 class Pi0Config:含inputs_spec、get_freeze_filter

2.1.3.1 模型配置参数的定义

2.1.3.2 inputs_spec:定义了π0模型本身接收的输入数据格式​编辑

2.1.3.3 get_freeze_filter:针对是否LoRA的处理

2.1.4 class Pi0:初始化、特征嵌入、损失函数、推理(去噪生成动作)

2.1.4.1 初始化方法 `__init__`

2.1.4.2 特征嵌入方法:embed_prefix(图像和文本输入)、embed_suffix(状态和动作信息)​编辑

2.1.4.3 损失函数 `compute_loss`

2.1.4.4 推理函数 `sample_actions`:基于扩散模型逆向采样,生成机器人动作序列

第一部分 examples、packages、scripts等结构的分析

1.1 examples :各种机器人平台的示例实现

根据π0对应examples模块的结构

其涉及以下模块

  1. aloha_real/:真实机器人ALOHA的示例
  2. aloha_sim/:ALOHA模拟器的示例
  3. droid/:DROID机器人的示例
  4. libero/:LIBERO基准测试的示例
  5. simple_client/:简单客户端的示例
  6. ur5/:UR5机器人的示例
  7. inference.ipynb:推理示例的Jupyter Notebook
  8. policy_records.ipynb:策略记录示例的Jupyter Notebook

1.2 packages

该模块的目录结构如下

1.3 scripts:包含数据处理、模型训练/推理的多个脚本

根据下图

可知,scripts 目录包含多个 Python 脚本,这些脚本用于数据处理、模型训练和服务部署等任务,每个脚本通常对应一个特定的功能或任务

  1. __init__.py
  2. compute_norm_stats.py: 计算数据的归一化统计信息
  3. serve_policy.py: 启动策略服务,提供模型推理接口
  4. train_test.py: 训练和测试模型
  5. train.py: 训练模型

1.3.1 __init__.py

1.3.2 compute_norm_stats.py:计算数据的归一化统计信息

1.3.3 serve_policy.py:启动策略服务,用于模型推理

  1. 在这个代码片段中,首先导入了一些必要的模块和库,包括 `policy`、`policy_config`、`websocket_policy_server` 和 `config`,这些模块来自 `openpi` 项目
    from openpi.policies import policy as _policy       # 导入 openpi.policies.policy 模块并重命名为 _policy
    from openpi.policies import policy_config as _policy_config  # 导入 openpi.policies.policy_config 模块并重命名为 _policy_config
    from openpi.serving import websocket_policy_server  # 导入 openpi.serving.websocket_policy_server 模块
    from openpi.training import config as _config       # 导入 openpi.training.config 模块并重命名为 _config
    接下来定义了一个枚举类 `EnvMode`,它表示支持的环境类型,包括 `ALOHA`、`ALOHA_SIM`、`DROID` 和 `LIBERO`
    class EnvMode(enum.Enum):"""支持的环境。"""ALOHA = "aloha"              # ALOHA 环境ALOHA_SIM = "aloha_sim"      # ALOHA 模拟环境DROID = "droid"              # DROID 环境LIBERO = "libero"            # LIBERO 环境
  2. 然后定义了几个数据类
    `Checkpoint` 类用于从训练好的检查点加载策略,包含两个字段:`config`(训练配置名称)和 `dir`(检查点目录)
    `Default` 类表示使用默认策略
    `Args` 类定义了脚本的参数,包括环境类型、默认提示、端口、是否记录策略行为以及如何加载策略
  3. 接下来定义了一个字典 `DEFAULT_CHECKPOINT`,它为每个环境类型指定了默认的检查点配置
    # 每个环境应使用的默认检查点
    DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {EnvMode.ALOHA: Checkpoint(config="pi0_aloha",dir="s3://openpi-assets/checkpoints/pi0_base",),EnvMode.ALOHA_SIM: Checkpoint(config="pi0_aloha_sim",dir="s3://openpi-assets/checkpoints/pi0_aloha_sim",),EnvMode.DROID: Checkpoint(config="pi0_fast_droid",dir="s3://openpi-assets/checkpoints/pi0_fast_droid",),EnvMode.LIBERO: Checkpoint(config="pi0_fast_libero",dir="s3://openpi-assets/checkpoints/pi0_fast_libero",),
    }
    `create_default_policy` 函数根据环境类型创建默认策略,如果环境类型不支持,则抛出异常
    def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:"""为给定环境创建默认策略 """if checkpoint := DEFAULT_CHECKPOINT.get(env):              # 获取环境对应的默认检查点return _policy_config.create_trained_policy(_config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt)  # 创建训练好的策略raise ValueError(f"Unsupported environment mode: {env}")   # 如果环境不支持,抛出异常
    `create_policy` 函数根据传入的参数创建策略,如果参数中指定了检查点,则从检查点加载策略,否则使用默认策略
    def create_policy(args: Args) -> _policy.Policy:"""根据给定的参数创建策略 """match args.policy:          # 匹配策略类型case Checkpoint():      # 如果是 Checkpoint 类型return _policy_config.create_trained_policy(_config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt)      # 创建训练好的策略case Default():          # 如果是 Default 类型return create_default_policy(args.env, default_prompt=args.default_prompt)      # 创建默认策略
  4. `main` 函数是脚本的入口点,它首先调用 `create_policy` 函数创建策略,然后记录策略的元数据
    def main(args: Args) -> None:policy = create_policy(args)           # 创建策略policy_metadata = policy.metadata      # 获取策略的元数据
    如果参数中指定了记录策略行为,则使用 `PolicyRecorder` 包装策略
        # 记录策略的行为if args.record:# 使用 PolicyRecorder 记录策略行为policy = _policy.PolicyRecorder(policy, "policy_records")  
    接着获取主机名和本地 IP 地址
        hostname = socket.gethostname()              # 获取主机名local_ip = socket.gethostbyname(hostname)    # 获取本地 IP 地址logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)  # 记录服务器创建信息
    并创建一个 WebSocket 服务器来提供策略服务,最后调用 `serve_forever` 方法启动服务器
        server = websocket_policy_server.WebsocketPolicyServer(policy=policy,host="0.0.0.0",port=args.port,metadata=policy_metadata,)  # 创建 WebSocket 策略服务器server.serve_forever()      # 启动服务器,永远运行
  5. 在脚本的最后,使用 `logging` 模块配置日志记录,并调用 `main` 函数启动脚本,参数通过 `tyro.cli` 解析

1.3.4 train_test.py:训练和测试模型

1.3.5 train.py:训练模型

1.3.6 scripts/docker

好的,下面是对 `openpi-main/scripts/docker` 目录的详细分析。这个目录通包含与 Docker 相关的脚本和配置文件,用于构建和管理 Docker 容器,具体而言,包含以下文件和子目录:

主要文件和功能如下所示

  1. docker/compose.yml
  2. docker/install_docker_ubuntu22.sh
  3. docker/install_nvidia_container_toolkit.sh
  4. docker/serve_policy.Dockerfile

// 待更

第二部分 核心模块src下models的全面分析与解读

接下来,我们来看核心src下的各个模块

首先是其中的src/openpi/models

2.1 models/pi0.py的实现

它结合了多模态输入(图像和文本)来生成机器人动作序列。下面是对代码的详细解析:

2.1.1 make_attn_mask:注意力掩码生成函数

这个函数生成transformer中使用的注意力掩码,控制 token 之间的注意力流动方式

def make_attn_mask(input_mask, mask_ar):"""从big_vision项目改编的注意力掩码生成函数Token可以关注那些累积mask_ar小于等于自己的有效输入token。这样`mask_ar` bool[?B, N]可用于设置几种类型的注意力,例如:[[1 1 1 1 1 1]]: 纯因果注意力。[[0 0 0 1 1 1]]: 前缀语言模型注意力。前3个token之间可以互相关注,后3个token有因果注意力。第一个条目也可以是1,不改变行为。[[1 0 1 0 1 0 0 1 0 0]]: 4个块之间的因果注意力。一个块的token可以关注所有之前的块和同一块内的所有token。参数:input_mask: bool[B, N] 如果是输入的一部分则为true,如果是填充则为falsemask_ar: bool[?B, N] 如果前面的token不能依赖于它则为true,如果它共享与前一个token相同的注意力掩码则为false"""# 将mask_ar广播到与input_mask相同的形状mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)  # 计算mask_ar在序列维度上的累积和cumsum = jnp.cumsum(mask_ar, axis=1)  # 创建注意力掩码:当目标位置的累积值<=查询位置的累积值时,允许注意力流动attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]  # 创建有效掩码:只有有效的输入位置之间才能有注意力valid_mask = input_mask[:, None, :] * input_mask[:, :, None]  # 结合注意力掩码和有效掩码return jnp.logical_and(attn_mask, valid_mask)  

它支持多种注意力模式:

  1. 纯因果注意力(每个 token 只能关注自己和之前的 token)
  2. 前缀语言模型注意力(允许前缀内部自由注意,后缀部分使用因果注意力)
  3. 块状因果注意力(在块内自由注意,块之间是因果的)

2.1.2 posemb_sincos:位置编码函数

使用正弦余弦函数实现位置编码

def posemb_sincos(pos: at.Real[at.Array, Any], embedding_dim: int, min_period: float, max_period: float
) -> at.Float[at.Array, f"b {embedding_dim}"]:"""计算标量位置的正弦余弦位置嵌入向量"""if embedding_dim % 2 != 0:      # 检查嵌入维度是否为偶数raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2)  # 创建均匀分布的分数值period = min_period * (max_period / min_period) ** fraction  # 计算周期值,对数空间中均匀分布sinusoid_input = jnp.einsum("i,j->ij",pos,1.0 / period * 2 * jnp.pi,                      # 计算角频率precision=jax.lax.Precision.HIGHEST,            # 使用最高精度进行计算)# 连接sin和cos值,形成完整的位置编码return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1)

2.1.3 class Pi0Config:含inputs_spec、get_freeze_filter

2.1.3.1 模型配置参数的定义

首先,这个类定义了模型的配置参数,比如PaLI-Gemma 变体:`gemma_2b

class Pi0Config(_model.BaseModelConfig):dtype: str = "bfloat16"  # 设置数据类型为bfloat16paligemma_variant: _gemma.Variant = "gemma_2b"          # 设置PaLI-Gemma变体为2B参数版本action_expert_variant: _gemma.Variant = "gemma_300m"    # 设置动作专家变体为300M参数版本# 设置模型特定的默认值action_dim: int = 32          # 设置动作维度为32action_horizon: int = 50      # 设置动作序列长度为50步max_token_len: int = 48       # 设置最大token长度为48
2.1.3.2 inputs_spec:定义了π0模型本身接收的输入数据格式

其次,通过inputs_spec函数定义了π0模型本身接收的输入数据格式,函数采用关键字参数 `batch_size`(默认为1),返回一个包含观察规格和动作规格的元组

def inputs_spec(self, *, batch_size: int = 1) -> Tuple[Type[_model.Observation], Type[_model.Actions]]
  1. 其支持多种输入,比如
    视觉输入(三个不同视角的RGB图像)语言输入(分词后的文本prompt)状态输入(当前机器人状态)
  2. 输出上
    则是一个时序动作序列(包含50个连续的动作向量,每个动作向量有32个维度,可能对应关节角度或其他控制信号)

具体而言该函数先
创建图像规格

        image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)

其中的

  1. `[batch_size, *_model.IMAGE_RESOLUTION, 3]` 定义了图像张量的形状:比如
    \rightarrow  批次大小
    \rightarrow  图像分辨率(从 `_model.IMAGE_RESOLUTION` 获取,可能是如 [224, 224] 这样的值
    \rightarrow  3 个颜色通道 (RGB)
  2. `jnp.float32` 指定了数据类型为 32 位浮点数

创建图像掩码规格

        image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)

其定义了图像掩码规格,每个批次中的每个图像都有一个布尔值,这个掩码用于指示哪些图像是有效的(`True`)或无效的(`False`)

创建观察规格:包含视觉输入、机器人状态、指令输入
`at.disable_typechecking()` 临时禁用类型检查,可能是因为这里创建的是类型规格而不是实际的数据,且观察规格包含多个组件:

  1. 多视角图像
    base_0_rgb: 机器人底座/身体视角的RGB图像
    left_wrist_0_rgb: 左手腕视角的RGB图像
    right_wrist_0_rgb: 右手腕视角的RGB图像
            with at.disable_typechecking():observation_spec = _model.Observation(images={"base_0_rgb": image_spec,"left_wrist_0_rgb": image_spec,"right_wrist_0_rgb": image_spec,},
  2. 图像掩码
    对应每个视角图像的有效性掩码
  3. 机器人状态:
    形状为 `[batch_size, self.action_dim]` 的浮点数张量
    `self.action_dim` 默认为32,表示状态向量的维度
                    state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
  4. 分词后的文本prompt
    形状为 `[batch_size, self.max_token_len]` 的整数张量
    `self.max_token_len` 默认为48,表示最大token数量
    数据类型为 `jnp.int32`,表示token ID
  5. 提示掩码
    与分词提示相同形状的布尔张量,用于指示哪些位置有有效的token
                    state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),)

创建动作规格

        action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)

其定义了动作数据的形状和类型:

  • `batch_size`: 批次大小
  • `self.action_horizon`: 动作序列长度,默认为50
  •  `self.action_dim`: 每个动作的维度,默认为32
  • `jnp.float32` 指定了数据类型为32位浮点数

然后返回

        return observation_spec, action_spec
2.1.3.3 get_freeze_filter:针对是否LoRA的处理

此外,该配置类还实现了get_freeze_filter这个函数作用是如果选择LoRA微调(冻结原始预训练模型的参数,只更新新添加的低秩适应层参数),则需要对模型中的某些参数做冻结

三种可能的情况:

  1. 只对 PaLI-Gemma 使用 LoRA:冻结 Gemma 参数(但排除动作专家参数)
  2. 只对动作专家使用 LoRA:冻结动作专家参数
  3. 对两者都使用 LoRA:冻结两者的基础参数

如此,可以选择性地微调模型的特定部分(语言部分或动作预测部分)

具体而言

  1. 首先,定义函数
        def get_freeze_filter(self) -> nnx.filterlib.Filter:"""返回基于模型配置的冻结过滤器"""
  2. 其次,初始化变量
            filters = []      # 初始化过滤器列表has_lora = False  # 初始化LoRA标志
  3. 接着,创建参数过滤器
            # 匹配所有LLM参数的正则表达式,用于选择 Gemma 语言模型的参数gemma_params_filter = nnx_utils.PathRegex(".*llm.*")  # 匹配动作专家参数的正则表达式action_expert_params_filter = nnx_utils.PathRegex(".*llm.*_1.*")  
  4. 接下来是对PaLI-Gemma变体的处理
            # 如果PaLI-Gemma使用LoRAif "lora" in self.paligemma_variant:filters.append(gemma_params_filter,  # 添加Gemma参数过滤器)if "lora" not in self.action_expert_variant:# 如果只冻结Gemma参数,排除动作专家参数filters.append(nnx.Not(action_expert_params_filter),)has_lora = True
  5. 再下来是对动作专家变体的处理
            elif "lora" in self.action_expert_variant:# 如果动作专家使用LoRAfilters.append(action_expert_params_filter,)has_lora = True

2.1.4 class Pi0:初始化、特征嵌入、损失函数、推理(去噪生成动作)

核心模型类,继承自 `_model.BaseModel`,实现了:

  1. 多模态输入处理
    处理多视角图像(基础视角、左手腕视角、右手腕视角)
    处理文本提示(如指令)
    处理机器人当前状态
  2. 扩散过程
    训练时:将干净动作添加噪声,让模型学习去噪
    推理时:从纯噪声开始,逐步降噪生成动作序列
  3. 注意力机制
    使用精心设计的注意力掩码控制信息流动
    前缀(图像和文本)内部使用全注意力
    后缀(状态和动作)使用特殊的注意力模式
2.1.4.1 初始化方法 `__init__`
class Pi0(_model.BaseModel):def __init__(self, config: Pi0Config, rngs: nnx.Rngs):# 初始化基类super().__init__(config.action_dim, config.action_horizon, config.max_token_len)# 获取PaLI-Gemma和动作专家配置paligemma_config = _gemma.get_config(config.paligemma_variant)action_expert_config = _gemma.get_config(config.action_expert_variant)

其组合了多个核心组件:

一个是PaLI-Gemma 模型:结合了 Gemma 语言模型和 SigLIP 视觉模型

  1. 先是对语言模型的初始化
            # 创建并初始化语言模型# TODO: 用NNX重写Gemma,目前使用桥接llm = nnx_bridge.ToNNX(_gemma.Module(configs=[paligemma_config, action_expert_config],  # 配置两个Gemma模型embed_dtype=config.dtype,          # 设置嵌入数据类型))llm.lazy_init(rngs=rngs, method="init")    # 延迟初始化LLM
  2. 然后是对视觉模型的初始化
            # 创建并初始化图像模型img = nnx_bridge.ToNNX(_siglip.Module(num_classes=paligemma_config.width,  # 设置图像特征维度与语言模型宽度相匹配variant="So400m/14",  # 使用400M参数SigLIP模型pool_type="none",  # 不使用池化,保留所有图像标记scan=True,  # 启用扫描优化dtype_mm=config.dtype,  # 设置矩阵乘法数据类型))# 使用假观察中的图像初始化图像模型img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
  3. 最后,把语言模型和视觉模型组合成PaLI-Gemma多模态模型
            # 组合LLM和图像模型为PaLI-Gemma多模态模型self.PaliGemma = nnx.Dict(llm=llm, img=img)

另一个是线性投影层:用于

  1. 状态投影
            # 状态投影层:将机器人状态投影到模型维度self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
  2. 动作投影
            # 动作输入投影层:将动作投影到模型维度self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
  3. 时间-动作混合等
            # 动作-时间MLP输入层:将连接的动作和时间特征投影到模型维度self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs)# 动作-时间MLP输出层self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)# 动作输出投影层:将模型输出投影回动作维度self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)
2.1.4.2 特征嵌入方法:embed_prefix(图像和文本输入)、embed_suffix(状态和动作信息)
  • `embed_prefix`:处理图像和文本输入(图像通过SigLip模型编码,文本通过Gemma LLM编码),创建前缀 token,皆为双向注意力,用ar_mask = false表示
  • `embed_suffix`:处理机器人状态信息q_t、噪声化的动作信息noise(状态和噪声动作经过线性投影和MLP处理),创建后缀 token
    其中
    状态为单个token,和第一个动作token均设置为单向注意力,用ar_mask = true表示
    其余动作tokens之间设置为双向注意力,用ar_mask = false表示

对于前者embed_prefix

    def embed_prefix(self, obs: _model.Observation) -> Tuple[at.Float[at.Array, Any], at.Bool[at.Array, Any], at.Bool[at.Array, Any]]:"""嵌入前缀部分(图像和文本)"""input_mask = []           # 初始化输入掩码列表ar_mask = []              # 初始化自回归掩码列表tokens = []               # 初始化token列表

其工作流程为

  1. 图像处理:说白了,就是把图像token化
    使用SigLip视觉模型处理每个图像,生成图像tokens序列
            # 嵌入图像for name in obs.images:# 通过图像模型获取图像tokenimage_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)tokens.append(image_tokens)      # 添加图像token
  2. 图像掩码扩展
    将图像掩码扩展到与图像tokens相同的序列长度,使用einops.repeat进行形状变换,这些掩码会指示哪些图像是有效的,而哪些是填充的
                # 重复图像掩码以匹配token维度input_mask.append(einops.repeat(obs.image_masks[name],"b -> b s",               # 调整形状:批次维度保持不变,添加序列维度s=image_tokens.shape[1],  # 序列长度等于图像token数))
  3. 自回归掩码设置
    设置图像tokens之间的注意力为双向(False表示双向注意力),原因在于图像内容通常是非时序性的数据
                # 图像token之间互相关注(非自回归)ar_mask += [False] * image_tokens.shape[1]
  4. 文本处理
    使用LLM模型对文本输入tokenized_inputs进行嵌入
            # 添加语言(即分词后的输入)if obs.tokenized_prompt is not None:# 通过语言模型嵌入分词后的提示tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")tokens.append(tokenized_inputs)                  # 添加文本tokeninput_mask.append(obs.tokenized_prompt_mask)     # 添加提示掩码
    且同样设置为双向注意力,相当于语言token可以关注图像token,图像token反过来亦可关注语言token,最终实现多模态融合
                # 图像和语言输入之间完全关注(非自回归)ar_mask += [False] * tokenized_inputs.shape[1]
  5. 最后,连接所有token和掩码,其中包含了
    \rightarrow  多模态信息的融合表示tokens——图像token和语言token
    \rightarrow  以及指示哪些token是有效信息的input_mask
    \rightarrow  和如何在这些token之间进行注意力计算规则的ar_mask——相当于控制信息流动的方向
            # 连接所有token和掩码tokens = jnp.concatenate(tokens, axis=1)    # 在序列维度上连接tokeninput_mask = jnp.concatenate(input_mask, axis=1)  # 在序列维度上连接输入掩码ar_mask = jnp.array(ar_mask)                # 转换自回归掩码为数组return tokens, input_mask, ar_mask          # 返回token、输入掩码和自回归掩码

顺便,再回顾下此图

对于后者embed_suffix

    def embed_suffix(self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, Any]) -> Tuple[at.Float[at.Array, Any], at.Bool[at.Array, Any], at.Bool[at.Array, Any]]:"""嵌入后缀部分(状态和动作)"""input_mask = []           # 初始化输入掩码列表ar_mask = []              # 初始化自回归掩码列表tokens = []               # 初始化token列表

其工作流程为

  1. 状态处理
    将状态信息投影到embedding空间
            # 添加单个状态tokenstate_token = self.state_proj(obs.state)[:, None, :]  # 投影状态并添加序列维度tokens.append(state_token)                            # 添加状态token# 添加状态掩码(全为1),表示这个状态token是有效的input_mask.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_))  
    并设置为单向注意力(True),表明图像和语言输入不能关注状态信息,因为image/language do not attend to state or actions
            # 图像/语言输入不关注状态或动作(自回归)ar_mask += [True]
  2. 时间步嵌入,使用正弦-余弦位置编码生成时间步嵌入
           # 使用正弦余弦位置编码嵌入时间步,敏感度范围为[0, 1]time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)
  3. 动作和时间信息融合
            # 混合时间步+动作信息,使用MLPaction_tokens = self.action_in_proj(noisy_actions)  # 投影带噪声的动作# 重复时间嵌入以匹配动作序列长度time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)# 连接动作和时间tokenaction_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
  4. MLP处理
    使用两层MLP和swish激活函数对「动作和时间的组合表示」进行非线性变换,以进一步融合:动作和时间信息
            # 通过MLP处理action_time_tokens = self.action_time_mlp_in(action_time_tokens)   # 输入层action_time_tokens = nnx.swish(action_time_tokens)                 # Swish激活函数action_time_tokens = self.action_time_mlp_out(action_time_tokens)  # 输出层
  5. 注意力掩码设置
    第一个动作token设置为单向注意力「上面说过了的,单向注意力,用ar_mask = true表示」,其余动作tokens之间设置为双向注意力
            # 添加动作时间tokentokens.append(action_time_tokens)# 添加掩码(全为1),表示所有动作token都是有效的input_mask.append(jnp.ones(action_time_tokens.shape[:2], dtype=jnp.bool_))  # 图像/语言/状态输入不关注动作token(动作第一个是自回归的——单向,其余不是——双向)ar_mask += [True] + ([False] * (self.action_horizon - 1))
  6. 最后连接所有token和掩码
            # 连接所有token和掩码tokens = jnp.concatenate(tokens, axis=1)          # 在序列维度上连接tokeninput_mask = jnp.concatenate(input_mask, axis=1)  # 在序列维度上连接输入掩码ar_mask = jnp.array(ar_mask)        # 转换自回归掩码为数组return tokens, input_mask, ar_mask  # 返回token、输入掩码和自回归掩码
2.1.4.3 损失函数 `compute_loss`

实现了扩散模型的训练损失计算

  1. 对输入观察进行预处理,其中
    preprocess_rng用于观察预处理(比如图像增强等)
    noise_rng用于生成噪声
    time_rng用于从beta分布采样时间步
        def compute_loss(self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False) -> at.Float[at.Array, Any]:"""计算扩散模型的损失函数"""# 分割随机数生成器为三部分,用于不同的随机操作preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
  2. 生成随机噪声并采样时间点 t
            # 获取动作的批次形状batch_shape = actions.shape[:-2]# 生成与动作相同形状的高斯噪声noise = jax.random.normal(noise_rng, actions.shape)# 从Beta分布采样时间点,范围为[0.001, 1],Beta(1.5, 1)偏向较低的值time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001# 扩展时间维度以匹配动作形状time_expanded = time[..., None, None]
  3. 创建带噪动作序列 x_t,相当于x_t是噪声化的动作,随着时间从0到1,原始动作逐渐加噪,变为纯噪声
    而u_t代表所加的真实噪声,而咱们就是要预测所添加的噪声(而所添加的噪声即等于加满噪声的动作 - 原始动作)
            # 创建带噪声的动作:t*noise + (1-t)*actionsx_t = time_expanded * noise + (1 - time_expanded) * actions# 计算真实噪声减去动作的差异,这是模型需要预测的目标u_t = noise - actions
    扩散策略diffusion policy的灵感来源于图像生成中的扩散模型DDPM,通过逐步去除噪声来生成目标数据(比如机器人的动作序列),如果对DDPM原理不太明白的,详见此文《图像生成发展起源:从VAE、扩散模型DDPM、DDIM到DETR、ViT、Swin transformer》
  4. 嵌入前缀和后缀
            # 一次性前向传递前缀+后缀# 嵌入前缀(图像和文本)prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)# 嵌入后缀(状态和带噪声的动作)suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(observation, x_t, time)
  5. 构建注意力掩码和位置编码
    根据下图

    可得
            # 连接掩码:通过链接前缀和后缀的掩码,从而创建完整的输入掩码input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)# 创建注意力掩码make_attn_mask,从而控制不同token之间的可见性attn_mask = make_attn_mask(input_mask, ar_mask)# 计算位置编码positions = jnp.cumsum(input_mask, axis=1) - 1
  6. 模型前向传播,即使用PaliGemma进行推理,处理前缀和后缀token
    当然了,输出中我们只关注与后缀相关的部分,因为其中包含了我们想要的动作预测的部分
            # 通过PaLI-Gemma模型处理token_, suffix_out = self.PaliGemma.llm([prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions)
  7. 预测噪声v_t
            # 将模型输出投影回动作空间v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
  8. 计算预测噪声与实际噪声间的均方误差
            # 返回预测噪声和真实噪声之间的均方误差return jnp.mean(jnp.square(v_t - u_t), axis=-1)
2.1.4.4 推理函数 `sample_actions`:基于扩散模型逆向采样,生成机器人动作序列

sample_actions函数是Pi0模型的核心推理方法,实现了基于扩散模型的逆向采样过程——说白了 就是去噪,它从纯噪声开始,通过多步骤逐渐"去噪",最终生成符合条件分布的机器人动作序列

函数的核心是一个基于while循环的迭代过程,每一步都使用训练好的神经网络预测从当前噪声化动作到目标动作的方向——从噪声到目标的方向 代表速度场,毕竟咱们去噪的方向得对 不然就去歪了

总之,这个函数将观察数据(图像和可选的文本提示)转换为具体的动作轨迹,是模型部署时的主要接口,简言之,其包含以下流程

  1. 首先从纯噪声开始 (t=1)
  2. 通过重复迭代降噪步骤,逐步将噪声转化为有意义的动作序列
  3. 使用KV缓存优化推理速度
  4. 实现了一个迭代降噪过程:
  5. 最终返回完全降噪后的动作序列 x_0

具体而言,包含如下步骤

第一,初始化

首先,函数对输入观察数据进行预处理,包括标准化图像大小等操作

def sample_actions(self,rng: at.KeyArrayLike,               # 随机数生成器observation: _model.Observation,    # 观察输入,包含图像和文本等*,num_steps: int = 10,                # 扩散过程的步数,默认为10步
) -> _model.Actions:                    # 返回生成的动作序列# 对观察数据进行预处理,不进行训练时的数据增强observation = _model.preprocess_observation(None, observation, train=False)

然后设置时间步长`dt`为负值(因为我们是从t=1向t=0方向演化),生成初始随机噪声作为起点,且时间上约定:"t=1是噪声,t=0是目标分布",这是扩散文献中常见的约定,不过与Pi0论文相反

    # 注意:这里使用扩散模型文献中更常见的约定,t=1是噪声,t=0是目标分布# 这与pi0论文相反dt = -1.0 / num_steps                       # 计算时间步长,从1到0batch_size = observation.state.shape[0]     # 获取批次大小# 生成初始噪声,形状为[批次大小, 动作序列长度, 动作维度]noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))

第二,Key-Value缓存初始化(预计算并存储前缀表示,减少冗余计算)

处理观察数据,得到前缀表示和相关掩码

    # 首先通过前缀的前向传递填充KV缓存# 获取前缀的token表示和掩码prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)# 创建前缀的注意力掩码prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)# 计算位置编码positions = jnp.cumsum(prefix_mask, axis=1) - 1

然后使用PaliGemma语言模型进行一次前向传递,生成Key-Value缓存(`kv_cache`)——这是一个性能优化:因为前缀部分在整个采样过程中保持不变,预先计算并缓存它们的表示可以避免重复计算

    # 进行前向传递,获取KV缓存_, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)

第三,通过step函数构建注意力掩码系统并让PaliGemma做推理

核心迭代通过 `jax.lax.while_loop` 实现

根据源码

可知,该class Pi0(_model.BaseModel)类的最后两行是

    # 使用while循环进行迭代采样,从t=1(噪声)开始x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))# 返回最终的去噪结果(生成的动作序列)return x_0

具体而言,包含 `step` 函数和 `cond` 函数,其中,`step` 函数是每次迭代的核心

首先,step函数通过 `embed_suffix` 处理当前状态,包括状态信息嵌入、噪声化动作、时间步编码

    def step(carry):"""定义单步去噪函数"""x_t, time = carry  # carry数组包含当前状态和时间# 将时间广播到批次维度,并嵌入后缀(状态和动作)suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(observation, x_t, jnp.broadcast_to(time, batch_size))

其次,构建复杂的注意力掩码系统,处理前缀-后缀之间的注意力关系——这个复杂的掩码系统允许后缀token(包括状态和动作)有选择地关注前缀token(图像和文本),实现了条件生成,具体而言,其构建了三层注意力掩码:

  • 后缀内部注意力掩码,控制后缀token(状态和动作)之间的注意力关系
        # 创建后缀内部的注意力掩码,形状为(批次, 后缀长度, 后缀长度)suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
  • 前缀-后缀注意力掩码,控制后缀token如何关注前缀token(图像和文本输入)
        # 创建后缀对前缀的注意力掩码,形状为(批次, 后缀长度, 前缀长度)prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
  • 完整注意力掩码,将前两个掩码组合,形成完整的注意力控制机制
        # 组合掩码,形状为(批次, 后缀长度, 前缀长度+后缀长度)# 控制后缀token(生成查询)如何关注完整序列(生成键和值)full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)

当然了,过程中还做了形状检查,确保张量维度正确

        # 验证掩码形状正确assert full_attn_mask.shape == (batch_size,suffix_tokens.shape[1],prefix_tokens.shape[1] + suffix_tokens.shape[1],)

接着,计算位置编码,为后缀token计算其在完整序列中的位置,这对于Transformer模型理解序列顺序很重要

        # 计算后缀token的位置编码positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1

之后,模型推理,使用PaliGemma语言模型进行推理,利用缓存的前缀信息(`kv_cache`)提高效率

        # 使用KV缓存进行高效的前向传递(prefix_out, suffix_out), _ = self.PaliGemma.llm([None, suffix_tokens], mask=full_attn_mask, positions=positions, kv_cache=kv_cache)# 且确保前缀输出为None(因为使用了KV缓存)assert prefix_out is None

第四,step函数中做最后的速度预测与动作更新(去噪)

在每一步中,模型预测速度场 `v_t`(从噪声到目标的方向),并通过类欧拉法更新动作表示——使用简单而有效的欧拉方法x_{t+1}=x_{t}+v_{t} \cdot d t进行轨迹采样

具体而言

  • 一方面,提取模型输出并预测速度场`v_t`——相当于本质是通过PaliGemma模型预测去噪方向 `v_t`
        # 预测噪声v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
  • 二方面,使用欧拉法更新动作状态和时间步
        # 使用欧拉方法更新状态和时间return x_t + dt * v_t, time + dt

至于cond函数确定何时停止迭代,通过检查时间是否接近零(当然,要考虑浮点精读可能存在的误差)

    def cond(carry):"""定义循环终止条件"""x_t, time = carry# 考虑浮点误差,当时间接近0时停止return time >= -dt / 2

// 待更

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/29378.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Dify+DeepSeek | Excel数据一键可视化(创建步骤案例)(echarts助手.yml)(文档表格转图表、根据表格绘制图表、Excel绘制图表)

Dify部署参考&#xff1a;Dify Rag部署并集成在线Deepseek教程&#xff08;Windows、部署Rag、安装Ragan安装、安装Dify安装、安装ollama安装&#xff09; DifyDeepSeek - Excel数据一键可视化&#xff08;创建步骤案例&#xff09;-DSL工程文件&#xff08;可直接导入&#x…

由麻省理工学院计算机科学与人工智能实验室等机构创建低成本、高效率的物理驱动数据生成框架,助力接触丰富的机器人操作任务

2025-02-28&#xff0c;由麻省理工学院计算机科学与人工智能实验室&#xff08;CSAIL&#xff09;和机器人与人工智能研究所的研究团队创建了一种低成本的数据生成框架&#xff0c;通过结合物理模拟、人类演示和基于模型的规划&#xff0c;高效生成大规模、高质量的接触丰富型机…

OpenCV计算摄影学(11)色调映射算法类cv::TonemapDrago

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 cv::TonemapDrago 是 OpenCV 中实现的基于 Paul Debevec 和 Jorge Moraleda 以及后来由 Rogier van de Weijer 和 Theo Drago 改进的色调映射算法…

蓝桥杯 Excel地址

Excel地址 题目描述 Excel 单元格的地址表示很有趣&#xff0c;它使用字母来表示列号。 比如&#xff0c; A 表示第 1 列&#xff0c; B 表示第 2 列&#xff0c; Z 表示第 26 列&#xff0c; AA 表示第 27 列&#xff0c; AB 表示第 28 列&#xff0c; BA 表示第 53 列&#x…

JS禁止web页面调试

前言 由于前端在页面渲染的过程中 会调用很多后端的接口&#xff0c;而有些接口是不希望别人看到的&#xff0c;所以前端调用后端接口的行为动作就需要做一个隐藏。 禁用右键菜单 document.oncontextmenu function() {console.log("禁用右键菜单");return false;…

实例详细演示在Pytest中如何忽略警告

关注开源优测不迷路 大数据测试过程、策略及挑战 测试框架原理&#xff0c;构建成功的基石 在自动化测试工作之前&#xff0c;你应该知道的10条建议 在自动化测试中&#xff0c;重要的不是工具 当你尝试运行Pytest代码时&#xff0c;那些不相关的警告突然弹出&#xff0c;是不是…

OpenGL ES -> GLSurfaceView纹理贴图VBO(Vertex Buffer Object)方法实现

贴图 XML文件 <?xml version"1.0" encoding"utf-8"?> <com.example.myapplication.MyGLSurfaceViewxmlns:android"http://schemas.android.com/apk/res/android"android:layout_width"match_parent"android:layout_height…

IDEA中Git版本回退终极指南:Reset与Revert双方案详解

目录 前言一、版本回退前置知识二、Reset方案&#xff1a;整体改写历史1、IDEA图形化操作&#xff08;推荐&#xff09;1.1、查看提交历史1.2、选择目标版本1.3、选择回退模式1.3.1、Soft&#xff08;推荐&#xff09;1.3.2、Mixed1.3.3、Hard&#xff08;慎用&#xff09;1.3.…

面试题02.02.返回倒数第k个节点

实现一种算法&#xff0c;找出单向链表中倒数第 k 个节点。返回该节点的值。 注意&#xff1a;本题相对原题稍作改动 示例&#xff1a; 输入&#xff1a; 1->2->3->4->5 和 k 2 输出&#xff1a; 4 说明&#xff1a; 给定的 k 保证是有效的。 题解&#xff…

【经验分享】Ubuntu20.04编译RK3568 AI模型报错问题(已解决)

【经验分享】Ubuntu20.04编译RK3568 AI模型报错问题&#xff08;已解决&#xff09; 前言问题现象问题分析解决方案总结 前言 这里使用的是Rockchip提供的rknn_model_zoo&#xff0c;https://github.com/airockchip/rknn_model_zoo/tree/main 此解决方案适用于Rockchip芯片在U…

Python的那些事第四十一篇:简化数据库交互的利器Django ORM

Django ORM:简化数据库交互的利器 摘要 随着互联网技术的飞速发展,Web开发越来越受到重视。Django作为一款流行的Python Web框架,以其高效、安全、可扩展等特点受到了广大开发者的喜爱。其中,Django ORM(对象关系映射)是Django框架的核心组件之一,它为开发者提供了一种…

Swagger UI界面的使用

访问地址 一般格式&#xff1a;http://xxxx:端口号/上下文路径/swagger-ui/index.html 首先保证当前项目已经集成Swagger的功能 上下文路径&#xff1a;指的配置文件中的&#xff1a;server.servlet.context-path的值 刚进入界面&#xff0c;找到自己的服务接口&#xff0c;开…

WPS工具栏添加Mathtype加载项

问题描述&#xff1a; 分别安装好WPS和MathType之后&#xff0c;WPS工具栏没直接显示MathType工具&#xff0c;或者是前期使用正常&#xff0c;由于WPS更新之后MathType工具消失&#xff0c;如下图 解决办法 将文件“MathType Commands 2016.dotm”和“MathPage.wll”从Matht…

部署RabbitMQ集群详细教程

部署RabbitMQ集群详细教程 下面是一份在 Ubuntu 环境下部署 RabbitMQ 集群的详细步骤说明&#xff0c;涉及主机名设置、Erlang & RabbitMQ 安装、管理插件启用、集群通信 Cookie 配置、节点加入集群、镜像队列策略设置以及集群验证等。为了演示方便&#xff0c;以下示例假…

三维数据可视化与表面重建:Marching Cubes算法的原理与应用

1. 引言 随着现代医学影像技术的飞速发展&#xff0c;三维数据的可视化与重建已成为医学研究、临床诊断和手术规划的重要工具。在众多三维重建算法中&#xff0c;Marching Cubes算法因其高效、稳定的特性成为从离散数据场中提取等值面的经典方法。本报告将深入探讨Marching Cu…

IDEA 2024.1.7 Java EE 无框架配置servlet

1、创建一个目录&#xff08;文件夹&#xff09;lib来放置我们的库 2、将tomcat目录下的lib文件夹中的servlet-api.jar文件复制到刚创建的lib文件夹下。 3、把刚才复制到lib下的servlet-api.jar添加为库 4、在src下新建一个package&#xff1a;com.demo&#xff0c;然后创…

【文生图】windows 部署stable-diffusion-webui

windows 部署stable-diffusion-webui AUTOMATIC1111 stable-diffusion-webui Detailed feature showcase with images: 带图片的详细功能展示: Original txt2img and img2img modes 原始的 txt2img 和 img2img 模式 One click install and run script (but you still must i…

【TCP/IP协议栈】【传输层】端口号、套接字、多路复用/分解、网络字节序

参考资料&#xff1a; 前言&#xff1a; 总结&#xff1a; 【计算机网络】套接字&#xff08;应用层和传输层之间的接口&#xff09; 套接字是一个通用的通信接口抽象不仅限于TCP/IP协议族作为应用层和传输层之间的桥梁支持多种通信方式和协议族 套接字定义 在 TCP 或者 UDP…

【AI大模型】DeepSeek + Kimi 高效制作PPT实战详解

目录 一、前言 二、传统 PPT 制作问题 2.1 传统方式制作 PPT 2.2 AI 大模型辅助制作 PPT 2.3 适用场景对比分析 2.4 最佳实践与推荐 三、DeepSeek Kimi 高效制作PPT操作实践 3.1 Kimi 简介 3.2 DeepSeek Kimi 制作PPT优势 3.2.1 DeepSeek 优势 3.2.2 Kimi 制作PPT优…

【哇! C++】类和对象(三) - 构造函数和析构函数

目录 一、构造函数 1.1 构造函数的引入 1.2 构造函数的定义和语法 1.2.1 无参构造函数&#xff1a; 1.2.2 带参构造函数 1.3 构造函数的特性 1.4 默认构造函数 二、析构函数 2.1 析构函数的概念 2.2 特性 如果一个类中什么成员都没有&#xff0c;简称为空类。 空类中…