ETH开源PPO算法学习

前言

项目地址:https://github.com/leggedrobotics/rsl_rl

项目简介:快速简单的强化学习算法实现,设计为完全在 GPU 上运行。这段代码是 NVIDIA Isaac GYM 提供的 rl-pytorch 的进化版。

下载源码,查看目录,整个项目模块化得非常好,每个部分各司其职。下面我们自底向上地进行讲解加粗的部分。

rsl_rl/
│ __init__.py

├─algorithms/
│ │ __init__.py
│ │ ppo.py # PPO算法的实现
│ │
├─env/
│ │ __init__.py
│ │ vec_env.py # 实现并行处理多个环境的向量化环境
│ │
├─modules/
│ │ __init__.py
│ │ actor_critic.py # 定义 Actor-Critic 网络结构
│ │ actor_critic_recurrent.py # 定义包含循环层的 Actor-Critic 网络
│ │ normalizer.py # 数据正规化工具,有助于训练过程的稳定性
│ │
├─runners/
│ │ __init__.py
│ │ on_policy_runner.py # 实现用于执行 on-policy 算法训练循环的运行器
│ │
├─storage/
│ │ __init__.py
│ │ rollout_storage.py # 存储和管理策略 rollout 数据的工具
│ │
└─utils/
│ __init__.py
│ neptune_utils.py # 用于与 Neptune.ai 集成的工具
│ utils.py # 通用实用工具函数
│ wandb_utils.py # 用于与 Weights & Biases 集成的工具

rollout 数据储存和管理(rollout_storage.py)

定义了一个名为 RolloutStorage 的类,用于存储和管理在强化学习训练过程中从环境中收集到的数据(称为rollouts)。

  • 定义Transition

用于存储单个时间步的所有相关数据,包括观察值、动作、奖励、完成标志(dones)、值函数估计、动作的对数概率、动作的均值和标准差,以及可能的隐藏状态(对于使用循环网络的情况)。

  • 特权观察值(Privileged Observations)

除了self.observations外还有self.privileged_observations的使用,在强化学习中是指那些在训练期间可用但在实际部署或测试时不可用的额外信息。这些信息通常提供了环境的内部状态或其他有助于学习的提示,但在现实世界应用中可能难以获得或完全不可用。在训练期间使用特权观察值的一种常见方法是通过教师-学生架构(我们常常也称作特权学习),其中一个拥有全部信息的教师模型(可以访问特权观察值)来指导一个学生模型(只能访问普通观察值)。学生模型的目标是模仿教师模型的决策,尽管它没有直接访问特权信息。

  • 奖励和优势的计算
    def compute_returns(self, last_values, gamma, lam):advantage = 0for step in reversed(range(self.num_transitions_per_env)):if step == self.num_transitions_per_env - 1:next_values = last_valueselse:next_values = self.values[step + 1]next_is_not_terminal = 1.0 - self.dones[step].float()delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]advantage = delta + next_is_not_terminal * gamma * lam * advantageself.returns[step] = advantage + self.values[step]# Compute and normalize the advantagesself.advantages = self.returns - self.valuesself.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)

这段代码实现的是在强化学习中计算回报(returns)和优势(advantages)的逻辑,具体是使用了一种称为广义优势估算(Generalized Advantage Estimation, GAE)的方法。GAE是一种权衡偏差和方差以及平滑回报信号的技术,由以下几个数学公式定义:

  1. TD残差(Temporal Difference Residual):
    δ t = R t + γ V ( S t + 1 ) ( 1 − d o n e t ) − V ( S t ) \delta_t = R_t + \gamma V(S_{t+1}) (1 - done_t) - V(S_t) δt=Rt+γV(St+1)(1donet)V(St)
    其中, δ t \delta_t δt是时刻 t t t的TD残差, R t R_t Rt是奖励, γ \gamma γ是折扣因子, V ( S t ) V(S_t) V(St)是状态 S t S_t St的价值函数估计, d o n e t done_t donet是表示当前状态是否为终止状态的指示函数(如果当前状态为终止状态,则 d o n e t = 1 done_t = 1 donet=1;否则, d o n e t = 0 done_t = 0 donet=0)。如果 d o n e t = 1 done_t = 1 donet=1,那么 γ V ( S t + 1 ) \gamma V(S_{t+1}) γV(St+1)项将为 0,因为终止状态之后没有未来回报。

  2. GAE优势估计:
    A t G A E ( γ , λ ) = ∑ l = 0 ∞ ( γ λ ) l δ t + l A_t^{GAE(\gamma, \lambda)} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l} AtGAE(γ,λ)=l=0(γλ)lδt+l
    在代码中,这个无限求和是通过迭代地计算来近似的,具体的迭代公式为:
    A t = δ t + ( γ λ ) A t + 1 ( 1 − d o n e t ) A_t = \delta_t + (\gamma \lambda) A_{t+1} (1 - done_t) At=δt+(γλ)At+1(1donet)
    其中, A t A_t At是时刻 t t t的优势估计, λ \lambda λ是用来平衡TD估计和蒙特卡罗估计之间权重的参数。

  3. 回报的计算:
    G t = A t + V ( S t ) G_t = A_t + V(S_t) Gt=At+V(St)
    其中, G t G_t Gt是时刻 t t t的回报估计。

