Hands on RL 之 Deep Deterministic Policy Gradient(DDPG)

Hands on RL 之 Deep Deterministic Policy Gradient(DDPG)

文章目录

  • Hands on RL 之 Deep Deterministic Policy Gradient(DDPG)
    • 1. 理论部分
      • 1.1 回顾 Deterministic Policy Gradient(DPG)
      • 1.2 Neural Network Difference
      • 1.3 Why is off-policy?
      • 1.4 Soft target update
      • 1.5 Maintain Exploration
      • 1.6 Other Techniques
      • 1.7 Pesudocode
    • 2. 代码实践
    • Reference

1. 理论部分

1.1 回顾 Deterministic Policy Gradient(DPG)

在介绍DDPG之前,我们先回顾一下DPG中最重要的结论,

Deterministic Policy Gradient Theorem即确定性策略梯度定理

∇ θ J ( μ θ ) = ∫ S ρ μ ( s ) ∇ θ μ θ ( s ) ∇ a Q μ ( s , a ) ∣ a = μ θ ( s ) d s = E s ∼ ρ μ [ ∇ θ μ θ ( s ) ∇ a Q μ ( s , a ) ∣ a = μ θ ( s ) ] \begin{aligned} \nabla_\theta J(\mu_\theta) & = \int_{\mathcal{S}} \rho^\mu(s) \nabla_\theta \mu_\theta(s) \nabla_a Q^\mu(s,a)|_{a=\mu_\theta(s)} \mathrm{d}s \\ & = \mathbb{E}_{s\sim\rho^\mu} \Big[ \nabla_\theta \mu_\theta(s) \nabla_a Q^\mu(s,a)|_{a=\mu_\theta(s)} \Big] \end{aligned} θJ(μθ)=Sρμ(s)θμθ(s)aQμ(s,a)a=μθ(s)ds=Esρμ[θμθ(s)aQμ(s,a)a=μθ(s)]

其中, a = μ θ ( s ) a=\mu_\theta(s) a=μθ(s)表示确定性的策略是从状态空间到动作空间的映射 μ θ : S → A \mu_\theta: \mathcal{S}\to\mathcal{A} μθ:SA,网络的参数为 θ \theta θ s ∼ ρ μ s\sim\rho^\mu sρμ表示状态 s s s符合在策略 μ \mu μ下的状态访问分布。如何推导的,这里不详细阐述。(可以参考Deterministic policy gradient algorithms)

接下来逐点介绍DDPG相较于DPG的改进

1.2 Neural Network Difference

​ DDPG在相较于传统的AC算法在网络结构上也有很大不同,首先看看传统算法的网络结构

Image

然后再看看DDPG的网络结构

Image

为什么DDPG会是这样的网络结构呢,这是因为DDPG中的actor输出的是确定性动作,而不是动作的概率分布,因此确定性的动作是连续的可以看作动作空间的维度为无穷,如果采用AC中critic的结构,我们无法通过遍历所有动作来取出某个特定动作对应的Q-value。因此DDPG中将actor的输出作为critic的输入,再联合状态输入,就能直接获得所采取动作 a = μ ( s t ) a=\mu(st) a=μ(st)的Q-value。

1.3 Why is off-policy?

​ 首先为什么DDPG或者说DPG是off-policy的?我们回顾stochastic policy π θ ( a ∣ s ) \pi_\theta(a|s) πθ(as)定义下的Q-value
Q π ( s t , a t ) = E r t , s t + 1 ∼ E [ r ( s t , a t ) + γ E a t + 1 ∼ π [ Q π ( s t + 1 , a t + 1 ) ] ] Q^\pi(s_t,a_t) = \mathbb{E}_{r_t, s_{t+1}\sim E}[r(s_t,a_t) + \gamma \mathbb{E}_{a_{t+1}\sim\pi}[Q^\pi(s_{t+1}, a_{t+1})]] Qπ(st,at)=Ert,st+1E[r(st,at)+γEat+1π[Qπ(st+1,at+1)]]
其中, E E E表示的是环境,即状态 s ∼ E s\sim E sE状态符合环境本身的分布。当我们使用确定性策略的时候 a = μ θ ( s ) a=\mu_\theta(s) a=μθ(s),那么inner expectation就自动被抵消掉了
Q π ( s t , a t ) = E r t , s t + 1 ∼ E [ r ( s t , a t ) + γ Q π ( s t + 1 , a t + 1 = μ ( s t + 1 ) ) ] Q^\pi(s_t,a_t) = \mathbb{E}_{r_t, s_{t+1}\sim E}[r(s_t,a_t) + \gamma Q^\pi(s_{t+1}, a_{t+1}=\mu(s_{t+1}))] Qπ(st,at)=Ert,st+1E[r(st,at)+γQπ(st+1,at+1=μ(st+1))]
这就意味着Q-value不再依赖于动作的访问分布,即没有了 a t + 1 ∼ π a_{t+1}\sim\pi at+1π。那么我们就可以通过行为策略behavior policy β \beta β产生的结果来计算该值,这让off-policy成为可能。

