ray.rllib-入门实践-11: 自定义模型/网络

在ray.rllib中定义和使用自己的模型, 分为以下三个步骤:

1. 定义自己的模型。

2. 向ray注册自定义的模型

3. 在config中配置使用自定义的模型

环境配置:

        torch==2.5.1

        ray==2.10.0

        ray[rllib]==2.10.0

        ray[tune]==2.10.0

        ray[serve]==2.10.0

        numpy==1.23.0

        python==3.9.18

一、 定义自己的模型 

需要继承自 TFModel 或 TorchModelV2, 并重写需要自定义的方法, 其代码框架如下:

import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2class My_Model(TorchModelV2, nn.Module): ## 重构以下函数, 函数接口不能变。def __init__(self, obs_space, action_space, num_outputs, model_config, name, *, custom_arg1, custom_arg2): ...def forward(self, input_dict, state, seq_lens): ...def value_function(self): ...

示例如下:

## 1. 定义自己的模型
import numpy as np 
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import gymnasium as gym 
from gymnasium import spaces  
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDictclass My_Model(TorchModelV2, nn.Module):def __init__(self, obs_space:gym.spaces.Space, action_space:gym.spaces.Space, num_outputs:int, model_config:ModelConfigDict,  ## PPOConfig.training(model = ModelConfigDict), 调用的是config.model中的参数name:str,*, custom_arg1, custom_arg2):TorchModelV2.__init__(self, obs_space, action_space, num_outputs,model_config,name)nn.Module.__init__(self)## 测试 custom_arg1 , custom_arg2 传递进来的是什么数值print(f"===========================  custom_arg1 = {custom_arg1}, custom_arg2 = {custom_arg2}")## 定义网络层obs_dim = int(np.product(obs_space.shape))action_dim = int(np.product(action_space.shape))## shareNetself.shared_fc = nn.Linear(obs_dim,128)## actorNetself.actorNet = nn.Linear(128, action_dim)## criticNetself.criticNet = nn.Linear(128,1)self._feature = None def forward(self, input_dict, state, seq_lens):obs = input_dict["obs"].float()self._feature = self.shared_fc.forward(obs)action_logits = self.actorNet.forward(self._feature)return action_logits, state def value_function(self):value = self.criticNet.forward(self._feature).squeeze(1)return value 

        在rllib中,每个算法的所有网络都被汇集到同一个 ModelV2 类下,供算法调用。actor 网络和critic网络可以在外面定义,也可以在model内部直接定义。 model的forward用于返回actor网络的输出, value_function函数用于返回critic网络的输出。 网络结构和网络层共享可以自定义设置。输入输出接口,需要与上面保持严格一致。

二、 向ray注册自定义模型

        ray.rllib.model.ModelCatalog 类,用于向ray注册自定义的model, 还可以用于获取env的 preprocessors 和 action distributions。

import ray 
from ray.rllib.models import ModelCatalog # ModelCatalog 类: 用于注册 models, 获取env的 preprocessors 和 action distributions。 ModelCatalog.register_custom_model(model_name="my_torch_model", model_class = My_Model)

三、 在算法中配置并使用自定义的模型

主要是在 config.training() 模块中的 model 子模块中传入两个配置信息:

        1)"custom_model":"my_torch_model" ,                      
         2)"custom_model_config": {"custom_arg1": 1, "custom_arg2": 2,}})  

两个关键字固定不变,填入自己注册的模型名和对应的模型参数即可。

可以有以下三种配置代码的编写方式:

配置方法1:

## 3. 在训练中使用自定义模型
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") ## 配置使用自定义的模型
config = config.training(model= {"custom_model":"my_torch_model" ,                      "custom_model_config": {"custom_arg1": 1, "custom_arg2": 2,}})  
## 主要在上面两行配置使用自己的模型
##    配置 model 的 "custom_model" 项,用于指定rllib算法所使用的模型
##    配置 model 的 "custom_model_config" 项,用于传入自定义的网络参数,供自定义的model使用。
##    这两个关键词不可更改。algo = config.build()
## 4. 执行训练
result = algo.train()
print(pretty_print(result))

与以上配置内容一样,还可以用以下两种配置写法:

配置方法2:

config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") ## 配置自定义模型
model_config_dict = {}
model_config_dict["custom_model"] = "my_torch_model" 
model_config_dict["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}
config = config.training(model= model_config_dict)  algo = config.build()

 配置方法3(推荐):

config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") ## 配置自定义模型
config.model["custom_model"] = "my_torch_model"
config.model["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}algo = config.build()

 代码汇总:

"""
在ray.rllib中定义和使用自己的模型, 分为以下三个步骤:
1. 定义自己的模型。 需要继承自 TFModel 或 TorchModelV2, 并重写需要自定义的方法import torch.nn as nnfrom ray.rllib.models.torch.torch_modelv2 import TorchModelV2class CustomTorchModel(TorchModelV2, nn.Module): ## 重构以下函数, 函数接口不能变。 def __init__(self, obs_space, action_space, num_outputs, model_config, name, *, custom_arg1, custom_arg2): ...def forward(self, input_dict, state, seq_lens): ...def value_function(self): ...2. 向ray注册自定义的模型from ray.rllib.models import ModelCatalogModelCatalog.register_custom_model("wzg_torch_model", CustomTorchModel)3. 在config中配置使用自定义的模型model_config_dict = {"custom_model":"wzg_torch_model","custom_model_config":{"custom_arg1": 1,"custom_arg2": 2}}config = PPOConfig()# config = config.training(model = model_config_dict)config.model["custom_model"] = "wzg_torch_model"config.model["custom_model_config"] = {"custom_arg1": 1,"custom_arg2": 2}
"""## 1. 定义自己的模型
import numpy as np 
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import gymnasium as gym 
from gymnasium import spaces  
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDictclass My_Model(TorchModelV2, nn.Module):def __init__(self, obs_space:gym.spaces.Space, action_space:gym.spaces.Space, num_outputs:int, model_config:ModelConfigDict,  ## PPOConfig.training(model = ModelConfigDict), 调用的是config.model中的参数name:str,*, custom_arg1, custom_arg2):TorchModelV2.__init__(self, obs_space, action_space, num_outputs,model_config,name)nn.Module.__init__(self)## 测试 custom_arg1 , custom_arg2 传递进来的是什么数值print(f"===========================  custom_arg1 = {custom_arg1}, custom_arg2 = {custom_arg2}")## 定义网络层obs_dim = int(np.product(obs_space.shape))action_dim = int(np.product(action_space.shape))## shareNetself.shared_fc = nn.Linear(obs_dim,128)## actorNetself.actorNet = nn.Linear(128, action_dim)## criticNetself.criticNet = nn.Linear(128,1)self._feature = None def forward(self, input_dict, state, seq_lens):obs = input_dict["obs"].float()self._feature = self.shared_fc.forward(obs)action_logits = self.actorNet.forward(self._feature)return action_logits, state def value_function(self):value = self.criticNet.forward(self._feature).squeeze(1)return value ## 2. 向ray注册自定义模型
import ray 
from ray.rllib.models import ModelCatalog # ModelCatalog 类: 用于注册 models, 获取env的 preprocessors 和 action distributions。 ModelCatalog.register_custom_model(model_name="my_torch_model", model_class = My_Model)
ray.init()## 3. 在训练中使用自定义模型
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") # ## 配置自定义模型:方法 1
# config = config.training(model= {"custom_model":"my_torch_model" ,                      
#                                  "custom_model_config": {"custom_arg1": 1, "custom_arg2": 2,}})  
# ## 配置自定义模型:方法 2
# model_config_dict = {}
# model_config_dict["custom_model"] = "my_torch_model" 
# model_config_dict["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}
# config = config.training(model= model_config_dict) ## 配置自定义模型: 方法 3 (个人更喜欢, 因为嵌套层次少)
config.model["custom_model"] = "my_torch_model"
config.model["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}## 错误方法:
# model_config_dict = {}
# model_config_dict["custom_model"] = "my_torch_model" 
# model_config_dict["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}
# config.model = model_config_dict # 会清空 model 里面的其他默认配置,导致报错algo = config.build()## 4. 执行训练
result = algo.train()
print(pretty_print(result))


 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/7550.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于SpringBoot的网上考试系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏:…

【java数据结构】map和set

【java数据结构】map和set 一、Map和Set的概念以及背景1.1 概念1.2 背景1.3 模型 二、Map2.1 Map说明2.2 Map的常用方法 三、Set3.1 Set说明3.2 Set的常用方法 四、Set和Map的关系 博客最后附有整篇博客的全部代码!!! 一、Map和Set的概念以及…

基于迁移学习的ResNet50模型实现石榴病害数据集多分类图片预测

完整源码项目包获取→点击文章末尾名片! 番石榴病害数据集 背景描述 番石榴 (Psidium guajava) 是南亚的主要作物,尤其是在孟加拉国。它富含维生素 C 和纤维,支持区域经济和营养。不幸的是,番石榴生产受到降…

企业信息化2:行政办公管理系统

总裁办公室作为综合行政管理部门服务于整个公司,工作职责包含从最基础的行政综合到协调督办、对外政务、品牌建设等等,工作量繁多而且琐碎。如何通过信息化来实现标准化和常态化的管理手段,确保总裁办的各项工作有章可循,提高工作…

基于springboot+vue的古城景区管理系统的设计与实现

开发语言:Java框架:springbootJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:…

使用 Elasticsearch 导航检索增强生成图表

作者:来自 Elastic Louis Jourdain 及 Ivan Monnier 了解如何使用知识图谱来增强 RAG 结果,同时在 Elasticsearch 中高效存储图谱。本指南探讨了根据用户查询动态生成知识子图的详细策略。 检索增强生成 (RAG) 通过将大型语言模型 (LLM) 的输出基于事实数…

【数据结构】_以SLTPushBack(尾插)为例理解单链表的二级指针传参

目录 1. 第一版代码 2. 第二版代码 3. 第三版代码 前文已介绍无头单向不循环链表的实现,详见下文: 【数据结构】_不带头非循环单向链表-CSDN博客 但对于部分方法如尾插、头插、任意位置前插入、任意位置前删除的相关实现,其形参均采用了…

ceph新增节点,OSD设备,标签管理(二)

一、访问客户端集群方式 方式一: 使用cephadm shell交互式配置 [rootceph141 ~]# cephadm shell # 注意,此命令会启动一个新的容器,运行玩后会退出! Inferring fsid c153209c-d8a0-11ef-a0ed-bdb84668ed01 Inferring config /var/lib/ce…

Spring Data JPA 实战:构建高性能数据访问层

1 简介 1.1 Spring Data JPA 概述 1.1.1 什么是 Spring Data JPA? Spring Data JPA 是 Spring Data 项目的一部分,旨在简化对基于 JPA 的数据库访问操作。它通过提供一致的编程模型和接口,使得开发者可以更轻松地与关系型数据库进行交互,同时减少了样板代码的编写。Spri…

Git进阶笔记系列(01)Git核心架构原理 | 常用命令实战集合

读书笔记:卓越强迫症强大恐惧症,在亲子家庭、职场关系里尤其是纵向关系模型里,这两种状态很容易无缝衔接。尤其父母对子女、领导对下属,都有望子成龙、强将无弱兵的期望,然而在你的面前,他们才是永远强大的…

基于模糊PID的孵化箱温度控制系统(论文+源码)

1系统方案设计 本课题为基于模糊PID的孵化箱温度控制系统,其以STM32最小系统与模糊PID控制器为控制核心。系统主要包括数据采集模块、处理器模块、电机控制模块。 数据采集模块由温度传感器构成,通过温度传感器感应温度变化,获得待处理的数据…

Arcgis国产化替代:Bigemap Pro正式发布

在数字化时代,数据如同新时代的石油,蕴含着巨大的价值。从商业决策到科研探索,从城市规划到环境监测,海量数据的高效处理、精准分析与直观可视化,已成为各行业突破发展瓶颈、实现转型升级的关键所在。历经十年精心打磨…

ThreeJS示例教程200+【目录】

Three.js 是一个强大的 JavaScript 库,旨在简化在网页上创建和展示3D图形的过程。它基于 WebGL 技术,但提供了比直接使用 WebGL 更易于使用的API,使得开发者无需深入了解 WebGL 的复杂细节就能创建出高质量的3D内容。 由于目前内容还不多,下面的内容暂时做一个占位。 文章目…

opengrok_使用技巧

Searchhttps://xrefandroid.com/android-15.0.0_r1/https://xrefandroid.com/android-15.0.0_r1/ 选择搜索的目录(工程) 手动在下拉框中选择,或者 使用下面三个快捷按钮进行选择或者取消选择。 输入搜索的条件 搜索域说明 域 fullSearc…

无人机如何自主侦察?UEAVAD:基于视觉的无人机主动目标探测与导航数据集

作者:Xinhua Jiang, Tianpeng Liu, Li Liu, Zhen Liu, and Yongxiang Liu 单位:国防科技大学电子科学学院 论文标题:UEVAVD: A Dataset for Developing UAV’s Eye View Active Object Detection 论文链接:https://arxiv.org/p…

【图文详解】lnmp架构搭建Discuz论坛

安装部署LNMP 系统及软件版本信息 软件名称版本nginx1.24.0mysql5.7.41php5.6.27安装nginx 我们对Markdown编辑器进行了一些功能拓展与语法支持,除了标准的Markdown编辑器功能,我们增加了如下几点新功能,帮助你用它写博客: 关闭防火墙 systemctl stop firewalld &&a…

Ansible入门学习之基础元素介绍

一、Ansible目录结构介绍 1.通过rpm -ql ansible获取ansible所有文件存放的目录 有配置文件目录 /etc/ansible/ 执行文件目录 /usr/bin/ 其中 /etc/ansible/ 该文件目录的主要功能是 inventory主机信息配置,ansible工具功能配置。 ansible自身的配置文件…

git Bash通过SSH key 登录github的详细步骤

1 问题 通过在windows 终端中的通过git登录github 不再是通过密码登录了,需要本地生成一个密钥,配置到gihub中才能使用 2 步骤 (1)首先配置用户名和邮箱 git config --global user.name "用户名"git config --global…

矩阵的秩在机器学习中具有广泛的应用

矩阵的秩在机器学习中具有广泛的应用,主要体现在以下几个方面: 一、数据降维与特征提取 主成分分析(PCA): PCA是一种常用的数据降维技术,它通过寻找数据中的主成分(即最大方差方向&#xff09…

Windows Defender添加排除项无权限的解决方法

目录 起因Windows Defender添加排除项无权限通过管理员终端添加排除项管理员身份运行打开PowerShell添加/移除排除项的命令 起因 博主在打软件补丁时,遇到 Windows Defender 一直拦截并删除文件,而在 Windows Defender 中无权限访问排除项。尝试通过管理…