PPO算法逐行代码详解

前言:本文会从理论部分、代码部分、实践部分三方面进行PPO算法的介绍。其中理论部分会介绍PPO算法的推导流程,代码部分会给出PPO算法的各部分的代码以及简略介绍,实践部分则会通过debug代码调试的方式从头到尾的带大家看清楚应用PPO算法在cartpole环境上进行训练的整体流程,进而帮助大家将理论与代码实践相结合,更好的理解PPO算法。

文章目录

      • 1. 理论部分
      • 2. 代码部分
        • 2.1 神经网络的定义
        • 2.2 PPO算法的定义
        • 2.3 on policy算法的训练代码
      • 3. 实践部分

1. 理论部分

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

2. 代码部分

这里使用的是《动手学强化学习》中提供的代码,我将这本书中的代码整理到了github上,并且方便使用pycharm进行运行和调试。代码地址:https://github.com/zxs-000202/dsx-rl

代码核心部分整体上可以分为三部分,分别是关于神经网络的类的定义(PolicyNet,ValueNet),关于PPO算法的类的定义(PPO),on policy算法训练流程的定义(train_on_policy_agent)。

2.1 神经网络的定义

策略网络actor采用两层全连接层,第一层采用relu激活函数,第二层采用softmax函数。

class PolicyNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(PolicyNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))return F.softmax(self.fc2(x), dim=1)

价值网络critic采用两层全连接层,第一层和第二层均采用relu激活函数,第二层最后输出的维度是1,表示t时刻某个状态s的价值V

class ValueNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim):super(ValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x):x = F.relu(self.fc1(x))return self.fc2(x)
2.2 PPO算法的定义

这部分是PPO算法最核心的部分。本部分包括神经网络的初始化、相关参数的定义,如何根据状态s选择动作,以及网络参数是如何更新的。整体的代码如下:

class PPO:''' PPO算法,采用截断方式 '''def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,lmbda, epochs, eps, gamma, device):self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.critic = ValueNet(state_dim, hidden_dim).to(device)self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),lr=actor_lr)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),lr=critic_lr)self.gamma = gammaself.lmbda = lmbdaself.epochs = epochs  # 一条序列的数据用来训练轮数self.eps = eps  # PPO中截断范围的参数self.device = devicedef take_action(self, state):state = torch.tensor([state], dtype=torch.float).to(self.device)probs = self.actor(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def update(self, transition_dict):states = torch.tensor(transition_dict['states'],dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)td_target = rewards + self.gamma * self.critic(next_states) * (1 -dones)td_delta = td_target - self.critic(states)advantage = rl_utils.compute_advantage(self.gamma, self.lmbda,td_delta.cpu()).to(self.device)old_log_probs = torch.log(self.actor(states).gather(1,actions)).detach()for _ in range(self.epochs):log_probs = torch.log(self.actor(states).gather(1, actions))ratio = torch.exp(log_probs - old_log_probs)surr1 = ratio * advantagesurr2 = torch.clamp(ratio, 1 - self.eps,1 + self.eps) * advantage  # 截断actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO损失函数critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))self.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()actor_loss.backward()critic_loss.backward()self.actor_optimizer.step()self.critic_optimizer.step()
2.3 on policy算法的训练代码

我们知道on policy算法与off policy算法的区别就在与进行采样的网络和用来参数更新的训练网络是否是一个网络。PPO是一种on policy的算法,on policy算法要求训练的网络参数一更新就需要重新进行采样然后训练。但PPO有点特殊,它是利用重要性采样方法来实现数据的多次利用,提高了数据的利用效率。

从下面的代码中可以看到每当一个episode的数据采样完毕之后都会执行agent.update(transition_dict)。在update中会将当前采样得到的数据(也就是当前episode的每个t时刻的 [ 当前状态,动作,奖励,是否结束,下一个状态 ] 这个五元组)通过重要性采样的方法进行多次的神经网络参数的更新,这个具体过程我会在第三部分实践部分进行详细说明。