代码中使用的变量名与数学符号的对应关系:

变量名数学符号含义
rewards[step] R t R_t Rt时刻 t t t的奖励
gamma γ \gamma γ折扣因子,用于计算未来奖励的现值
values[step] V ( S t ) V(S_t) V(St)状态 S t S_t St在当前策略下的价值函数估计
dones[step] d o n e t done_t donet指示当前状态 S t S_t St是否为终止状态的标志(1 表示终止,0 表示非终止)
delta δ t \delta_t δt时刻 t t t的 TD 残差
advantage A t A_t At时刻 t t t的优势估计,根据 GAE 方法计算
lam λ \lambda λ用于 GAE 计算中平衡 TD 估计和蒙特卡罗估计之间权重的参数
returns[step] G t G_t Gt时刻 t t t的回报估计
advantages A t n o r m A_t^{norm} Atnorm标准化后的优势估计
mu_A, sigma_A μ A \mu_A μA, σ A \sigma_A σA优势估计的平均值和标准差
epsilon ϵ \epsilon ϵ避免除零错误而加的小常数,通常取值为 1e-8

代码中的循环从最后一个转换开始向前迭代,使用以上的数学公式来计算每一步的优势和回报。最后,它还对优势进行了标准化处理,即从每个优势中减去所有优势的平均值,并除以标准差,以减少训练期间的方差并加速收敛。标准化公式如下:
A t n o r m = A t − μ A σ A + ϵ A_t^{norm} = \frac{A_t - \mu_A}{\sigma_A + \epsilon} Atnorm=σA+ϵAtμA
其中, μ A \mu_A μA是优势的平均值, σ A \sigma_A σA是优势的标准差, ϵ \epsilon ϵ​ 是为了防止除以零而加的一个小常数(在代码中为 1e-8)。

  • 轨迹的平均长度

类中并没有显式存储轨迹的长度,轨迹长度隐含在self.dones之中。代码中使用的方法是:将每个环境中最后一步置为‘1’,然后flatten(展开)、拼接所有环境中的dones得到flat_dones,差分数组中为‘1’位置的索引得到智能体在每个环境中的步数,即轨迹长度。这个统计量有助于了解训练过程中智能体的表现。

  • mini-batch迭代器

mini_batch_generator 函数通过在多个训练周期(num_epochs)内,从经验回放缓冲区中随机选择小批量数据(包括观察值 observations、动作 actions、奖励 rewards 等)来生成小批量数据集。该函数利用 torch.randperm 生成随机索引 indices 来随机化数据抽样,进而支持基于批处理的学习方法,如梯度下降。通过每次只处理必要的数据量,该生成器在优化模型参数的同时,也优化了内存使用,确保了训练过程的高效性和灵活性。

(未完待续)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/265632.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis 16种妙用

1、缓存 2、数据共享分布式 3、分布式锁 4、全局ID 5、计数器 6、限流 7、位统计 8、购物车 9、用户消息时间线timeline 10、消息队列 11、抽奖 12、点赞、签到、打卡 13、商品标签 14、商品筛选 15、用户关注、推荐模型 16、排行榜 1、缓存 String类型 例如:热点…

提示emp.dll丢失怎么办?emp.dll的多种解决方法

在计算机游戏运行过程中,如果系统提示“emp.dll文件丢失”,这可能会引发一系列影响游戏正常运行的问题。由于emp.dll是动态链接库文件中的一种,它通常承载着游戏核心功能的重要模块或组件,一旦缺失,意味着相关功能将无…

Flask——基于python完整实现客户端和服务器后端流式请求及响应

文章目录 本地客户端Flask服务器后端客户端/服务器端流式接收[打字机]效果 看了很多相关博客,但是都没有本地客户端和服务器后端的完整代码示例,有的也只说了如何流式获取后端结果,基本没有讲两端如何同时实现流式输入输出,特此整…

Linux:Ansible的常用模块

模块帮助 ansible-doc -l 列出ansible的模块 ansible-doc 模块名称 # 查看指定模块的教程 ansible-doc command 查看command模块的教程 退出教程时候建议不要使用ctrlc 停止,某些shell工具会出现错误 command ansible默认的模块,执行命令,注意&#x…

Docker数据卷-自定义镜像

一.数据卷 1.1数据卷的基本使用 数据卷是一个特殊的目录,用于在Docker容器中持久化和共享数据。 数据卷的主要特点包括: 数据持久性:数据卷允许您在容器的生命周期之外保持数据的持久性。即使容器被删除,数据卷中的数据依然存在&…

小程序框架(概念、工作原理、发展及应用)