​ 实际上Q-value不再依赖于动作的访问分布,那么确定性梯度定理可以写作
∇ θ J ( μ θ ) ≈ E s ∼ ρ β [ ∇ θ μ θ ( s ) ∇ a Q μ ( s , a ) ∣ a = μ θ ( s ) ] \textcolor{red}{\nabla_\theta J(\mu_\theta) \approx \mathbb{E}_{s\sim\rho^\beta} \Big[ \nabla_\theta \mu_\theta(s) \nabla_a Q^\mu(s,a)|_{a=\mu_\theta(s)} \Big]} θJ(μθ)Esρβ[θμθ(s)aQμ(s,a)a=μθ(s)]
可以写作依赖于behavior policy β \beta β产生的状态访问分布的期望,这就是一种off-policy的形式。

1.4 Soft target update

​ 在DDPG中维护了四个神经网络,分别是policy network, target policy network, action value network, target action value network。使用了DQN中的将目标网络和训练网络分离的思想,并且采用soft更新的方式,能够更有效维护训练中的稳定性。soft更新方式如下
θ − ← τ θ + ( 1 − τ ) θ − \theta^- \leftarrow \tau \theta + (1-\tau)\theta^- θτθ+(1τ)θ
其中, θ − \theta^- θ表示目标网络参数, θ \theta θ表示训练网络参数, τ ≪ 1 \tau \ll 1 τ1 τ \tau τ是软更新参数。

1.5 Maintain Exploration

​ 确定性的策略是不具有探索性的,为了保持策略的探索性,我们可以在策略网络的输出中增加高斯噪声,让输出的动作值有些许偏差来增加网络的探索性。用数学的方式来表示即是
μ ′ ( s t ) = μ θ ( s t ) + N \mu^\prime(s_t) = \mu_\theta(s_t) + \mathcal{N} μ(st)=μθ(st)+N
其中 μ ′ \mu^\prime μ表示探索性的策略, N \mathcal{N} N表示高斯噪声。

1.6 Other Techniques

​ DDPG还集成了一些别的算法的常用技巧,比如Replay Buffer来产生independent and identically distribution的样本,使用了Batch Normalization来预处理数据。

1.7 Pesudocode

伪代码如下

Image

2. 代码实践

我们采用gym中的Pendulum-v1作为本次实验的环境,Pendulum-v1是典型的确定性连续动作空间环境,整体的代码如下