def train_on_policy_agent(env, agent, num_episodes):return_list = []for i in range(10):with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes / 10)):episode_return = 0transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}state = env.reset()done = Falsewhile not done:action = agent.take_action(state)next_state, reward, done, _ = env.step(action)transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)agent.update(transition_dict)if (i_episode + 1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode + 1),'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)return return_list

3. 实践部分

该部分带领大家从代码运行流程上从头到尾过一遍。
本次PPO算法训练应用的gym环境是CartPole-v0,如下图。该gym环境的状态空间是四个连续值用来表示杆所处的状态,而动作空间是两个离散值用来表示给杆施加向左或者向右的力(也就是action为0或者1)。
在这里插入图片描述
代码中前面的部分是网络、算法、训练参数的定义,真正开始训练是在return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)。我们在这里打一个断点进行调试。
在这里插入图片描述
比较重要的训练参数是episode为500,epoch为10。表示一共采样500个episode的数据,每个episode数据采样完毕后再update阶段进行10次的梯度下降参数更新。
在这里插入图片描述
执行到这里我们跳到take_action里面看看agent是如何选取动作的。
在这里插入图片描述
可以看到首先将输入的state转换为tensor,然后因为actor网络的最后一层是softmax函数,所以通过actor网络输出两个执行两个动作可能性的大小,然后通过action_dist = torch.distributions.Categorical(probs) action = action_dist.sample()根据可能性大小进行采样最后得到这次选择动作1进行返回。

第一个episode中我们采样到的数据如下:
在这里插入图片描述
可以看到第一个episode执行了41个step之后done了。
接下来我们跳进update函数中看看具体如何用这个episode采样的数据进行神经网络参数的更新。
在这里插入图片描述
首先从transition_dict中将采样的41个step的数据取出来存到tensor中方便之后进行运算。
在这里插入图片描述
接下来计算td_targettd_deltaadvantageold_log_probs
其中td_target表示t时刻的得到的奖励值r加上根据t+1时刻critic估计的状态价值V。用公式表示的话就是:

td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)

在这里插入图片描述
td_delta表示的是td_targett时刻采用critic估计的状态价值之间的差值。用公式表示的话就是:

td_delta = td_target - self.critic(states)

在这里插入图片描述
advantage则表示的是某状态下采取某动作的优势值,也就是下面PPO算法公式中的A

advantage = rl_utils.compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)

在这里插入图片描述
old_log_probs则表示的是旧策略下某个状态下采取某个动作的概率值的对数值。对应下面PPO算法公式中的clip函数中分母。

old_log_probs = torch.log(self.actor(states).gather(1,actions)).detach()

在这里插入图片描述

到这里我们计算出这些值的shape如下图所示。接下来我们就要进入一个循环中,循环epoch=10次进行参数的神经网络参数的更新。
在这里插入图片描述

具体而言首先在每个循环中计算一次当前的actor参数下对应的stateaction的概率值的对数值,也就是下面分式中的分子。
这里用到的高中对数数学公式有:
在这里插入图片描述
在这里插入图片描述
通过上面两个公式可以计算得到
在这里插入图片描述

然后计算这个分式的值ratio
之后结合下面的公式和代码可以看出surr1surr2分别是下图公式中对应的值。
在这里插入图片描述

log_probs = torch.log(self.actor(states).gather(1, actions))
ratio = torch.exp(log_probs - old_log_probs)
surr1 = ratio * advantage
surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage  # 截断

然后计算actorloss值和criticloss

actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO损失函数
critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))

计算得到的数据的shape如下图:
在这里插入图片描述
最后对actorcritic进行梯度下降进行参数更新。

self.actor_optimizer.zero_grad()
self.critic_optimizer.zero_grad()
ctor_loss.backward()
critic_loss.backward()
self.actor_optimizer.step()
self.critic_optimizer.step()

至此第一次循环执行完毕,接下来还需要再进行九次这种循环,把当前episode的41个step的数据再进行九次actorcritic参数的更新。

然后update相当于执行完毕了。在这里插入图片描述
接下来一个episode的训练过程相当于结束了,我们一共有500个episode需要训练,再循环499次就完成了整个训练的流程。

还有一部分比较重要的就是通过GAE计算优势函数Advantage的代码。

def compute_advantage(gamma, lmbda, td_delta):td_delta = td_delta.detach().numpy()advantage_list = []advantage = 0.0for delta in td_delta[::-1]:advantage = gamma * lmbda * advantage + deltaadvantage_list.append(advantage)advantage_list.reverse()return torch.tensor(advantage_list, dtype=torch.float)

这部分就留给大家结合第一部分理论部分最后一张图来进行理解了。

参考资料

  1. 磨菇书easy-rl https://datawhalechina.github.io/easy-rl/#/chapter5/chapter5
  2. 《动手学强化学习》 https://hrl.boyuai.com/chapter/2/ppo%E7%AE%97%E6%B3%95

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/156584.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

iMazing2023免费版苹果iPhone手机备份应用软件

iMazing是一款功能强大的苹果手机备份软件,它可通过备份功能将通讯录备份到电脑上,并在电脑端iMazing“通讯录”功能中随时查看和导出联系人信息。它自带Wi-Fi自动备份功能,能够保证通讯录备份数据是一直在动态更新的,防止手机中新…

webdriver.Chrome()没反应

今天学习爬虫安装selenium之后刚开始webdriver.Chrome()正常 后面运行突然卡在这一步了 百度发现是版本不匹配 我们下载旧版本的chrome Download Google Chrome 95.0.4638.69 for Windows - Filehippo.com 禁用chrome的自动更新 打开文件所在位置 点击Google文件夹 右键up…

HDLbits: Lemmings3

Lemmings又多了一种状态:dig,我按照上一篇文章里大神的思路又多加了两种状态:LEFT_DIGGING与RIGHT_DIGGING,写出了如下的代码: module top_module(input clk,input areset, // Freshly brainwashed Lemmings walk …

【Java 进阶篇】JavaScript 与 HTML 的结合方式

JavaScript是一种广泛应用于Web开发中的脚本语言,它与HTML(Hypertext Markup Language)结合使用,使开发人员能够创建交互式和动态的网页。在这篇博客中,我们将深入探讨JavaScript与HTML的结合方式,包括如何…

图像滤波总结

中值滤波器 中值滤波器是一种常用的非线性滤波器,其基本原理是:选择待处理像素的一个邻域中各像素值的中值来代替待处理的像素。主要功能使某像素的灰度值与周围领域内的像素比较接近,从而消除一些孤立的噪声点,所以中值滤波器能够…

超美!ChatGPT DALL-E 3已可用,另外GPT-4可上传图片进行问答

今天,在ChatGPT里使用DALL-E 3的功能终于上线了。以下是截图: 在GPT-4下加了一个菜单入口,名为 DALL-E 3,这也意味着ChatGPT免费账户暂时不能使用这个功能。 我们体验一下这个功能。 技术交流 建了技术交流群!想要进…

解决echarts配置滚动(dataZoom)后导出图片数据不全问题

先展现一个echarts&#xff0c;并配置dataZoom&#xff0c;每页最多10条数据&#xff0c;超出滚动 <div class"echartsBox" id"echartsBox"></div>onMounted(() > {nextTick(() > {var chartDom document.getElementById(echartsBox);…

如何在雷电模拟器上安装Magisk并加载movecert模块抓https包(二)

接来下在PC端安装和配置Charles&#xff0c;方法同下面链接&#xff0c;不再赘述。在模拟器上安装magisk实现Charles抓https包&#xff08;二&#xff09;_小小爬虾的博客-CSDN博客 一、记录下本机IP和代理端口 二、在手机模拟器上设置代理192.168.31.71:8888&#xff0c;设置…

接口自动化测试_L1

目录&#xff1a; 接口自动化测试框架介绍 接口测试场景自动化测试场景接口测试在分层测试中的位置接口自动化测试与 Web/App 自动化测试对比接口自动化测试与 Web/App 自动化测试对比接口测试工具类型为什么推荐 RequestsRequests 优势Requests 环境准备接口请求方法接口请求…

LeetCode【118】杨辉三角

题目&#xff1a; 解析&#xff1a; 该题目解析起来更像是数学推导&#xff0c;找到里面的规律 1、第n行有n个元素 2、第i行第j个元素&#xff0c;为第i-1行&#xff0c;j-1个元素和j个元素的和 3、每行第一个&#xff0c;最后一个元素是1 代码&#xff1a; public List<…

Docker 的数据管理和网络通信

目录 Docker 的数据管理 管理 Docker 容器中数据的方式 端口映射 容器互联&#xff08;使用centos镜像&#xff09; Docker 镜像的创建 Dockerfile 操作常用的指令 编写 Dockerfile 时格式 Dockerfile 案例 Docker 的数据管理 管理 Docker 容器中数据的方式 管理 Doc…

STM32使用HAL库驱动DS3231

1、STM32通讯口配置 启动IIC&#xff0c;默认配置即可。 2、头文件 #ifndef __DS3231_H #define __DS3231_H#include "main.h"#define DS3231_COM_PORT hi2c1 /*通讯端口*//**************************** defines *******************************/ #define DS3231…

cocos2d-x C++与Lua交互

Cocos版本&#xff1a; 3.10 Lua版本&#xff1a; 5.1.4 环境&#xff1a; window Visual Studio 2013 Lua Lua作为一种脚本语言&#xff0c; 它的运行需要有宿主的存在&#xff0c;通过Lua虚拟栈进行数据交互。 它的底层实现是C语言&#xff0c;C语言封装了很多的API接口&a…

ICCV23中的域泛化相关研究

ICCV23中的域泛化相关研究 【OCR】Order-preserving Consistency Regularization for Domain Adaptation and Generalization【iDAG】iDAG: Invariant DAG Searching for Domain Generalization【RIDG】Domain Generalization via Rationale Invariance【3DLabelProp】Domain G…

2023年医药商业行业发展研究报告

第一章 行业概况 1.1 定义 医药商业行业&#xff0c;作为医药领域的重要组成部分&#xff0c;扮演着至关重要的角色。这一行业专注于医药商品的经营与流通&#xff0c;确保药品能够有效、安全地到达消费者手中。随着医药科技的进步和市场需求的增长&#xff0c;医药商业行业在…

Android攻城狮学鸿蒙 -- 点击事件

具体参考&#xff1a;华为官网学习地址 1、点击事件&#xff0c;界面跳转 对于一个按钮设置点击事件&#xff0c;跳转页面。但是onclick中&#xff0c;如果pages前边加上“/”&#xff0c;就没法跳转。但是开发工具加上“/”才会给出提示。不知道是不是开发工具的bug。&#…

C++day6

编程题&#xff1a; 以下是一个简单的比喻&#xff0c;将多态概念与生活中的实际情况相联系&#xff1a; 比喻&#xff1a;动物园的讲解员和动物表演 想象一下你去了一家动物园&#xff0c;看到了许多不同种类的动物&#xff0c;如狮子、大象、猴子等。现在&#xff0c;动物…

ubuntu下查看realsense摄像头查看支持的分辨率和频率

引言&#xff1a; 在实际应用中&#xff0c;摄像头的频率如果过高&#xff0c;可能会造成系统图像处理的压力过大&#xff0c;因此需要选择合适的参数才能达到预期的效果。本文主要探讨设置realsense相关参数 1、打开终端&#xff0c;输入rs-enumerate-devices rs-enumerate-…

Java架构师系统架构设计性能评估

目录 1 导论2 架构评估基础系统性能衡量的基本指标2.1 系统性能的指标2.2 数据库指标2.3 并发用户数2.4 网络延迟2.4 系统吞吐量2.5 资源性能指标 3 架构评估基础服务端性能测试3.1基准测试3.2 负载测试3.3 压力测试3.4 疲劳强度测试3.5 容量测试 1 导论 本章的主要内容是掌握架…

018-第三代软件开发-整体介绍

第三代软件开发-整体介绍 文章目录 第三代软件开发-整体介绍项目介绍整体介绍Qt 属性系统QML 最新软件技术框架 关键字&#xff1a; Qt、 Qml、 属性、 Qml 软件架构 项目介绍 欢迎来到我们的 QML & C 项目&#xff01;这个项目结合了 QML&#xff08;Qt Meta-Object …