引言 移动应用的普及使得用户对于轻量级、即时可用的应用程序需求越来越迫切。在这个背景下,小程序应运而生,成为一种无需下载安装、即点即用的应用形式,为用户提供了更便捷的体验。小程序的快速发展离不开强大的开发支持,而小程…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的活体人脸检测系统(Python+PySide6界面+训练代码)

摘要:本篇博客详细讲述了如何利用深度学习构建一个活体人脸检测系统,并且提供了完整的实现代码。该系统基于强大的YOLOv8算法,并进行了与前代算法YOLOv7、YOLOv6、YOLOv5的细致对比,展示了其在图像、视频、实时视频流和批量文件处…

如何做代币分析:以 TRX 币为例

作者:lesleyfootprint.network 编译:cicifootprint.network 数据源:TRX 代币仪表板 (仅包括以太坊数据) 在加密货币和数字资产领域,代币分析起着至关重要的作用。代币分析指的是深入研究与代币相关的数据…

c# .net8 香橙派orangepi + hc-04蓝牙 实例

这些使用c# .net8开发,硬件 香橙派 orangepi 3lts和 hc-04蓝牙 使用场景:可以通过这个功能,手机连接orangepi进行wifi等参数配置 硬件: 1、带USB口的linux开发板orangepi 2、USB 转TTL 中转接蓝牙(HC-04) 某宝上买…

网站三合一缩略图片介绍展示源码

网站三合一缩略图片介绍展示源码,PHP源码,运行需要php环境支持,效果截图如下 蓝奏云下载:https://wfr.lanzout.com/ihY8y1pgim6j

45、WEB攻防——通用漏洞PHP反序列化POP链构造魔术方法原生类

文章目录 序列化:将java、php等代码中的对象转化为数组或字符串等格式。代表函数serialize(),将一个对象转换成一个字符;反序列化:将数组或字符串等格式还成对象。代表函数unserialize(),将字符串还原成一个对象。 P…

【算法与数据结构】复杂度深度解析(超详解)

文章目录 📝算法效率🌠 算法的复杂度🌠 时间复杂度的概念🌉大O的渐进表示法。 🌠常见复杂度🌠常见时间复杂度计算举例🌉常数阶O(1)🌉对数阶 O(logN)🌉线性阶 O(N)&#x…

论文设计任务书学习文档|基于智能搜索引擎的图书管理系统的设计与实现

文章目录 论文(设计)题目:基于智能搜索引擎的图书管理系统的设计与实现1、论文(设计)的主要任务及目标2、论文(设计)的主要内容3、论文(设计)的基本要求4、进度安排论文(设计)题目:基于智能搜索引擎的图书管理系统的设计与实现 1、论文(设计)的主要任务及目标 …

Android Activity启动模式

文章目录 Android Activity启动模式概述四种启动模式Intent标记二者区别 Android Activity启动模式 概述 Activity 的管理方式是任务栈。栈是先进后出的结构。 四种启动模式 启动模式说明适用场景standard标准模式默认模式,每次启动Activity都会创建一个新的Act…

【解读】工信部数据安全能力提升实施方案

近日,工信部印发《工业领域数据安全能力提升实施方案(2024-2026年)》,提出到2026年底,我国工业领域数据安全保障体系基本建立,基本实现各工业行业规上企业数据安全要求宣贯全覆盖。数据安全保护意识普遍提高…

【两颗二叉树】【递归遍历】【▲队列层序遍历】Leetcode 617. 合并二叉树

【两颗二叉树】【递归遍历】【▲队列层序遍历】Leetcode 617. 合并二叉树 解法1 深度优先 递归 前序解法2 采用队列进行层序遍历 挺巧妙的可以再看 ---------------🎈🎈题目链接🎈🎈------------------- 解法1 深度优先 递归 前…

动态规划(题目提升)

[NOIP2012 普及组] 摆花 方法一:记忆化搜索 何为记忆化搜素:就是使用递归函数对每次得到的结果进行保存,下次遇到就直接输出即可 那么这个题目使用递归(DFS)是怎样的? 首先我们需要搞清楚几个坑点&#x…

centos离线使用源码安装Postgre SQL

1源代码包获取 https://www.postgresql.org/ftp/source/ 这里我以下载14.5版本的tar.gz为例 2源码包上传服务器 将下载的源码上传至目标服务器,这里我上传到:/usr/local/postgres14.5/src,我们可以创建这个目录,然后在上传到…

Kali Linux下载与安装

目录 1 kali官网下载镜像文件 2 VMware打开kali linux文件 3 启动kali-linux-2023.4操作系统 1 kali官网下载镜像文件 kali官网:https://www.kali.org/get-kali/#kali-platforms 进入kali官网主页后看到如图所示界面,左边“Installer Images”界面是…

【MATLAB】tvf_emd_ MFE_SVM_LSTM 神经网络时序预测算法

有意向获取代码,请转文末观看代码获取方式~也可转原文链接获取~ 1 基本定义 TVF-EMD_MFE_SVM_LSTM 神经网络时序预测算法是一种结合了变分模态分解(TVF-EMD)、多尺度特征提取(MFE)、聚类后展开支持向量机(…