import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import random
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import collections# Policy Network
class PolicyNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim, action_bound):super(PolicyNet, self).__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, action_dim)self.action_bound = action_bounddef forward(self, observation):x = F.relu(self.fc1(observation))x = F.tanh(self.fc2(x))return x * self.action_bound# Q Value Network
class QValueNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(QValueNet, self).__init__()self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, hidden_dim)self.fc_out = nn.Linear(hidden_dim, 1)def forward(self, x, a):cat = torch.cat([x, a], dim=1)    # 拼接状态和动作x = F.relu(self.fc1(cat))x = F.relu(self.fc2(x))return self.fc_out(x)# Deep Deterministic Policy Gradient
class DDPG():def __init__(self, state_dim, hidden_dim, action_dim, action_bound, actor_lr, critic_lr, sigma, tau, gamma, device):self.actor = PolicyNet(state_dim, hidden_dim, action_dim, action_bound).to(device)self.critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)self.target_actor = PolicyNet(state_dim, hidden_dim, action_dim, action_bound).to(device)self.target_critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)# initialize target actor network with same parametersself.target_actor.load_state_dict(self.actor.state_dict())# initialize target critic network with same parametersself.target_critic.load_state_dict(self.critic.state_dict())self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)self.gamma = gammaself.sigma = sigma  # 高斯噪声的标准差,均值直接设置为0self.action_dim = action_dimself.device = deviceself.tau = taudef take_action(self, state):state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)action = self.actor(state).item()# add noise to increase exploratoryaction = action + self.sigma * np.random.randn(self.action_dim)return actiondef soft_update(self, net, target_net):# implement soft update rulefor param_target, param in zip(target_net.parameters(), net.parameters()):param_target.data.copy_(param_target.data * (1.0-self.tau) + param.data * self.tau)def update(self, transition_dict):states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1,1).to(self.device)actions = torch.tensor(transition_dict['actions'], dtype=torch.float).view(-1,1).to(self.device)next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1,1).to(self.device)next_q_values = self.target_critic(next_states, self.target_actor(next_states))td_targets = rewards + self.gamma * next_q_values * (1-dones)critic_loss = torch.mean(F.mse_loss(self.critic(states, actions), td_targets))self.critic_optimizer.zero_grad()critic_loss.backward()self.critic_optimizer.step()actor_loss = torch.mean( - self.critic(states, self.actor(states)))self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()# soft update actor net and critic netself.soft_update(self.actor, self.target_actor)self.soft_update(self.critic, self.target_critic)class ReplayBuffer():def __init__(self, capacity):self.buffer = collections.deque(maxlen=capacity)def add(self, s, a, r, s_, d):self.buffer.append((s,a,r,s_,d))def sample(self, batch_size):transitions = random.sample(self.buffer, batch_size)states, actions, rewards, next_states, dones = zip(*transitions)return np.array(states), actions, np.array(rewards), np.array(next_states), donesdef size(self):return len(self.buffer)def train_off_policy_agent(env, agent, num_episodes, replay_buffer, minimal_size, batch_size, render, seed_number):return_list = []for i in range(10):with tqdm(total=int(num_episodes/10), desc='Iteration %d'%(i+1)) as pbar:for i_episode in range(int(num_episodes/10)):observation, _ = env.reset(seed=seed_number)done = Falseepisode_return = 0while not done:if render:env.render()action = agent.take_action(observation)observation_, reward, terminated, truncated, _ = env.step(action)done = terminated or truncatedreplay_buffer.add(observation, action, reward, observation_, done)# swap statesobservation = observation_episode_return += rewardif replay_buffer.size() > minimal_size:b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)transition_dict = {'states': b_s,'actions': b_a,'rewards': b_r,'next_states': b_ns,'dones': b_d}agent.update(transition_dict)return_list.append(episode_return)if(i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d'%(num_episodes/10 * i + i_episode + 1),'return': "%.3f"%(np.mean(return_list[-10:]))})pbar.update(1)env.close()return return_listdef moving_average(a, window_size):cumulative_sum = np.cumsum(np.insert(a, 0, 0)) middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_sizer = np.arange(1, window_size-1, 2)begin = np.cumsum(a[:window_size-1])[::2] / rend = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]return np.concatenate((begin, middle, end))def plot_curve(return_list, mv_return, algorithm_name, env_name):episodes_list = list(range(len(return_list)))plt.plot(episodes_list, return_list, c='gray', alpha=0.6)plt.plot(episodes_list, mv_return)plt.xlabel('Episodes')plt.ylabel('Returns')plt.title('{} on {}'.format(algorithm_name, env_name))plt.show()if __name__ == "__main__":# reproducibleseed_number = 0random.seed(seed_number)np.random.seed(seed_number)torch.manual_seed(seed_number)num_episodes = 250     # episodes lengthhidden_dim = 128        # hidden layers dimensiongamma = 0.98            # discounted rateactor_lr = 1e-3         # lr of actorcritic_lr = 1e-3        # lr of critictau = 0.005             # soft update parametersigma = 0.01            # std variance of guassian noisebuffer_size = 10000minimal_size = 1000batch_size = 64device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')env_name = 'Pendulum-v1'render = Falseif render:env = gym.make(id=env_name, render_mode='human')else:env = gym.make(id=env_name)state_dim = env.observation_space.shape[0]action_dim = env.action_space.shape[0]  action_bound = env.action_space.high[0]replay_buffer = ReplayBuffer(buffer_size)        agent = DDPG(state_dim, hidden_dim, action_dim, action_bound, actor_lr, critic_lr, sigma, tau, gamma, device)return_list = train_off_policy_agent(env, agent, num_episodes, replay_buffer, minimal_size, batch_size, render, seed_number)mv_return = moving_average(return_list, 9)plot_curve(return_list, mv_return, 'DDPG', env_name)

