DP3 无论是 train 还是 eval 均使用了 Hydra 这一个python 库,这就有些代码在看的时候难以理解其通讯逻辑,例如:
@hydra.main(version_base=None,config_path=str(pathlib.Path(__file__).parent.joinpath('diffusion_policy_3d', 'config'))
)
Hydra 是一个用于管理复杂配置的开源框架,特别适用于需要动态配置的 Python 项目。它的主要功能是允许用户通过多种方式(如配置文件、命令行参数、环境变量等)管理和合成配置
Hydra 最初由 Facebook 开发,已成为了处理大规模机器学习、深度学习和其他数据科学项目中配置的一个主流工具
而 DP3 中就是使用 Hydra 管理 yaml 配置文件
所以本文讲解一下 Hydra 基础知识和使用教程
官方网站:Hydra
官方教程地址:Getting started
Github地址:hydra
目录
1. 安装
2. 配置文件
3. 应用文件
4. 补充示例
1. 安装
pip install hydra-core --upgrade
2. 配置文件
Hydra 使用 YAML 语言书写配置文件。YAML(YAML Ain't Markup Language)是一种简洁的数据序列化格式,常用于配置文件、数据交换、日志记录等场景
通常把需要的配置写在 config.yaml 中,DP3 的配置文件在如下位置
3. 应用文件
编写 application:my_app.py,模板如下:
import hydra
from omegaconf import DictConfig, OmegaConf@hydra.main(version_base=None, config_path="conf", config_name="config")
def my_app(cfg : DictConfig) -> None:print(OmegaConf.to_yaml(cfg))if __name__ == "__main__":my_app()
这下就能明白 DP3 eval.py 文件的内容了
import os
import hydra
import torch
import dill
from omegaconf import OmegaConf
import pathlib
from train import TrainDP3WorkspaceOmegaConf.register_new_resolver("eval", eval, replace=True)@hydra.main(version_base=None,config_path=str(pathlib.Path(__file__).parent.joinpath('diffusion_policy_3d', 'config'))
)
def main(cfg):workspace = TrainDP3Workspace(cfg)workspace.eval()if __name__ == "__main__":main()
OmegaConf 更多相关内容可以参考 Usage
在运行 application 时候,config.yaml 会自动加载
也可以在 application 中通过命令行覆盖 config.yaml 中的值,例如 DP3 命令示例:
bash scripts/train_policy.sh dp3 adroit_hammer 0112 0 0
4. 补充示例
首先,config.yaml 处于文件夹 config 中,config 应与 application 处于同一级
其次,有时候需要使用不同模型并设置不同的超参数,此时可以在文件夹 config 中为每个模型编写一个配置文件,被称为一个配置组(configuration group),DP3样例如下: