ray.rllib-入门实践-12:自定义policy

        在本博客开始之前,先厘清一下几个概念之间的区别与联系:env,  agent,  model, algorithm, policy. 

        强化学习由两部分组成: 环境(env)和智能体(agent)。环境(env)提供观测值和奖励; agent读取观测值,输出动作或决策。agent是algorithm的类对象。 policy是algorithm的子类, 比如ppo, dqn等。因此,自定义policy本质上是自定义algorithm.  algorithm 主要由两部分组成: 网络结构(model)和损失函数(loss)。 网络结构(model)的自定义由上一个博客ray.rllib-入门实践-11: 自定义模型/网络 进行了介绍:在alrorithm外创建新的model类, 通过 AlgorithmConfig类传入algorithm。因此, c从实际操作上, 自定义algorithm就变成了自定义algorithm 的 loss.

        因此,本博客所提到的自定义policy, 本质上就是继承一个Algorithm, 并修改它的loss函数。

        与之前介绍的自定义env, 自定义model一样, 自定义policy也包含三个步骤:

        1. 继承某个Policy, 创建一个新Policy类, 修改它的损失函数。

        2. 把自己的Policy封装为一个Algorithm, 使ray可识别

        3. 配置使用自己的Policy.

环境配置:

        torch==2.5.1

        ray==2.10.0

        ray[rllib]==2.10.0

        ray[tune]==2.10.0

        ray[serve]==2.10.0

        numpy==1.23.0

        python==3.9.18

一、 自定义 policy

import torch 
import gymnasium as gym 
from gymnasium import spaces
from ray.rllib.utils.annotations import override
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.algorithms.ppo import PPO, PPOConfig, PPOTorchPolicy
from typing import Dict, List, Type, Union
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.sample_batch import SampleBatch## 1. 自定义 policy, 主要是改变 policy 的 loss 的计算  # 神经网络的损失函数
class MY_PPOTorchPolicy(PPOTorchPolicy):"""PyTorch policy class used with PPO."""def __init__(self, observation_space:gym.spaces.Box, action_space:gym.spaces.Box, config:PPOConfig): PPOTorchPolicy.__init__(self,observation_space,action_space,config)## PPOTorchPolicy 内部对 PPOConfig 格式的config 执行了to_dict()操作,后面可以以 dict 的形式使用 config@override(PPOTorchPolicy) def loss(self,model: ModelV2,dist_class: Type[ActionDistribution],train_batch: SampleBatch):## 原始损失original_loss = super().loss(model, dist_class, train_batch) # PPO原来的损失函数, 也可以完全自定义新的loss函数, 但是非常不建议。## 新增自定义损失,这里以正则化损失作为示例addiontial_loss = torch.tensor(0.0) ## 自己定义的lossaddiontial_loss = torch.tensor(0.)for param in model.parameters():addiontial_loss += torch.norm(param)## 得到更新后的损失new_loss = original_loss + 0.01 * addiontial_lossreturn new_loss

二、 把自己的policy封装在一个算法中

## 2. 把自己的 policy 封装在算法中: 
##    继承自PPO, 创建一个新的算法类, 默认调用的是自定义的policy
class MY_PPO(PPO):## 重写 get_default_policy_class 函数, 使其返回自定义的policy def get_default_policy_class(self, config):return MY_PPOTorchPolicy

三、使用自己的策略创建智能体,执行训练

配置方法1:

## 三、使用自己的策略创建智能体,执行训练
## method-1
from ray.tune.logger import pretty_print 
config = PPOConfig(algo_class = MY_PPO) ## 配置使用自己的算法
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 
algo = config.build()
result = algo.train()
print(pretty_print(result))

配置方法2:

## 3. 使用新策略执行训练
## method-2
from ray.tune.logger import pretty_print 
config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 
algo = MY_PPO(config=config,)  ## 在这里使用自己的policy
result = algo.train()
print(pretty_print(result))

四、代码汇总:

import torch 
import gymnasium as gym 
from gymnasium import spaces
from ray.rllib.utils.annotations import override
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.algorithms.ppo import PPO, PPOConfig, PPOTorchPolicy
from typing import Dict, List, Type, Union
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.sample_batch import SampleBatch
from ray.tune.logger import pretty_print ## 1. 自定义 policy, 主要是改变 policy 的 loss 的计算  # 神经网络的损失函数
class MY_PPOTorchPolicy(PPOTorchPolicy):"""PyTorch policy class used with PPO."""def __init__(self, observation_space:gym.spaces.Box, action_space:gym.spaces.Box, config:PPOConfig): PPOTorchPolicy.__init__(self,observation_space,action_space,config)## PPOTorchPolicy 内部对 PPOConfig 格式的config 执行了to_dict()操作,后面可以以 dict 的形式使用 config@override(PPOTorchPolicy) def loss(self,model: ModelV2,dist_class: Type[ActionDistribution],train_batch: SampleBatch):## 原始损失original_loss = super().loss(model, dist_class, train_batch) # PPO原来的损失函数, 也可以完全自定义新的loss函数, 但是非常不建议。## 新增自定义损失,这里以正则化损失作为示例addiontial_loss = torch.tensor(0.0) ## 自己定义的lossaddiontial_loss = torch.tensor(0.)for param in model.parameters():addiontial_loss += torch.norm(param)## 得到更新后的损失new_loss = original_loss + 0.01 * addiontial_lossreturn new_loss## 2. 把自己的 policy 封装在算法中: 
##    继承自PPO, 创建一个新的算法类, 默认调用的是自定义的policy
class MY_PPO(PPO):## 重写 get_default_policy_class 函数, 使其返回自定义的policy def get_default_policy_class(self, config):return MY_PPOTorchPolicy## 三、使用自己的策略创建智能体,执行训练
## method-1
# config = PPOConfig(algo_class = MY_PPO) ## 配置使用自己的算法
# config = config.environment("CartPole-v1")
# config = config.rollouts(num_rollout_workers=2)
# config = config.framework(framework="torch") 
# algo = config.build()
# result = algo.train()
# print(pretty_print(result))## method-2
config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 
algo = MY_PPO(config=config,)  ## 在这里配置使用自己的policy
result = algo.train()
print(pretty_print(result))

五、后记:

        如何既要修改网络结构,又要修改loss函数, 可以结合 上一篇博客 和本博客共同实现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/7376.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

侧边导航(Semi Design)

根据前几次的导航栏设计,从最简单的三行导航栏到后面响应式的导航栏,其实可以在这个的基础上慢慢优化,就可以得到一个日常使用设计的导航栏。设计步骤也和之前的类似。 一、实现步骤 1、先下载安装好npm install douyinfe/semi-icons 2、引…

【中间件快速入门】什么是Redis

现在后端开发会用到各种中间件,一不留神项目可能在哪天就要用到一个我们之前可能听过但是从来没接触过的中间件,这个时候对于开发人员来说,如果你不知道这个中间件的设计逻辑和使用方法,那在后面的开发和维护工作中可能就会比较吃…

将 OneLake 数据索引到 Elasticsearch - 第二部分

作者:来自 Elastic Gustavo Llermaly 及 Jeffrey Rengifo 本文分为两部分,第二部分介绍如何使用自定义连接器将 OneLake 数据索引并搜索到 Elastic 中。 在本文中,我们将利用第 1 部分中学到的知识来创建 OneLake 自定义 Elasticsearch 连接器…

“AI教学实训系统:打造未来教育的超级引擎

嘿,各位教育界的伙伴们,今天我要跟你们聊聊一个绝对能让你们眼前一亮的教学神器——AI教学实训系统。作为资深产品经理,我可是亲眼见证了这款系统如何颠覆传统教学,成为未来教育的超级引擎。 一、什么是AI教学实训系统&#xff1f…

Linux下php8安装phpredis扩展的方法

Linux下php8安装phpredis扩展的方法 下载redis扩展执行安装编辑php.ini文件重启php-fpmphpinfo 查看 下载redis扩展 前提是已经安装好redis服务了 php-redis下载地址 https://github.com/phpredis/phpredis 执行命令 git clone https://github.com/phpredis/phpredis.git执行…

基于SMPL的三维人体重建-深度学习经典方法之VIBE

本文以开源项目VIBE[1-2]为例,介绍下采用深度学习和SMPL模板的从图片进行三维人体重建算法的整体流程。如有错误,欢迎评论指正。 一.算法流程 包含生成器模块和判别器模块,核心贡献就在于引入了GRU模块,使得当前帧包含了先前帧的先…

2.1.3 第一个工程,点灯!

新建工程 点击菜单栏左上角,新建工程或者选择“文件”-“新建工程”,选择工程类型“标准工程”选择设备类型和编程语言,并指定工程文件名及保存路径,如下图所示: 选择工程类型为“标准工程” 选择主模块机型&#x…

CVE-2025-0411 7-zip 漏洞复现

文章目录 免责申明漏洞描述影响版本漏洞poc漏洞复现修复建议 免责申明 本文章仅供学习与交流,请勿用于非法用途,均由使用者本人负责,文章作者不为此承担任何责任 漏洞描述 此漏洞 (CVSS SCORE 7.0) 允许远程攻击者绕…

mysql 学习6 DML语句,对数据库中的表进行 增 删 改 操作

添加数据 我们对 testdatabase 数据中 的 qqemp 这张表进行 增加数据,在这张表 下 打开 命令行 query console 在 软件中就是打开命令行的意思 可以先执行 desc qqemp; 查看一下当前表的结构。 插入一条数据 到qqemp 表,插入时要每个字段都有值 insert…

[特殊字符]【计算机视觉】r=2 采样滤波器全解析 ✨

Hey小伙伴们!今天来给大家分享一个在 计算机视觉 领域中非常有趣但又超级重要的概念——r2 采样滤波器(Sampling Filter with r2)。通过这种滤波器,我们可以在图像降采样的过程中有效地减少混叠效应,提升图像质量。 如…

数据库SQLite和SCADA DIAView应用教程

课程简介 此系列课程大纲主要包含七个课时。主要使用到的开发工具有:SQLite studio 和 SCADA DIAView。详细的可成内容大概如下: 1、SQLite 可视化管理工具SQLite Studio :打开数据库和查询数据;查看视频 2、创建6个变量&#x…

【开源免费】基于Vue和SpringBoot的景区民宿预约系统(附论文)

本文项目编号 T 162 ,文末自助获取源码 \color{red}{T162,文末自助获取源码} T162,文末自助获取源码 目录 一、系统介绍二、数据库设计三、配套教程3.1 启动教程3.2 讲解视频3.3 二次开发教程 四、功能截图五、文案资料5.1 选题背景5.2 国内…

[Spring] Gateway详解

🌸个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 🏵️热门专栏: 🧊 Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 🍕 Collection与…

Pandas基础02(DataFrame创建/索引/切片/属性/方法/层次化索引)

DataFrame数据结构 DataFrame 是一个二维表格的数据结构,类似于数据库中的表格或 Excel 工作表。它由多个 Series 组成,每个 Series 共享相同的索引。DataFrame 可以看作是具有列名和行索引的二维数组。设计初衷是将Series的使用场景从一维拓展到多维。…

矩阵快速幂

矩阵快速幂: 高效计算矩阵的幂次(如A^n)的一种算法,只适用于计算某一项,而不是全部项。 递推公式 如果 n为偶数,则: A^nA^(n/2)A^(n/2) 如果 nnn 为奇数,则: A^nA^(n-1…

复位信号的同步与释放(同步复位、异步复位、异步复位同步释放)

文章目录 背景前言一、复位信号的同步与释放1.1 同步复位1.1.1 综述1.1.2 优缺点 1.2 recovery time和removal time1.3 异步复位1.3.1 综述1.3.2 优缺点 1.4 同步复位 与 异步复位1.5 异步复位、同步释放1.5.1 总述1.5.2 机理1.5.3 复位网络 二、思考与补充2.1 复…

【Git版本控制器--3】Git的远程操作

目录 理解分布式版本控制系统 创建远程仓库 仓库被创建后的配置信息 克隆远程仓库 https克隆仓库 ssh克隆仓库 向远程仓库推送 拉取远程仓库 忽略特殊文件 为什么要忽略特殊文件? 如何配置忽略特殊文件? 配置命令别名 标签管理 理…

ios打包:uuid与udid

ios的uuid与udid混乱的网上信息 新人开发ios,发现uuid和udid在网上有很多帖子里是混淆的,比如百度下,就会说: 在iOS中使用UUID(通用唯一识别码)作为永久签名,通常是指生成一个唯一标识&#xf…

中国认知作战研究中心:从认知战角度分析2007年iPhone发布

中国认知作战研究中心:从认知战角度分析2007年iPhone发布 中国认知作战研究中心:从认知战角度分析2007年iPhone发布 关键词 认知作战,新质生产力,人类命运共同体,认知战,认知域,认知战研究中心,认知战争,认知战战术,认知战战略,认知域作战研究,认知作…

Chameleon(变色龙) 跨平台编译C文件,并一次性生成多个平台的可执行文件

地址:https://github.com/MartinxMax/Chameleon Chameleon 跨平台编译C文件,并一次性生成多个平台的可执行文件。可以通过编译Chameleon自带的.C文件反向Shell生成不同平台攻击载荷。 登录 & 代理设置 按照以下步骤设置 Docker 的代理: 创建配置目…