DDPG训练的回报曲线如图所示

Image

Reference

Tutorial: Hands on RL

Paper: Continuous control with deep reinforcement learning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/91683.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue数组变更方法和替换方法

一、可以引起UI界面变化 Vue 将被侦听的数组的变更方法进行了包裹,所以它们也将会触发视图更新。这些被包裹过的方法包括: push()pop()shift()unshift()splice()sort()reverse() 以上七个数组都会改变原数组,下面来分别讲解它们的区别&…

Electron 应用实现截图并编辑功能

Electron 应用实现截图并编辑功能 Electron 应用如何实现截屏功能,有两种思路,作为一个框架是否可以通过框架实现截屏,另一种就是 javaScript 结合 html 中画布功能实现截屏。 在初步思考之后,本文优先探索使用 Electron 实现截屏…

C++11实用技术(四)for循环该怎么写

普通用法 在C遍历stl容器的方法通常是&#xff1a; #include <iostream> #include <vector>int main() {std::vector<int> arr {1, 2, 3};for (auto it arr.begin(); it ! arr.end(); it){std::cout << *it << std::endl;}return 0; }上述代…

Word 2019打开.doc文档后图片和公式不显示(呈现为白框)的解决办法

Word 2019打开.doc文档后图片和公式不显示&#xff08;呈现为白框&#xff09;的解决办法 目录 Word 2019打开.doc文档后图片和公式不显示&#xff08;呈现为白框&#xff09;的解决办法一、问题描述二、解决方法1.打开 WORD 2019&#xff0c;点击菜单中的“文件”&#xff1b;…

【笔试题心得】排序算法总结整理

排序算法汇总 常用十大排序算法_calm_G的博客-CSDN博客 以下动图参考 十大经典排序算法 Python 版实现&#xff08;附动图演示&#xff09; - 知乎 冒泡排序 排序过程如下图所示&#xff1a; 比较相邻的元素。如果第一个比第二个大&#xff0c;就交换他们两个。对每一对相邻…

杨氏矩阵!!!!

杨氏矩阵&#x1f438; &#x1f4d5;题目要求&#xff1a; 杨氏矩阵 题目内容&#x1f4da;&#xff1a; 有一个数字矩阵&#xff0c;矩阵的每行从左到右是递增的&#xff0c;矩阵从上到下是递增的&#xff0c;请编写程序在这样的矩阵中查找某个数字是否存在。 &#x1f9e0;题…

Microsoft ISA服务器配置及日志分析

Microsoft ISA 分析器工具&#xff0c;可分析 Microsoft ISA 服务器&#xff08;或 Forefront 威胁管理网关服务器&#xff09;的日志并生成安全和流量报告。支持来自 Microsoft ISA 服务器组件的以下日志&#xff1a; 数据包过滤器ISA 服务器防火墙服务ISA 服务器网络代理服务…

24届近3年青岛理工大学自动化考研院校分析

今天给大家带来的是青岛理工大学控制考研分析 满满干货&#xff5e;还不快快点赞收藏 一、青岛理工大学 学校简介 青岛理工大学是一所以工为主&#xff0c;土木建筑、机械制造、环境能源学科特色鲜明&#xff0c;理工经管文法艺等学科协调发展的多科性大学。是国家首批地方…

安达发APS|APS排产软件之计划甘特图

在当今全球化和竞争激烈的市场环境下&#xff0c;制造业企业面临着巨大的压力&#xff0c;如何在保证产品质量、降低成本以及满足客户需求的同时&#xff0c;提高生产效率和竞争力成为企业需要迫切解决的问题。在这个背景下&#xff0c;生产计划的制定和执行显得尤为重要。然而…

LCS最大公共子序列 与 LIS最大递增子序列

LCS Largest Common Subsequence 最大公共子序列 /* Input s1 s2//两个字符串Output length//长度 ans//具体字母 */ #include<iostream> using namespace std; int main() {string s1,s2;cin>>s1>>s2;int dp[100][100]{0};//dp[i][j]表示s1取前i位&#x…

(二)结构型模式:5、装饰器模式(Decorator Pattern)(C++实例)

目录 1、装饰器模式&#xff08;Decorator Pattern&#xff09;含义 2、装饰器模式的UML图学习 3、装饰器模式的应用场景 4、装饰器模式的优缺点 5、C实现装饰器模式的简单实例 1、装饰器模式&#xff08;Decorator Pattern&#xff09;含义 装饰模式&#xff08;Decorato…

【软件工程】面向对象方法-RUP

RUP&#xff08;Rational Unified Process&#xff0c;统一软件开发过程&#xff09;。 RUP特点 以用况驱动的&#xff0c;以体系结构为中心的&#xff0c;迭代增量式开发 用况驱动 用况是能够向用户提供有价值结果的系统中的一种功能用况获取的是功能需求 在系统的生存周期中…

前后端分离------后端创建笔记(05)用户列表查询接口(上)

本文章转载于【SpringBootVue】全网最简单但实用的前后端分离项目实战笔记 - 前端_大菜007的博客-CSDN博客 仅用于学习和讨论&#xff0c;如有侵权请联系 源码&#xff1a;https://gitee.com/green_vegetables/x-admin-project.git 素材&#xff1a;https://pan.baidu.com/s/…

css3 瀑布流布局遇见截断下一列展示后半截现象

css3 瀑布流布局遇见截断下一列展示后半截现象 注&#xff1a;css3实现瀑布流布局简直不要太香&#xff5e;&#xff5e;&#xff5e;&#xff5e;&#xff5e; 场景-在uniapp项目中 当瀑布流布局column-grap:10px 相邻两列之间的间隙为10px&#xff0c;column-count:2,2列展…

数据结构入门指南:二叉树

目录 文章目录 前言 1. 树的概念及结构 1.1 树的概念 1.2 树的基础概念 1.3 树的表示 1.4 树的应用 2. 二叉树 2.1 二叉树的概念 2.2 二叉树的遍历 前言 在计算机科学中&#xff0c;数据结构是解决问题的关键。而二叉树作为最基本、最常用的数据结构之一&#xff0c;不仅在算法…

LC-相交链表(解法2)

LC-相交链表&#xff08;解法2&#xff09; 链接&#xff1a;https://leetcode.cn/problems/intersection-of-two-linked-lists/description/ 描述&#xff1a;给你两个单链表的头节点 headA 和 headB &#xff0c;请你找出并返回两个单链表相交的起始节点。如果两个链表不存在…

ABAP Der Open SQL command is too big.

ABAP Der Open SQL command is too big. DBSQL_STMNT_TOO_LARGE CX_SY_OPEN_SQL_DB 应该是选择条件中 维护的条件值条数太多了

CSS:background 复合属性详解(用法 + 例子 + 效果)

目录 background 复合属性background-color 背景颜色&#xff08;纯&#xff09;background-image 背景图片 或者 渐变颜色background-repeat 背景是否重复background-size 设置图片大小background-position 设置背景图片显示位置background-attachment 设置背景图片是否随页面…

Windows下升级jdk1.8小版本

1.首先下载要升级jdk最新版本&#xff0c;下载地址&#xff1a;Java Downloads | Oracle 中国 2.下载完毕之后&#xff0c;直接双击下载完毕后的文件&#xff0c;进行安装。 3.安装完毕后&#xff0c;调整环境变量至新安装的jdk位置 4.此时&#xff0c;idea启动项目有可能会出…

如何给 Keycloak 用户加上“部门”、“电话”等自定义属性

Keycloak 是一款开源的用户认证和授权软件。在默认安装情况下&#xff0c;它只给新创建的用户提供了 email 属性&#xff0c;但是在许多应用场景中&#xff0c;客户都会要求给新创建的用户增加诸如“部门”、“电话”等自定义属性。 本文会介绍如何给 keycloak 中新创建的用户…